『TensorFlow』SSD源码学习_其八:网络训练
Fork版本项目地址:SSD
作者使用了分布式训练的写法,这使得训练部分代码异常臃肿,我给出了部分注释。我对于多机分布式并不很熟,而且不是重点,所以不过多介绍,简单的给出一点训练中作者的优化手段,包含优化器选择之类的。
一、滑动平均
# =================================================================== #
# Configure the moving averages.
# =================================================================== #
if FLAGS.moving_average_decay:
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, global_step)
else:
moving_average_variables, variable_averages = None, None
二、学习率衰减
with tf.device(deploy_config.optimizer_device()):
learning_rate = tf_utils.configure_learning_rate(FLAGS,
dataset.num_samples,
global_step)
细节实现函数,有三种形式,一种是常数学习率,两种不同的衰减方式(默认参数:exponential):
def configure_learning_rate(flags, num_samples_per_epoch, global_step):
"""Configures the learning rate. Args:
num_samples_per_epoch: The number of samples in each epoch of training.
global_step: The global_step tensor.
Returns:
A `Tensor` representing the learning rate.
"""
decay_steps = int(num_samples_per_epoch / flags.batch_size *
flags.num_epochs_per_decay) if flags.learning_rate_decay_type == 'exponential':
return tf.train.exponential_decay(flags.learning_rate,
global_step,
decay_steps,
flags.learning_rate_decay_factor,
staircase=True,
name='exponential_decay_learning_rate')
elif flags.learning_rate_decay_type == 'fixed':
return tf.constant(flags.learning_rate, name='fixed_learning_rate')
elif flags.learning_rate_decay_type == 'polynomial':
return tf.train.polynomial_decay(flags.learning_rate,
global_step,
decay_steps,
flags.end_learning_rate,
power=1.0,
cycle=False,
name='polynomial_decay_learning_rate')
三、优化器选择
optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
选择很丰富(默认参数:adam):
def configure_optimizer(flags, learning_rate):
"""Configures the optimizer used for training. Args:
learning_rate: A scalar or `Tensor` learning rate.
Returns:
An instance of an optimizer.
"""
if flags.optimizer == 'adadelta':
optimizer = tf.train.AdadeltaOptimizer(
learning_rate,
rho=flags.adadelta_rho,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'adagrad':
optimizer = tf.train.AdagradOptimizer(
learning_rate,
initial_accumulator_value=flags.adagrad_initial_accumulator_value)
elif flags.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(
learning_rate,
beta1=flags.adam_beta1,
beta2=flags.adam_beta2,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'ftrl':
optimizer = tf.train.FtrlOptimizer(
learning_rate,
learning_rate_power=flags.ftrl_learning_rate_power,
initial_accumulator_value=flags.ftrl_initial_accumulator_value,
l1_regularization_strength=flags.ftrl_l1,
l2_regularization_strength=flags.ftrl_l2)
elif flags.optimizer == 'momentum':
optimizer = tf.train.MomentumOptimizer(
learning_rate,
momentum=flags.momentum,
name='Momentum')
elif flags.optimizer == 'rmsprop':
optimizer = tf.train.RMSPropOptimizer(
learning_rate,
decay=flags.rmsprop_decay,
momentum=flags.rmsprop_momentum,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'sgd':
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
else:
raise ValueError('Optimizer [%s] was not recognized', flags.optimizer)
return optimizer
四、训练
实际上中间有好一段分布式梯度计算过程,这里不多介绍,大概就是在各个clone上计算出梯度,汇总梯度,再优化各个clone网络,将优化节点提出作为train_tensor等等。
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(log_device_placement=False,
gpu_options=gpu_options)
saver = tf.train.Saver(max_to_keep=5,
keep_checkpoint_every_n_hours=1.0,
write_version=2,
pad_step_number=False)
slim.learning.train(
train_tensor,
logdir=FLAGS.train_dir,
master='',
is_chief=True,
init_fn=tf_utils.get_init_fn(FLAGS), # 看函数实现就明白了,assign变量用
summary_op=summary_op, # tf.summary.merge节点
number_of_steps=FLAGS.max_number_of_steps, # 训练step
log_every_n_steps=FLAGS.log_every_n_steps, # 输出训练信息间隔
save_summaries_secs=FLAGS.save_summaries_secs, # 每次summary时间间隔
saver=saver, # tf.train.Saver节点
save_interval_secs=FLAGS.save_interval_secs, # 每次model保存step间隔
session_config=config, # sess参数
sync_optimizer=None)
其中调用的初始化函数如下:
def get_init_fn(flags):
"""Returns a function run by the chief worker to warm-start the training.
Note that the init_fn is only run when initializing the model during the very
first global step. Returns:
An init function run by the supervisor.
"""
if flags.checkpoint_path is None:
return None
# Warn the user if a checkpoint exists in the train_dir. Then ignore.
if tf.train.latest_checkpoint(flags.train_dir):
tf.logging.info(
'Ignoring --checkpoint_path because a checkpoint already exists in %s'
% flags.train_dir)
return None exclusions = []
if flags.checkpoint_exclude_scopes:
exclusions = [scope.strip()
for scope in flags.checkpoint_exclude_scopes.split(',')] # TODO(sguada) variables.filter_variables()
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
# Change model scope if necessary.
if flags.checkpoint_model_scope is not None:
variables_to_restore = \
{var.op.name.replace(flags.model_name,
flags.checkpoint_model_scope): var
for var in variables_to_restore} if tf.gfile.IsDirectory(flags.checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
else:
checkpoint_path = flags.checkpoint_path
tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars)) return slim.assign_from_checkpoint_fn(
checkpoint_path,
variables_to_restore,
ignore_missing_vars=flags.ignore_missing_vars)
至此,SSD项目介绍完毕,训练命令如下,不过默认训练step是无限的,不手动终止会一直训练下去,所以要关注一下训练的指标,够用了就关了吧,
DATASET_DIR=./tfrecords
TRAIN_DIR=./logs/
CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
python train_ssd_network.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=pascalvoc_2012 \
--dataset_split_name=train \
--model_name=ssd_300_vgg \
--checkpoint_path=${CHECKPOINT_PATH} \
--save_summaries_secs=60 \
--save_interval_secs=600 \
--weight_decay=0.0005 \
--optimizer=adam \
--learning_rate=0.001 \
--batch_size=32
如何使用训练好模型见集智专栏的文章最后一部分。
『TensorFlow』SSD源码学习_其八:网络训练的更多相关文章
- 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
- 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理
Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...
- 『TensorFlow』SSD源码学习_其四:数据介绍及TFR文件生成
Fork版本项目地址:SSD 一.数据格式介绍 数据文件夹命名为VOC2012,内部有5个子文件夹,如下, 我们的检测任务中使用JPEGImages文件夹和Annotations文件夹. JPEGIm ...
- 『TensorFlow』SSD源码学习_其二:基于VGG的SSD网络前向架构
Fork版本项目地址:SSD 参考自集智专栏 一.SSD基础 在分类器基础之上想要识别物体,实质就是 用分类器扫描整张图像,定位特征位置 .这里的关键就是用什么算法扫描,比如可以将图片分成若干网格,用 ...
- 『TensorFlow』SSD源码学习_其七:损失函数
Fork版本项目地址:SSD 一.损失函数介绍 SSD损失函数分为两个部分:对应搜索框的位置loss(loc)和类别置信度loss(conf).(搜索框指网络生成的网格) 详细的说明如下: i指代搜索 ...
- 『TensorFlow』SSD源码学习_其六:标签整理
Fork版本项目地址:SSD 一.输入标签生成 在数据预处理之后,图片.类别.真实框格式较为原始,不能够直接作为损失函数的输入标签(ssd向前网络只需要图像就行,这里的处理主要需要满足loss的计算) ...
- 『TensorFlow』SSD源码学习_其三:锚框生成
Fork版本项目地址:SSD 上一节中我们定义了vgg_300的网络结构,实际使用中还需要匹配SSD另一关键组件:被选取特征层的搜索网格.在项目中,vgg_300网络和网格生成都被统一进一个class ...
- nginx源码学习_源码结构
nginx的优秀除了体现在程序结构以及代码风格上,nginx的源码组织也同样简洁明了,目录结构层次结构清晰,值得我们去学习.nginx的源码目录与nginx的模块化以及功能的划分是紧密结合,这也使得我 ...
- 『TensorFlow』读书笔记_TFRecord学习
一.程序介绍 1.包导入 # Author : Hellcat # Time : 17-12-29 import os import numpy as np np.set_printoptions(t ...
随机推荐
- 【原理、注意点】Quartz的原理和需要注意的地方
基本介绍和核心接口 1.quartz是完全基于java的可用于进行定时任务调度的开源框架,使用的时候需要引入: <dependency> <groupId>org.quartz ...
- .Net Core项目在Docker上运行,内存占用过多导致pods重启的问题
默认情况下,.NET Core应用的内存回收模式是Server模式,这种情况下,内存占用和服务器核心数量有关,一半占用量比较大. 我们的应用目前吞吐量都不大,可以采用Workstation模式,这种模 ...
- CommandLine exe参数
[Verb("OptionsEntity")] public class OptionsEntity { [Option('a', HelpText = "Plantfo ...
- Writing device drivers in Linux: A brief tutorial
“Do you pine for the nice days of Minix-1.1, when men were men and wrote their own device drivers?” ...
- php的Allowed memory size of 134217728 bytes exhausted问题
提示Allowed memory size of 134217728 bytes exhausted,出现这种错误的情况常见的有三种: 0:查询的数据量大. 1:数据量不大,但是php.ini配置的内 ...
- Java中泛型Class<T>、T与Class<?>、 Object类和Class类、 object.getClass()和Object.class
一.区别 单独的T 代表一个类型(表现形式是一个类名而已) ,而 Class<T>代表这个类型所对应的类(又可以称做类实例.类类型.字节码文件), Class<?>表示类型不确 ...
- mysql联合主键自增、主键最大长度小记
前言 一. 联合主键自增问题 今天上午闲来无事翻看了下数据库分类表的设计,看到这样一幕: 当时我好奇的是怎么cateId自增会存在重复值的问题,然后翻看了下主键是由siteId和cateId组成.所以 ...
- Java转义形如nbsp;的HTML编码
需要引用一个maven <!-- https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 --> <d ...
- 因样式冲突引起的div消失问题
工作需要,搭建一个网站的模型,简单分成三个部分,标题栏,导航栏,主界面,效果如图: 但是点击界面的任意地方,中间的div块消失了,如图所示: 调试,发现在点击界面其他地方的时候display属性有变化 ...
- rtrim
<?php $str = '14岁'; $new_str = rtrim($str, '岁'); echo $new_str; 如果右边是'岁',就过滤掉.