『TensorFlow』SSD源码学习_其八:网络训练
Fork版本项目地址:SSD
作者使用了分布式训练的写法,这使得训练部分代码异常臃肿,我给出了部分注释。我对于多机分布式并不很熟,而且不是重点,所以不过多介绍,简单的给出一点训练中作者的优化手段,包含优化器选择之类的。
一、滑动平均
# =================================================================== #
# Configure the moving averages.
# =================================================================== #
if FLAGS.moving_average_decay:
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, global_step)
else:
moving_average_variables, variable_averages = None, None
二、学习率衰减
with tf.device(deploy_config.optimizer_device()):
learning_rate = tf_utils.configure_learning_rate(FLAGS,
dataset.num_samples,
global_step)
细节实现函数,有三种形式,一种是常数学习率,两种不同的衰减方式(默认参数:exponential):
def configure_learning_rate(flags, num_samples_per_epoch, global_step):
"""Configures the learning rate. Args:
num_samples_per_epoch: The number of samples in each epoch of training.
global_step: The global_step tensor.
Returns:
A `Tensor` representing the learning rate.
"""
decay_steps = int(num_samples_per_epoch / flags.batch_size *
flags.num_epochs_per_decay) if flags.learning_rate_decay_type == 'exponential':
return tf.train.exponential_decay(flags.learning_rate,
global_step,
decay_steps,
flags.learning_rate_decay_factor,
staircase=True,
name='exponential_decay_learning_rate')
elif flags.learning_rate_decay_type == 'fixed':
return tf.constant(flags.learning_rate, name='fixed_learning_rate')
elif flags.learning_rate_decay_type == 'polynomial':
return tf.train.polynomial_decay(flags.learning_rate,
global_step,
decay_steps,
flags.end_learning_rate,
power=1.0,
cycle=False,
name='polynomial_decay_learning_rate')
三、优化器选择
optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
选择很丰富(默认参数:adam):
def configure_optimizer(flags, learning_rate):
"""Configures the optimizer used for training. Args:
learning_rate: A scalar or `Tensor` learning rate.
Returns:
An instance of an optimizer.
"""
if flags.optimizer == 'adadelta':
optimizer = tf.train.AdadeltaOptimizer(
learning_rate,
rho=flags.adadelta_rho,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'adagrad':
optimizer = tf.train.AdagradOptimizer(
learning_rate,
initial_accumulator_value=flags.adagrad_initial_accumulator_value)
elif flags.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(
learning_rate,
beta1=flags.adam_beta1,
beta2=flags.adam_beta2,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'ftrl':
optimizer = tf.train.FtrlOptimizer(
learning_rate,
learning_rate_power=flags.ftrl_learning_rate_power,
initial_accumulator_value=flags.ftrl_initial_accumulator_value,
l1_regularization_strength=flags.ftrl_l1,
l2_regularization_strength=flags.ftrl_l2)
elif flags.optimizer == 'momentum':
optimizer = tf.train.MomentumOptimizer(
learning_rate,
momentum=flags.momentum,
name='Momentum')
elif flags.optimizer == 'rmsprop':
optimizer = tf.train.RMSPropOptimizer(
learning_rate,
decay=flags.rmsprop_decay,
momentum=flags.rmsprop_momentum,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'sgd':
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
else:
raise ValueError('Optimizer [%s] was not recognized', flags.optimizer)
return optimizer
四、训练
实际上中间有好一段分布式梯度计算过程,这里不多介绍,大概就是在各个clone上计算出梯度,汇总梯度,再优化各个clone网络,将优化节点提出作为train_tensor等等。
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(log_device_placement=False,
gpu_options=gpu_options)
saver = tf.train.Saver(max_to_keep=5,
keep_checkpoint_every_n_hours=1.0,
write_version=2,
pad_step_number=False)
slim.learning.train(
train_tensor,
logdir=FLAGS.train_dir,
master='',
is_chief=True,
init_fn=tf_utils.get_init_fn(FLAGS), # 看函数实现就明白了,assign变量用
summary_op=summary_op, # tf.summary.merge节点
number_of_steps=FLAGS.max_number_of_steps, # 训练step
log_every_n_steps=FLAGS.log_every_n_steps, # 输出训练信息间隔
save_summaries_secs=FLAGS.save_summaries_secs, # 每次summary时间间隔
saver=saver, # tf.train.Saver节点
save_interval_secs=FLAGS.save_interval_secs, # 每次model保存step间隔
session_config=config, # sess参数
sync_optimizer=None)
其中调用的初始化函数如下:
def get_init_fn(flags):
"""Returns a function run by the chief worker to warm-start the training.
Note that the init_fn is only run when initializing the model during the very
first global step. Returns:
An init function run by the supervisor.
"""
if flags.checkpoint_path is None:
return None
# Warn the user if a checkpoint exists in the train_dir. Then ignore.
if tf.train.latest_checkpoint(flags.train_dir):
tf.logging.info(
'Ignoring --checkpoint_path because a checkpoint already exists in %s'
% flags.train_dir)
return None exclusions = []
if flags.checkpoint_exclude_scopes:
exclusions = [scope.strip()
for scope in flags.checkpoint_exclude_scopes.split(',')] # TODO(sguada) variables.filter_variables()
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
# Change model scope if necessary.
if flags.checkpoint_model_scope is not None:
variables_to_restore = \
{var.op.name.replace(flags.model_name,
flags.checkpoint_model_scope): var
for var in variables_to_restore} if tf.gfile.IsDirectory(flags.checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
else:
checkpoint_path = flags.checkpoint_path
tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars)) return slim.assign_from_checkpoint_fn(
checkpoint_path,
variables_to_restore,
ignore_missing_vars=flags.ignore_missing_vars)
至此,SSD项目介绍完毕,训练命令如下,不过默认训练step是无限的,不手动终止会一直训练下去,所以要关注一下训练的指标,够用了就关了吧,
DATASET_DIR=./tfrecords
TRAIN_DIR=./logs/
CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
python train_ssd_network.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=pascalvoc_2012 \
--dataset_split_name=train \
--model_name=ssd_300_vgg \
--checkpoint_path=${CHECKPOINT_PATH} \
--save_summaries_secs=60 \
--save_interval_secs=600 \
--weight_decay=0.0005 \
--optimizer=adam \
--learning_rate=0.001 \
--batch_size=32
如何使用训练好模型见集智专栏的文章最后一部分。
『TensorFlow』SSD源码学习_其八:网络训练的更多相关文章
- 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
- 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理
Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...
- 『TensorFlow』SSD源码学习_其四:数据介绍及TFR文件生成
Fork版本项目地址:SSD 一.数据格式介绍 数据文件夹命名为VOC2012,内部有5个子文件夹,如下, 我们的检测任务中使用JPEGImages文件夹和Annotations文件夹. JPEGIm ...
- 『TensorFlow』SSD源码学习_其二:基于VGG的SSD网络前向架构
Fork版本项目地址:SSD 参考自集智专栏 一.SSD基础 在分类器基础之上想要识别物体,实质就是 用分类器扫描整张图像,定位特征位置 .这里的关键就是用什么算法扫描,比如可以将图片分成若干网格,用 ...
- 『TensorFlow』SSD源码学习_其七:损失函数
Fork版本项目地址:SSD 一.损失函数介绍 SSD损失函数分为两个部分:对应搜索框的位置loss(loc)和类别置信度loss(conf).(搜索框指网络生成的网格) 详细的说明如下: i指代搜索 ...
- 『TensorFlow』SSD源码学习_其六:标签整理
Fork版本项目地址:SSD 一.输入标签生成 在数据预处理之后,图片.类别.真实框格式较为原始,不能够直接作为损失函数的输入标签(ssd向前网络只需要图像就行,这里的处理主要需要满足loss的计算) ...
- 『TensorFlow』SSD源码学习_其三:锚框生成
Fork版本项目地址:SSD 上一节中我们定义了vgg_300的网络结构,实际使用中还需要匹配SSD另一关键组件:被选取特征层的搜索网格.在项目中,vgg_300网络和网格生成都被统一进一个class ...
- nginx源码学习_源码结构
nginx的优秀除了体现在程序结构以及代码风格上,nginx的源码组织也同样简洁明了,目录结构层次结构清晰,值得我们去学习.nginx的源码目录与nginx的模块化以及功能的划分是紧密结合,这也使得我 ...
- 『TensorFlow』读书笔记_TFRecord学习
一.程序介绍 1.包导入 # Author : Hellcat # Time : 17-12-29 import os import numpy as np np.set_printoptions(t ...
随机推荐
- SCU 4437 Carries(二分乱搞)题解
题意:问任意两对ai,aj相加的总进位数为多少.比如5,6,95分为(5,6)(5,95)(6,95),进位数 = 1 + 2 + 2 = 5 思路:显然暴力是会超时的.我们可以知道总进位数等于每一位 ...
- python 之 文件I/0
打开和关闭文件 open()函数 必须要open()内置函数打开一个文件,创建一个file对象,相关的方法才可以调用它进行读写. 语法 file object=open(file_name [,acc ...
- Vue的生命周期(钩子函数)
Vue一共有10个生命周期函数,我们可以利用这些函数在vue的每个阶段都进行操作数据或者改变内容. 其实在Vue的官网有一张图已经很好的诠释了生命周期,我在这里就不再多讲了,直接贴图,然后上程序代码. ...
- 【Hadoop 分布式部署 五:分布式部署之分发、基本测试及监控】
1.对 hadoop 进行格式化 到 /opt/app/hadoop-2.5.0 目录下 执行命令: bin/hdfs namenode -format 执行的效果图如下 ( 下图成功 ...
- Java基础 【自动装箱和拆箱、面试题】
JDK 1.5 (以后的版本)的新特性自动装箱和拆箱 1. 自动装箱:把基本类型转换为包装类类型 int a =10; Integer i = new Integer(a); Integer valu ...
- 浅谈IIS 和 asp.net的应用之间的关系
IIS可以理解为一个web服务器. 用于提供web相关的各种服务. IIS6.0中添加了一个新的功能, application pool. application pool的作用是将运行在同一个ser ...
- 复习ing
记不住啊,我能有什么办法,只好一遍又一遍看
- 给大家讲个故事,感受一下什么叫CF。不知道的请认真听。
我朋友是个温柔.体贴.负责.做事认真和口才流利的好男人 他在大一时喜欢上别系的女同学,像他这样的好人我以为这段恋情是手到擒来 但并没有,女方只把它当工具人,一当就当了四年 身为室友的我每天看著她为女方 ...
- 【译】第39节---EF6-数据库命令日志
原文:http://www.entityframeworktutorial.net/entityframework6/database-command-logging.aspx 本节,我们学习如何记录 ...
- HDU 1247 Hat’s Words(字典树)
http://acm.hdu.edu.cn/showproblem.php?pid=1247 题意: 给出一些单词,问哪些单词可以正好由其他的两个单词首尾相连而成. 思路: 先将所有单独插入字典树,然 ...