训练自己的数据集(以bottle为例):

 

1、准备数据

文件夹结构:
models
├── images
├── annotations
│ ├── xmls
│ └── trainval.txt
└── bottle
├── train_logs 训练文件夹
└── val_logs 日志文件夹

1)、下载官方预训练模型: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 
ssd_mobilenet_v1_coco为例,将压缩包内model.ckpt*的三个文件复制到bottle内

2)、准备jpg图片数据,放入images文件夹(图片文件命名要求“名字+下划线+编号.jpg”,必须使用下划线,编号从1开始) 
使用https://github.com/tzutalin/labelImg工具对图片进行标注,生成xml文件放置xmls文件夹,并保持xml和jpg命名相同 
3)、新建 bottle/trainval.txt 文件,内容为(图片名 1 1 1),每行一个文件,如:

bottle_1 1 1 1
bottle_2 1 1 1

4)、新建object_detection/data/bottle_label_map.pbtxt,内容如下

item {
id: 1
name: 'bottle'
}
 

2、生成数据

# From tensorflow/models
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/bottle_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`

得到 pet_train.record 和 pet_val.record 移动至bottle文件夹

 

3、准备conf文件

复制object_detection/samples/configs/ssd_mobilenet_v1_pets.config到 /bottle/ssd_mobilenet_v1_bottle.config 
对ssd_mobilenet_v1_bottle.config文件进行一下修改:

修改第9行为 num_classes: 1,此数值代表bottle_label_map.pbtxt文件配置item的数量
修改第158行为 fine_tune_checkpoint: "bottle/model.ckpt"
修改第177行为 input_path: "bottle/pet_train.record"
修改第179行和193行为 label_map_path: "object_detection/data/bottle_label_map.pbtxt"
修改第191行为 input_path: "bottle/pet_val.record"
 

4、训练

# From tensorflow/models
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--train_dir=bottle/train_logs \
2>&1 | tee bottle/train_logs.txt &
 

5、验证

# From tensorflow/models
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--checkpoint_dir=bottle/train_logs \
--eval_dir=bottle/val_logs &
 

6、可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

浏览器访问 ip:6006,可看到趋势以及具体image的预测结果

 

7、导出模型

# From tensorflow/models
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
--trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
--output_directory bottle

生成 bottle/frozen_inference_graph.pb 文件

 

8、测试图片

运行object_detection_tutorial.ipynb并修改其中的各种路径即可 
或自写编译inference脚本,如tensorflow/models/object_detection/infer.py:

import sys
sys.path.append('..')
import os
import time
import tensorflow as tf
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from utils import label_map_util
from utils import visualization_utils as vis_util
PATH_TEST_IMAGE = sys.argv[1]
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 21
IMAGE_SIZE = (18, 12)
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with detection_graph.as_default():
with tf.Session(graph=detection_graph, config=config) as sess:
start_time = time.time()
print(time.ctime())
image = Image.open(PATH_TEST_IMAGE)
image_np = np.array(image).astype(np.uint8)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
vis_util.visualize_boxes_and_labels_on_image_array(
image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
category_index, use_normalized_coordinates=True, line_thickness=8)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)

运行infer.py test_images/image1.jpg即可

使用Tensorflow训练自己的数据的更多相关文章

  1. TensorFlow.训练_资料(有视频)

    ZC:自己训练 的文章 貌似 能度娘出来很多,得 自己弄过才知道哪些个是坑 哪些个好用...(在CSDN文章的右侧 也有列出很多相关的文章链接)(貌似 度娘的关键字是"TensorFlow ...

  2. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  3. smallcorgi/Faster-RCNN_TF训练自己的数据

    熟悉了github项目提供的训练测试后,可以来训练自己的数据了.本文只介绍改动最少的方法,只训练2个类, 即自己添加的类(如person)和 background,使用的数据格式为pascal_voc ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow.js之根据数据拟合曲线

    这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线.即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数.简单的说,就是给一些在二维 ...

  6. tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了

    tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了 感觉它能找到词与词之间的关系,应该可以用来做推荐系统.自动摘要.相关搜索.联想什么的 tensorflow1.1.0 ...

  7. 人脸检测及识别python实现系列(3)——为模型训练准备人脸数据

    人脸检测及识别python实现系列(3)——为模型训练准备人脸数据 机器学习最本质的地方就是基于海量数据统计的学习,说白了,机器学习其实就是在模拟人类儿童的学习行为.举一个简单的例子,成年人并没有主动 ...

  8. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  9. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

随机推荐

  1. c# List< int>和List< string>互相转换

    c# List< int>和List< string>互相转换 定义一个list< t> List<int> list = new List<in ...

  2. [日常] Go语言圣经--并发的web爬虫

    两种: crawler.go package main import ( "fmt" "links" //"log" "os&qu ...

  3. 胜利大逃亡(杭电hdu1253)bfs简单题

    胜利大逃亡 Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others) Total Subm ...

  4. 小tip:CSS vw让overflow:auto页面滚动条出现时不跳动——张鑫旭

    小tip:CSS vw让overflow:auto页面滚动条出现时不跳动 这篇文章发布于 2015年01月25日,星期日,23:08,归类于 css相关. 阅读 46274 次, 今日 91 次 by ...

  5. PHP 协程最简洁的讲解

    协程,又称微线程,纤程.英文名Coroutine.协程的概念很早就提出来了,但直到最近几年才在某些语言(如Lua)中得到广泛应用. 子程序,或者称为函数,在所有语言中都是层级调用,比如A调用B,B在执 ...

  6. switch的用法

    switch case 语句有如下规则: switch 语句中的变量类型可以是: byte.short.int 或者 char.从 Java SE 7 开始,switch 支持字符串 String 类 ...

  7. HTTP协议学习随笔

    一 HTTP概述 HTTP简单说其实就是一套语言交流规则!Http使用的是可靠的数据传输协议,因此即使数据来自地球的另一端,也能够确保数据在传输过程中不会被损坏或产生混乱. B/S结构 用户在浏览器, ...

  8. [可能不知道]什么是PeopleSoft的JOLT以及相关进程

    PeopleSoft applecation server依赖于Jolt,Jolt是与Tuxedo配套的产品,可以处理所有web请求.换句话说,Jolt是application server与web ...

  9. 在JavaScript文件中用jQuery方法实现日期时间选择功能

    JavaScript Document $(document).ready(function(e) { 在文本框里面显示当前日期 var date = new Date(); var nian = d ...

  10. Python Python-MySQLdb中的DictCursor使用方法简介

    Python-MySQLdb中的DictCursor使用方法简介 by:授客 QQ:1033553122     DictCursor的这个功能是继承于CursorDictRowsMixIn,这个Mi ...