训练自己的数据集(以bottle为例):

 

1、准备数据

文件夹结构:
models
├── images
├── annotations
│ ├── xmls
│ └── trainval.txt
└── bottle
├── train_logs 训练文件夹
└── val_logs 日志文件夹

1)、下载官方预训练模型: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md 
ssd_mobilenet_v1_coco为例,将压缩包内model.ckpt*的三个文件复制到bottle内

2)、准备jpg图片数据,放入images文件夹(图片文件命名要求“名字+下划线+编号.jpg”,必须使用下划线,编号从1开始) 
使用https://github.com/tzutalin/labelImg工具对图片进行标注,生成xml文件放置xmls文件夹,并保持xml和jpg命名相同 
3)、新建 bottle/trainval.txt 文件,内容为(图片名 1 1 1),每行一个文件,如:

bottle_1 1 1 1
bottle_2 1 1 1

4)、新建object_detection/data/bottle_label_map.pbtxt,内容如下

item {
id: 1
name: 'bottle'
}
 

2、生成数据

# From tensorflow/models
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/bottle_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`

得到 pet_train.record 和 pet_val.record 移动至bottle文件夹

 

3、准备conf文件

复制object_detection/samples/configs/ssd_mobilenet_v1_pets.config到 /bottle/ssd_mobilenet_v1_bottle.config 
对ssd_mobilenet_v1_bottle.config文件进行一下修改:

修改第9行为 num_classes: 1,此数值代表bottle_label_map.pbtxt文件配置item的数量
修改第158行为 fine_tune_checkpoint: "bottle/model.ckpt"
修改第177行为 input_path: "bottle/pet_train.record"
修改第179行和193行为 label_map_path: "object_detection/data/bottle_label_map.pbtxt"
修改第191行为 input_path: "bottle/pet_val.record"
 

4、训练

# From tensorflow/models
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--train_dir=bottle/train_logs \
2>&1 | tee bottle/train_logs.txt &
 

5、验证

# From tensorflow/models
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--checkpoint_dir=bottle/train_logs \
--eval_dir=bottle/val_logs &
 

6、可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

浏览器访问 ip:6006,可看到趋势以及具体image的预测结果

 

7、导出模型

# From tensorflow/models
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
--trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
--output_directory bottle

生成 bottle/frozen_inference_graph.pb 文件

 

8、测试图片

运行object_detection_tutorial.ipynb并修改其中的各种路径即可 
或自写编译inference脚本,如tensorflow/models/object_detection/infer.py:

import sys
sys.path.append('..')
import os
import time
import tensorflow as tf
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from utils import label_map_util
from utils import visualization_utils as vis_util
PATH_TEST_IMAGE = sys.argv[1]
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 21
IMAGE_SIZE = (18, 12)
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with detection_graph.as_default():
with tf.Session(graph=detection_graph, config=config) as sess:
start_time = time.time()
print(time.ctime())
image = Image.open(PATH_TEST_IMAGE)
image_np = np.array(image).astype(np.uint8)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
vis_util.visualize_boxes_and_labels_on_image_array(
image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
category_index, use_normalized_coordinates=True, line_thickness=8)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)

运行infer.py test_images/image1.jpg即可

使用Tensorflow训练自己的数据的更多相关文章

  1. TensorFlow.训练_资料(有视频)

    ZC:自己训练 的文章 貌似 能度娘出来很多,得 自己弄过才知道哪些个是坑 哪些个好用...(在CSDN文章的右侧 也有列出很多相关的文章链接)(貌似 度娘的关键字是"TensorFlow ...

  2. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  3. smallcorgi/Faster-RCNN_TF训练自己的数据

    熟悉了github项目提供的训练测试后,可以来训练自己的数据了.本文只介绍改动最少的方法,只训练2个类, 即自己添加的类(如person)和 background,使用的数据格式为pascal_voc ...

  4. 通过TensorFlow训练神经网络模型

    神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...

  5. TensorFlow.js之根据数据拟合曲线

    这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线.即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数.简单的说,就是给一些在二维 ...

  6. tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了

    tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了 感觉它能找到词与词之间的关系,应该可以用来做推荐系统.自动摘要.相关搜索.联想什么的 tensorflow1.1.0 ...

  7. 人脸检测及识别python实现系列(3)——为模型训练准备人脸数据

    人脸检测及识别python实现系列(3)——为模型训练准备人脸数据 机器学习最本质的地方就是基于海量数据统计的学习,说白了,机器学习其实就是在模拟人类儿童的学习行为.举一个简单的例子,成年人并没有主动 ...

  8. 2、TensorFlow训练MNIST

    装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...

  9. tensorflow训练验证码识别模型

    tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...

随机推荐

  1. vue项目webpack打包后图片路径错误

    首先项目是vue-cli搭建的,项目结构如下: 然后发现在css里写的图片引用地址在开发时正常显示,但在打包扔上服务器之后报错 报的是404,路径前面多了/static/css,不知道为啥. 在自己慢 ...

  2. [日常] Go语言圣经-匿名函数习题2

    练习5.13: 修改crawl,使其能保存发现的页面,必要时,可以创建目录来保存这些页面.只保存来自原始域名下的页面.假设初始页面在golang.org下,就不 要保存vimeo.com下的页面. p ...

  3. 使用spring的JavaMail发送邮件

    以前我们使用JavaMail发送邮件,步骤挺多的.现在的项目跟Spring整合的比较多.所以这里主要谈谈SpringMail发送. 导入jar包. 配置applicationContext-email ...

  4. JavaScript的三种对话框是通过调用window对象的三个方法alert(),confirm()和prompt()

    第一种:alert()方法 alert()方法是这三种对话框中最容易使用的一种,她可以用来简单而明了地将alert()括号内的文本信息显示在对话框中,我们将它称为警示对话框,要显示的信息放置在括号内, ...

  5. hightcharts 如何修改legend图例的样式

    正常情况下hightcharts 的legend图形是根据他本身默认的样式来显示,如下图 这几个图形的形状一般也是改不了的,只能根据图表的类型显示默认的.但是我们可以通过修改默认的样式来展示一些可以实 ...

  6. scikit-learn画ROC图

    1.使用sklearn库和matplotlib.pyplot库 import sklearn import matplotlib.pyplot as plt 2.准备绘图函数的传入参数1.预测的概率值 ...

  7. js-ES6学习笔记-Symbol

    1.ES6引入了一种新的原始数据类型Symbol,表示独一无二的值.它是JavaScript语言的第七种数据类型,前六种是:Undefined.Null.布尔值(Boolean).字符串(String ...

  8. 【代码笔记】iOS-UIActionSheet字体的修改

    一,效果图. 二,代码. RootViewController.h #import <UIKit/UIKit.h> @interface RootViewController : UIVi ...

  9. JS 解决 IOS 中拍照图片预览旋转 90度 BUG

    上篇博文[ Js利用Canvas实现图片压缩 ]中做了图片压缩上传,但是在IOS真机测试的时候,发现图片预览的时候自动逆时针旋转了90度.对于这个bug,我完全不知道问题出在哪里,接下来就是面向百度编 ...

  10. Linux服务器ftp+httpd部署

    一.ftp安装 1.安装vsftpd 命令:yum -y install vsftpd 2.修改ftp配置文件 命令:vim /etc/vsftpd/vsftpd.conf 3.按i进入insert模 ...