Usage Notes: TF Helper Library -- tensorflow slim (TF-Slim)
If you set aside Keras, TensorLayer, and TFLearn, can you still write concise code in tensorflow? You can! The slim module was released in 2016, and its main purpose is so-called "code slimming".
1. Introduction
slim lives under tensorflow.contrib and is imported like this:
import tensorflow.contrib.slim as slim
As everyone knows, the official description of tensorflow.contrib is: any code in this directory is not officially supported and may change or be removed at any time; each directory has a designated owner; it is meant to contain additional functionality and contributions that will eventually be merged into core TensorFlow, but whose interfaces may still change or which need further testing to see whether they can find broader acceptance. So slim still does not belong to native tensorflow.
slim is a library that makes building, training, and evaluating neural networks simple. It eliminates much of the repetitive boilerplate of raw tensorflow, making code more compact and more readable. In addition, slim ships with many well-known computer-vision models (VGG, AlexNet, and so on) that we can not only use directly but also extend in various ways.
slim's submodules and what they do:
arg_scope: provides a new scope named arg_scope that allows a user to define default arguments for specific operations within that scope.
On top of the basic name_scope and variable_scope, arg_scope was added to control the default hyperparameters of each layer. (Covered in detail below.)
data: contains TF-slim's dataset definition, data providers, parallel_reader, and decoding utilities.
slim apparently also has its own set of dataset definitions; we skip it here since we rarely use it.
evaluation: contains routines for evaluating models.
Routines for evaluating models; also not used much.
layers: contains high level layers for building models using tensorflow.
This one is important: the core and essence of slim, with definitions of complex layers.
learning: contains routines for training models.
Routines for training models.
losses: contains commonly used loss functions.
Commonly used loss functions.
metrics: contains popular evaluation metrics.
Metrics for evaluating models (a minimal sketch follows this list).
nets: contains popular network definitions such as VGG and AlexNet models.
Contains classic networks such as VGG; used quite a lot.
queues: provides a context manager for easily and safely starting and closing QueueRunners.
Queue (QueueRunner) management; quite useful.
regularizers: contains weight regularizers.
Weight regularizers.
variables: provides convenience wrappers for variable creation and manipulation.
slim's mechanism for managing variables.
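Most of these submodules are demonstrated later in this note; metrics is not, so here is a minimal sketch of the streaming-metric pattern it uses. LoadTestData and MyModel are hypothetical placeholders for your own input pipeline and model; the slim.metrics calls themselves are the real API.

# Hypothetical data loader and model standing in for your own code.
images, labels = LoadTestData(...)
predictions = MyModel(images)

# Each streaming metric returns a value op and an update op; running the
# update op over many batches accumulates the metric across the test set.
mae_value_op, mae_update_op = slim.metrics.streaming_mean_absolute_error(predictions, labels)
mse_value_op, mse_update_op = slim.metrics.streaming_mean_squared_error(predictions, labels)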
2. Defining Models with slim
An example of defining variables in slim:
# Model Variables
weights = slim.model_variable('weights',
                              shape=[10, 10, 3, 3],
                              initializer=tf.truncated_normal_initializer(stddev=0.1),
                              regularizer=slim.l2_regularizer(0.05),
                              device='/CPU:0')
model_variables = slim.get_model_variables()

# Regular variables
my_var = slim.variable('my_var',
                       shape=[20, 1],
                       initializer=tf.zeros_initializer())
regular_variables_and_model_variables = slim.get_variables()
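Model variables are the ones slim tracks for saving and fine-tuning; regular variables are everything else. If a model variable is created outside slim's own wrappers, it can still be registered by hand. A minimal sketch, where CreateViaCustomCode is a hypothetical stand-in for any non-slim code that creates a variable:

# Hypothetical helper that creates a variable without slim's wrappers.
my_model_variable = CreateViaCustomCode()

# Letting TF-Slim know about the variable so that
# slim.get_model_variables() will return it too.
slim.add_model_variable(my_model_variable)

Implementing a layer in slim: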
First, let's look at how a layer, say a convolutional layer, is implemented in raw tensorflow:
input = ...
with tf.name_scope('conv1_1') as scope:
    kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32, stddev=1e-1),
                         name='weights')
    conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
    biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),
                         trainable=True, name='biases')
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope)

In slim, the same layer is a single call:

input = ...
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')

slim.repeat goes further and folds a run of identical layers into one call. The verbose version:

net = ...
net = slim.conv2d(net, 256, [3, 3], scope='conv3_1')
net = slim.conv2d(net, 256, [3, 3], scope='conv3_2')
net = slim.conv2d(net, 256, [3, 3], scope='conv3_3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

collapses to:

net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], scope='pool2')

slim.repeat unrolls the call and names the sub-scopes conv3/conv3_1, conv3/conv3_2, and conv3/conv3_3 automatically. Suppose we define three FC layers:
# Verbose way:
x = slim.fully_connected(x, 32, scope='fc/fc_1')
x = slim.fully_connected(x, 64, scope='fc/fc_2')
x = slim.fully_connected(x, 128, scope='fc/fc_3')

# Equivalent, using slim.stack:
x = slim.stack(x, slim.fully_connected, [32, 64, 128], scope='fc')

slim.stack also handles layers whose arguments change from call to call, for example a tower of convolutions:

# Verbose way:
x = slim.conv2d(x, 32, [3, 3], scope='core/core_1')
x = slim.conv2d(x, 32, [1, 1], scope='core/core_2')
x = slim.conv2d(x, 64, [3, 3], scope='core/core_3')
x = slim.conv2d(x, 64, [1, 1], scope='core/core_4')

# Using stack:
slim.stack(x, slim.conv2d, [(32, [3, 3]), (32, [1, 1]), (64, [3, 3]), (64, [1, 1])], scope='core')

arg_scope in slim:
If your network has many layers that share the same arguments, like this:
net = slim.conv2d(inputs, 64, [11, 11], 4, padding='SAME',
                  weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                  weights_regularizer=slim.l2_regularizer(0.0005), scope='conv1')
net = slim.conv2d(net, 128, [11, 11], padding='VALID',
                  weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                  weights_regularizer=slim.l2_regularizer(0.0005), scope='conv2')
net = slim.conv2d(net, 256, [11, 11], padding='SAME',
                  weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                  weights_regularizer=slim.l2_regularizer(0.0005), scope='conv3')

then arg_scope lets you set the shared defaults once, and individual calls can still override them (conv2 below):

with slim.arg_scope([slim.conv2d], padding='SAME',
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                    weights_regularizer=slim.l2_regularizer(0.0005)):
    net = slim.conv2d(inputs, 64, [11, 11], scope='conv1')
    net = slim.conv2d(net, 128, [11, 11], padding='VALID', scope='conv2')
    net = slim.conv2d(net, 256, [11, 11], scope='conv3')

arg_scopes can also be nested, with inner scopes refining the outer defaults:

with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    activation_fn=tf.nn.relu,
                    weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                    weights_regularizer=slim.l2_regularizer(0.0005)):
    with slim.arg_scope([slim.conv2d], stride=1, padding='SAME'):
        net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
        net = slim.conv2d(net, 256, [5, 5],
                          weights_initializer=tf.truncated_normal_initializer(stddev=0.03),
                          scope='conv2')
        net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc')

Putting repeat, arg_scope, and the layer wrappers together gives a compact VGG-16:
def vgg16(inputs):
    with slim.arg_scope([slim.conv2d, slim.fully_connected],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0005)):
        net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
        net = slim.max_pool2d(net, [2, 2], scope='pool1')
        net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [2, 2], scope='pool2')
        net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
        net = slim.max_pool2d(net, [2, 2], scope='pool3')
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
        net = slim.max_pool2d(net, [2, 2], scope='pool4')
        net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
        net = slim.max_pool2d(net, [2, 2], scope='pool5')
        net = slim.fully_connected(net, 4096, scope='fc6')
        net = slim.dropout(net, 0.5, scope='dropout6')
        net = slim.fully_connected(net, 4096, scope='fc7')
        net = slim.dropout(net, 0.5, scope='dropout7')
        net = slim.fully_connected(net, 1000, activation_fn=None, scope='fc8')
    return net

3. Training Models
import tensorflow as tf
vgg = tf.contrib.slim.nets.vgg

# Load the images and labels.
images, labels = ...

# Create the model.
predictions, _ = vgg.vgg_16(images)

# Define the loss functions and get the total loss.
loss = slim.losses.softmax_cross_entropy(predictions, labels)
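Defining the loss is only half of training; slim.learning wires the total loss to an optimizer and runs the loop. A minimal sketch along the lines of the TF-Slim README, where the learning rate, step count, and log directory are placeholder values:

# Total loss (model losses plus regularization) and an optimizer.
total_loss = slim.losses.get_total_loss()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)

# create_train_op computes the loss and applies the gradients in one op.
train_op = slim.learning.create_train_op(total_loss, optimizer)

# Run the training loop, checkpointing and summarizing into logdir.
slim.learning.train(train_op,
                    logdir='/tmp/vgg_train',  # placeholder directory
                    number_of_steps=1000,
                    save_summaries_secs=300,
                    save_interval_secs=600)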
A word about losses: if you define a loss of your own, do not forget to add it into slim so that slim can see it.
Regularization terms likewise have to be added into the loss by hand when you total the losses yourself; otherwise the regularization objective is never actually optimized. The multi-task example below shows both points:
# Load the images and labels.
images, scene_labels, depth_labels, pose_labels = ...

# Create the model.
scene_predictions, depth_predictions, pose_predictions = CreateMultiTaskModel(images)

# Define the loss functions and get the total loss.
classification_loss = slim.losses.softmax_cross_entropy(scene_predictions, scene_labels)
sum_of_squares_loss = slim.losses.sum_of_squares(depth_predictions, depth_labels)
pose_loss = MyCustomLossFunction(pose_predictions, pose_labels)
slim.losses.add_loss(pose_loss)  # Letting TF-Slim know about the additional loss.

# The following two ways to compute the total loss are equivalent:
regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
total_loss1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss

# (Regularization Loss is included in the total loss by default).
total_loss2 = slim.losses.get_total_loss()

4. Reading and Saving Model Variables
With the following we can restore just a subset of a model's variables:
# Create some variables.
v1 = slim.variable(name="v1", ...)
v2 = slim.variable(name="nested/v2", ...)
...

# Get list of variables to restore (which contains only 'v2').
variables_to_restore = slim.get_variables_by_name("v2")

# Create the saver which will be used to restore the variables.
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
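Besides get_variables_by_name, slim can select variables by suffix, by scope, or by exclusion. A short sketch of the variants documented in the TF-Slim README:

# Select by name suffix, by scope, or by excluding certain variables.
variables_to_restore = slim.get_variables_by_suffix("2")              # e.g. nested/v2
variables_to_restore = slim.get_variables(scope="nested")             # everything under 'nested'
variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])  # all but v1

Now suppose the variable in our graph is named conv1/weights while the checkpoint saved from VGG calls it vgg16/conv1/weights. A plain restore will raise an error (variable name not found), but we can remap the names: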
def name_in_checkpoint(var):
    return 'vgg16/' + var.op.name

variables_to_restore = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var): var for var in variables_to_restore}
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")

In this way we can load variables whose names differ from those stored in the checkpoint.
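A common use of this is warm-starting a new task from pretrained weights. A sketch of that workflow under stated assumptions: the checkpoint path is a placeholder, and train_op and log_dir are assumed to come from a training setup like the one in section 3; slim.assign_from_checkpoint_fn is the real helper.

# Restore only the convolutional weights from a pretrained VGG checkpoint,
# leaving the new task's fully connected layers freshly initialized.
model_path = '/path/to/pre_trained_on_imagenet.checkpoint'  # placeholder path
variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

# slim.learning.train calls init_fn once before the training loop starts.
slim.learning.train(train_op, log_dir, init_fn=init_fn)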