1. Preface

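In recent years, breakthroughs in deep learning for image recognition, led by convolutional neural networks (CNNs), have produced a steady stream of new image recognition algorithms. Last year we made a first successful attempt at applying image recognition to testing: we recast website layout breakage and mobile device adaptation problems as a binary classification of normal versus abnormal images in a specific scenario, and used transfer learning on Google's open-source Inception V3 network to retrain an image classification model for each scenario. Accuracy on problem images exceeded 95%.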

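Over the past year, our main work on intelligent image recognition has included:

  • Putting the model into production and tuning its parameters
  • Exposing the model as a service
  • Optimizing the model service (introducing a database connection pool, introducing gunicorn, Dockerizing, and so on)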

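This article studies and analyzes the source code used to retrain the model, to deepen our understanding of the training process so that later adjustments to it can be made with a clear purpose.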

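A brief note on transfer learning: image recognition models often contain millions of parameters. Training one from scratch requires a large number of labeled images and a great deal of computing power (often hundreds of GPU hours). Transfer learning is a shortcut: it takes a model already trained on a similar task and continues training a new model on top of it.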

2. Source code analysis of retrain.py

The transfer learning code in our image intelligence service is based on the open-source code at github: tensorflow/hub/image_retraining/retrain.py

Below is our reading of the source code:

2.1 The main entry point:

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_dir',
      type=str,
      default='',
      help='Path to folders of labeled images.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='/tmp/output_graph.pb',
      help='Where to save the trained graph.'
  )
  # ...... omitted ......
  parser.add_argument(
      '--logging_verbosity',
      type=str,
      default='INFO',
      choices=['DEBUG', 'INFO', 'WARN', 'ERROR', 'FATAL'],
      help='How much logging output should be produced.')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

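As shown, the entry point mainly declares and parses the input arguments. Arguments passed at run time are stored in the FLAGS variable, and tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) then starts the actual training.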
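The split between declared and undeclared flags can be tried in isolation with the standard argparse module. The sketch below declares two flags in the style of retrain.py; the argument list and the --some_other_flag name are made up for the demonstration and are not part of the original script.

```python
import argparse


def build_parser():
    """Declare two flags in the same style as retrain.py."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, default='',
                        help='Path to folders of labeled images.')
    parser.add_argument('--logging_verbosity', type=str, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARN', 'ERROR', 'FATAL'],
                        help='How much logging output should be produced.')
    return parser


# Hypothetical command line: one declared flag and one undeclared flag.
argv = ['--image_dir', '/data/pages', '--some_other_flag=1']
FLAGS, unparsed = build_parser().parse_known_args(argv)

print(FLAGS.image_dir)          # value supplied on the command line
print(FLAGS.logging_verbosity)  # falls back to the declared default
print(unparsed)                 # arguments argparse did not recognize
```

Unlike parse_args, parse_known_args does not fail on unknown arguments; it returns them in a second list, which retrain.py forwards to tf.app.run.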

2.2 The main(_) function

def main(_):
  # Needed to make sure the logging output is visible.
  # See https://github.com/tensorflow/tensorflow/issues/3047
  ## Set the log level
  logging_verbosity = logging_level_verbosity(FLAGS.logging_verbosity)
  tf.logging.set_verbosity(logging_verbosity)

  ## Check that image_dir was passed; it is the path of the training image set
  if not FLAGS.image_dir:
    tf.logging.error('Must set flag --image_dir.')
    return -1

  # Prepare necessary directories that can be used during training
  ## Recreate summaries_dir and make sure intermediate_output_graphs_dir exists
  prepare_file_system()

  # Look at the folder structure, and create lists of all the images.
  ## Split the input images into training, testing, and validation sets based
  ## on the image path, the testing percentage, and the validation percentage
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  ## The number of subdirectories under image_dir is the number of classes.
  ## Each subdirectory is one class, and each class is split into its own
  ## training, testing, and validation sets. Zero or one class is an error,
  ## because classification needs at least two classes.
  class_count = len(image_lists.keys())
  if class_count == 0:
    tf.logging.error('No valid folders of images found at ' + FLAGS.image_dir)
    return -1
  if class_count == 1:
    tf.logging.error('Only one valid folder of images found at ' +
                     FLAGS.image_dir +
                     ' - multiple classes are needed for classification.')
    return -1

  # See if the command-line flags mean we're applying any distortions.
  ## Decide from the flags whether the images should be distorted
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)

  # Set up the pre-trained graph.
  ## Load the module. Inception V3 is the default; --tfhub_module selects
  ## another pre-trained model
  module_spec = hub.load_module_spec(FLAGS.tfhub_module)
  ## Create the model graph
  graph, bottleneck_tensor, resized_image_tensor, wants_quantization = (
      create_module_graph(module_spec))

  # Add the new layer that we'll be training.
  ## Call add_final_retrain_ops to get the train step, the cross-entropy, the
  ## bottleneck input, the ground-truth input, and the final tensor
  with graph.as_default():
    (train_step, cross_entropy, bottleneck_input,
     ground_truth_input, final_tensor) = add_final_retrain_ops(
         class_count, FLAGS.final_tensor_name, bottleneck_tensor,
         wants_quantization, is_training=True)

  with tf.Session(graph=graph) as sess:
    # Initialize all weights: for the module to their pretrained values,
    # and for the newly added retraining layer to random initial values.
    ## Initialize the variables
    init = tf.global_variables_initializer()
    sess.run(init)

    # Set up the image decoding sub-graph.
    ## Call the JPEG decoding function to get the input image tensor and the
    ## decoded image tensor
    jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)

    if do_distort_images:
      # We will be applying distortions, so set up the operations we'll need.
      (distorted_jpeg_data_tensor,
       distorted_image_tensor) = add_input_distortions(
           FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
           FLAGS.random_brightness, module_spec)
    else:
      # We'll make sure we've calculated the 'bottleneck' image summaries and
      # cached them on disk.
      ## Compute the bottleneck of every image and cache it on disk
      cache_bottlenecks(sess, image_lists, FLAGS.image_dir,
                        FLAGS.bottleneck_dir, jpeg_data_tensor,
                        decoded_image_tensor, resized_image_tensor,
                        bottleneck_tensor, FLAGS.tfhub_module)

    # Create the operations we need to evaluate the accuracy of our new layer.
    ## Create the evaluation op
    evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)

    # Merge all the summaries and write them out to the summaries_dir
    ## Merge the summaries and write them under summaries_dir
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                         sess.graph)

    validation_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/validation')

    # Create a train saver that is used to restore values into an eval graph
    # when exporting models.
    train_saver = tf.train.Saver()

    # Run the training for as many cycles as requested on the command line.
    ## Train for the requested number of iterations
    for i in range(FLAGS.how_many_training_steps):
      # Get a batch of input bottleneck values, either calculated fresh every
      # time with distortions applied, or from the cache stored on disk.
      if do_distort_images:
        (train_bottlenecks,
         train_ground_truth) = get_random_distorted_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.image_dir, distorted_jpeg_data_tensor,
             distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
      else:
        ## Get the bottleneck values of the training images. train_batch_size
        ## defaults to 100, so each iteration trains on a batch of 100 images
        (train_bottlenecks,
         train_ground_truth, _) = get_random_cached_bottlenecks(
             sess, image_lists, FLAGS.train_batch_size, 'training',
             FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
             decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
             FLAGS.tfhub_module)
      # Feed the bottlenecks and ground truth into the graph, and run a training
      # step. Capture training summaries for TensorBoard with the `merged` op.
      ## Run the merged summary op and the train step, filling the
      ## placeholders from feed_dict
      train_summary, _ = sess.run(
          [merged, train_step],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      train_writer.add_summary(train_summary, i)

      # Every so often, print out how well the graph is training.
      ## Check whether this is the last training step
      is_last_step = (i + 1 == FLAGS.how_many_training_steps)
      ## eval_step_interval defaults to 10, so the current results are printed
      ## every 10 steps and when training finishes
      if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
        ## Print the training accuracy and the cross-entropy
        train_accuracy, cross_entropy_value = sess.run(
            [evaluation_step, cross_entropy],
            feed_dict={bottleneck_input: train_bottlenecks,
                       ground_truth_input: train_ground_truth})
        tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %
                        (datetime.now(), i, train_accuracy * 100))
        tf.logging.info('%s: Step %d: Cross entropy = %f' %
                        (datetime.now(), i, cross_entropy_value))
        # TODO: Make this use an eval graph, to avoid quantization
        # moving averages being updated by the validation set, though in
        # practice this makes a negligible difference.
        ## Get the bottleneck values of the validation images, also 100 per
        ## batch
        validation_bottlenecks, validation_ground_truth, _ = (
            get_random_cached_bottlenecks(
                sess, image_lists, FLAGS.validation_batch_size, 'validation',
                FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
                decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
                FLAGS.tfhub_module))
        # Run a validation step and capture training summaries for TensorBoard
        # with the `merged` op.
        validation_summary, validation_accuracy = sess.run(
            [merged, evaluation_step],
            feed_dict={bottleneck_input: validation_bottlenecks,
                       ground_truth_input: validation_ground_truth})
        validation_writer.add_summary(validation_summary, i)
        ## Print the validation accuracy and the number of images evaluated
        tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
                        (datetime.now(), i, validation_accuracy * 100,
                         len(validation_bottlenecks)))

      # Store intermediate results
      ## Store intermediate results
      intermediate_frequency = FLAGS.intermediate_store_frequency

      if (intermediate_frequency > 0 and (i % intermediate_frequency == 0)
          and i > 0):
        # If we want to do an intermediate save, save a checkpoint of the train
        # graph, to restore into the eval graph.
        train_saver.save(sess, CHECKPOINT_NAME)
        intermediate_file_name = (FLAGS.intermediate_output_graphs_dir +
                                  'intermediate_' + str(i) + '.pb')
        tf.logging.info('Save intermediate result to : ' +
                        intermediate_file_name)
        save_graph_to_file(intermediate_file_name, module_spec,
                           class_count)

    # After training is complete, force one last save of the train checkpoint.
    train_saver.save(sess, CHECKPOINT_NAME)

    # We've completed all our training, so run a final test evaluation on
    # some new images we haven't used before.
    ## Run the final evaluation
    run_final_eval(sess, module_spec, class_count, image_lists,
                   jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
                   bottleneck_tensor)

    # Write out the trained graph and labels with the weights stored as
    # constants.
    tf.logging.info('Save final result to : ' + FLAGS.output_graph)
    if wants_quantization:
      tf.logging.info('The model is instrumented for quantization with TF-Lite')
    save_graph_to_file(FLAGS.output_graph, module_spec, class_count)
    with tf.gfile.GFile(FLAGS.output_labels, 'w') as f:
      f.write('\n'.join(image_lists.keys()) + '\n')

    ## Save the trained graph
    if FLAGS.saved_model_dir:
      export_model(module_spec, class_count, FLAGS.saved_model_dir)

Some details of main are annotated in the code above, in comments beginning with "##". Its main steps are:

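  • Set the log level
  • Prepare the workspace
  • Load the input images from image_dir and build image_lists, a dictionary whose keys are the classes and whose values are the image sets for each class (training, testing, and validation sets; the default split is 0.8, 0.1, 0.1)
  • Load the feature tensor of the Inception V3 network pre-trained on ImageNet
  • For each image, run the decoding ops to obtain the raw image tensor and the decoded tensor
  • For each image's jpeg_data_tensor and decoded_image_tensor, compute the corresponding bottleneck (in practice a 1*2048 tensor) and cache it on disk
  • Obtain the training step and the cross-entropy
  • Start the training iterations
  • Every 10 iterations, print the training accuracy and cross-entropy, and print the result on the validation set. By default the training and validation batches each contain 100 images
  • After training completes, run a final evaluation on the test set
  • Print and save the results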
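To make the "every 10 iterations" rule concrete, the sketch below reproduces only the loop condition from main, without any TensorFlow code. The helper name evaluation_steps is ours, and the step counts are arbitrary illustration values.

```python
def evaluation_steps(how_many_training_steps, eval_step_interval):
    """Return the step indices at which main would print an evaluation."""
    steps = []
    for i in range(how_many_training_steps):
        is_last_step = (i + 1 == how_many_training_steps)
        # Same condition as in main: every N-th step, plus the final step.
        if (i % eval_step_interval) == 0 or is_last_step:
            steps.append(i)
    return steps


print(evaluation_steps(25, 10))  # [0, 10, 20, 24]
```

Note that step 0 is always evaluated, and the last step is evaluated even when it does not fall on the interval.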

2.3 Other functions

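With the main execution path covered, we now turn to the other functions. The full source is very long and space is limited, so they are introduced only briefly below, in order.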

2.3.1 create_image_lists

def create_image_lists(image_dir, testing_percentage, validation_percentage):
  # ...... omitted ......
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result

Splits the image set under image_dir according to testing_percentage and validation_percentage. The return value looks like this:

{
  'correct': {
    'dir': correct_image_dir,
    'training': correct_training_images,
    'testing': correct_testing_images,
    'validation': correct_validation_images
  },
  'error': {
    'dir': error_image_dir,
    'training': error_training_images,
    'testing': error_testing_images,
    'validation': error_validation_images
  }
}

The value for each of training/testing/validation is a list of image file names.
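The omitted part of create_image_lists decides which set each file belongs to. As we understand it, the upstream script derives that decision from a hash of the file name, so a file stays in the same set across runs. The sketch below follows that idea in simplified form: the helper name split_images and the hash-to-bucket mapping are our assumptions, not the upstream code.

```python
import hashlib


def split_images(file_names, testing_percentage, validation_percentage):
    """Assign each file name to training/testing/validation by hashing it."""
    result = {'training': [], 'testing': [], 'validation': []}
    for name in file_names:
        digest = hashlib.sha1(name.encode('utf-8')).hexdigest()
        # Map the hash to a stable bucket in [0, 100); assumed mapping.
        bucket = int(digest, 16) % 100
        if bucket < validation_percentage:
            result['validation'].append(name)
        elif bucket < validation_percentage + testing_percentage:
            result['testing'].append(name)
        else:
            result['training'].append(name)
    return result


# Made-up file names for illustration.
names = ['page_%03d.jpg' % i for i in range(200)]
split = split_images(names, testing_percentage=10, validation_percentage=10)
print({category: len(files) for category, files in split.items()})
```

Because the assignment depends only on the file name, adding new images later does not move existing images between sets; the actual proportions only approximate the requested percentages.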

2.3.2 get_image_path

Returns the full path of an image.
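A minimal sketch of how such a path can be assembled from the image_lists structure shown in 2.3.1. The function name image_path_sketch, the sample data, and the index wrap-around are our assumptions and should be checked against the upstream source.

```python
import os


def image_path_sketch(image_lists, label_name, index, image_dir, category):
    """Build the full path of one image from the image_lists structure."""
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    # Wrap the index so any non-negative integer selects a valid file.
    base_name = category_list[index % len(category_list)]
    return os.path.join(image_dir, label_lists['dir'], base_name)


# Made-up data in the shape returned by create_image_lists.
image_lists = {
    'correct': {'dir': 'correct', 'training': ['a.jpg', 'b.jpg'],
                'testing': ['c.jpg'], 'validation': ['d.jpg']},
}
path = image_path_sketch(image_lists, 'correct', 5, '/data/pages', 'training')
print(path)
```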

2.3.3 get_bottleneck_path

Returns the path of the bottleneck file for an image in a given category (training, testing, or validation).

2.3.4 create_module_graph

Creates the model graph from the given pre-trained Hub module.

2.3.5 run_bottleneck_on_image

def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            decoded_image_tensor, resized_input_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Args:
    sess: Current active TensorFlow Session.
    image_data: String of raw JPEG data.
    image_data_tensor: Input data layer in the graph.
    decoded_image_tensor: Output of initial image resizing and preprocessing.
    resized_input_tensor: The input node of the recognition graph.
    bottleneck_tensor: Layer before the final softmax.

  Returns:
    Numpy array of bottleneck values.
  """
  # First decode the JPEG image, resize it, and rescale the pixel values.
  resized_input_values = sess.run(decoded_image_tensor,
                                  {image_data_tensor: image_data})
  # Then run it through the recognition network.
  bottleneck_values = sess.run(bottleneck_tensor,
                               {resized_input_tensor: resized_input_values})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values

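Computes bottleneck_values from the decoded tensor of the given input image, then applies squeeze (removing single-dimensional entries, that is, dropping the dimensions of size 1 from the shape).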

2.3.6 ensure_dir_exists

Ensures that a directory exists: if it does not, it is created.

2.3.7 create_bottleneck_file

Calls run_bottleneck_on_image to compute the bottleneck values and caches them in a file on disk.

2.3.8 get_or_create_bottleneck

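Retrieves the bottleneck values for an image: if a cached file already exists on disk it is read back; otherwise the values are computed and cached first.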
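The caching behaviour of create_bottleneck_file and get_or_create_bottleneck follows a common "compute once, then read from disk" pattern. The sketch below shows that pattern on its own, assuming the values are stored as comma-separated text; the names get_or_create_cached and fake_bottleneck are hypothetical, and fake_bottleneck stands in for running the network on an image.

```python
import os
import tempfile


def get_or_create_cached(cache_path, compute_fn):
    """Return cached values if the file exists, else compute and cache them."""
    if not os.path.exists(cache_path):
        values = compute_fn()
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, 'w') as f:
            f.write(','.join(str(x) for x in values))
    with open(cache_path, 'r') as f:
        return [float(x) for x in f.read().split(',')]


calls = []


def fake_bottleneck():
    # Stand-in for the expensive forward pass through the network.
    calls.append(1)
    return [0.25, 0.5, 0.0]


cache_dir = tempfile.mkdtemp()
path = os.path.join(cache_dir, 'correct', 'page_001.jpg.txt')
first = get_or_create_cached(path, fake_bottleneck)   # computes and caches
second = get_or_create_cached(path, fake_bottleneck)  # served from disk
print(first, second, len(calls))
```

Since the pre-trained layers do not change during retraining, the bottleneck of an image is the same in every iteration, which is why caching it pays off.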

2.3.9 cache_bottlenecks

Computes and caches the bottlenecks of all images (training, testing, and validation) in batch.

2.3.10 get_random_cached_bottlenecks

Randomly retrieves a batch of cached bottlenecks, together with their ground-truth labels (ground_truths) and file names (filenames).
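The random selection can be sketched without TensorFlow. The version below only picks labels and file names; the bottleneck lookup is left out. The helper name sample_batch, the sorted class order, and the sample data are assumptions made for the illustration.

```python
import random


def sample_batch(image_lists, how_many, category, rng):
    """Pick how_many random (label index, file name) pairs from a category."""
    class_names = sorted(image_lists.keys())
    ground_truths, filenames = [], []
    for _ in range(how_many):
        # Choose a class at random, then an image of that class at random.
        label_index = rng.randrange(len(class_names))
        label_name = class_names[label_index]
        filenames.append(rng.choice(image_lists[label_name][category]))
        ground_truths.append(label_index)
    return ground_truths, filenames


# Made-up data in the shape returned by create_image_lists.
image_lists = {
    'correct': {'dir': 'correct', 'training': ['a.jpg', 'b.jpg']},
    'error': {'dir': 'error', 'training': ['x.jpg', 'y.jpg', 'z.jpg']},
}
rng = random.Random(0)
ground_truths, filenames = sample_batch(image_lists, 4, 'training', rng)
print(ground_truths, filenames)
```

Sampling the class first means every class is drawn with equal probability, regardless of how many images it contains.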

2.3.11 add_final_retrain_ops

def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
                          quantize_layer, is_training):
  batch_size, bottleneck_tensor_size = bottleneck_tensor.get_shape().as_list()
  assert batch_size is None, 'We want to work with arbitrary batch size.'
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor,
        shape=[batch_size, bottleneck_tensor_size],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(
        tf.int64, [batch_size], name='GroundTruthInput')

  # Organizing the following ops so they are easier to see in TensorBoard.
  layer_name = 'final_retrain_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      initial_value = tf.truncated_normal(
          [bottleneck_tensor_size, class_count], stddev=0.001)
      layer_weights = tf.Variable(initial_value, name='final_weights')
      variable_summaries(layer_weights)

    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)

    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)

  # The tf.contrib.quantize functions rewrite the graph in place for
  # quantization. The imported model graph has already been rewritten, so upon
  # calling these rewrites, only the newly added final layer will be
  # transformed.
  if quantize_layer:
    if is_training:
      tf.contrib.quantize.create_training_graph()
    else:
      tf.contrib.quantize.create_eval_graph()

  tf.summary.histogram('activations', final_tensor)

  # If this is an eval graph, we don't need to add loss ops or an optimizer.
  if not is_training:
    return None, None, bottleneck_input, ground_truth_input, final_tensor

  with tf.name_scope('cross_entropy'):
    cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy(
        labels=ground_truth_input, logits=logits)

  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    train_step = optimizer.minimize(cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)

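Adds a new fully connected layer (y = Wx + b) followed by a softmax layer at the end of the network, for training and evaluation. This is the same as a logistic regression model in its multi-class (softmax) form: training iteratively minimizes the cross-entropy by gradient descent.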
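To show what this final layer computes, here is a small pure-Python sketch of a softmax layer trained by gradient descent on the mean cross-entropy. It is not the TensorFlow implementation: the toy 2-dimensional "bottleneck" vectors, the zero initial weights (retrain.py uses small random values), and the learning rate are all made up for the illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward(x, weights, biases):
    """Compute softmax(xW + b); weights is indexed [feature][class]."""
    logits = [sum(x[i] * weights[i][k] for i in range(len(x))) + biases[k]
              for k in range(len(biases))]
    return softmax(logits)


def train_step(batch, labels, weights, biases, learning_rate):
    """One gradient-descent step on mean cross-entropy; returns the loss."""
    n, n_feat, n_cls = len(batch), len(weights), len(biases)
    grad_w = [[0.0] * n_cls for _ in range(n_feat)]
    grad_b = [0.0] * n_cls
    loss = 0.0
    for x, y in zip(batch, labels):
        probs = forward(x, weights, biases)
        loss -= math.log(probs[y]) / n
        for k in range(n_cls):
            # d(loss)/d(logit_k) = p_k - 1[k == y], averaged over the batch.
            delta = (probs[k] - (1.0 if k == y else 0.0)) / n
            grad_b[k] += delta
            for i in range(n_feat):
                grad_w[i][k] += delta * x[i]
    for i in range(n_feat):
        for k in range(n_cls):
            weights[i][k] -= learning_rate * grad_w[i][k]
    for k in range(n_cls):
        biases[k] -= learning_rate * grad_b[k]
    return loss


# Toy "bottleneck" vectors for two classes (made-up data).
batch = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
labels = [0, 0, 1, 1]
weights = [[0.0, 0.0], [0.0, 0.0]]
biases = [0.0, 0.0]
losses = [train_step(batch, labels, weights, biases, 0.5) for _ in range(200)]
print(losses[0], losses[-1])
```

Only W and b are updated here, which mirrors the retraining setup: the pre-trained network supplies fixed bottleneck vectors, and the new layer is the only part that learns.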

2.3.12 add_evaluation_step

def add_evaluation_step(result_tensor, ground_truth_tensor):
  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      ## For each row (one example), find the index of the largest value
      prediction = tf.argmax(result_tensor, 1)
      ## Compare each prediction with the ground truth: True if they match,
      ## False otherwise
      correct_prediction = tf.equal(prediction, ground_truth_tensor)
    with tf.name_scope('accuracy'):
      ## Cast True/False to float and take the mean
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  return evaluation_step, prediction

See the comments in the code above. The function returns the final accuracy and the list of predicted values.
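The same three steps (argmax, comparison, mean) can be written in plain Python to check the arithmetic by hand. The function name evaluate and the sample probabilities are ours.

```python
def evaluate(result_rows, ground_truth):
    """Return (accuracy, predictions) for rows of class probabilities."""
    # Index of the largest value in each row, like tf.argmax(..., 1).
    prediction = [row.index(max(row)) for row in result_rows]
    # True where the prediction matches the label, like tf.equal.
    correct = [p == t for p, t in zip(prediction, ground_truth)]
    # Cast booleans to floats and average, like tf.reduce_mean(tf.cast(...)).
    accuracy = sum(1.0 if c else 0.0 for c in correct) / len(correct)
    return accuracy, prediction


probs = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]
truth = [0, 1, 1, 1]
accuracy, prediction = evaluate(probs, truth)
print(accuracy, prediction)  # 0.75 [0, 1, 0, 1]
```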

2.3.13 run_final_eval

Runs the final evaluation on the test set. If the print_misclassified_test_images flag is set, the names of the misclassified images and their predicted results are printed.
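Listing the misclassified images amounts to comparing predictions with ground truth and keeping the mismatches. A minimal sketch, with a hypothetical helper name and made-up data:

```python
def list_misclassified(filenames, predictions, ground_truth, class_names):
    """Return (file name, predicted class name) for every wrong prediction."""
    return [(name, class_names[pred])
            for name, pred, truth in zip(filenames, predictions, ground_truth)
            if pred != truth]


class_names = ['correct', 'error']
filenames = ['p1.jpg', 'p2.jpg', 'p3.jpg']
predictions = [0, 1, 0]
ground_truth = [0, 1, 1]
wrong = list_misclassified(filenames, predictions, ground_truth, class_names)
print(wrong)  # [('p3.jpg', 'correct')]
```

In our testing scenario this list is the useful output: each entry is a page screenshot the model judged differently from its label, and is worth a manual look.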

2.3.14 save_graph_to_file

Saves the graph to a file.

2.3.15 prepare_file_system

Prepares the workspace.

2.3.16 add_jpeg_decoding

Adds the ops that take the raw JPEG data as an input tensor, decode it, and resize it for the network.

2.3.17 export_model

Exports the model to saved_model_dir.
