import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import keras import utils
import model as modellib
import visualize
from model import log %matplotlib inline # Root directory of the project
ROOT_DIR = os.getcwd() # Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs") # Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH) # Path to Shapes trained weights
SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_shapes.h5")
# Run one of the code blocks

# Shapes toy dataset
# import shapes
# config = shapes.ShapesConfig() # MS COCO Dataset
import coco
config = coco.CocoConfig()
# Device to load the neural network on.
# Useful if you're training a model on the same
# machine, in which case use CPU and leave the
# GPU for training.
DEVICE = "/cpu:0" # /cpu:0 or /gpu:0
def get_ax(rows=1, cols=1, size=16):
"""Return a Matplotlib Axes array to be used in
all visualizations in the notebook. Provide a
central point to control graph sizes. Adjust the size attribute to control how big to render images
"""
_, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
return ax
# Create model in inference mode
with tf.device(DEVICE):
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
config=config) # Set weights file path
if config.NAME == "shapes":
weights_path = SHAPES_MODEL_PATH
elif config.NAME == "coco":
weights_path = COCO_MODEL_PATH
# Or, uncomment to load the last model you trained
# weights_path = model.find_last()[1] # Load weights
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)
# Show stats of all trainable weights
visualize.display_weight_stats(model)

# Pick layer types to display
LAYER_TYPES = ['Conv2D', 'Dense', 'Conv2DTranspose']
# Get layers
layers = model.get_trainable_layers()
layers = list(filter(lambda l: l.__class__.__name__ in LAYER_TYPES,
layers))
# Display Histograms
fig, ax = plt.subplots(len(layers), 2, figsize=(10, 3*len(layers)),
gridspec_kw={"hspace":1})
for l, layer in enumerate(layers):
weights = layer.get_weights()
for w, weight in enumerate(weights):
tensor = layer.weights[w]
ax[l, w].set_title(tensor.name)
_ = ax[l, w].hist(weight[w].flatten(), 50)

吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(5)的更多相关文章

  1. 吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(4)

    import os import sys import random import math import re import time import numpy as np import tenso ...

  2. 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(3)

    import os import sys import random import math import re import time import numpy as np import cv2 i ...

  3. 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(2)

    import os import sys import itertools import math import logging import json import re import random ...

  4. 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(1)

    import os import sys import random import math import numpy as np import skimage.io import matplotli ...

  5. 吴裕雄 python 人工智能——基于神经网络算法在智能医疗诊断中的应用探索代码简要展示

    #K-NN分类 import os import sys import time import operator import cx_Oracle import numpy as np import ...

  6. 吴裕雄 PYTHON 人工智能——智能医疗系统后台智能分诊模块及系统健康养生公告简约版代码展示

    #coding:utf-8 import sys import cx_Oracle import numpy as np import pandas as pd import tensorflow a ...

  7. 吴裕雄 python 人工智能——智能医疗系统后台用户复诊模块简约版代码展示

    #复诊 import sys import os import time import operator import cx_Oracle import numpy as np import pand ...

  8. 吴裕雄 python 人工智能——智能医疗系统后台用户注册、登录和初诊简约版代码展示

    #用户注册.登录模块 #数据库脚本 CREATE TABLE usertable( userid number(8) primary key not null , username varchar(5 ...

  9. TF项目实战(基于SSD目标检测)——人脸检测1

    SSD实战——人脸检测 Tensorflow 一 .人脸检测的困难: 1. 姿态问题 2.不同种族人, 3.光照 遮挡 带眼睛 4.视角不同 5. 不同尺度 二. 数据集介绍以及转化VOC: 1. F ...

随机推荐

  1. Spring整合Mybatis错误解决方案

    ERROR:java.lang.AbstractMethodError: org.mybatis.spring.transaction.SpringManagedTransactionFactory. ...

  2. 简述python(threading)多线程

    一.概述 import threading 调用 t1 = threading.Thread(target=function , args=(,)) Thread类的实例方法 # join():在子线 ...

  3. org.apache.catalina.connector.ClientAbortException: java.io.IOException: 您的主机中的软件中止了一个已建立的连接。

    日志文件中有“java.io.IOException: 您的主机中的软件中止了一个已建立的连接.”错误 org.apache.catalina.connector.ClientAbortExcepti ...

  4. spring(三):BeanDefiniton

  5. C++-POJ3321-Apple Tree[数据结构][树状数组]

    树上的单点修改+子树查询 用dfn[u]和num[u]可以把任意子树表示成一段连续区间,此时结合树状数组就好了 #include <set> #include <map> #i ...

  6. [lua]紫猫lua教程-命令宝典-L1-01-09. string字符串函数库

    L1[string]01. ASCII码互转 小知识:字符串处理的几个共同的几点 1.字符串处理函数 字符串索引可以为负数 表示从字符串末尾开始算起 所有字符串处理函数的 字符串索引参数都使用 2.所 ...

  7. Python的深拷贝、浅拷贝

    浅拷贝 定义:浅拷贝只是对另外一个变量的内存地址的拷贝,这两个变量指向同一个内存地址的变量值. 浅拷贝的特点: 公用一个值: 这两个变量的内存地址一样: 对其中一个变量的值改变,另外一个变量的值也会改 ...

  8. Systemd 学习

    转:http://www.ruanyifeng.com/blog/2016/03/systemd-tutorial-commands.html 原文链接:https://www.jianshu.com ...

  9. 使用maven构建项目时,SSM和springboot项目的打包与云服务器部署

    下面讲讲如何打包SSM和springboot项目,并部署到云服务器上. 由于使用的IDE不同,有的使用eclipse,有的使用idea,所以如果在IDE中按照 maven clean 再 maven ...

  10. android .9背景图作为TextView背景时文字无法居中问题

    问题产生原因: .9图黑色边框绘制伸缩区域有问题,重叠的最大区域是TextView文字所能显示的区域 如下图所示,横向和纵向最大重叠部分就是文字可显示部分,这个图作为背景后文字整体偏下,无法上下居中对 ...