#xilerihua
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys #objectlocation
import six.moves.urllib as urllib
import tarfile
import matplotlib
matplotlib.use('Agg')
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from utils import label_map_util
from utils import visualization_utils as vis_util
import time class multi():
"""初始化所有模型"""
def __init__(self):
# 加载faster_rcnn 计算图
self.faster_graph = tf.Graph()
with self.faster_graph.as_default():
self.od_graph_def2 = tf.GraphDef()
with tf.gfile.GFile(r'E:/Project/TaoBaoLocation_new/research/object_detection/faster_rcnn_inception_v2_coco_2018_01_28/frozen_inference_graph.pb', 'rb') as fid:
self.serialized_graph = fid.read()
self.od_graph_def2.ParseFromString(self.serialized_graph)
tf.import_graph_def(self.od_graph_def2, name='')
self.faster_sess = tf.Session(graph=self.faster_graph) # 加载inception_v3计算图
self.inception_graph = tf.Graph()
with self.inception_graph.as_default():
self.od_graph_def2 = tf.GraphDef()
with tf.gfile.GFile(r'E:/Project/XiLeRiHuaReg/inception_v3_model/output_graph.pb', 'rb') as fid:
self.serialized_graph = fid.read()
self.od_graph_def2.ParseFromString(self.serialized_graph)
tf.import_graph_def(self.od_graph_def2, name='')
self.inception_sess = tf.Session(graph=self.inception_graph) def get_result(self, type, image_path):
if type == '2':
#xilerihua
lines = tf.gfile.GFile('E:/Project/XiLeRiHuaReg/inception_v3_model/output_labels.txt').readlines()
uid_to_human = {}
for uid, line in enumerate(lines):
line = line.strip('\n')
uid_to_human[uid] = line def id_to_string(node_id):
if node_id not in uid_to_human:
return ''
return uid_to_human[node_id] softmax_tensor = self.inception_sess.graph.get_tensor_by_name('final_result:0') image_data = tf.gfile.GFile(image_path, 'rb').read()
predictions = self.inception_sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})
predictions = np.squeeze(predictions) # image_path = os.path.join(sys.argv[2]) top_k = predictions.argsort()[::-1][:1] # 取前k个,此处取最相似的那个 for node_id in top_k: # 只取第一个
human_string = id_to_string(node_id)
score = predictions[node_id] human_kanji = {
'baby wipes': '婴儿湿巾',
'bath towel': '洗澡巾',
'convenient toothpick box': '便捷牙具盒',
'dish rack': '沥水架',
'hooks4': '挂钩粘钩4个装',
'kitchen towel': '厨房方巾',
'towel': '毛巾',
'macaron basin': '马卡龙家用多用盆',
'multi functional dental box': '多功能牙具盒',
'paring knife': '削皮刀',
'pineapple towel set': '菠萝纹毛巾浴巾套装',
'rubbish bag': '垃圾袋',
'sponge': '清洁海绵',
'stainless hook': '不锈钢多用挂钩',
'storage boxes': '三格储物盒',
'towel set': '毛巾浴巾套装',
'usb cable': '数据线',
'liquor': '劲酒'
}
thres = 0.6
if score < thres:
print('不在17个范围之内')
elif human_kanji[human_string] == '劲酒':
print('不在17个范围之内')
else:
print(human_kanji[human_string]) if type == '1': # List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt') NUM_CLASSES = 90 ##################### Loading label map
# print('Loading label map...')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES,
use_display_name=True)
category_index = label_map_util.create_category_index(categories) ##################### Helper code
def load_image_into_numpy_array(image):
(im_width, im_height) = image.size
return np.array(image.getdata()).reshape(
(im_height, im_width, 3)).astype(np.uint8) ##################### Detection
# 测试图片的路径,可以根据自己的实际情况修改
# TEST_IMAGE_PATH = 'test_images/image1.jpg'
TEST_IMAGE_PATH = image_path
# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8) # with tf.Session(graph=self.faster_graph) as self.faster_sess:
# print(TEST_IMAGE_PATH)
image = Image.open(TEST_IMAGE_PATH)
image_np = load_image_into_numpy_array(image)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = self.faster_graph.get_tensor_by_name('image_tensor:0')
boxes = self.faster_graph.get_tensor_by_name('detection_boxes:0')
scores = self.faster_graph.get_tensor_by_name('detection_scores:0')
classes = self.faster_graph.get_tensor_by_name('detection_classes:0')
num_detections = self.faster_graph.get_tensor_by_name('num_detections:0') # Actual detection.
(boxes, scores, classes, num_detections) = self.faster_sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded}) scores = np.squeeze(scores)
scores = scores.reshape((100, 1))
boxes = np.squeeze(boxes)
res = np.hstack((boxes, scores)) # 筛选>thres的box
thres = 0.55
reserve_boxes_0 = []
for b in res:
if b[-1]>thres:
reserve_boxes_0.append(b.tolist()) # print('reserve_boxes_0:',reserve_boxes_0) #转换坐标
reserve_boxes=[]
w = image_np.shape[1] # 1,3乘 1024
h = image_np.shape[0] # 0,2乘 636
# print('h:',h,'w:',w) for box in reserve_boxes_0:
# print([int(float(box[0]*h)),int(float(box[2]*h)),int(float(box[1]*w)),int(float(box[3]*w))],'tran')
# reserve_boxes.append([int(float(box[0]*h)),int(float(box[2]*h)),int(float(box[1]*w)),int(float(box[3]*w))])
reserve_boxes.append([int(float(box[1]*w)),int(float(box[0]*h)),int(float(box[3]*w)),int(float(box[2]*h))]) # print('reserve_boxes:',reserve_boxes) #没有找到一个框的情况
if len(reserve_boxes)==0:#为0的情况,裁剪返回图片坐标
w_subtract = int(image_np.shape[1] / 10)
h_subtract = int(image_np.shape[0] / 10)
print(w_subtract, h_subtract, image_np.shape[1] - w_subtract, image_np.shape[0] - h_subtract)
else:
# 保留最靠近中间的那个框的情况
# print('w:',image_np.shape[1],'h:',image_np.shape[0])
# 1.计算图片的中心点
# y:im.shape[0],x:im.shape[1]
x_center, y_center = image_np.shape[1] / 2, image_np.shape[0] / 2
# print(x_center,y_center) # 2 计算找出来的框到中心点的距离
dis_l = []
for b in reserve_boxes:
b_xcenter, b_ycenter = int((b[0] + b[2]) / 2), int((b[1] + b[3]) / 2)
distance = np.sqrt((x_center - b_xcenter) ** 2 + (y_center - b_ycenter) ** 2)
dis_l.append(distance)
# print('b_xcenter,b_ycenter:',b_xcenter,b_ycenter,distance) # 拿到最靠中心的box的index
center_index = dis_l.index(min(dis_l))
det = reserve_boxes[center_index]
print(det[0],det[1],det[2],det[3]) #可视化1
# cv2.rectangle(image_np, (det[0], det[1]), (det[2], det[3]), thickness=2, color=(0, 0, 255))
# cv2.imshow('res',image_np)
# cv2.waitKey(0)
# cv2.destroyAllWindows() #初始化
multi = multi() for i in range(5):
start_t=time.time()
multi.get_result("1","1.jpg")
end_t=time.time()
print('t1:',end_t-start_t)
multi.get_result("2","1.jpg")
start_t3=time.time()
print('t2:',start_t3-end_t)

  

通过类来实现多session 运行的更多相关文章

  1. C++ //多态 //静态多态:函数重载 和 运算符重载 属于静态多态 ,复用函数名 //动态多态:派生类和虚函数实现运行时多态

    1 //多态 2 //静态多态:函数重载 和 运算符重载 属于静态多态 ,复用函数名 3 //动态多态:派生类和虚函数实现运行时多态 4 5 //静态多态和动态多态的区别 6 //静态多态的函数地址早 ...

  2. 教你在Java的普通类中轻松获取Session以及request中保存的值

    曾经有多少人因为不知如何在业务类中获取自己在Action或页面上保存在Session中值,当然也包括我,但是本人已经学到一种办法可以解决这个问题,来分享下,希望对你有多多少少的帮助! 如何在Java的 ...

  3. [转载]tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定

    tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...

  4. 深度解剖session运行原理

    已经大半年没有更新博客了,一方面有比博客更重要的事情要做,另外一方面也没有时间来整理知识,所以希望在接下来的日子里面能够多多的写博客来与大家交流 什么是session session的官方定义是:Se ...

  5. [转] spring的普通类中如何取session和request对像

    在使用spring时,经常需要在普通类中获取session,request等对像.比如一些AOP拦截器类,在有使用struts2时,因为struts2有一个接口使用org.apache.struts2 ...

  6. jeecg中的一个上下文工具类获取request,session

    通过调用其中的方法可以获取到request和session,调用方式如下: HttpServletRequest request = ContextHolderUtils.getRequest();H ...

  7. spring的普通类中如何取session和request对像

    在使用spring时,经常需要在普通类中获取session,request等对像. 比如一些AOP拦截器类,在有使用struts2时,因为struts2有一个接口使用org.apache.struts ...

  8. tensorflow中使用tf.ConfigProto()配置Session运行参数&&GPU设备指定

    tf.ConfigProto()函数用在创建session的时候,用来对session进行参数配置: config = tf.ConfigProto(allow_soft_placement=True ...

  9. 使用tf.ConfigProto()配置Session运行参数和GPU设备指定

    参考链接:https://blog.csdn.net/dcrmg/article/details/79091941 tf.ConfigProto()函数用在创建session的时候,用来对sessio ...

随机推荐

  1. Excel催化剂开源第11波-动态数组函数技术开源及要点讲述

    在Excel催化剂中,大量的自定义函数使用了动态数组函数效果,虽然不是原生的Excel365版效果(听说Excel2019版取消了支持动态数组函数,还没求证到位,Excel365是可以用,但也仅限于部 ...

  2. 原 docker 安装使用 solr

    1.安装solr 7.5 docker solr 官网:https://hub.docker.com/_/solr/ docker pull solr:7.5.0 2.启动solr服务 docker ...

  3. Linux 安装MySql——apt-get版

    0)apt-get update 1)通过apt-get安装 更新设置到最新系统:    sudo apt-get update    sudo apt-get upgrade sudo apt-ge ...

  4. linux初学者-磁盘加密篇

    linux初学者-磁盘加密篇 因为保密需要,一般系统中会在文件和磁盘中进行加密,但是文件的加密比较容易破解,不安全.所以在特殊需要下,会对磁盘进行加密,磁盘加密后在磁盘损坏的同时,其中的数据也会损坏, ...

  5. zabbix3.4汉化

    1.管理员用户登入zabbix页面,更改语言为Chinese(zh_CN),点击Update 2.解决zabbix页面中文乱码 2.1在windows的C:\Windows\Fonts找到字体文件si ...

  6. 我是这样一步步理解--主题模型(Topic Model)、LDA

    1. LDA模型是什么 LDA可以分为以下5个步骤: 一个函数:gamma函数. 四个分布:二项分布.多项分布.beta分布.Dirichlet分布. 一个概念和一个理念:共轭先验和贝叶斯框架. 两个 ...

  7. 动态规划_Apple Catching_POJ-2385

    It and ) in his field, each full of apples. Bessie cannot reach the apples when they are on the tree ...

  8. JDBC连接mysql数据库操作详解

    1.什么是JDBC JDBC(Java Data Base Connectivity,java数据库连接)是一种用于执行SQL语句的Java API,可以为多种关系数据库提供统一访问,它由一组用Jav ...

  9. Django使用本机IP无法访问,使用127.0.0.1能正常访问

    使用Django搭建web站点后,使用127.0.0.1能访问,但是用自己本机IP却无法访问. 我们先到Django项目中找到setting文件 找到——> ALLOWED_HOSTS = [] ...

  10. PyQt4 在Windows下安装

    快来加入群[python爬虫交流群](群号570070796),发现精彩内容.     首先在网上下载sip文件下载完之后解压, 在Windows的开始菜单栏中进入sip的解压目录下:   在目录下面 ...