This post is for personal study only. Original article: https://www.jianshu.com/p/dc24e54aec81

It draws on many blog posts and serves as a summary of the slim library.

Importing the slim module

import tensorflow as tf
import tensorflow.contrib.slim as slim

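Defining variables with slim

# Model variables
weights = slim.model_variable('weights',
                              shape=[10, 10, 3, 3],
                              initializer=tf.truncated_normal_initializer(stddev=0.1),
                              regularizer=slim.l2_regularizer(0.05),
                              device='/CPU:0')
model_variables = slim.get_model_variables()  # returns the list of all model variables

# Regular variables
my_var = slim.variable('my_var',
                       shape=[20, 1],
                       initializer=tf.zeros_initializer())
regular_variables_and_model_variables = slim.get_variables()

# model_variable creates a model parameter: it is learned during training and restored
# from a checkpoint for evaluation or inference (slim.conv2d and slim.fully_connected
# create their weights this way). variable creates an ordinary, non-model variable, such
# as global_step, that is used during training but is not part of the model itself.
# Both are global variables, so a tf.train.Saver built without an explicit variable
# list saves both of them.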
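To make the distinction concrete, here is a toy model in plain Python (no TensorFlow) of how the two kinds of variables are registered. The `COLLECTIONS` dict and the `toy_variable` / `toy_model_variable` helpers are hypothetical stand-ins, not slim's code; the sketch only assumes, as described above, that a model variable is a global variable that is additionally tracked in a model-variables collection.

```python
# Toy stand-in for TensorFlow graph collections: variable names only, no tensors.
COLLECTIONS = {"global_variables": [], "model_variables": []}


def toy_variable(name):
    """Hypothetical analogue of slim.variable: registered only as a global variable."""
    COLLECTIONS["global_variables"].append(name)
    return name


def toy_model_variable(name):
    """Hypothetical analogue of slim.model_variable: a global AND a model variable."""
    toy_variable(name)
    COLLECTIONS["model_variables"].append(name)
    return name


toy_model_variable("conv1/weights")
toy_model_variable("conv1/biases")
toy_variable("global_step")

# A saver built without an explicit list covers every global variable...
saved_by_default = list(COLLECTIONS["global_variables"])
# ...while inference only needs the model's learned parameters.
needed_for_inference = list(COLLECTIONS["model_variables"])

print(saved_by_default)      # ['conv1/weights', 'conv1/biases', 'global_step']
print(needed_for_inference)  # ['conv1/weights', 'conv1/biases']
```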

Implementing layers with slim

inputs = ...
net = slim.conv2d(inputs, 128, [3, 3], scope='conv1_1')

# Code reuse: apply the same layer several times
net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

# Handling layers with different arguments, written out by hand:
x = slim.fully_connected(x, 32, scope='fc/fc_1')
x = slim.fully_connected(x, 64, scope='fc/fc_2')
x = slim.fully_connected(x, 128, scope='fc/fc_3')
# or, equivalently:
x = slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')

# The long way:
x = slim.conv2d(x, 32, [3, 3], scope='core/core_1')
x = slim.conv2d(x, 32, [1, 1], scope='core/core_2')
x = slim.conv2d(x, 64, [3, 3], scope='core/core_3')
x = slim.conv2d(x, 64, [1, 1], scope='core/core_4')
# The shortcut:
x = slim.stack(x, slim.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')
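slim.repeat and slim.stack are essentially loops that call a layer function several times and give each call its own nested scope: with scope='conv3', the three convolutions end up under conv3/conv3_1, conv3/conv3_2 and conv3/conv3_3, and slim.stack(..., scope='fc') produces fc/fc_1 to fc/fc_3, matching the hand-written version. The pure-Python sketch below imitates only that unrolling; `toy_repeat`, `toy_stack` and `fake_layer` are hypothetical names, and no real variables are created.

```python
def toy_repeat(inputs, repetitions, layer, *args, scope, **kwargs):
    """Call `layer` `repetitions` times, naming each call scope/scope_i (toy slim.repeat)."""
    outputs = inputs
    for i in range(repetitions):
        outputs = layer(outputs, *args, scope=f"{scope}/{scope}_{i + 1}", **kwargs)
    return outputs


def toy_stack(inputs, layer, stack_args, *, scope, **kwargs):
    """Call `layer` once per entry of `stack_args` (toy slim.stack)."""
    outputs = inputs
    for i, layer_args in enumerate(stack_args):
        if not isinstance(layer_args, (list, tuple)):
            layer_args = [layer_args]  # a bare value such as 32 becomes a single argument
        outputs = layer(outputs, *layer_args, scope=f"{scope}/{scope}_{i + 1}", **kwargs)
    return outputs


CALLS = []


def fake_layer(inputs, *args, scope):
    """Hypothetical layer: records its scope and arguments, passes the input through."""
    CALLS.append((scope, args))
    return inputs


toy_repeat("net", 3, fake_layer, 256, [3, 3], scope="conv3")
toy_stack("x", fake_layer, [32, 64, 128], scope="fc")
toy_stack("x", fake_layer, [(32, [3, 3]), (32, [1, 1])], scope="core")

for call in CALLS:
    print(call)
```

Each tuple in `stack_args` is unpacked into the positional arguments of one layer call, which is why `(32, [3, 3])` means "32 outputs, 3x3 kernel" for slim.conv2d.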

Simplifying layers that share the same arguments (arg_scope)

with slim.arg_scope([slim.conv2d], padding='SAME',
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                    weights_regularizer=slim.l2_regularizer(0.0005)):
    net = slim.conv2d(inputs, 64, [11, 11], scope='conv1')
    net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
    net = slim.conv2d(net, 256, [11, 11], scope='conv3')

# Nesting arg_scope
with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                    weights_regularizer=slim.l2_regularizer(0.0005)):
    with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):
        # Pass stride as a keyword: arg_scope injects stride=1 as a keyword argument,
        # so a positional 4 would raise "got multiple values for argument 'stride'".
        net = slim.conv2d(inputs, 64, [11, 11], stride=4, padding='VALID', scope='conv1')
        net = slim.conv2d(net, 256, [5, 5],
                          weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
                          scope='conv2')
        net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc')

Training the model: defining losses

loss = slim.losses.softmax_cross_entropy(predictions, labels)

# Custom losses
# Define the loss functions and get the total loss.
classification_loss = slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)
sum_of_squares_loss = slim.losses.sum_of_squares(depth_predictions, depth_labels)
pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)
slim.losses.add_loss(pose_loss)  # Let TF-Slim know about the additional loss.

# The following two ways to compute the total loss are equivalent:
regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
total_loss1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss

# (Regularization losses are included in the total loss by default.)
total_loss2 = slim.losses.get_total_loss()
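slim.losses keeps every loss it creates (plus any registered with slim.losses.add_loss) in a collection, and get_total_loss sums that collection together with the regularization terms recorded by each layer's weights_regularizer. The sketch below imitates that bookkeeping with plain floats; the `LOSSES` / `REGULARIZATION_LOSSES` lists and the `toy_*` helpers are hypothetical stand-ins, and the loss values are made up. The L2 term uses the scale * sum(w^2) / 2 formula of tf.nn.l2_loss, on which slim.l2_regularizer is built.

```python
# Toy stand-ins for slim's loss collections (plain floats instead of tensors).
LOSSES = []
REGULARIZATION_LOSSES = []


def toy_add_loss(value):
    """Register a task loss, the way slim.losses.add_loss registers a loss tensor."""
    LOSSES.append(value)
    return value


def toy_l2_regularizer(scale):
    """Return a function computing scale * sum(w^2) / 2 for a list of weights."""
    return lambda weights: scale * sum(w * w for w in weights) / 2.0


def toy_get_total_loss(add_regularization_losses=True):
    """Sum every registered loss, optionally plus the regularization terms."""
    losses = list(LOSSES)
    if add_regularization_losses:
        losses += REGULARIZATION_LOSSES
    return sum(losses)


# Hypothetical task losses.
classification_loss = toy_add_loss(0.75)
sum_of_squares_loss = toy_add_loss(0.5)
pose_loss = toy_add_loss(0.25)  # the "custom" loss, registered by hand

# A layer built with a weights regularizer would record a term like this one.
REGULARIZATION_LOSSES.append(toy_l2_regularizer(0.5)([0.5, 1.0]))

regularization_loss = sum(REGULARIZATION_LOSSES)
total_loss1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss
total_loss2 = toy_get_total_loss()
print(regularization_loss, total_loss1, total_loss2)  # 0.3125 1.8125 1.8125
```

If a custom loss is computed but never registered, the manual sum and the collection-based total silently diverge, which is the reason for the add_loss call.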

Restoring saved models with slim

# Create some variables.
v1 = slim.variable(name='v1', ...)
v2 = slim.variable(name='nested/v2', ...)
...

# Get the list of variables to restore (which contains only 'v2').
variables_to_restore = slim.get_variables_by_name("v2")

# Create the saver which will be used to restore the variables.
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")

# Adding a prefix to the model's variable names
# Suppose the variable in our network is conv1/weights, while the variable loaded from
# the VGG checkpoint is named vgg16/conv1/weights; a plain restore would fail.
def name_in_checkpoint(var):
    return 'vgg16/' + var.op.name

variables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var): var for var in variables_to_restore}
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
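The dict handed to tf.train.Saver maps the name stored in the checkpoint to the variable in the current graph. The pure-Python sketch below mimics just that lookup, using hypothetical `FakeVariable` objects and a dict pretending to be a checkpoint; it does not use TensorFlow, and `restore` is an illustrative helper, not a Saver method.

```python
from collections import namedtuple

# Hypothetical stand-in for a tf.Variable: only the op name matters here.
FakeVariable = namedtuple("FakeVariable", ["op_name"])

# Hypothetical checkpoint contents: everything was saved under a 'vgg16/' prefix.
checkpoint = {"vgg16/conv1/weights": [0.1, 0.2], "vgg16/conv1/biases": [0.0]}

model_variables = [FakeVariable("conv1/weights"), FakeVariable("conv1/biases")]


def name_in_checkpoint(var):
    return "vgg16/" + var.op_name


def restore(var_map, ckpt):
    """Look up each checkpoint name; missing names fail, as a mismatched restore would."""
    missing = [name for name in var_map if name not in ckpt]
    if missing:
        raise KeyError(f"not found in checkpoint: {missing}")
    return {var.op_name: ckpt[name] for name, var in var_map.items()}


# Using the graph's own names fails: the checkpoint only knows 'vgg16/...'.
try:
    restore({var.op_name: var for var in model_variables}, checkpoint)
except KeyError as err:
    print("plain names:", err)

# Mapping checkpoint name -> variable works.
restored = restore({name_in_checkpoint(var): var for var in model_variables}, checkpoint)
print(restored)  # {'conv1/weights': [0.1, 0.2], 'conv1/biases': [0.0]}
```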

Training the model: running the training loop

In this example, slim.learning.train uses train_op to compute the loss and apply a gradient step. logdir specifies where checkpoints and event files are stored. The number of gradient steps can be limited to any value; here we use 1000. Finally, save_summaries_secs=300 computes summaries every 5 minutes, and save_interval_secs=600 saves a model checkpoint every 10 minutes.

g = tf.Graph()

# Create the model and specify the losses...
...

total_loss = slim.losses.get_total_loss()
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

# create_train_op ensures that each time we ask for the loss, the update_ops
# are run and the gradients being computed are applied too.
train_op = slim.learning.create_train_op(total_loss, optimizer)
logdir = ...  # Where checkpoints are stored.

slim.learning.train(
    train_op,
    logdir,
    number_of_steps=1000,
    save_summaries_secs=300,
    save_interval_secs=600)
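The parameters above run on two clocks: training stops after number_of_steps, while summaries and checkpoints are triggered by elapsed seconds, not by step counts. The toy loop below is only meant to show that interplay; it uses plain Python, a 1-D quadratic loss, and a simulated clock that assumes every step takes exactly one second. It is not how slim.learning.train is implemented, which runs the real train_op in a session against real time.

```python
def toy_train(w, learning_rate, number_of_steps, save_summaries_secs,
              save_interval_secs, secs_per_step):
    """Gradient descent on loss(w) = (w - 3)^2 with a simulated wall clock."""
    clock = 0.0
    last_summary = last_checkpoint = 0.0
    summary_steps, checkpoint_steps = [], []
    for step in range(1, number_of_steps + 1):
        grad = 2.0 * (w - 3.0)      # d/dw of (w - 3)^2
        w -= learning_rate * grad   # one plain gradient-descent update
        clock += secs_per_step      # pretend each step takes this long
        if clock - last_summary >= save_summaries_secs:
            summary_steps.append(step)
            last_summary = clock
        if clock - last_checkpoint >= save_interval_secs:
            checkpoint_steps.append(step)
            last_checkpoint = clock
    return w, summary_steps, checkpoint_steps


w, summary_steps, checkpoint_steps = toy_train(
    w=0.0, learning_rate=0.1, number_of_steps=1000,
    save_summaries_secs=300, save_interval_secs=600, secs_per_step=1.0)
print(round(w, 6), summary_steps, checkpoint_steps)  # 3.0 [300, 600, 900] [600]
```

With one second per step, 1000 steps take under 17 minutes, so only one time-based checkpoint fires; with faster steps, training could finish before the first save interval elapses.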

Fine-Tuning a Model on a different task

Suppose we have a VGG16 model pre-trained on ImageNet, which has 1000 classes, and we want to apply it to Pascal VOC, which has only 20 classes. To do this, we can initialize the new model with the pre-trained weights of every layer except the last fully connected layers:

from tensorflow.contrib.slim.nets import vgg

# Load the Pascal VOC data
image, label = MyPascalVocDataLoader(...)
images, labels = tf.train.batch([image, label], batch_size=32)

# Create the model (vgg_16 returns the logits and a dict of end points)
predictions, _ = vgg.vgg_16(images, num_classes=20)

train_op = slim.learning.create_train_op(...)

# Specify where the model trained on ImageNet was saved.
model_path = '/path/to/pre_trained_on_imagenet.checkpoint'

# Specify where the new model will live:
log_dir = '/path/to/my_pascal_model_dir/'

# Restore only the convolutional layers. Exclusions are matched against the start of
# each variable name, and vgg_16 creates its layers under the 'vgg_16' scope.
variables_to_restore = slim.get_variables_to_restore(
    exclude=['vgg_16/fc6', 'vgg_16/fc7', 'vgg_16/fc8'])
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

# Start training.
slim.learning.train(train_op, log_dir, init_fn=init_fn)
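slim.get_variables_to_restore(exclude=...) filters variables by scope, and as far as I understand the underlying collection lookup, each exclude entry is applied with re.match, i.e. against the beginning of the full variable name. The sketch below reproduces only that filtering rule on a hypothetical list of VGG-16 variable names; `toy_variables_to_restore` is illustrative, not slim's function.

```python
import re

# Hypothetical variable names, laid out the way a VGG-16 built under 'vgg_16' names them.
all_variables = [
    "vgg_16/conv1/conv1_1/weights",
    "vgg_16/conv5/conv5_3/biases",
    "vgg_16/fc6/weights",
    "vgg_16/fc7/weights",
    "vgg_16/fc8/weights",
    "vgg_16/fc8/biases",
]


def toy_variables_to_restore(names, exclude=()):
    """Drop every name that some exclude pattern matches at its start (re.match)."""
    return [n for n in names if not any(re.match(pattern, n) for pattern in exclude)]


# Bare layer names exclude nothing: 'fc6' does not match at the start of 'vgg_16/fc6/...'.
print(len(toy_variables_to_restore(all_variables, exclude=["fc6", "fc7", "fc8"])))  # 6

# Full scope prefixes keep only the convolutional layers.
print(toy_variables_to_restore(all_variables,
                               exclude=["vgg_16/fc6", "vgg_16/fc7", "vgg_16/fc8"]))
# ['vgg_16/conv1/conv1_1/weights', 'vgg_16/conv5/conv5_3/biases']
```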

Evaluation loop

import math

import tensorflow as tf

slim = tf.contrib.slim

# Load the data
images, labels = load_data(...)

# Define the network
predictions = MyModel(images)

# Choose the metrics to compute. Each streaming metric returns a (value, update_op) pair.
names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'accuracy': slim.metrics.streaming_accuracy(predictions, labels),
    'precision': slim.metrics.streaming_precision(predictions, labels),
    'recall': slim.metrics.streaming_recall(predictions, labels),
})

# Create the summary ops such that they also print out to std output:
summary_ops = []
for metric_name, metric_value in names_to_values.items():
    op = tf.summary.scalar(metric_name, metric_value)
    op = tf.Print(op, [metric_value], metric_name)
    summary_ops.append(op)

num_examples = 10000
batch_size = 32
num_batches = int(math.ceil(num_examples / float(batch_size)))

# Set up the global step.
slim.get_or_create_global_step()

checkpoint_dir = ...  # Where the checkpoints to evaluate are stored.
log_dir = ...  # Where the summaries are stored.
eval_interval_secs = ...  # How often to run the evaluation.

slim.evaluation.evaluation_loop(
    'local',
    checkpoint_dir,
    log_dir,
    num_evals=num_batches,
    eval_op=list(names_to_updates.values()),
    summary_op=tf.summary.merge(summary_ops),
    eval_interval_secs=eval_interval_secs)
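Each streaming metric is a pair: an update op that accumulates statistics batch by batch (run num_batches times through eval_op) and a value op that reads the accumulated result; aggregate_metric_map simply splits a dict of such pairs into two dicts. The plain-Python sketch below imitates that split with a hypothetical toy accuracy metric and made-up data; it is not slim's implementation.

```python
import math


def streaming_accuracy():
    """Toy streaming metric: returns a (value, update) pair sharing running counts."""
    state = {"correct": 0, "total": 0}

    def update(predictions, labels):
        # Accumulate over one batch; nothing is reset between batches.
        state["correct"] += sum(int(p == t) for p, t in zip(predictions, labels))
        state["total"] += len(labels)

    def value():
        return state["correct"] / state["total"] if state["total"] else 0.0

    return value, update


def aggregate_metric_map(metric_map):
    """Split {name: (value, update)} into two dicts, one per role."""
    names_to_values = {name: pair[0] for name, pair in metric_map.items()}
    names_to_updates = {name: pair[1] for name, pair in metric_map.items()}
    return names_to_values, names_to_updates


# Hypothetical evaluation data: 10 examples evaluated in batches of 4.
predictions = [1, 0, 1, 1, 0, 0, 1, 0, 1, 1]
labels = [1, 0, 0, 1, 0, 1, 1, 0, 1, 1]
num_examples, batch_size = len(labels), 4
num_batches = math.ceil(num_examples / float(batch_size))  # 3; the last batch is partial

names_to_values, names_to_updates = aggregate_metric_map({"accuracy": streaming_accuracy()})
for b in range(num_batches):
    start, end = b * batch_size, min((b + 1) * batch_size, num_examples)
    for update in names_to_updates.values():  # the "eval_op": run every update
        update(predictions[start:end], labels[start:end])

for name, value in names_to_values.items():  # read the final values once at the end
    print(name, value())  # accuracy 0.8
```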
