import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import keras import utils
import model as modellib
import visualize
from model import log %matplotlib inline # Root directory of the project
ROOT_DIR = os.getcwd() # Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs") # Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH) # Path to Shapes trained weights
SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_shapes.h5")
# Run one of the code blocks

# Shapes toy dataset
# import shapes
# config = shapes.ShapesConfig() # MS COCO Dataset
import coco
config = coco.CocoConfig()
# Device to load the neural network on.
# Useful if you're training a model on the same
# machine, in which case use CPU and leave the
# GPU for training.
DEVICE = "/cpu:0" # /cpu:0 or /gpu:0
def get_ax(rows=1, cols=1, size=16):
"""Return a Matplotlib Axes array to be used in
all visualizations in the notebook. Provide a
central point to control graph sizes. Adjust the size attribute to control how big to render images
"""
_, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
return ax
# Create model in inference mode
with tf.device(DEVICE):
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
config=config) # Set weights file path
if config.NAME == "shapes":
weights_path = SHAPES_MODEL_PATH
elif config.NAME == "coco":
weights_path = COCO_MODEL_PATH
# Or, uncomment to load the last model you trained
# weights_path = model.find_last()[1] # Load weights
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)
# Show stats of all trainable weights
visualize.display_weight_stats(model)

# Pick layer types to display
LAYER_TYPES = ['Conv2D', 'Dense', 'Conv2DTranspose']
# Get layers
layers = model.get_trainable_layers()
layers = list(filter(lambda l: l.__class__.__name__ in LAYER_TYPES,
layers))
# Display Histograms
fig, ax = plt.subplots(len(layers), 2, figsize=(10, 3*len(layers)),
gridspec_kw={"hspace":1})
for l, layer in enumerate(layers):
weights = layer.get_weights()
for w, weight in enumerate(weights):
tensor = layer.weights[w]
ax[l, w].set_title(tensor.name)
_ = ax[l, w].hist(weight[w].flatten(), 50)

吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(5)的更多相关文章

  1. 吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(4)

    import os import sys import random import math import re import time import numpy as np import tenso ...

  2. 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(3)

    import os import sys import random import math import re import time import numpy as np import cv2 i ...

  3. 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(2)

    import os import sys import itertools import math import logging import json import re import random ...

  4. 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(1)

    import os import sys import random import math import numpy as np import skimage.io import matplotli ...

  5. 吴裕雄 python 人工智能——基于神经网络算法在智能医疗诊断中的应用探索代码简要展示

    #K-NN分类 import os import sys import time import operator import cx_Oracle import numpy as np import ...

  6. 吴裕雄 PYTHON 人工智能——智能医疗系统后台智能分诊模块及系统健康养生公告简约版代码展示

    #coding:utf-8 import sys import cx_Oracle import numpy as np import pandas as pd import tensorflow a ...

  7. 吴裕雄 python 人工智能——智能医疗系统后台用户复诊模块简约版代码展示

    #复诊 import sys import os import time import operator import cx_Oracle import numpy as np import pand ...

  8. 吴裕雄 python 人工智能——智能医疗系统后台用户注册、登录和初诊简约版代码展示

    #用户注册.登录模块 #数据库脚本 CREATE TABLE usertable( userid number(8) primary key not null , username varchar(5 ...

  9. TF项目实战(基于SSD目标检测)——人脸检测1

    SSD实战——人脸检测 Tensorflow 一 .人脸检测的困难: 1. 姿态问题 2.不同种族人, 3.光照 遮挡 带眼睛 4.视角不同 5. 不同尺度 二. 数据集介绍以及转化VOC: 1. F ...

随机推荐

  1. CQYZOJ P1392 拔河问题

    题目\(1\) Description 一个学校举行拔河比赛,所有的人被分成了两组,每个人必须(且只能够)在其中的一组,且两个组内的所有人体重加起来尽可能地接近. Input 第\(1\)行是一个\( ...

  2. 原来window 也可以使用pthreads

    https://blog.csdn.net/clever101/article/details/101029850

  3. java多线程--死锁

    1. Java中导致死锁的原因 Java中死锁最简单的情况是,一个线程T1持有锁L1并且申请获得锁L2,而另一个线程T2持有锁L2并且申请获得锁L1,因为默认的锁申请操作都是阻塞的,所以线程T1和T2 ...

  4. <context:component-scan>标签

    在spring-mvc的配置文件Springmvc-servlet.xml中,要扫描Controller注解的类,用<context:include-filter>标签 <conte ...

  5. 查看gcc编译器版本

    我们在windows下DS5中编译时使用GCC交叉编译器,但是在ubuntu时也需要使用GCC编译器,这时最好时保持版本一致,所以就需要查看windows下版本,如下图,在按装的文件夹中找到对应得文件 ...

  6. Makefile文件(DE1-soc软件实验”hello_word")

    DE1-soc软件实验”hello_word"中,hello_word此程序很好理解,那Makefile文件又如何理解呢? 所要完成的Makefile 文件描述了整个工程的编译.连接等规则. ...

  7. python的os库

    os库(operating system,提供操作系统函数) 1. __file__是什么? ans:当前文件的名字. 例如r.py内容如下 import os if __name__ == &quo ...

  8. 生产环境设置mysql主从复制

    Slave服务器的版本要等于或者高于master服务器 现在的实例是在mysql 5.7上的主从配置 a) master服务器的my.cnf配置,server_id 推荐用IP的后两位数字 [mysq ...

  9. Javaweb项目的命名规范

    项目名称:一般是英文 包名:公司域名的倒写,例如com.baidu 数据访问层:dao,persist,mapper 实体:entity,model,bean,javabean,pojo 业务逻辑:s ...

  10. 软件工程2020第一次作业(by cybersa)

    1 作业描述 作业属于哪个课程 2020春福大软工实践W班 这个作业要求在哪里 寒假作业(1/2) 这个作业的目标 建立博客.掌握markdown语法,学习写博客,回顾,总结,展望自己的学习历程 作业 ...