Tensorflow之实现物体检测
目录
- 项目背景
- TensorFlow介绍
- 环境搭建
- 模型选用
- Api使用说明
- 运行路由
- 小结
项目背景
产品看到竞品可以标记物体的功能,秉承一贯的他有我也要有,他没有我更要有的作风,丢过来一网站,说这个功能很简单,一定可以实现

这时候万能的谷歌发挥了作用,在茫茫的数据大海中发现了Tensorflow机器学习框架,也就是目前非常火爆的的深度学习(人工智能),既然方案已有,就差一个程序员了
Tensorflow介绍
百科介绍:TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,可被用于语音识别或图像识别等多项机器学习和深度学习领域。

翻译成大白话:是一个深度学习和神经网络的框架,底层C++,通过Python进行控制,当然,也是支持Go、Java等语言
环境搭建
- Linux/Unix(笔者使用Mac)
- Python3.6
- protoc 3.5.1
- tensorflow 1.7.0
1、克隆文件
文件目录格式如下
└── tensorflow
├── Dockerfile
├── README.md
├── data
│ ├── models
│ ├── pbtxt
│ └── tf_models
├── object_detection_api.py
├── server.py
├── sh
│ ├── download_data.sh
│ └── ods.sh
├── static
├── templates
└── upload
- data/models 存放
- data/pbtxt 物体标识名称
- data/tf_models 存放tensorflow/models数据
2、安装依赖库
pip3 install -r requirements.txt
3、下载模型
sh sh/download_data.sh
4、添加环境变量PYTHONPATH
echo 'export PYTHONPATH=$PYTHONPATH:
pwd/data/tf_models/models/research'>> ~/.bashrc && source ~/.bashrc
5、启动服务
python3 server.py
没有报错,说明你已成功搭建环境,使用过程是不是非常简单,下面介绍代码调用逻辑过程
模型选用
我从谷歌提供几种模型选出来对比

- Speed 是识别物体速度,值越小,识别越快
- mAP(平均准确率)是精度和检测边界盒的乘积,值越高神经网络的识别精确度越高,对应Speed越大
为了测试方便,笔者选用轻量级(ssd_mobilenet)作为本次识别物体模型
引入Python库
import numpy as np
import os
import tensorflow as tf
import json
import time
from PIL import Image
# 兼容Python2.7版本
try:
import urllib.request as ulib
except Exception as e:
import urllib as ulib
import re
from object_detection.utils import label_map_util
载入模型
MODEL_NAME = 'data/models/ssd_mobilenet_v2_coco_2018_03_29'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('data/pbtxt','mscoco_label_map.pbtxt') # CWH: Add object_detection path
# data/pbtxt下mscoco_label_map.pbtxt最大item.id
NUM_CLASSES = 90
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
# 加载模型
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
载入标签映射,内置函数返回整数会映射到pbtxt字符标签
mscoco_label_map.pbtxt格式如下
item {
name: "/m/01g317"
id: 1
display_name: "person"
}
item {
name: "/m/0199g"
id: 2
display_name: "bicycle"
}
# 加载标签
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
with detection_graph.as_default():
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(graph=detection_graph,config=config) as sess:
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
# 物体坐标
detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
# 检测到物体的准确度
detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
def get_objects(file_name, threshold=0.5):
image = Image.open(file_name)
# 判断文件是否是jpeg格式
if not image.format=='JPEG':
result['status'] = 0
result['msg'] = file_name+ ' is ' + image.format + ' ods system allow jpeg or jpg'
return result
image_np = load_image_into_numpy_array(image)
# 扩展维度
image_np_expanded = np.expand_dims(image_np, axis=0)
output = []
# 获取运算结果
(boxes, scores, classes, num) = sess.run(
[detection_boxes, detection_scores, detection_classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
# 去掉纬度为1的数组
classes = np.squeeze(classes).astype(np.int32)
scores = np.squeeze(scores)
boxes = np.squeeze(boxes)
for c in range(0, len(classes)):
if scores[c] >= threshold:
item = Object()
item.class_name = category_index[classes[c]]['name'] # 物体名称
item.score = float(scores[c]) # 准确率
# 物体坐标轴百分比
item.y1 = float(boxes[c][0])
item.x1 = float(boxes[c][1])
item.y2 = float(boxes[c][2])
item.x2 = float(boxes[c][3])
output.append(item)
# 返回JSON格式
outputJson = json.dumps([ob.__dict__ for ob in output])
return outputJson
运行路由
server.py下的逻辑
def image():
startTime = time.time()
if request.method=='POST':
image_file = request.files['file']
base_path = os.path.abspath(os.path.dirname(__file__))
upload_path = os.path.join(base_path,'static/upload/')
# 保存上传图片文件
file_name = upload_path + image_file.filename
image_file.save(file_name)
# 准确率过滤值
threshold = request.form.get('threshold',0.5)
# 调用Api服务
objects = object_detection_api.get_objects(file_name, threshold)
# 模板显示
return render_template('index.html',json_data = objects,img=image_file.filename)
curl http://localhost:5000 | python -m json.tool
[
{
"y2": 0.9886252284049988,
"class_name": "bed",
"x2": 0.4297400414943695,
"score": 0.9562674164772034,
"y1": 0.5202791094779968,
"x1": 0
},
{
"y2": 0.9805927872657776,
"class_name": "couch",
"x2": 0.4395904541015625,
"score": 0.6422878503799438,
"y1": 0.5051193833351135,
"x1": 0.00021047890186309814
}
]
- class_name表示物体标签名称
- score 可信度值
- x1,y1表示对象所在最左上点位置
- x2,y2表示对象最右下点位置
在浏览器访问网址体验
小结
- Tensorflow使用GPU效率提升几个数量级
- 可以尝试不同的模型比较速度和准确度
- 本案例也是支持python2,为了跟上时代步伐,建议使用python3
- 案例有个摄像头演示,需要https支持,且使用安卓系统
大家肯定很好奇,怎么训练自己需要检测的物体,可以期待下一篇文章
Tensorflow之实现物体检测的更多相关文章
- Tensorflow物体检测(Object Detection)API的使用
Tensorflow在更新1.2版本之后多了很多新功能,其中放出了很多用tf框架写的深度网络结构(看这里),大大降低了吾等调包侠的开发难度,无论是fine-tuning还是该网络结构都方便了不少.这里 ...
- Tensorflow 之物体检测
1)安装Protobuf TensorFlow内部使用Protocol Buffers,物体检测需要特别安装一下. # yum info protobuf protobuf-compiler 2.5. ...
- 物体检测之FPN及Mask R-CNN
对比目前科研届普遍喜欢把问题搞复杂,通过复杂的算法尽量把审稿人搞蒙从而提高论文的接受率的思想,无论是著名的残差网络还是这篇Mask R-CNN,大神的论文尽量遵循著名的奥卡姆剃刀原理:即在所有能解决问 ...
- 物体检测丨Faster R-CNN详解
这篇文章把Faster R-CNN的原理和实现阐述得非常清楚,于是我在读的时候顺便把他翻译成了中文,如果有错误的地方请大家指出. 原文:http://www.telesens.co/2018/03/1 ...
- OpenCV学习 物体检测 人脸识别 填充颜色
介绍 OpenCV是开源计算机视觉和机器学习库.包含成千上万优化过的算法.项目地址:http://opencv.org/about.html.官方文档:http://docs.opencv.org/m ...
- opencv,关于物体检测
关于物体检测 环境:opencv 2.4.11+vs2013 参考: http://www.cnblogs.com/tornadomeet/archive/2012/06/02/2531705.htm ...
- 『计算机视觉』物体检测之RefineDet系列
Two Stage 的精度优势 二阶段的分类:二步法的第一步在分类时,正负样本是极不平衡的,导致分类器训练比较困难,这也是一步法效果不如二步法的原因之一,也是focal loss的motivation ...
- 后RCNN时代的物体检测及实例分割进展
https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650736740&idx=3&sn=cdce446703e69b ...
- 利用opencv进行移动物体检测
进行运动物体检测就是将动态的前景从静态的背景中分离出来.将当前画面与假设是静态背景进行比较发现有明显的变化的区域,就可以认为该区域出现移动的物体.在实际情况中由于光照阴影等因素干扰比较大,通过像素直接 ...
随机推荐
- Codeforces Round #455 (Div. 2) 909E. Coprocessor
题 OvO http://codeforces.com/contest/909/problem/E CF455 div2 E CF 909E 解 类似于拓扑排序地进行贪心, 对于 Ei=0 并且入度为 ...
- jquery实现ajax提交表单数据或json数据
- 建造者模式(Builder)---创建型
1 定义域特征 定义:将一个复杂的对象构建与其表示分离,使得同样的构建过程可以创建不同的表示.特征:用户只需要指定需要建造的类型即可,对于中间的细节不考虑. 本质:分离整体构建算法和部件构造.构建一个 ...
- 自己总结:汇编CALL和RET指令
ret指令,相当于 pop IP:修改IP的内容,从而实现近转移 retf指令,相当于 pop IP pop CS:修改CS和IP的内容,从而实现远转移 -------------- CPU执行cal ...
- Java面向对象5(V~Z)
计算各种图形的周长(接口与多态)(SDUT 3338) import java.util.Scanner; public class Main { public static void main(St ...
- Python面试题:使用栈处理括号匹配问题
括号匹配是栈应用的一个经典问题, 题目 判断一个文本中的括号是否闭合, 如: text = "({[({{abc}})][{1}]})2([]){({[]})}[]", 判断所有括 ...
- javascript数组的增删改和查询
数组的增删改操作 对数组的增删改操作进行总结,下面(一,二,三)是对数组的增加,修改,删除操作都会改变原来的数组. (一)增加 向末尾增加 push() 返回新增后的数组长度 arr[arr.leng ...
- LGU67496 小$s$的玻璃弹珠
题意 在一幢\(m\)层建筑你将获得\(n\)个一样的鸡蛋,从高于\(x\)的楼层落下的鸡蛋都会碎.如果一个蛋碎了,你就不能再把它掉下去. 你的目标是确切地知道\(x\)的值.问至少要扔几次才能确定. ...
- 微信小程序之状态管理A
其实这个标题 不是很对 主要是最近小程序项目中 有这么一个状态 所有商品都共用一个商品详情页面 大概就是这样子 为了公司 保险起见,一些展示的内容已经处理 但是无伤大雅 就是这么两个按钮 左侧粉色 ...
- 因OpenCV版本不一致所引发的报错
目录 一 因OpenCV版本不一致所引发的报错 注:原创不易,转载请务必注明原作者和出处,感谢支持! 一 因OpenCV版本不一致所引发的报错 今天遇到了一个很有意思的报错. 事情是这样的, 在编译& ...