前言

最近尝试看TensorFlow中Slim模块的代码,看的比较郁闷,所以试着写点小的代码,动手验证相关的操作,以增加直观性。

卷积函数

slim模块的conv2d函数,是二维卷积接口,顺着源代码可以看到最终调的TensorFlow接口是convolution,这个地方就进入C++层面了,暂时不涉及。先来看看这个convolution函数,官方定义是这样的:

tf.nn.convolution(
input,
filter,
padding,
strides=None,
dilation_rate=None,
name=None,
data_format=None
)

其中在默认情况下,也就是data_format=None的时候,input的要求格式是[batch_size] + input_spatial_shape + [in_channels],  也就是要求第一维是batch,最后一维是channel,中间是真正的卷积维度。所以这个接口不仅只支持2维卷积,猜测2维卷积tf.nn.conv2d是对此接口的封装。[batch, height, weight, channel]就是conv2d的input参数格式,batch就是样本数,或者更狭隘一点,图片数量,height是图片高,weight是图片的宽,Slim的分类网络都是height=weight的,以实现方阵运算,所有slim模块中的原始图片都需要经过预处理过程,这里不展开。

filter参数是卷积核的定义,spatial_filter_shape + [in_channels, out_channels],对于2维卷积同样是4维参数[weight, height, channel, out_channel]。

明明是2维卷积,输入都是4维,已经有点抽象了,所以进入下一个阶段,写段代码,验证一下吧。

实践一下

这个例子先定义一个3X3的图片,再定义一个2X2的卷积核,代码如下:

import tensorflow as tf

input = tf.constant(
[
[
[
[100., 100., 100.],
[100., 100., 100.],
[100., 100., 100.]
],
[
[100., 100., 100.],
[100., 100., 100.],
[100., 100., 100.]
],
[
[100., 100., 100.],
[100., 100., 100.],
[100., 100., 100.],
]
]
]
); filter = tf.constant(
[
[
[
[0.5],
[0.5],
[0.5]
],
[
[0.5],
[0.5],
[0.5]
]
],
[
[
[0.5],
[0.5],
[0.5]
],
[
[0.5],
[0.5],
[0.5]
]
],
]
); result = tf.nn.convolution(input, filter, padding='VALID'); with tf.Session() as sess:
print sess.run(result)

从上述代码可以看到,input的shape是[1, 3, 3, 3],filter的shape是[2, 2, 3, 1 ],卷积的过程在方阵[3, 3] 和 核[2, 2]上展开,并且由于有三个通道,每个通道分别卷积后求和。

代码的执行结果:

[

  [

    [

      [600.]
      [600.]

    ]

    [

      [600.]

      [600.]

    ]

  ]

]

由于我们填的padding参数是VALID,所以最后的结果矩阵面积会缩小,满足(3-2)+1,即 (iw - kw) + 1。

以上例子,我们可以将它称为单张图片的二维3通道卷积,所以计算过程应该是每个通道进行卷积后最后三个通道的数值累加。

如果是从单个通道看,input就是:

[

  [100., 100., 100,]

  [100., 100., 100,]

  [100., 100., 100,]

]

卷积核:

[

  [0.5, 0.5]

  [0.5, 0.5]

]

那么单层卷积结果:

[

  [200., 200.]

  [200., 200.]

]

将三层结果叠加就是程序输出结果。

增加输出通道

slim.conv2d函数的第二参数就是输出通道的数量,就是对应convolution接口filter的第4维,我们把程序改一下,增加一个输出通道:

filter = tf.constant(
[
[
[
[0.5, 0.1],
[0.5, 0.1],
[0.5, 0.1]
],
[
[0.5, 0.1],
[0.5, 0.1],
[0.5, 0.1]
]
],
[
[
[0.5, 0.1],
[0.5, 0.1],
[0.5, 0.1]
],
[
[0.5, 0.1],
[0.5, 0.1],
[0.5, 0.1]
]
],
]
);

最后的输出结果:

[

  [

    [

      [600. 120.]
      [600. 120.]

    ]
    [

      [600. 120.]
      [600. 120.]

    ]

  ]

]

其中 120 = 3 * (100 * 0.1 + 100 * 0.1 + 100 * 0.1 + 100 * 0.1)

从结果可以看到,输出结果满足 [batch_size] + output_spatial_shape + [out_channels]的格式。

padding=SAME更常用

上面的例子中使用了padding=VALID,是指不填充的情况下进行的有效卷积结果矩阵面积会收缩。而我们在阅卷几个经典网络时,都是使用padding=SAME的方式,这种方式下,结果输出矩阵形状不变,这样就便于对不同分支结果进行连接等操作。

将第一个例子中的padding改为SAME,输出结果为:

[

  [

    [

      [600.]
      [600.]
      [300.]

    ]
    [

      [600.]
      [600.]
      [300.]

    ]

    [

      [300.]
      [300.]
      [150.]

    ]

  ]

]

在SAME模式下,为了保证输出结果输入输入形状一致,实时上在原矩阵的的右侧和底部扩展了行、列 0

暂时性结束

作为新手,一旦碰到多维就蒙了,所有以上的实践,都是只是为了增加理解。

												

TensorFlow中的卷积函数的更多相关文章

  1. TensorFlow 中的卷积网络

    TensorFlow 中的卷积网络 是时候看一下 TensorFlow 中的卷积神经网络的例子了. 网络的结构跟经典的 CNNs 结构一样,是卷积层,最大池化层和全链接层的混合. 这里你看到的代码与你 ...

  2. 【tensorflow基础】tensorflow中 tf.reduce_mean函数

    参考 1. tensorflow中 tf.reduce_mean函数: 完

  3. tensorflow中的卷积和池化层(一)

    在官方tutorial的帮助下,我们已经使用了最简单的CNN用于Mnist的问题,而其实在这个过程中,主要的问题在于如何设置CNN网络,这和Caffe等框架的原理是一样的,但是tf的设置似乎更加简洁. ...

  4. Tensorflow中的run()函数

    1 run()函数存在的意义 run()函数可以让代码变得更加简洁,在搭建神经网络(一)中,经历了数据集准备.前向传播过程设计.损失函数及反向传播过程设计等三个过程,形成计算网络,再通过会话tf.Se ...

  5. 【转载】 tf.Print() (------------ tensorflow中的print函数)

    原文地址: https://blog.csdn.net/weixin_36670529/article/details/100191674 ------------------------------ ...

  6. tensorflow中 tf.reduce_mean函数

    tf.reduce_mean 函数用于计算张量tensor沿着指定的数轴(tensor的某一维度)上的的平均值,主要用作降维或者计算tensor(图像)的平均值. reduce_mean(input_ ...

  7. 对于tensorflow中的gradient_override_map函数的理解

    # #############添加############## def binarize(self, x): """ Clip and binarize tensor u ...

  8. 卷积运算的本质,以tensorflow中VALID卷积方式为例。

    卷积运算在数学上是做矩阵点积,这样可以调整每个像素上的BGR值或HSV值来形成不同的特征.从代码上看,每次卷积核扫描完一个通道是做了一次四重循环.下面以VALID卷积方式为例进行解释. 下面是pyth ...

  9. Tensorflow中的transpose函数解析

    transpose函数作用是对矩阵进行转换操作 相信说完上面这一句,大家和我一样都是懵逼状态,完全不知道是怎么回事,那么接下来和我一起探讨吧 1.二维数组 x = [[1,3,5],  [2,4,6] ...

随机推荐

  1. 使用阿里云的maven仓库

    在maven的settings.xml文件里的mirrors节点,添加如下子节点: <mirror> <id>nexus-aliyun</id> <mirro ...

  2. 【node.js】Express 框架

    Express 是一个简洁而灵活的 node.js Web应用框架, 提供了一系列强大特性帮助你创建各种 Web 应用,和丰富的 HTTP 工具. 使用 Express 可以快速地搭建一个完整功能的网 ...

  3. c++ const static

    const作用: 1.定义常量,可以保护被修饰的东西,防止意外的修改,增强程序的健壮性. const int Max = 100; void f(const int i) { i=10;//error ...

  4. Python之Web2py框架使用

    本文主要是对Web2py框架的介绍和安装使用. 一. 介绍 全栈式Web框架:Web2py是 Google 在 web.py 基础上二次开发而来的,兼容 Google App Engine .是一个为 ...

  5. PyQt 5 的学习引言

    Python 是我学习的第二门编程语言,第一门编程语言是C. 曾经用C和C++的一个库(easyx库)写过图形界面应用, 感受就是难受又难看, 现在想学一下 PyQt 5 这个python的库, 用博 ...

  6. 如何方便的结果ajax使用html5的新type类型

    今天需要做手机端的输入表单自动生成器,突然就想到了手机端对input的输入类型支持还不错,于是翻遍了资料,有了下面的使用方法,闲话少说,上正文: html5现在可以用的新input type类型一共有 ...

  7. 2017-2018-1 20155226 《信息安全系统设计基础》课下实践——实现mypwd

    2017-2018-1 20155226 <信息安全系统设计基础>课下实践--实现mypwd 1 学习pwd命令 输入pwd命令 发现他是给出当前文件夹的绝对路径. 于是 man 1 pw ...

  8. tomcat软连接的使用

    软连接说白了就是一个映射.可以映射文件,也可以映射目录.linux和windows都可以做软连接,加入现在把文件A.txt做软连接到B.txt: linux命令如下: ln -s A.txt B.tx ...

  9. 2-[Mysql]- 初识sql语句

    1.统一字符编码  强调:配置文件中的注释可以有中文,但是配置项中不能出现中文 mysql> \s # 查看字符编码 # 1.在mysql的解压目录下,新建my.ini,然后配置 #mysql5 ...

  10. Codeforces 908 D.New Year and Arbitrary Arrangement (概率&期望DP)

    题目链接:New Year and Arbitrary Arrangement 题意: 有一个ab字符串,初始为空. 用Pa/(Pa+Pb)的概率在末尾添加字母a,有 Pb/(Pa+Pb)的概率在末尾 ...