吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(5)
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import keras import utils
import model as modellib
import visualize
from model import log %matplotlib inline # Root directory of the project
ROOT_DIR = os.getcwd() # Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs") # Local path to trained weights file
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download COCO trained weights from Releases if needed
if not os.path.exists(COCO_MODEL_PATH):
utils.download_trained_weights(COCO_MODEL_PATH) # Path to Shapes trained weights
SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_shapes.h5")
# Run one of the code blocks # Shapes toy dataset
# import shapes
# config = shapes.ShapesConfig() # MS COCO Dataset
import coco
config = coco.CocoConfig()
# Device to load the neural network on.
# Useful if you're training a model on the same
# machine, in which case use CPU and leave the
# GPU for training.
DEVICE = "/cpu:0" # /cpu:0 or /gpu:0
def get_ax(rows=1, cols=1, size=16):
"""Return a Matplotlib Axes array to be used in
all visualizations in the notebook. Provide a
central point to control graph sizes. Adjust the size attribute to control how big to render images
"""
_, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
return ax
# Create model in inference mode
with tf.device(DEVICE):
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
config=config) # Set weights file path
if config.NAME == "shapes":
weights_path = SHAPES_MODEL_PATH
elif config.NAME == "coco":
weights_path = COCO_MODEL_PATH
# Or, uncomment to load the last model you trained
# weights_path = model.find_last()[1] # Load weights
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)
# Show stats of all trainable weights
visualize.display_weight_stats(model)

# Pick layer types to display
LAYER_TYPES = ['Conv2D', 'Dense', 'Conv2DTranspose']
# Get layers
layers = model.get_trainable_layers()
layers = list(filter(lambda l: l.__class__.__name__ in LAYER_TYPES,
layers))
# Display Histograms
fig, ax = plt.subplots(len(layers), 2, figsize=(10, 3*len(layers)),
gridspec_kw={"hspace":1})
for l, layer in enumerate(layers):
weights = layer.get_weights()
for w, weight in enumerate(weights):
tensor = layer.weights[w]
ax[l, w].set_title(tensor.name)
_ = ax[l, w].hist(weight[w].flatten(), 50)

吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(5)的更多相关文章
- 吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(4)
import os import sys import random import math import re import time import numpy as np import tenso ...
- 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(3)
import os import sys import random import math import re import time import numpy as np import cv2 i ...
- 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(2)
import os import sys import itertools import math import logging import json import re import random ...
- 吴裕雄 python 人工智能——基于Mask_RCNN目标检测(1)
import os import sys import random import math import numpy as np import skimage.io import matplotli ...
- 吴裕雄 python 人工智能——基于神经网络算法在智能医疗诊断中的应用探索代码简要展示
#K-NN分类 import os import sys import time import operator import cx_Oracle import numpy as np import ...
- 吴裕雄 PYTHON 人工智能——智能医疗系统后台智能分诊模块及系统健康养生公告简约版代码展示
#coding:utf-8 import sys import cx_Oracle import numpy as np import pandas as pd import tensorflow a ...
- 吴裕雄 python 人工智能——智能医疗系统后台用户复诊模块简约版代码展示
#复诊 import sys import os import time import operator import cx_Oracle import numpy as np import pand ...
- 吴裕雄 python 人工智能——智能医疗系统后台用户注册、登录和初诊简约版代码展示
#用户注册.登录模块 #数据库脚本 CREATE TABLE usertable( userid number(8) primary key not null , username varchar(5 ...
- TF项目实战(基于SSD目标检测)——人脸检测1
SSD实战——人脸检测 Tensorflow 一 .人脸检测的困难: 1. 姿态问题 2.不同种族人, 3.光照 遮挡 带眼睛 4.视角不同 5. 不同尺度 二. 数据集介绍以及转化VOC: 1. F ...
随机推荐
- bootstrap图片上传控件 fileinput
前端 1.要引用的js fileinput.js fileinput.css <link type="text/css" rel="stylesheet& ...
- Safari 导航栏
目录 引子 隐藏 Safari 导航栏 显示 Safari 导航栏 iPhone 系统占比 参考资料 引子 最近在 iPhone 的 Safari 查看 h5 页面时,发现有些平台的页面向下滚动时,顶 ...
- Java-POJ1013-Counterfeit Dollar
在13枚硬币中找出fake的那一个 输入:三次天平称量结果 package poj.ProblemSet; import java.util.Scanner; /* 我怎么觉得是贪心算法呢? 起初对所 ...
- 1.EntityManaget的persist和merge方法的区别
1.persist和merge的区别: Persist:添加 Merge : 分两种情况,当对象存在id,则修改:当对象不存在id则添加. 看个例子: 1 public class Account { ...
- python UI自动化生成BeautifulReport测试报告并保存截图
前面已经写过利用BeautifulReport生成测试报告,那么接下来讲讲如何在测试报告里面保存截图 首先需要在测试用例中定义一个截图的方法: # 截图方法 """ os ...
- HTML学习(11)表格
HTML表格由<table>标签定义,下面是一个2行3列的表格: <table> <tr> <td>11</td> <td>12 ...
- 服务端捡起或丢弃指定物品ID触发详解
传奇服务端捡起或丢弃指定物品ID触发详解: @PickUpItemsX X是物品数据库中对应的IDX@DropItemsX X是物品数据库中对应的IDX@H.PickUpItemsX X是物品数据库中 ...
- 13.56Mhz下直接阻抗匹配调试步骤
直接匹配阻抗,天线与射频芯片在同一块板子,调试步骤与50欧姆阻抗匹配调试天线参数差不多,多了一部分射频芯片端的滤波部分的参数计算.下面介绍调试过程. 1.首先看一下射频芯片发射部分原理图:分析原理图时 ...
- 安装k8s出现 Failed to list *api.Node: Get http://192.168.144.131:8080...: dial tcp 192.168.144.131:8080: getsockopt: no route to
原因是master主机的防火墙没关,导致无法访问主机的8080端口,解决方法暂时关闭主机上的防火墙. * centos6 : service iptables stop * centos7 : sys ...
- pandas 进行excel绘图
python主流绘图工具:matplotlib ,seaborn,pandas ,openpyxl ,xslwriter openpyxl :首先说下这个官网的demo,看的有点懵,没有具体说明多个图 ...