训练自己的数据集(以bottle为例):

 

1、准备数据

文件夹结构:
models
├── images
├── annotations
│ ├── xmls
│ └── trainval.txt
└── bottle
├── train_logs 训练文件夹
└── val_logs 日志文件夹

1)、下载官方预训练模型: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 
ssd_mobilenet_v1_coco为例,将压缩包内model.ckpt*的三个文件复制到bottle内

2)、准备jpg图片数据,放入images文件夹(图片文件命名要求“名字+下划线+编号.jpg”,必须使用下划线,编号从1开始) 
使用https://github.com/tzutalin/labelImg工具对图片进行标注,生成xml文件放置xmls文件夹,并保持xml和jpg命名相同 
3)、新建 bottle/trainval.txt 文件,内容为(图片名 1 1 1),每行一个文件,如:

bottle_1 1 1 1
bottle_2 1 1 1

4)、新建object_detection/data/bottle_label_map.pbtxt,内容如下

item {
id: 1
name: 'bottle'
}
 

2、生成数据

# From tensorflow/models
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/bottle_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`

得到 pet_train.record 和 pet_val.record 移动至bottle文件夹

 

3、准备conf文件

复制object_detection/samples/configs/ssd_mobilenet_v1_pets.config到 /bottle/ssd_mobilenet_v1_bottle.config 
对ssd_mobilenet_v1_bottle.config文件进行一下修改:

修改第9行为 num_classes: 1,此数值代表bottle_label_map.pbtxt文件配置item的数量
修改第158行为 fine_tune_checkpoint: "bottle/model.ckpt"
修改第177行为 input_path: "bottle/pet_train.record"
修改第179行和193行为 label_map_path: "object_detection/data/bottle_label_map.pbtxt"
修改第191行为 input_path: "bottle/pet_val.record"
 

4、训练

# From tensorflow/models
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--train_dir=bottle/train_logs \
2>&1 | tee bottle/train_logs.txt &
 

5、验证

# From tensorflow/models
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--checkpoint_dir=bottle/train_logs \
--eval_dir=bottle/val_logs &
 

6、可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

浏览器访问 ip:6006,可看到趋势以及具体image的预测结果

 

7、导出模型

# From tensorflow/models
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
--trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
--output_directory bottle

生成 bottle/frozen_inference_graph.pb 文件

 

8、测试图片

运行object_detection_tutorial.ipynb并修改其中的各种路径即可 
或自写编译inference脚本,如tensorflow/models/object_detection/infer.py:

import sys
sys.path.append('..')
import os
import time
import tensorflow as tf
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from utils import label_map_util
from utils import visualization_utils as vis_util
PATH_TEST_IMAGE = sys.argv[1]
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 21
IMAGE_SIZE = (18, 12)
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with detection_graph.as_default():
with tf.Session(graph=detection_graph, config=config) as sess:
start_time = time.time()
print(time.ctime())
image = Image.open(PATH_TEST_IMAGE)
image_np = np.array(image).astype(np.uint8)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
vis_util.visualize_boxes_and_labels_on_image_array(
image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
category_index, use_normalized_coordinates=True, line_thickness=8)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)

运行infer.py test_images/image1.jpg即可

使用Tensorflow训练自己的数据的更多相关文章

  1. TensorFlow.训练_资料(有视频)

    ZC:自己训练 的文章 貌似 能度娘出来很多,得 自己弄过才知道哪些个是坑 哪些个好用...(在CSDN文章的右侧 也有列出很多相关的文章链接)(貌似 度娘的关键字是"TensorFlow ...

  2. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  3. smallcorgi/Faster-RCNN_TF训练自己的数据

    熟悉了github项目提供的训练测试后,可以来训练自己的数据了.本文只介绍改动最少的方法,只训练2个类, 即自己添加的类(如person)和 background,使用的数据格式为pascal_voc ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow.js之根据数据拟合曲线

    这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线.即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数.简单的说,就是给一些在二维 ...

  6. tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了

    tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了 感觉它能找到词与词之间的关系,应该可以用来做推荐系统.自动摘要.相关搜索.联想什么的 tensorflow1.1.0 ...

  7. 人脸检测及识别python实现系列(3)——为模型训练准备人脸数据

    人脸检测及识别python实现系列(3)——为模型训练准备人脸数据 机器学习最本质的地方就是基于海量数据统计的学习,说白了,机器学习其实就是在模拟人类儿童的学习行为.举一个简单的例子,成年人并没有主动 ...

  8. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  9. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

随机推荐

  1. 关于eclipse的项目前有感叹号和errors exist in required project相关问题

    一般来说 项目运行中 各个类的信息中并没有报错 但在运行中会出现errors exist in required project 且有时候运行也会成功.这种情况是由于项目中其他的类存在问题未解决  导 ...

  2. 关于__int64的使用!

    关于__int64的使用! 类型  long long __int64 intmax_t 格式 %lld %I64d %I64d 在Dev C++中,三种类型均需用%I64d格式输出 ,c语言中int ...

  3. EF框架的三种模式

    Database First就是先建数据库或使用已有的数据库.然后在vs中添加ADO.Net实体数据模型,设置连接并且选择需要的数据库和表.它是以数据库设计为基础的,并根据数据库自动生成实体数据模型, ...

  4. POJ3090(SummerTrainingDay04-M 欧拉函数)

    Visible Lattice Points Time Limit: 1000MS   Memory Limit: 65536K Total Submissions: 7450   Accepted: ...

  5. python学习之老男孩python全栈第九期_day011知识点总结

    # 装饰器的形成的过程:最简单的装饰器:有返回值的:有一个参数的:万能参数# 装饰器的作用# 原则:开放封闭原则# 语法糖# 装饰器的固定模式:# def wrapper(f): # 装饰器函数,f是 ...

  6. ThinkPHP中create()方法自动验证表单信息

    自动验证是ThinkPHP模型层提供的一种数据验证方法,可以在使用create创建数据对象的时候自动进行数据验证. 原理: create()方法收集表单($_POST)信息并返回,同时触发表单自动验证 ...

  7. 活字格Web应用平台学习笔记2-基础教程-开始

    今天先学活字格第一个教程,开始. 目标是能够用活字格创建一个简单的Web页面. 哈哈,简单,跟Excel一样,做单元格输入.合并.文字居中.加底色.加图片,然后发布. 然后就真的生成了一个Web页面! ...

  8. PGIS下载离线地图 SQLite+WPF

    项目是超高分辨率屏幕墙,实时在线加载PGIS地图速度会比较慢,造成屏幕大量留白.于是使用地图缓存,事先把这个区块的地图全部down下来,使用Sqlite数据库保存.留存. //Task taskDow ...

  9. atitit.网络文件访问协议.unc smb nfs ftp http的区别

    atitit.网络文件访问协议.unc smb nfs ftp http的区别 1. 网络文件访问协议1 2. NETBios协议  2 3. SMB(Server Message Block)2 3 ...

  10. JMeter初体验

    Meter是开源软件Apache基金会下的一个性能测试工具,用来测试部署在服务器端的应用程序的性能. 1.JMeter下载和安装 JMeter可以在JMeter的官方网站下载,目前能下载的是JMete ...