Tensorflow的slim框架可以写出像keras一样简单的代码来实现网络结构(虽然现在keras也已经集成在tf.contrib中了),而且models/slim提供了类似之前说过的object detection接口类似的image classification接口,可以很方便的进行fine-tuning利用自己的数据集训练自己所需的模型。

官方文档提供了比较详细的从数据准备,预训练模型的model zoo,fine-tuning,freeze model等一系列流程的步骤,但是缺少了inference的文档,不过tf所有模型的加载方式是通用的,所以调用方法和调用其他pb模型是一样的。

根据TF开发人员是说法Tensorflow对于模型读写的保存和调用的步骤一般如下:Build your graph --> write your graph --> import from written graph --> run compute etc

以下我们使用slim提供的网络inception-resnet-v2作为例子:

1. export inference graph

import tensorflow as tf
import nets.inception_resnet_v2 as net slim = tf.contrib.slim # checkpoint path
checkpoint_path = "/your/path/to/inception_resnet_v2.ckpt" # ckpt file obtained during model training or fine-tuning # set up and load session
sess = tf.Session()
arg_scope = net.inception_resnet_v2_arg_scope()
# initialize tensor suitable for model input
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
logits, end_points = net.inception_resnet_v2(inputs=input_tensor) # set up model saver
saver = tf.train.Saver()
saver.restore(sess, checkpoint_path)
with tf.gfile.GFile('/your/path/to/model_graph.pb', 'w') as f: # save model to given pb file
f.write(sess.graph_def.SerializeToString())
f.close()

2. freeze model

这里用tf提供的tensorflow/python/tools下的freeze_graph工具:

$ bazel build tensorflow/python/tools:freeze_graph
$ bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/your/path/to/model_graph.pb \ # obtained above
--input_checkpoint=/your/path/to/inception_resnet_v2.ckpt \
--input_binary=true
--output_graph=/your/path/to/frozen_graph.pb \
--output_node_names=InceptionResnetV2/Logits/Predictions # output node name defined in inception resnet v2 net

(Optional) visualize frozen graph

LOG_DIR = ‘/tmp/graphdeflogdir’
model_filename = '/your/path/to/frozen_graph.pb' with tf.Session() as sess:
with tf.gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
writer = tf.summary.FileWriter(LOG_DIR, graph_def)
writer.close()

然后用tensorborad --logdir=LOG_DIR选择graph就可以查看到frozen后的网络结构。

3. inference

import cv2
import numpy as np def preprocess_inception(image_np, central_fraction=0.875):
image_height, image_width, image_channel = image_np.shape
if central_fraction:
bbox_start_h = int(image_height * (1 - central_fraction) / 2)
bbox_end_h = int(image_height - bbox_start_h)
bbox_start_w = int(image_width * (1 - central_fraction) / 2)
bbox_end_w = int(image_width - bbox_start_w)
image_np = image_np[bbox_start_h:bbox_end_h, bbox_start_w:bbox_end_w]
# normalize
image_np = 2 * (image_np / 255.) - 1
return image_np image_np = cv2.imread("test.jpg")
# preprocess image as inception resnet v2 does
image_np = preprcess_inception(image_np)
# resize to model input image size
image_np = cv2.resize(image_np, (299, 299))
# expand dims to shape [None, 299, 299, 3]
image_np = np.expand_dims(image_np, 0)
# load model
with tf.gfile.GFile('/your/path/to/frozen_graph.pb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
graph = tf.import_graph_def(graph_def, name='')
with tf.Session(graph=graph) as sess:
input tensor = sess.graph.get_tensor_by_name("input:0") # get input tensor
output_tensor = sess.graph.get_tensor_by_name("InceptionResnetV2/Logits/Predictions:0") # get output tensor
logits = sess.run(output_tensor, feed_dict={input_tensor: image_np})
print "Prediciton label index:", np.argmax(logits[0], 1)
print "Top 3 Prediciton label index:", np.argsort(logits[0], 3)

参考:

  1. https://stackoverflow.com/questions/42961243/using-pre-trained-inception-v4-model
  2. https://gist.github.com/cchadowitz-pf/f1c3e781c125813f9976f6e69c06fec2
  3. https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
  4. https://github.com/tensorflow/models/blob/master/slim/README.md
  5. https://gist.github.com/tokestermw/795cc1fd6d0c9069b20204cbd133e36b

Tensorflow 使用slim框架下的分类模型进行分类的更多相关文章

  1. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  2. Windows下mnist数据集caffemodel分类模型训练及测试

    1. MNIST数据集介绍 MNIST是一个手写数字数据库,样本收集的是美国中学生手写样本,比较符合实际情况,大体上样本是这样的: MNIST数据库有以下特性: 包含了60000个训练样本集和1000 ...

  3. TensorFlow(十八):从零开始训练图片分类模型

    (一):进入GitHub下载模型-->下载地址 因为我们需要slim模块,所以将包中的slim文件夹复制出来使用. (1):在slim中新建images文件夹存放图片集 (2):新建model文 ...

  4. 三分钟快速上手TensorFlow 2.0 (下)——模型的部署 、大规模训练、加速

    前文:三分钟快速上手TensorFlow 2.0 (中)——常用模块和模型的部署 TensorFlow 模型导出 使用 SavedModel 完整导出模型 不仅包含参数的权值,还包含计算的流程(即计算 ...

  5. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  6. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  7. tensorflow中slim模块api介绍

    tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35   http://blog.csdn.net/guvcolie/article/details/77686 ...

  8. tensorflow实现基于LSTM的文本分类方法

    tensorflow实现基于LSTM的文本分类方法 作者:u010223750 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实 ...

  9. Tensorflow object detection API 搭建物体识别模型(二)

    二.数据准备 1)下载图片 图片来源于ImageNet中的鲤鱼分类,下载地址:https://pan.baidu.com/s/1Ry0ywIXVInGxeHi3uu608g 提取码: wib3 在桌面 ...

随机推荐

  1. 走进java

    Java 技术体系 1.java技术语言 2.各种硬件平台上的java虚拟机 3.Class文件格式 4.Java API类库 5.来自商业机构和开源社区的第三方Java类库 我们把Java程序设计语 ...

  2. zabbix 添加被监控主机

    点击 configured > host > create host 主机名:输入主机名,允许使用字母数字,空格,点,破折号和下划线 组:从右侧选择框中选择一个或多个组,然后单击 « 将其 ...

  3. PHP 利用QQ邮箱发送邮件「PHPMailer」

    在 PHP 应用开发中,往往需要验证用户邮箱.发送消息通知,而使用 PHP 内置的 mail() 函数,则需要邮件系统的支持. 如果熟悉 IMAP/SMTP 协议,结合 Socket 功能就可以编写邮 ...

  4. WTL中最简单的实现窗口拖动的方法(转)

    目前,很多基于对话框的应用程序中对话框都是不带框架的,也就是说对话框没有标题栏.众所周知,窗口的移动都是通过鼠标拖动窗口的标题栏来实现的,那么现在应用程序中的对话框没有了标题栏,用户如何移动对话框呢? ...

  5. VS2008中捕获内存泄露(转)

    内存泄露十分讨厌,捕获内存泄露更加令人厌烦…… 其实,VS本身就有内存泄露的检测机制.只需做以下操作即可开启.(同时必须在debug模式 下运行程序并且以 正常流程退出 ) // 在入口函数cpp中添 ...

  6. extern字符串常量,宏定义字符串常量,怎么选

    在使用常量的时候,我看到主要有两种写法: #define RKLICURegexEnumerationOptionsErrorKey @"RKLICURegexEnumerationOpti ...

  7. WCF:又是枚举惹的祸

    在WCF中使用枚举不便于服务的演化,因为增加一个枚举值,需要更新所有客户端.某种程度上说这也带来了好处,即:防止了新增枚举值带来的意外(宁可失败,也不意外). 鉴于枚举的这种表现,以后尽可能的采用in ...

  8. How to chain a command after sudo su?

    The idea is simple, for example: alias foo='sudo su foo && cd /tmp' However, it does not exe ...

  9. 微软BI 之SSRS 系列 - 实现 Excel 中图表结合的报表设计

    来自群里面讨论的一个问题,EXCEL 中有类似于这样的图形,上面是 Chart, Chart X轴上的值正好就是下方 Table 的列头,这个在 SSRS 中应该如何实现?   SSRS 2008.2 ...

  10. Long polling failed, will retry in 16 seconds. appId: zeus-guard, cluster: default, namespaces: application, long polling url: null, reason: Get config services failed from···

    当dubbo应用启动之前, 如果apollo 未启动好,那么我们dubbo应用会一直等待,直到apollo准备就绪,注意其中轮询时间是从1,2,3,4,8,14,32, 方式一直增长,单位是s.