使用Tensorflow训练自己的数据
训练自己的数据集(以bottle为例):
1、准备数据
文件夹结构:
models
├── images
├── annotations
│ ├── xmls
│ └── trainval.txt
└── bottle
├── train_logs 训练文件夹
└── val_logs 日志文件夹
1)、下载官方预训练模型: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
以ssd_mobilenet_v1_coco为例,将压缩包内model.ckpt*的三个文件复制到bottle内
2)、准备jpg图片数据,放入images文件夹(图片文件命名要求“名字+下划线+编号.jpg”,必须使用下划线,编号从1开始)
使用https://github.com/tzutalin/labelImg工具对图片进行标注,生成xml文件放置xmls文件夹,并保持xml和jpg命名相同
3)、新建 bottle/trainval.txt 文件,内容为(图片名 1 1 1),每行一个文件,如:
bottle_1 1 1 1
bottle_2 1 1 1
4)、新建object_detection/data/bottle_label_map.pbtxt,内容如下
item {
id: 1
name: 'bottle'
}
2、生成数据
# From tensorflow/models
python object_detection/create_pet_tf_record.py \
--label_map_path=object_detection/data/bottle_label_map.pbtxt \
--data_dir=`pwd` \
--output_dir=`pwd`
得到 pet_train.record 和 pet_val.record 移动至bottle文件夹
3、准备conf文件
复制object_detection/samples/configs/ssd_mobilenet_v1_pets.config到 /bottle/ssd_mobilenet_v1_bottle.config
对ssd_mobilenet_v1_bottle.config文件进行一下修改:
修改第9行为 num_classes: 1,此数值代表bottle_label_map.pbtxt文件配置item的数量
修改第158行为 fine_tune_checkpoint: "bottle/model.ckpt"
修改第177行为 input_path: "bottle/pet_train.record"
修改第179行和193行为 label_map_path: "object_detection/data/bottle_label_map.pbtxt"
修改第191行为 input_path: "bottle/pet_val.record"
4、训练
# From tensorflow/models
python object_detection/train.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--train_dir=bottle/train_logs \
2>&1 | tee bottle/train_logs.txt &
5、验证
# From tensorflow/models
python object_detection/eval.py \
--logtostderr \
--pipeline_config_path=bottle/ssd_mobilenet_v1_bottle.config \
--checkpoint_dir=bottle/train_logs \
--eval_dir=bottle/val_logs &
6、可视化log
可一边训练一边可视化训练的log,可看到Loss趋势。
tensorboard --logdir train_logs/
浏览器访问 ip:6006,可看到趋势以及具体image的预测结果
7、导出模型
# From tensorflow/models
python object_detection/export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path bottle/ssd_mobilenet_v1_bottle.config \
--trained_checkpoint_prefix bottle/train_logs/model.ckpt-8 \
--output_directory bottle
生成 bottle/frozen_inference_graph.pb 文件
8、测试图片
运行object_detection_tutorial.ipynb并修改其中的各种路径即可
或自写编译inference脚本,如tensorflow/models/object_detection/infer.py:
import sys
sys.path.append('..')
import os
import time
import tensorflow as tf
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from utils import label_map_util
from utils import visualization_utils as vis_util
PATH_TEST_IMAGE = sys.argv[1]
PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
NUM_CLASSES = 21
IMAGE_SIZE = (18, 12)
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with detection_graph.as_default():
with tf.Session(graph=detection_graph, config=config) as sess:
start_time = time.time()
print(time.ctime())
image = Image.open(PATH_TEST_IMAGE)
image_np = np.array(image).astype(np.uint8)
image_np_expanded = np.expand_dims(image_np, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
scores = detection_graph.get_tensor_by_name('detection_scores:0')
classes = detection_graph.get_tensor_by_name('detection_classes:0')
num_detections = detection_graph.get_tensor_by_name('num_detections:0')
(boxes, scores, classes, num_detections) = sess.run(
[boxes, scores, classes, num_detections],
feed_dict={image_tensor: image_np_expanded})
print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
vis_util.visualize_boxes_and_labels_on_image_array(
image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
category_index, use_normalized_coordinates=True, line_thickness=8)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)
运行infer.py test_images/image1.jpg即可
使用Tensorflow训练自己的数据的更多相关文章
- TensorFlow.训练_资料(有视频)
ZC:自己训练 的文章 貌似 能度娘出来很多,得 自己弄过才知道哪些个是坑 哪些个好用...(在CSDN文章的右侧 也有列出很多相关的文章链接)(貌似 度娘的关键字是"TensorFlow ...
- 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练
将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...
- smallcorgi/Faster-RCNN_TF训练自己的数据
熟悉了github项目提供的训练测试后,可以来训练自己的数据了.本文只介绍改动最少的方法,只训练2个类, 即自己添加的类(如person)和 background,使用的数据格式为pascal_voc ...
- 通过TensorFlow训练神经网络模型
神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...
- TensorFlow.js之根据数据拟合曲线
这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线.即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数.简单的说,就是给一些在二维 ...
- tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了
tensorflow训练了10万次,运行完毕,对这个word2vec终于有点感觉了 感觉它能找到词与词之间的关系,应该可以用来做推荐系统.自动摘要.相关搜索.联想什么的 tensorflow1.1.0 ...
- 人脸检测及识别python实现系列(3)——为模型训练准备人脸数据
人脸检测及识别python实现系列(3)——为模型训练准备人脸数据 机器学习最本质的地方就是基于海量数据统计的学习,说白了,机器学习其实就是在模拟人类儿童的学习行为.举一个简单的例子,成年人并没有主动 ...
- 2、TensorFlow训练MNIST
装载自:http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html TensorFlow训练MNIST 这个教程的目标读者是对机器学习和T ...
- tensorflow训练验证码识别模型
tensorflow训练验证码识别模型的样本可以使用captcha生成,captcha在linux中的安装也很简单: pip install captcha 生成验证码: # -*- coding: ...
随机推荐
- VS2017 IIS 部署.net core web项目
1.点击IIS,查看模块 查看是否安装了 AspNetCoreModule 模块,如果没有安装可下载:https://dotnet.microsoft.com/download 下载安装后,即可部署项 ...
- [日常] Go语言圣经--Map习题
练习 4.8: 修改charcount程序,使用unicode.IsLetter等相关的函数,统计字母.数字等Unicode中不同的字符类别. 练习 4.9: 编写一个程序wordfreq程序,报告输 ...
- SpringBoot整合Druid数据连接池
SpringBoot整合Druid数据连接池 Druid是什么? Druid是Alibaba开源的的数据库连接池.Druid能够提供强大的监控和扩展功能. 在哪里下载druid maven中央仓库: ...
- python的变量以及常量介绍
变量概念: 把程序运行过程中产生的中间值保存在内存. 方便后面使用. 命名规范: 1. 数字, 字母, 下划线组成 2. 不能数字开头, 更不能是纯数字 3. 不能用关键字 4. 不要用中文 5. 要 ...
- BZOJ2388: 旅行规划(分块 凸包)
题意 题目链接 Sol 直接挂队爷的题解了 分块题好难调啊qwq #include<bits/stdc++.h> #define LL long long using namespace ...
- 我最喜欢用的css3之2D转换之translate用法
CSS3 2D 转换 div { transform: rotate(30deg); -ms-transform: rotate(30deg); /* IE 9 */ -webkit-transfor ...
- SSM 实训笔记 -11- 使用 Spring MVC + JDBC Template 实现筛选、检索功能(maven)
SSM 实训笔记 -11- 使用 Spring MVC + JDBC Template 实现筛选.检索功能(maven) 本篇是新建的一个数据库,新建的一个完整项目. 本篇内容: (1)使用 Spri ...
- Java语言的特点以及Java与C/C++的异同
Java语言的特点 1. Java为纯面向对象的语言,能够直接反应现实生活中的对象,容易理解,编程更容易. 2.跨平台,java是解释性语言,编译器会把java代码变成中间代码,然后在JVM上解释执行 ...
- windows 远程连接
* 方法1:windows自带的远程工具 缺点:如果操作系统是家庭版,会一致连接不上:尽管想办法把这个功能打开: 步骤: * 打开允许远程连接: 点进去自己设置就行,没有什么好说的 设置完之后,需要允 ...
- CSS揭秘(一)引言
借了一本CSS揭秘,国外的一本书,应该是目前相关书目里最好的了,内容非常扎实,不得不说图灵教育出的书真的不错,不然不是很厚的一本书卖到99也是.... 国外的这类书总是以问题开始,然后通过解决问题引出 ...