Retraining a pretrained model with TensorFlow
I recently needed to retrain a TensorFlow pretrained model at work. These are my notes.
Why a pretrained model can be retrained
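This comes down to what a CNN really does. A CNN learns suitable convolution kernels, uses them to extract useful low-level features, and then assigns weights to those features in order to represent the image.
Put simply: given a picture of a cat, how do you tell that it is a cat and not a dog? You probably notice the cat's head, its paws and its tail. Head, paws and tail are what we call high-level features, and they are assembled by the later layers of a CNN; features at this level still make sense to a human. These high-level features are themselves built from far more detailed and abstract low-level features extracted by the earlier layers, such as a point or a line. In the end, the image is roughly these low-level features multiplied by different weights and summed.
Suppose you have a model trained on a public dataset, and that dataset contains none of the images you want to recognize, for example traffic lights. That is fine. The model has never seen a traffic light, but it has still learned many abstract low-level features such as points and lines. We can still use those features to represent a traffic light image; only the weight of each feature has to change. This is what is called transfer learning.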
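The idea of keeping the learned features and only changing their weights can be sketched in a few lines. This is a minimal pure-Python illustration, not what retrain.py actually does: `FROZEN_FILTERS` is a hypothetical stand-in for a pretrained network, and the new "final layer" is a plain logistic regression trained on the frozen features.

```python
import math

# Hypothetical stand-in for a pretrained network: these filters are frozen
# and are never updated while the new classifier is trained.
FROZEN_FILTERS = [
    [1.0, -1.0, 0.0, 0.0],
    [0.0, 1.0, -1.0, 0.0],
    [0.0, 0.0, 1.0, -1.0],
]


def extract_features(pixels):
    """Apply each frozen filter followed by ReLU (a fixed feature extractor)."""
    return [max(0.0, sum(w * p for w, p in zip(f, pixels))) for f in FROZEN_FILTERS]


def train_head(samples, labels, epochs=200, lr=0.5):
    """Train only a new logistic-regression layer on top of the frozen features."""
    weights = [0.0] * len(FROZEN_FILTERS)
    bias = 0.0
    feats = [extract_features(s) for s in samples]  # computed once, reused every epoch
    for _ in range(epochs):
        for x, y in zip(feats, labels):
            z = sum(w * v for w, v in zip(weights, x)) + bias
            pred = 1.0 / (1.0 + math.exp(-z))
            err = pred - y  # gradient of the cross-entropy loss w.r.t. z
            weights = [w - lr * err * v for w, v in zip(weights, x)]
            bias -= lr * err
    return weights, bias


def predict(pixels, weights, bias):
    """Classify one sample with the frozen features and the new layer."""
    z = sum(w * v for w, v in zip(weights, extract_features(pixels))) + bias
    return 1 if z > 0 else 0


# Two toy "new classes" that the frozen filters were never tuned for.
samples = [[1, 0, 0, 0], [0.9, 0.1, 0, 0], [0, 0, 1, 0], [0, 0, 0.9, 0.1]]
labels = [0, 0, 1, 1]
weights, bias = train_head(samples, labels)
predictions = [predict(s, weights, bias) for s in samples]
print(predictions)
```

Only `weights` and `bias` change during training; the filters stay exactly as they were, which is the point of the approach.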
In TensorFlow, the files that store these "abstract low-level features, points, lines and so on" are called a module. For more detail, see https://www.tensorflow.org/hub/tutorials/image_retraining
Environment setup
- conda activate venv_python3.6
- pip install "tensorflow>=1.7.0"
- pip install tensorflow-hub
Data preparation
- cd ~
- curl -LO http://download.tensorflow.org/example_images/flower_photos.tgz
- tar xzf flower_photos.tgz
Downloading the example code
- mkdir ~/example_code
- cd ~/example_code
- curl -LO https://github.com/tensorflow/hub/raw/master/examples/image_retraining/retrain.py
Retraining
- python retrain.py --image_dir ~/flower_photos
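The files produced by training, including the model, are stored under /tmp
- /tmp/bottleneck: can be thought of as the feature map of each image; it stores the abstract features of the images of the new classes
- /tmp/output_graph.pb: the new model
- /tmp/output_labels.txt: the labels of the newly recognized classes
A bottleneck can be understood as an image feature vector, that is, a set of abstract features such as points, straight lines and polylines that the model uses to classify the image.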
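To show how the labels file is typically paired with the model output, here is a small sketch. It assumes the labels file holds one label per line and that the model returns one score per label in the same order; the label names and scores below are made-up illustration values, and no TensorFlow call is involved.

```python
def load_labels(text):
    """Parse the contents of a labels file, assuming one label per line."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def top_k(labels, scores, k=3):
    """Return the k (label, score) pairs with the highest scores."""
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Hypothetical file contents and scores, for illustration only.
labels_text = "daisy\ndandelion\nroses\nsunflowers\ntulips\n"
scores = [0.05, 0.10, 0.70, 0.05, 0.10]

labels = load_labels(labels_text)
best = top_k(labels, scores, k=2)
print(best)
```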
The script can take thirty minutes or more to complete, depending on the speed of your machine. The first phase analyzes all the images on disk and calculates and caches the bottleneck values for each of them. 'Bottleneck' is an informal term we often use for the layer just before the final output layer that actually does the classification. (TensorFlow Hub calls this an "image feature vector".) This penultimate layer has been trained to output a set of values that's good enough for the classifier to use to distinguish between all the classes it's been asked to recognize. That means it has to be a meaningful and compact summary of the images, since it has to contain enough information for the classifier to make a good choice in a very small set of values. The reason our final layer retraining can work on new classes is that it turns out the kind of information needed to distinguish between all the 1,000 classes in ImageNet is often also useful to distinguish between new kinds of objects.
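The caching phase described above can be sketched as "compute once, then read from disk". This is only an illustration of the idea: `fake_bottleneck` is a hypothetical stand-in for running an image through the frozen network, and the comma-separated text format and file naming are assumptions made for this example.

```python
import os
import tempfile


def fake_bottleneck(image_bytes):
    """Hypothetical stand-in for the frozen network's penultimate-layer output."""
    return [len(image_bytes) % 7 / 7.0, sum(image_bytes) % 11 / 11.0]


def get_or_create_bottleneck(cache_dir, image_name, image_bytes, compute=fake_bottleneck):
    """Return (values, cache_hit); compute and store the values on a cache miss."""
    path = os.path.join(cache_dir, image_name + ".txt")
    if os.path.exists(path):
        with open(path) as f:
            return [float(v) for v in f.read().split(",")], True
    values = compute(image_bytes)
    with open(path, "w") as f:
        f.write(",".join(str(v) for v in values))
    return values, False


cache_dir = tempfile.mkdtemp()
first, hit_first = get_or_create_bottleneck(cache_dir, "rose_001.jpg", b"\x01\x02\x03")
second, hit_second = get_or_create_bottleneck(cache_dir, "rose_001.jpg", b"\x01\x02\x03")
print(hit_first, hit_second, first == second)
```

Because the feature extractor is frozen, the cached values stay valid for every training step, so the expensive forward pass runs only once per image.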
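- training accuracy: accuracy on the training set
- validation accuracy: accuracy on the validation set
- cross entropy: the loss value
Cross entropy is a loss function which gives a glimpse into how well the learning process is progressing.
Overall, cross entropy should keep decreasing, although small fluctuations along the way are possible.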
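As a small worked example of why a lower cross entropy means better predictions, the sketch below computes the loss for a single example with a softmax over three classes. The logit values are made up for illustration.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities; subtracting the max keeps exp() stable."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, true_index):
    """Cross entropy for one example: negative log of the true class probability."""
    return -math.log(probs[true_index])


# The true class is index 0 in both cases.
confident = cross_entropy(softmax([4.0, 0.5, 0.1]), 0)  # model strongly favors class 0
unsure = cross_entropy(softmax([1.0, 0.9, 0.8]), 0)     # model barely favors class 0
print(confident < unsure)
```

The more probability the model puts on the correct class, the closer the loss gets to zero.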
Running retrain.py with a specific module
python retrain.py \
--image_dir ~/flower_photos \
--tfhub_module https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2
- The module package is downloaded from url + '?tf-hub-format=compressed'. By default it is saved under /tmp/tfhub_modules
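The sketch below reconstructs the download URL and a plausible cache location for a module handle. The query parameter comes from the note above; the per-module directory name (a SHA-1 hash of the handle) and the `TFHUB_CACHE_DIR` override are my assumptions about how tensorflow-hub lays out its cache, so treat `assumed_cache_dir` as a hypothetical helper rather than the library's API.

```python
import hashlib
import os
import tempfile
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse


def compressed_url(handle):
    """Append tf-hub-format=compressed to the handle's query string."""
    parts = urlparse(handle)
    query = parse_qsl(parts.query)
    query.append(("tf-hub-format", "compressed"))
    return urlunparse(parts._replace(query=urlencode(query)))


def assumed_cache_dir(handle, cache_root=None):
    """Assumed layout: <cache root>/<sha1 of the handle>."""
    root = (
        cache_root
        or os.environ.get("TFHUB_CACHE_DIR")
        or os.path.join(tempfile.gettempdir(), "tfhub_modules")
    )
    return os.path.join(root, hashlib.sha1(handle.encode("utf8")).hexdigest())


handle = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2"
print(compressed_url(handle))
print(assumed_cache_dir(handle, cache_root="/tmp/tfhub_modules"))
```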
tar -xvf ../module.tar ./
./
./saved_model.pb
./variables/
./variables/variables.index
./variables/variables.data-00000-of-00001
./assets/
./tfhub_module.pb
These files contain the abstract low-level features.
Downloading the SSD module
https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1
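Dataset structure

The image directory has one subdirectory per class, and each subdirectory holds the JPG files of that class.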
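A quick way to check that a dataset follows this layout is to walk it and count the images per class. The helper below is a hypothetical sketch (it is not part of retrain.py), and the demo builds a throwaway directory with made-up class names instead of touching real data.

```python
import os
import tempfile


def scan_image_dir(image_dir, extensions=(".jpg", ".jpeg")):
    """Map each class subdirectory name to the sorted list of its image files."""
    classes = {}
    for name in sorted(os.listdir(image_dir)):
        sub = os.path.join(image_dir, name)
        if not os.path.isdir(sub):
            continue  # skip stray files such as a license text
        files = sorted(f for f in os.listdir(sub) if f.lower().endswith(extensions))
        classes[name] = files
    return classes


# Build a tiny throwaway dataset to demonstrate the expected layout.
root = tempfile.mkdtemp()
for label, count in (("daisy", 3), ("roses", 2)):
    os.makedirs(os.path.join(root, label))
    for i in range(count):
        open(os.path.join(root, label, "img_%d.jpg" % i), "w").close()
open(os.path.join(root, "LICENSE.txt"), "w").close()

counts = {label: len(files) for label, files in scan_image_dir(root).items()}
print(counts)
```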
Points to watch when collecting the dataset
The first place to start is by looking at the images you've gathered, since the most common issues we see with training come from the data that's being fed in.
For training to work well, you should gather at least a hundred photos of each kind of object you want to recognize. The more you can gather, the better the accuracy of your trained model is likely to be. You also need to make sure that the photos are a good representation of what your application will actually encounter. For example, if you take all your photos indoors against a blank wall and your users are trying to recognize objects outdoors, you probably won't see good results when you deploy.
Another pitfall to avoid is that the learning process will pick up on anything that the labeled images have in common with each other, and if you're not careful that might be something that's not useful. For example if you photograph one kind of object in a blue room, and another in a green one, then the model will end up basing its prediction on the background color, not the features of the object you actually care about. To avoid this, try to take pictures in as wide a variety of situations as you can, at different times, and with different devices.
You may also want to think about the categories you use. It might be worth splitting big categories that cover a lot of different physical forms into smaller ones that are more visually distinct. For example instead of 'vehicle' you might use 'car', 'motorbike', and 'truck'. It's also worth thinking about whether you have a 'closed world' or an 'open world' problem. In a closed world, the only things you'll ever be asked to categorize are the classes of object you know about. This might apply to a plant recognition app where you know the user is likely to be taking a picture of a flower, so all you have to do is decide which species. By contrast a roaming robot might see all sorts of different things through its camera as it wanders around the world. In that case you'd want the classifier to report if it wasn't sure what it was seeing. This can be hard to do well, but often if you collect a large number of typical 'background' photos with no relevant objects in them, you can add them to an extra 'unknown' class in your image folders.
It's also worth checking to make sure that all of your images are labeled correctly. Often user-generated tags are unreliable for our purposes. For example: pictures tagged #daisy might also include people and characters named Daisy. If you go through your images and weed out any mistakes it can do wonders for your overall accuracy.
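The "at least a hundred photos" guideline is easy to turn into a sanity check before training. The sketch below assumes you already have a per-class image count; the class names and numbers are made up for illustration.

```python
def underrepresented(counts, minimum=100):
    """Return the classes that have fewer than `minimum` images, sorted by name."""
    return sorted(label for label, n in counts.items() if n < minimum)


# Hypothetical per-class image counts.
image_counts = {"car": 240, "motorbike": 85, "truck": 130, "unknown": 40}
too_small = underrepresented(image_counts)
print(too_small)
```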
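How to retrain with a local model
This step has not worked yet. My requirement is somewhat unusual: I need to run the model on a Jetson Nano, and TensorRT currently still has bugs. It cannot run inference on every model, because some models use operators it does not support. The model obtained by retraining the SSD module downloaded from the official TensorFlow site cannot run inference on the Jetson Nano.
What I need right now is the module corresponding to the ssd_inception_v2_coco_2017_11_17 model. Unfortunately no such module exists, so I have to write the conversion code myself, and using the official create_module_spec_from_saved_model API still runs into problems.
Links related to this problem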
https://github.com/tensorflow/hub/issues/37
https://github.com/tensorflow/hub/blob/52d5066e925d345fbd54ddf98b7cadf027b69d99/examples/image_retraining/retrain.py (the corresponding revision)
https://www.tensorflow.org/hub/creating
python retrain.py \
--image_dir ~/flower_photos \
--tfhub_module ./ssd_inception_v2_coco_2017_11_17
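What the TensorFlow files mean
- .pb file: stores the complete model, including its structure and its variables.
- checkpoint file: records the paths of the model checkpoints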
cat checkpoint
model_checkpoint_path: "/tmp/_retrain_checkpoint"
all_model_checkpoint_paths: "/tmp/_retrain_checkpoint"
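- .meta file: stores the structure of the computation graph
- .index file: stores tensor metadata, as a mapping tensor name <--> BundleEntryProto
- .data file: stores the values of all variables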
meta file: describes the saved graph structure, including GraphDef, SaverDef, and so on; calling tf.train.import_meta_graph('/tmp/model.ckpt.meta') restores the Saver and the Graph.
index file: a string-string immutable table (tensorflow::table::Table). Each key is the name of a tensor and its value is a serialized BundleEntryProto. Each BundleEntryProto describes the metadata of a tensor: which of the "data" files contains the content of the tensor, the offset into that file, the checksum, some auxiliary data, etc.
data file: a TensorBundle collection that stores the values of all variables.
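To make the relationship between these files concrete, the sketch below parses the text of a checkpoint file like the one shown above and lists the companion files expected next to the checkpoint prefix. Both helpers are hypothetical; the shard naming pattern is inferred from file names such as variables.data-00000-of-00001, and the line-by-line regex is a simplification that only handles the two fields shown.

```python
import re


def parse_checkpoint_state(text):
    """Extract the checkpoint paths from the text of a checkpoint file."""
    state = {"model_checkpoint_path": None, "all_model_checkpoint_paths": []}
    for line in text.splitlines():
        match = re.match(r'\s*(\w+)\s*:\s*"(.*)"\s*$', line)
        if not match:
            continue
        key, value = match.groups()
        if key == "model_checkpoint_path":
            state[key] = value
        elif key == "all_model_checkpoint_paths":
            state[key].append(value)
    return state


def bundle_files(prefix, num_shards=1):
    """List the .index, .data and .meta files expected for a checkpoint prefix."""
    data = [prefix + ".data-%05d-of-%05d" % (i, num_shards) for i in range(num_shards)]
    return [prefix + ".index"] + data + [prefix + ".meta"]


checkpoint_text = (
    'model_checkpoint_path: "/tmp/_retrain_checkpoint"\n'
    'all_model_checkpoint_paths: "/tmp/_retrain_checkpoint"\n'
)
state = parse_checkpoint_state(checkpoint_text)
print(bundle_files(state["model_checkpoint_path"]))
```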