tensorflow高效地推导pb模型,完整代码
from matplotlib import pyplot as plt
import numpy as np
import os import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
import glob
from collections import defaultdict
from io import StringIO from PIL import Image
import DrawBox
import cv2
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util SaveFilename='D:/FasterR-CNNImageTest/11/'
if not os.path.exists(SaveFilename):
os.mkdir(SaveFilename)
with tf.device('/cpu:0'):
cap = cv2.VideoCapture(0)
PATH_TO_CKPT = 'pb/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('dataset', 'pascal_label_map.pbtxt')
NUM_CLASSES = 2
# Load a (frozen) Tensorflow model into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
use_display_name=True)
category_index = label_map_util.create_category_index(categories) def load_image_into_numpy_array(image):
(im_width, im_height) = image.size
return np.array(image.getdata()).reshape(
(im_height, im_width, 3)).astype(np.uint8) # # Detection
cnt = 0
PATH_TO_TEST_IMAGES_DIR = 'E:/PythonOpenCVCode/BalanceGeneratePicture/TestSetSaveImage/Test/JPEGImages' with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
# while True: # for image_path in TEST_IMAGE_PATHS: #changed 20170825
# Definite input and output Tensors for detection_graph
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# Each box represents a part of the image where a particular object was detected.
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
# Each score represent how level of confidence for each of the objects.
# Score is shown on the result image, together with the class label.
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
for pidImage in glob.glob(PATH_TO_TEST_IMAGES_DIR + "/*.jpg"):
TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, pidImage)]
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)
for image_path in TEST_IMAGE_PATHS:
image = Image.open(image_path)
# the array based representation of the image will be used later in order to prepare the
# result image with boxes and labels on it.
image_np = load_image_into_numpy_array(image)
# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image_np, axis=0)
# Actual detection.
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores, detection_classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
# Visualization of the results of a detection.
# print(boxes)
DrawBox.visualize_boxes_and_labels_on_image_array(
image_np,
np.squeeze(boxes),
np.squeeze(classes).astype(np.int32),
np.squeeze(scores),
category_index,
max_boxes_to_draw=400,
use_normalized_coordinates=True,
groundtruth_box_visualization_color='red',
line_thickness=8)
#plt.figure(figsize=IMAGE_SIZE)
cv2.imwrite(SaveFilename + os.path.basename(image_path), image_np)
# plt.imshow(image_np)
cnt = cnt + 1
print(image_path)
tensorflow高效地推导pb模型,完整代码的更多相关文章
- 使用redis的zset实现高效分页查询(附完整代码)
一.需求 移动端系统里有用户和文章,文章可设置权限对部分用户开放.现要实现的功能是,用户浏览自己能看的最新文章,并可以上滑分页查看. 二.数据库表设计 涉及到的数据库表有:用户表TbUser.文章表T ...
- tensorflow学习笔记——模型持久化的原理,将CKPT转为pb文件,使用pb模型预测
由题目就可以看出,本节内容分为三部分,第一部分就是如何将训练好的模型持久化,并学习模型持久化的原理,第二部分就是如何将CKPT转化为pb文件,第三部分就是如何使用pb模型进行预测. 一,模型持久化 为 ...
- 【5】TensorFlow光速入门-图片分类完整代码
本文地址:https://www.cnblogs.com/tujia/p/13862364.html 系列文章: [0]TensorFlow光速入门-序 [1]TensorFlow光速入门-tenso ...
- tensorflow C++接口调用图像分类pb模型代码
#include <fstream> #include <utility> #include <Eigen/Core> #include <Eigen/Den ...
- tensorflow C++接口调用目标检测pb模型代码
#include <iostream> #include "tensorflow/cc/ops/const_op.h" #include "tensorflo ...
- 导出pb模型之后测试的python代码
链接:https://blog.csdn.net/thriving_fcl/article/details/75213361 saved_model模块主要用于TensorFlow Serving.T ...
- 查看tensorflow pb模型文件的节点信息
查看tensorflow pb模型文件的节点信息: import tensorflow as tf with tf.Session() as sess: with open('./quantized_ ...
- tensorflow c++ API加载.pb模型文件并预测图片
tensorflow python创建模型,训练模型,得到.pb模型文件后,用c++ api进行预测 #include <iostream> #include <map> # ...
- 将keras的h5模型转换为tensorflow的pb模型
h5_to_pb.py from keras.models import load_model import tensorflow as tf import os import os.path as ...
随机推荐
- Python对象和类
Python 里的所有数据都是以对象形式存在的,对象是类的实例. 定义类(class) 使用class来定义一个类. 比如,定义一个cat类,如下: class Cat(): def __init__ ...
- (转)常用的三种修改mysql最大连接数的方法
MYSQL数据库安装完成后,默认最大连接数是100,一般流量稍微大一点的论坛或网站这个连接数是远远不够的,增加默认MYSQL连接数的方法有两个 方法一:进入MYSQL安装目录 打开MYSQL配置文件 ...
- IPMI 远程配置
#重启ipmi服务 #重启ipmi服务 #将 channel 1 设置为静态 IP #设置 IP #设置 channel 1 掩码 #设置 channel 1 网关 #查看用户名及 ID #设置ID号 ...
- 异常处理记录: Servlet class X is not a javax.servlet.Servlet
使用Maven的tomcat插件启动Web项目, 访问资源时, 发生如下异常: https://stackoverflow.com/questions/1036702/my-class-is-not- ...
- leetcode-337-打家劫舍三*
题目描述: 方法一:递归 # Definition for a binary tree node. # class TreeNode(object): # def __init__(self, x): ...
- webpack 清理旧打包资源插件
当我们修改带hash的文件并进行打包时,每打包一次就会生成一个新的文件,而旧的文件并 没有删除.为了解决这种情况,我们可以使用clean-webpack-plugin 在打包之前将文件先清除,之后再打 ...
- vue2.0 使用webpack搭建项目遇到的最搞笑的坑
报错如下: 源码: 然后找了半天没搞明白... 无意中翻看了一下ES6语法规则.. 然后我发现:源代码最后一行要空一行,我心想这什么狗屁规定?MMP
- Bean容器的初始化
Bean容器的初始化
- Jdk8 Hashmap ConcurrentHashMap
JDK1.8 Hashmap JDK1.8 ConcurrentHashMap 不采用segment而采用 synchronized (f) f = table[i]; 减小锁的力度 设计了MOVE ...
- Python学习day23-面向对象的编程
figure:last-child { margin-bottom: 0.5rem; } #write ol, #write ul { position: relative; } img { max- ...