深度学习应用系列(四)| 使用 TFLite Android构建自己的图像识别App
深度学习要想落地实践,一个少不了的路径即是朝着智能终端、嵌入式设备等方向发展。但终端设备没有GPU服务器那样的强大性能,那如何使得终端设备应用上深度学习呢?
所幸谷歌已经推出了TFMobile,去年又更进一步,推出了TFLite,其应用思路为在GPU服务器上利用迁移学习训练自己的模型,然后将定制化模型移植到TFLite上,
终端设备仅利用模型做前向推理,预测结果。本文基于以下三篇文章而成:
- 理论篇:https://www.tensorflow.org/hub/tutorials/image_retraining#other_architectures
- 实践篇一:https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/index.html#0
- 实践篇二:https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/#0
- 谷歌提供的预编译模型:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md
- TOCO官网:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco
相信大家掌握后,也能轻松定制化自己的图像识别应用。
第一步. 准备数据
数据下载地址为:http://download.tensorflow.org/example_images/flower_photos.tgz
这是一个关于花分类的图片集合,下载解压后,可以看出有5个品种分类:daisy(雏菊)、dandelion(蒲公英)、rose(玫瑰)、sunflower(向日葵)、tulip(郁金香)。
我们的目的即是通过重新训练预编译模型,得到一个花类识别的模型。
第二步. 重新训练
1. 挑选预编译模型
从上述“谷歌提供的预编译模型”列表中,我们大体可以看出分为两类模型,一种是Float Models(浮点数模型),一种是Quantized Models(量化模型),什么区别呢?
其实Float Models表示为一种高精度值的模型,该模型意味着模型size较大,识别精度更高、识别时长更长,适合高性能终端设备;而Quantized Models则反之,是低精度值的模型,其精度采取固定的8位大小,故其模型size较小,识别精度低、识别时长较短,适合低性能终端设备,更细的说明可以参见 https://www.tensorflow.org/performance/quantization 。
我们的手机设备更新换代很快,一般可以使用Float Models。在这个模型下,有不少预编译模型可选,对于本文来说,主要集中为Inception 和Mobilenet两种架构。
注意Mobilenet其实也分为很多种类,如Mobilenet_V1_0.50_224,其中第三个参数为模型大小比例值(只能算是近似,不准确),分为0.25/0.50/0.75/1.0四个比例值,第四个参数为图片大小,其值有128/160/192/224四种值。
有兴趣想观察各模型层次结构的可通过以下代码查看:
import tensorflow as tf
import tensorflow.gfile as gfile MODEL_PATH = '/home/yourname/Documents/mobilenet_v1_1.0_224/frozen_graph.pb' def main(unusedArgv):
with tf.Graph().as_default() as graph:
with gfile.FastGFile(MODEL_PATH, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
for op in graph.get_operations():
for tensor in op.values():
print(tensor) if __name__ == '__main__':
tf.app.run()
考虑到测试手机性能还不赖,我们选择mobilenet_v1_1.0_224这个版本作为我们的预编译模型。
2. 下载训练代码
需要下载训练模型代码和android相关代码,如下:
git clone https://github.com/googlecodelabs/tensorflow-for-poets-2 cd tensorflow-for-poets-
其中,scripts目录下的retrain.py是我们需要关注的,这个代码目前仅支持Inception_v3和Mobilenet两种预编译模型,默认的训练模型为Inception_v3。
3. 重新训练模型
两种模型的训练命令不同,若走默认的Inception_v3模型,可通过如下命令:
python -m scripts.retrain \
--learning_rate=0.01 \
--bottleneck_dir=tf_files/bottlenecks \
--how_many_training_steps= \
--model_dir=tf_files/models/ \
--output_graph=tf_files/retrained_graph.pb \
--output_labels=tf_files/retrained_labels.txt \
--image_dir=tf_files/flower_photos \
若走Mobilenet模型,可通过如下命令:
python -m scripts.retrain \
--learning_rate=0.01 \
--bottleneck_dir=tf_files/bottlenecks \
--how_many_training_steps= \
--model_dir=tf_files/models/ \
--output_graph=tf_files/retrained_graph.pb \
--output_labels=tf_files/retrained_labels.txt \
--image_dir=tf_files/flower_photos \
--architecture=mobilenet_1.0_224
模型命令解释如下:
--architecture 为架构类型,支持mobilenet和Inception_v3两种
--image_dir 为数据地址,假定你在tensorflow-for-poets-2目录下建立了tflite目录,把花图片集放入其中
--output_labels 最后训练生成模型的标签,由于花图片集合已经按照子目录进行了分类,故retrained_labels.txt最后包含了上述五种花的分类名称
--output_graph 最后训练生成的模型
--model_dir 命令启动后,预编译模型的下载地址
--how_many_training_steps 训练步数,不指定的话默认为4000
--bottleneck_dir用来把top层的训练数据缓存成文件
--learning_rate 学习率
此外,还有些参数可以根据需要进行调整:
--testing_percentage 把图片按多少比例划分出来当做test数据,默认为10
--validation_percentage 把图片按多少比例划分出来当做validation数据,默认为10,这两个值设置完后,training数据占比80%
--eval_step_interval 多少步训练后进行一次评估,默认为10
--train_batch_size 一次训练的图片数,默认为100
--validation_batch_size 一次验证的图片数,默认为100
--random_scale 给定一个比例值,然后随机扩大训练图片的大小,默认为0
--random_brightness 给定一个比例值,然后随机增强或减弱训练图片的明亮程度,默认为0
--random_crop 给定一个比例值,然后随机裁剪训练图片的边缘值,默认为0
4. 检验训练效果
我们用Mobilenet_1.0_224进行训练,完成后找一张图片看看是否能正确识别:
python -m scripts.label_image \
--graph=tf_files/retrained_graph.pb \
--image=tf_files/flower_photos/daisy/3475870145_685a19116d.jpg
结果为:
Evaluation time (-image): .010s daisy (score=0.62305)
tulips (score=0.22490)
dandelion (score=0.14169)
roses (score=0.00966)
sunflowers (score=0.00071)
还是准确地识别了daisy出来。
5. 转换模型格式
pb格式是不能运行在TFLite上的,TFLite吸收了谷歌的protobuffer优点,创造了FlatBuffer格式,具体表现就是后缀名为.tflite的文件。
上述TOCO的官网已经介绍了如何通过命令行把pb格式转成为tflite文件,或者在代码里也可以转换格式。不仅支持pb格式,也支持HDF5文件格式转换成tflite,实现了与其他框架的模型共享。
那如何转呢?本例通过命令行方式转换。若训练模型为Inception_v3,命令行方式如下:
toco \
--graph_def_file=tf_files/retrained_graph.pb \
--output_file=tf_files/optimized_graph.lite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--input_shape=,,, \
--input_array=Mul \
--output_array=final_result \
--inference_type=FLOAT \
--input_data_type=FLOAT
若训练模型为mobilenet,命令行方式则如下:
toco \
--graph_def_file=tf_files/retrained_graph.pb \
--output_file=tf_files/optimized_graph.lite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--input_shape=,,, \
--input_array=input \
--output_array=final_result \
--inference_type=FLOAT \
--input_data_type=FLOAT
需要说明几点:
--input_array 参数表示模型图结构的入口tensor op名称,mobilenet的入口名称为input,Inception_v3的入口名称为Mul,为什么这样?可查看scripts/retrain.py代码里内容:
if architecture == 'inception_v3':
# pylint: disable=line-too-long
data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long
bottleneck_tensor_name = 'pool_3/_reshape:0'
bottleneck_tensor_size = 2048
input_width = 299
input_height = 299
input_depth = 3
resized_input_tensor_name = 'Mul:0'
model_file_name = 'classify_image_graph_def.pb'
input_mean = 128
input_std = 128
elif architecture.startswith('mobilenet_'):
...
data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
data_url += version_string + '_' + size_string + '_frozen.tgz'
bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
bottleneck_tensor_size = 1001
input_width = int(size_string)
input_height = int(size_string)
input_depth = 3
resized_input_tensor_name = 'input:0'
其中的resized_input_tensor_name即是新生成模型的入口名称,大家也可以通过上面“1.挑选预编译模型”的代码可视化查看新生成的模型层次结构。所以名称必须正确写对,否则运行该命令会抛出“ValueError: Invalid tensors 'input' were found” 的异常。
--output_array则是模型的出口名称。为什么是final_result这个名称,因为在scripts/retrain.py里有:
parser.add_argument(
'--final_tensor_name',
type=str,
default='final_result',
help="""\
The name of the output classification layer in the retrained graph.\
"""
)
即出口名称默认为final_result。
--input_shape 需要注意的是mobilenet的训练图片大小为224,而Inception_v3的训练图片大小为299。
最后optimized_graph.lite即是我们要移植到android上的模型文件啦。
第三步. Android TFLite
1. 下载Android Studio
这一步骤不是本文重点,请大家自行在 https://developer.android.com/studio/ 进行下载安装,安装最新的SDK和NDK。
2. 引入工程
从android studio上引入 tensorflow-for-poets-2/android/tflite 下的代码,共有四个类,有三个类是跟布局打交道,而我们只需要关注ImageClassifier.java类。
3. 导入模型
可通过命令行方式把生成的模型导入上述工程的资源目录下:
cp tf_files/optimized_graph.lite android/tflite/app/src/main/assets/mobilenet.lite
cp tf_files/retrained_labels.txt android/tflite/app/src/main/assets/mobilenet.txt
4. 修改ImageClassifier.java类
注意修改四个地方即可:
/** Name of the model file stored in Assets. */
private static final String MODEL_PATH = "mobilenet.lite"; /** Name of the label file stored in Assets. */
private static final String LABEL_PATH = "mobilenet.txt"; static final int DIM_IMG_SIZE_X = 224; //若是inception,改成299
static final int DIM_IMG_SIZE_Y = 224; //若是inception,改成299
5. 运行观看效果
连上手机后,点击“Run”->"Run app"即会部署app到手机上,此时任何被摄像头捕获的图片都会按照标签里的5个分类进行识别排名。
我们可以通过百度搜一些这五种类别的花进行识别,以看看其识别的正确率。
后记:根据我的测试结果,在花的图片集上,mobilenet_1.0_244模型生成的新模型识别率较高,而inception_v3模型生成的新模型识别率较低或不准。
建议大家新的数据集可在两种模型间进行比较,以找到最适合自己的模型。
深度学习应用系列(四)| 使用 TFLite Android构建自己的图像识别App的更多相关文章
- 深度学习基础系列(九)| Dropout VS Batch Normalization? 是时候放弃Dropout了
Dropout是过去几年非常流行的正则化技术,可有效防止过拟合的发生.但从深度学习的发展趋势看,Batch Normalizaton(简称BN)正在逐步取代Dropout技术,特别是在卷积层.本文将首 ...
- 深度学习基础系列(五)| 深入理解交叉熵函数及其在tensorflow和keras中的实现
在统计学中,损失函数是一种衡量损失和错误(这种损失与“错误地”估计有关,如费用或者设备的损失)程度的函数.假设某样本的实际输出为a,而预计的输出为y,则y与a之间存在偏差,深度学习的目的即是通过不断地 ...
- 深度学习实践系列(2)- 搭建notMNIST的深度神经网络
如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...
- 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络
前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...
- UFLDL深度学习笔记 (四)用于分类的深度网络
UFLDL深度学习笔记 (四)用于分类的深度网络 1. 主要思路 本文要讨论的"UFLDL 建立分类用深度网络"基本原理基于前2节的softmax回归和 无监督特征学习,区别在于使 ...
- 《动手学深度学习》系列笔记—— 1.2 Softmax回归与分类模型
目录 softmax的基本概念 交叉熵损失函数 模型训练和预测 获取Fashion-MNIST训练集和读取数据 get dataset softmax从零开始的实现 获取训练集数据和测试集数据 模型参 ...
- C# Lambda 表达式学习之(四):动态构建类似于 c => c.Age == 2 || c.Age == 5 || c => c.Age == 17 等等一个或多个 OrElse 的表达式
可能你还感兴趣: 1. C# Lambda 表达式学习之(一):得到一个类的字段(Field)或属性(Property)名,强类型得到 2. C# Lambda 表达式学习之(二):LambdaExp ...
- 深度学习基础系列(四)| 理解softmax函数
深度学习最终目的表现为解决分类或回归问题.在现实应用中,输出层我们大多采用softmax或sigmoid函数来输出分类概率值,其中二元分类可以应用sigmoid函数. 而在多元分类的问题中,我们默认采 ...
- 基于深度学习的安卓恶意应用检测----------android manfest.xml + run time opcode, use 深度置信网络(DBN)
基于深度学习的安卓恶意应用检测 from:http://www.xml-data.org/JSJYY/2017-6-1650.htm 苏志达, 祝跃飞, 刘龙 摘要: 针对传统安卓恶意程序检测 ...
随机推荐
- 条件转化,2-sat BZOJ 1997
http://www.lydsy.com/JudgeOnline/problem.php?id=1997 1997: [Hnoi2010]Planar Time Limit: 10 Sec Memo ...
- codeblocks快捷键(转)
==日常编辑== • 按住Ctrl滚滚轮,代码的字体会随你心意变大变小.• 在编辑区按住右键可拖动代码,省去拉(尤其是横向)滚动条之麻烦:相关设置:Mouse Drag Scrolling.• Ctr ...
- linux中操作数据库的使用命令记录
1,mysql 查看数据库表编码格式: show create table widget; 修改数据库表编码格式: alter table widget default character set u ...
- 【BZOJ】3991: [SDOI2015]寻宝游戏 虚树+DFS序+set
[题意]给定n个点的带边权树,对于树上存在的若干特殊点,要求任选一个点开始将所有特殊点走遍后返回.现在初始没有特殊点,m次操作每次增加或减少一个特殊点,求每次操作后的总代价.n,m<=10^5. ...
- WordPress404页面自定义
不知道大家是怎么设计404页面,个性的404可以为网站增色不少,wordpress设置404是在主题里面的404.php页面上,当然比如你用Apache.nginx等服务器,你可以自己建一个单页,内容 ...
- 宋牧春: Linux设备树文件结构与解析深度分析(2) 【转】
转自:https://mp.weixin.qq.com/s/WPZSElF3OQPMGqdoldm07A 作者简介 宋牧春,linux内核爱好者,喜欢阅读各种开源代码(uboot.linux.ucos ...
- 64_r2
ruby-gnomecanvas2-0.90.4-7.fc26.3.x86_64.rpm 13-Feb-2017 08:00 75794 ruby-gnomecanvas2-devel-0.90.4- ...
- 【转载】selenium之 定位以及切换frame(iframe)
更多关于python selenium的文章,请关注我的专栏:Python Selenium自动化测试详解 总有人看不明白,以防万一,先在开头大写加粗说明一下: frameset不用切,frame需层 ...
- Codeforces Round #441 (Div. 2)
Codeforces Round #441 (Div. 2) A. Trip For Meal 题目描述:给出\(3\)个点,以及任意两个点之间的距离,求从\(1\)个点出发,再走\(n-1\)个点的 ...
- 千万不要运行的 Linux 命令
本文中列出的命令绝对不可以运行,即使你觉得很好奇也不行,除非你是在虚拟机上运行(出现问题你可以还原),因为它们会实实在在的破坏你的系统.所以不在root等高级管理权限下执行命令是很好的习惯. 本文的目 ...