这一节,介绍TensorFlow中的一个封装好的高级库,里面有前面讲过的很多函数的高级封装,使用这个高级库来开发程序将会提高效率。

我们改写第十三节的程序,卷积函数我们使用tf.contrib.layers.conv2d(),池化函数使用tf.contrib.layers.max_pool2d()和tf.contrib.layers.avg_pool2d(),全连接函数使用tf.contrib.layers.fully_connected()。

一 tf.contrib.layers中的具体函数介绍

1.tf.contrib.layers.conv2d()函数的定义如下:

def convolution(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
data_format=None,
rate=1,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):

常用的参数说明如下:

  • inputs:形状为[batch_size, height, width, channels]的输入。
  • num_outputs:代表输出几个channel。这里不需要再指定输入的channel了,因为函数会自动根据inpus的shpe去判断。
  • kernel_size:卷积核大小,不需要带上batch和channel,只需要输入尺寸即可。[5,5]就代表5x5的卷积核,如果长和宽都一样,也可以只写一个数5.
  • stride:步长,默认是长宽都相等的步长。卷积时,一般都用1,所以默认值也是1.如果长和宽都不相等,也可以用一个数组[1,2]。
  • padding:填充方式,'SAME'或者'VALID'。
  • activation_fn:激活函数。默认是ReLU。也可以设置为None
  • weights_initializer:权重的初始化,默认为initializers.xavier_initializer()函数。
  • weights_regularizer:权重正则化项,可以加入正则函数。biases_initializer:偏置的初始化,默认为init_ops.zeros_initializer()函数。
  • biases_regularizer:偏置正则化项,可以加入正则函数。
  • trainable:是否可训练,如作为训练节点,必须设置为True,默认即可。如果我们是微调网络,有时候需要冻结某一层的参数,则设置为False。

2.tf.contrib.layers.max_pool2d()函数的定义如下:

def max_pool2d(inputs,
kernel_size,
stride=2,
padding='VALID',
data_format=DATA_FORMAT_NHWC,
outputs_collections=None,
scope=None):

参数说明如下:

  • inputs: A 4-D tensor of shape `[batch_size, height, width, channels]` if`data_format` is `NHWC`, and `[batch_size, channels, height, width]` if `data_format` is `NCHW`.
  • kernel_size: A list of length 2: [kernel_height, kernel_width] of the pooling kernel over which the op is computed. Can be an int if both values are the same.
  • stride: A list of length 2: [stride_height, stride_width].Can be an int if both strides are the same. Note that presently both strides must have the same value.
  • padding: The padding method, either 'VALID' or 'SAME'.
  • data_format: A string. `NHWC` (default) and `NCHW` are supported.
  • outputs_collections: The collections to which the outputs are added.
  • scope: Optional scope for name_scope.

3.tf.contrib.layers.avg_pool2d()函数定义

def avg_pool2d(inputs,
kernel_size,
stride=2,
padding='VALID',
data_format=DATA_FORMAT_NHWC,
outputs_collections=None,
scope=None):

参数说明如下:

  • inputs: A 4-D tensor of shape `[batch_size, height, width, channels]` if`data_format` is `NHWC`, and `[batch_size, channels, height, width]` if `data_format` is `NCHW`.
  • kernel_size: A list of length 2: [kernel_height, kernel_width] of the pooling kernel over which the op is computed. Can be an int if both values are the same.
  • stride: A list of length 2: [stride_height, stride_width].Can be an int if both strides are the same. Note that presently both strides must have the same value.
  • padding: The padding method, either 'VALID' or 'SAME'.
  • data_format: A string. `NHWC` (default) and `NCHW` are supported.
  • outputs_collections: The collections to which the outputs are added.
  • scope: Optional scope for name_scope.

4.tf.contrib.layers.fully_connected()函数的定义如下:

def fully_connected(inputs,
num_outputs,
activation_fn=nn.relu,
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
variables_collections=None,
outputs_collections=None,
trainable=True,
scope=None):

参数说明如下:

  • inputs: A tensor of at least rank 2 and static value for the last dimension; i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
  • num_outputs: Integer or long, the number of output units in the layer.
  • activation_fn: Activation function. The default value is a ReLU function.Explicitly set it to None to skip it and maintain a linear activation.
  • normalizer_fn: Normalization function to use instead of `biases`. If `normalizer_fn` is provided then `biases_initializer` and
  • `biases_regularizer` are ignored and `biases` are not created nor added.default set to None for no normalizer function
  • normalizer_params: Normalization function parameters.
  • weights_initializer: An initializer for the weights.
  • weights_regularizer: Optional regularizer for the weights.
  • biases_initializer: An initializer for the biases. If None skip biases.
  • biases_regularizer: Optional regularizer for the biases.
  • reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given.
  • variables_collections: Optional list of collections for all the variables or a dictionary containing a different list of collections per variable.
  • outputs_collections: Collection to add the outputs.
  • trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).如果我们是微调网络,有时候需要冻结某一层的参数,则设置为False。
  • scope: Optional scope for variable_scope.

二 改写cifar10分类

代码如下:

# -*- coding: utf-8 -*-
"""
Created on Thu May 3 12:29:16 2018 @author: zy
""" '''
建立一个带有全连接层的卷积神经网络 并对CIFAR-10数据集进行分类
1.使用2个卷积层的同卷积操作,滤波器大小为5x5,每个卷积层后面都会跟一个步长为2x2的池化层,滤波器大小为2x2
2.对输出的64个feature map进行全局平均池化,得到64个特征
3.加入一个全连接层,使用softmax激活函数,得到分类
''' import cifar10_input
import tensorflow as tf
import numpy as np def print_op_shape(t):
'''
输出一个操作op节点的形状
'''
print(t.op.name,'',t.get_shape().as_list()) '''
一 引入数据集
'''
batch_size = 128
learning_rate = 1e-4
training_step = 15000
display_step = 200
#数据集目录
data_dir = './cifar10_data/cifar-10-batches-bin'
print('begin')
#获取训练集数据
images_train,labels_train = cifar10_input.inputs(eval_data=False,data_dir = data_dir,batch_size=batch_size)
print('begin data') '''
二 定义网络结构
''' #定义占位符
input_x = tf.placeholder(dtype=tf.float32,shape=[None,24,24,3]) #图像大小24x24x
input_y = tf.placeholder(dtype=tf.float32,shape=[None,10]) #0-9类别 x_image = tf.reshape(input_x,[batch_size,24,24,3]) #1.卷积层 ->池化层 h_conv1 = tf.contrib.layers.conv2d(inputs=x_image,num_outputs=64,kernel_size=5,stride=1,padding='SAME', activation_fn=tf.nn.relu) #输出为[-1,24,24,64]
print_op_shape(h_conv1)
h_pool1 = tf.contrib.layers.max_pool2d(inputs=h_conv1,kernel_size=2,stride=2,padding='SAME') #输出为[-1,12,12,64]
print_op_shape(h_pool1) #2.卷积层 ->池化层 h_conv2 =tf.contrib.layers.conv2d(inputs=h_pool1,num_outputs=64,kernel_size=[5,5],stride=[1,1],padding='SAME', activation_fn=tf.nn.relu) #输出为[-1,12,12,64]
print_op_shape(h_conv2)
h_pool2 = tf.contrib.layers.max_pool2d(inputs=h_conv2,kernel_size=[2,2],stride=[2,2],padding='SAME') #输出为[-1,6,6,64]
print_op_shape(h_pool2) #3全连接层 nt_hpool2 = tf.contrib.layers.avg_pool2d(inputs=h_pool2,kernel_size=6,stride=6,padding='SAME') #输出为[-1,1,1,64]
print_op_shape(nt_hpool2)
nt_hpool2_flat = tf.reshape(nt_hpool2,[-1,64])
y_conv = tf.contrib.layers.fully_connected(inputs=nt_hpool2_flat,num_outputs=10,activation_fn=tf.nn.softmax)
print_op_shape(y_conv) '''
三 定义求解器
''' #softmax交叉熵代价函数
cost = tf.reduce_mean(-tf.reduce_sum(input_y * tf.log(y_conv),axis=1)) #求解器
train = tf.train.AdamOptimizer(learning_rate).minimize(cost) #返回一个准确度的数据
correct_prediction = tf.equal(tf.arg_max(y_conv,1),tf.arg_max(input_y,1))
#准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,dtype=tf.float32)) '''
四 开始训练
'''
sess = tf.Session();
sess.run(tf.global_variables_initializer())
# 启动计算图中所有的队列线程 调用tf.train.start_queue_runners来将文件名填充到队列,否则read操作会被阻塞到文件名队列中有值为止。
tf.train.start_queue_runners(sess=sess) for step in range(training_step):
#获取batch_size大小数据集
image_batch,label_batch = sess.run([images_train,labels_train]) #one hot编码
label_b = np.eye(10,dtype=np.float32)[label_batch] #开始训练
train.run(feed_dict={input_x:image_batch,input_y:label_b},session=sess) if step % display_step == 0:
train_accuracy = accuracy.eval(feed_dict={input_x:image_batch,input_y:label_b},session=sess)
print('Step {0} tranining accuracy {1}'.format(step,train_accuracy))

第十六节,使用函数封装库tf.contrib.layers的更多相关文章

  1. 第三百三十六节,web爬虫讲解2—urllib库中使用xpath表达式—BeautifulSoup基础

    第三百三十六节,web爬虫讲解2—urllib库中使用xpath表达式—BeautifulSoup基础 在urllib中,我们一样可以使用xpath表达式进行信息提取,此时,你需要首先安装lxml模块 ...

  2. centos shell脚本编程2 if 判断 case判断 shell脚本中的循环 for while shell中的函数 break continue test 命令 第三十六节课

    centos  shell脚本编程2 if 判断  case判断   shell脚本中的循环  for   while   shell中的函数  break  continue  test 命令   ...

  3. ASP.NET MVC深入浅出系列(持续更新) ORM系列之Entity FrameWork详解(持续更新) 第十六节:语法总结(3)(C#6.0和C#7.0新语法) 第三节:深度剖析各类数据结构(Array、List、Queue、Stack)及线程安全问题和yeild关键字 各种通讯连接方式 设计模式篇 第十二节: 总结Quartz.Net几种部署模式(IIS、Exe、服务部署【借

    ASP.NET MVC深入浅出系列(持续更新)   一. ASP.NET体系 从事.Net开发以来,最先接触的Web开发框架是Asp.Net WebForm,该框架高度封装,为了隐藏Http的无状态模 ...

  4. 第一百二十六节,JavaScript,XPath操作xml节点

    第一百二十六节,JavaScript,XPath操作xml节点 学习要点: 1.IE中的XPath 2.W3C中的XPath 3.XPath跨浏览器兼容 XPath是一种节点查找手段,对比之前使用标准 ...

  5. 第四百一十六节,Tensorflow简介与安装

    第四百一十六节,Tensorflow简介与安装 TensorFlow是什么 Tensorflow是一个Google开发的第二代机器学习系统,克服了第一代系统DistBelief仅能开发神经网络算法.难 ...

  6. 第三百四十六节,Python分布式爬虫打造搜索引擎Scrapy精讲—Requests请求和Response响应介绍

    第三百四十六节,Python分布式爬虫打造搜索引擎Scrapy精讲—Requests请求和Response响应介绍 Requests请求 Requests请求就是我们在爬虫文件写的Requests() ...

  7. 第三百二十六节,web爬虫,scrapy模块,解决重复ur——自动递归url

    第三百二十六节,web爬虫,scrapy模块,解决重复url——自动递归url 一般抓取过的url不重复抓取,那么就需要记录url,判断当前URL如果在记录里说明已经抓取过了,如果不存在说明没抓取过 ...

  8. 大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数(5)

                                                    第二十六节决策树系列之Cart回归树及其参数(5) 上一节我们讲了不同的决策树对应的计算纯度的计算方法, ...

  9. m_Orchestrate learning system---二十六、动态给封装好的控件添加属性

    m_Orchestrate learning system---二十六.动态给封装好的控件添加属性 一.总结 一句话总结:比如我现在封装好了ueditor控件,我外部调用这个控件,因为要写数据到数据库 ...

随机推荐

  1. python学习笔记(6)--条件分支语句

    if xxxx: coding if xxxx: coding else: coding if xxxx: coding elif xxx: coding …… else: coding 或者一种简洁 ...

  2. ubuntu 有些软件中不能输入中文

    如果Ubuntu设定的是英文语言,在各种软件例如wps等中很有可能就不能输入中文.这种情况,我们的解决方案是,把中文输入法加到软件的启动文件中,如何加呢?把下面内容加进去就可以解决: export X ...

  3. 使用poi将Excel文件转换为data数据

    pom <?xml version="1.0" encoding="UTF-8"?> <project xmlns="http:// ...

  4. 微信小程序——demo合集及简单的文档解读【五】

    官方Demo https://github.com/wechat-miniprogram/miniprogram-demo 其他Demo https://www.cnblogs.com/ytkah/p ...

  5. Codeforces963C Frequency of String 【字符串】【AC自动机】

    题目大意: 给一个串s和很多模式串,对每个模式串求s的一个最短的子串使得这个子串中包含至少k个该模式串. 题目分析: 均摊分析,有sqrt(n)种长度不同的模式串,所以有关的串只有msqrt(n)种. ...

  6. Codeforces518 D. Ilya and Escalator

    传送门:>Here< 题意:有n个人排队做电梯,每个人必须等前面的人全部上了以后才能上.对于每秒钟,有p的概率选择上电梯,(1-p)的概率选择不上电梯.现在问t秒期望多少人上电梯 解题思路 ...

  7. 【XSY2753】Lcm 分治 FWT FFT 容斥

    题目描述 给你\(n,k\),要你选一些互不相同的正整数,满足这些数的\(lcm\)为\(n\),且这些数的和为\(k\)的倍数. 求选择的方案数.对\(232792561\)取模. \(n\leq ...

  8. 【BZOJ3999】【TJOI2015】旅游 树剖

    题目大意 给你一棵树,有\(n\)个点.有\(q\)个操作,每次要你从\(x\)到\(y\)的路径上选两个点,使得距离\(x\)比较远的点的点权\(-\)距离\(x\)比较近的点的点权最大,然后把这条 ...

  9. bzoj 3626 : [LNOI2014]LCA (树链剖分+线段树)

    Description 给出一个n个节点的有根树(编号为0到n-1,根节点为0).一个点的深度定义为这个节点到根的距离+1.设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先.有q ...

  10. MT【250】距离0-7

    是否存在一个正方体,它的8个顶点到某一个平面的距离恰好为$0,1,2,3,4,5,6,7$ ?若存在指出正方体与相应的平面的位置关系.不存在说明理由. 分析:设平面$\alpha$的单位法向量为$\o ...