『TensorFlow』SSD源码学习_其八:网络训练
Fork版本项目地址:SSD
作者使用了分布式训练的写法,这使得训练部分代码异常臃肿,我给出了部分注释。我对于多机分布式并不很熟,而且不是重点,所以不过多介绍,简单的给出一点训练中作者的优化手段,包含优化器选择之类的。
一、滑动平均
# =================================================================== #
# Configure the moving averages.
# =================================================================== #
if FLAGS.moving_average_decay:
moving_average_variables = slim.get_model_variables()
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, global_step)
else:
moving_average_variables, variable_averages = None, None
二、学习率衰减
with tf.device(deploy_config.optimizer_device()):
learning_rate = tf_utils.configure_learning_rate(FLAGS,
dataset.num_samples,
global_step)
细节实现函数,有三种形式,一种是常数学习率,两种不同的衰减方式(默认参数:exponential):
def configure_learning_rate(flags, num_samples_per_epoch, global_step):
"""Configures the learning rate. Args:
num_samples_per_epoch: The number of samples in each epoch of training.
global_step: The global_step tensor.
Returns:
A `Tensor` representing the learning rate.
"""
decay_steps = int(num_samples_per_epoch / flags.batch_size *
flags.num_epochs_per_decay) if flags.learning_rate_decay_type == 'exponential':
return tf.train.exponential_decay(flags.learning_rate,
global_step,
decay_steps,
flags.learning_rate_decay_factor,
staircase=True,
name='exponential_decay_learning_rate')
elif flags.learning_rate_decay_type == 'fixed':
return tf.constant(flags.learning_rate, name='fixed_learning_rate')
elif flags.learning_rate_decay_type == 'polynomial':
return tf.train.polynomial_decay(flags.learning_rate,
global_step,
decay_steps,
flags.end_learning_rate,
power=1.0,
cycle=False,
name='polynomial_decay_learning_rate')
三、优化器选择
optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
选择很丰富(默认参数:adam):
def configure_optimizer(flags, learning_rate):
"""Configures the optimizer used for training. Args:
learning_rate: A scalar or `Tensor` learning rate.
Returns:
An instance of an optimizer.
"""
if flags.optimizer == 'adadelta':
optimizer = tf.train.AdadeltaOptimizer(
learning_rate,
rho=flags.adadelta_rho,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'adagrad':
optimizer = tf.train.AdagradOptimizer(
learning_rate,
initial_accumulator_value=flags.adagrad_initial_accumulator_value)
elif flags.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(
learning_rate,
beta1=flags.adam_beta1,
beta2=flags.adam_beta2,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'ftrl':
optimizer = tf.train.FtrlOptimizer(
learning_rate,
learning_rate_power=flags.ftrl_learning_rate_power,
initial_accumulator_value=flags.ftrl_initial_accumulator_value,
l1_regularization_strength=flags.ftrl_l1,
l2_regularization_strength=flags.ftrl_l2)
elif flags.optimizer == 'momentum':
optimizer = tf.train.MomentumOptimizer(
learning_rate,
momentum=flags.momentum,
name='Momentum')
elif flags.optimizer == 'rmsprop':
optimizer = tf.train.RMSPropOptimizer(
learning_rate,
decay=flags.rmsprop_decay,
momentum=flags.rmsprop_momentum,
epsilon=flags.opt_epsilon)
elif flags.optimizer == 'sgd':
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
else:
raise ValueError('Optimizer [%s] was not recognized', flags.optimizer)
return optimizer
四、训练
实际上中间有好一段分布式梯度计算过程,这里不多介绍,大概就是在各个clone上计算出梯度,汇总梯度,再优化各个clone网络,将优化节点提出作为train_tensor等等。
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(log_device_placement=False,
gpu_options=gpu_options)
saver = tf.train.Saver(max_to_keep=5,
keep_checkpoint_every_n_hours=1.0,
write_version=2,
pad_step_number=False)
slim.learning.train(
train_tensor,
logdir=FLAGS.train_dir,
master='',
is_chief=True,
init_fn=tf_utils.get_init_fn(FLAGS), # 看函数实现就明白了,assign变量用
summary_op=summary_op, # tf.summary.merge节点
number_of_steps=FLAGS.max_number_of_steps, # 训练step
log_every_n_steps=FLAGS.log_every_n_steps, # 输出训练信息间隔
save_summaries_secs=FLAGS.save_summaries_secs, # 每次summary时间间隔
saver=saver, # tf.train.Saver节点
save_interval_secs=FLAGS.save_interval_secs, # 每次model保存step间隔
session_config=config, # sess参数
sync_optimizer=None)
其中调用的初始化函数如下:
def get_init_fn(flags):
"""Returns a function run by the chief worker to warm-start the training.
Note that the init_fn is only run when initializing the model during the very
first global step. Returns:
An init function run by the supervisor.
"""
if flags.checkpoint_path is None:
return None
# Warn the user if a checkpoint exists in the train_dir. Then ignore.
if tf.train.latest_checkpoint(flags.train_dir):
tf.logging.info(
'Ignoring --checkpoint_path because a checkpoint already exists in %s'
% flags.train_dir)
return None exclusions = []
if flags.checkpoint_exclude_scopes:
exclusions = [scope.strip()
for scope in flags.checkpoint_exclude_scopes.split(',')] # TODO(sguada) variables.filter_variables()
variables_to_restore = []
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
# Change model scope if necessary.
if flags.checkpoint_model_scope is not None:
variables_to_restore = \
{var.op.name.replace(flags.model_name,
flags.checkpoint_model_scope): var
for var in variables_to_restore} if tf.gfile.IsDirectory(flags.checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
else:
checkpoint_path = flags.checkpoint_path
tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars)) return slim.assign_from_checkpoint_fn(
checkpoint_path,
variables_to_restore,
ignore_missing_vars=flags.ignore_missing_vars)
至此,SSD项目介绍完毕,训练命令如下,不过默认训练step是无限的,不手动终止会一直训练下去,所以要关注一下训练的指标,够用了就关了吧,
DATASET_DIR=./tfrecords
TRAIN_DIR=./logs/
CHECKPOINT_PATH=./checkpoints/ssd_300_vgg.ckpt
python train_ssd_network.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=pascalvoc_2012 \
--dataset_split_name=train \
--model_name=ssd_300_vgg \
--checkpoint_path=${CHECKPOINT_PATH} \
--save_summaries_secs=60 \
--save_interval_secs=600 \
--weight_decay=0.0005 \
--optimizer=adam \
--learning_rate=0.001 \
--batch_size=32
如何使用训练好模型见集智专栏的文章最后一部分。
『TensorFlow』SSD源码学习_其八:网络训练的更多相关文章
- 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍
一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...
- 『TensorFlow』SSD源码学习_其五:TFR数据读取&数据预处理
Fork版本项目地址:SSD 一.TFR数据读取 创建slim.dataset.Dataset对象 在train_ssd_network.py获取数据操作如下,首先需要slim.dataset.Dat ...
- 『TensorFlow』SSD源码学习_其四:数据介绍及TFR文件生成
Fork版本项目地址:SSD 一.数据格式介绍 数据文件夹命名为VOC2012,内部有5个子文件夹,如下, 我们的检测任务中使用JPEGImages文件夹和Annotations文件夹. JPEGIm ...
- 『TensorFlow』SSD源码学习_其二:基于VGG的SSD网络前向架构
Fork版本项目地址:SSD 参考自集智专栏 一.SSD基础 在分类器基础之上想要识别物体,实质就是 用分类器扫描整张图像,定位特征位置 .这里的关键就是用什么算法扫描,比如可以将图片分成若干网格,用 ...
- 『TensorFlow』SSD源码学习_其七:损失函数
Fork版本项目地址:SSD 一.损失函数介绍 SSD损失函数分为两个部分:对应搜索框的位置loss(loc)和类别置信度loss(conf).(搜索框指网络生成的网格) 详细的说明如下: i指代搜索 ...
- 『TensorFlow』SSD源码学习_其六:标签整理
Fork版本项目地址:SSD 一.输入标签生成 在数据预处理之后,图片.类别.真实框格式较为原始,不能够直接作为损失函数的输入标签(ssd向前网络只需要图像就行,这里的处理主要需要满足loss的计算) ...
- 『TensorFlow』SSD源码学习_其三:锚框生成
Fork版本项目地址:SSD 上一节中我们定义了vgg_300的网络结构,实际使用中还需要匹配SSD另一关键组件:被选取特征层的搜索网格.在项目中,vgg_300网络和网格生成都被统一进一个class ...
- nginx源码学习_源码结构
nginx的优秀除了体现在程序结构以及代码风格上,nginx的源码组织也同样简洁明了,目录结构层次结构清晰,值得我们去学习.nginx的源码目录与nginx的模块化以及功能的划分是紧密结合,这也使得我 ...
- 『TensorFlow』读书笔记_TFRecord学习
一.程序介绍 1.包导入 # Author : Hellcat # Time : 17-12-29 import os import numpy as np np.set_printoptions(t ...
随机推荐
- 3. Elements of a Test Plan
https://jmeter.apache.org/usermanual/test_plan.html This section describes the different parts of a ...
- 不能安装64位office提示已安装32位的
安装64位office办公软件的时候提示已经安装32位的office办公软件所以无法继续安装,但实际上之前安装的32位的office办公软件已经卸载了.问题现象截图如下: 从问题描述中,我们其实已经能 ...
- Entity Framework Core
Entity Framework是一种支持 .NET 开发人员使用 .NET 对象处理数据库的对象关系映射程序 (O/RM). 它不要求提供开发人员通常需要编写的大部分数据访问代码. Entity F ...
- Redis架构设计
高可用Redis服务架构分析与搭建 各种web开发业务中最为常用的key-value数据库了 应用: 在业务中用其存储用户登陆态(Session存储),加速一些热数据的查询(相比较mysql而言,速度 ...
- JavaScript——语法与数据类型
严格模式 ECMA5引入了严格模式的概念.严格模式是为JavaScript定义了一种不同的解析与执行模型.在严格模式下,ECMA3中的一些不确定的行为将得到处理,而且对某些不安全的操作也会抛出错误.要 ...
- HDU 5119 Happy Matt Friends(递推)
http://acm.hdu.edu.cn/showproblem.php?pid=5119 题意:给出n个数和一个上限m,求从这n个数里取任意个数做异或运算,最后的结果不小于m有多少种取法. 思路: ...
- 无法启动此程序,因为计算机中丢失api-ms-win-crt-runtime-|1-1-0.dll
今天想把自己电脑上的python2换成python3时,安装完python3后,命令行启动时需要出现了上述错误,在网上查了资料后应该是库文件遭到了破坏,于是我下了一个东西安装后就解决了,如果出现了此问 ...
- python学习 day10打卡 函数的进阶
本节主要内容: 1.函数参数--动态参数 2.名称空间,局部名称空间,全局名称空间,作用域,加载顺序. 3.函数的嵌套 4.gloabal,nonlocal关键字 一.函数参数--动态传参 形参的第三 ...
- BZOJ 3878 【AHOI2014】 奇怪的计算器
题目链接:奇怪的计算器 如果没有溢出的话,所有的标记都可以在线段树上直接维护,所以一棵线段树就解决问题了. 现在有了溢出,怎么办呢? 发现就算溢出了,各个元素的相对大小关系也是不变的.所以,如果一开始 ...
- EditPlus查找替换
换行符\n,记得选择正则表达式 1]正则表达式应用——替换指定内容到行尾解决:① 在替换对话框,查找内容里输入“abc.*”② 同时勾选“正则表达式”复选框,然后点击“全部替换”按钮其中,符号的含义如 ...