TensorFlow conv2d原理及实践
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
官方教程说明:
给定四维的input和filter tensor,计算一个二维卷积
Args:
input: ATensor. type必须是以下几种类型之一:half,float32,float64.filter: ATensor. type和input必须相同strides: A list ofints.一维,长度4, 在input上切片采样时,每个方向上的滑窗步长,必须和format指定的维度同阶padding: Astringfrom:"SAME", "VALID". padding 算法的类型use_cudnn_on_gpu: An optionalbool. Defaults toTrue.data_format: An optionalstringfrom:"NHWC", "NCHW", 默认为"NHWC"。
指定输入输出数据格式,默认格式为"NHWC", 数据按这样的顺序存储:[batch, in_height, in_width, in_channels]
也可以用这种方式:"NCHW", 数据按这样的顺序存储:[batch, in_channels, in_height, in_width]name: 操作名,可选.
Returns:
A Tensor. type与input相同
Given an input tensor of shape [batch, in_height, in_width, in_channels]
and a filter / kernel tensor of shape[filter_height, filter_width, in_channels, out_channels]
conv2d实际上执行了以下操作:
- 将filter转为二维矩阵,shape为
[filter_height * filter_width * in_channels, output_channels]. - 从input tensor中提取image patches,每个patch是一个virtual tensor,shape
[batch, out_height, out_width, filter_height * filter_width * in_channels]. - 将每个filter矩阵和image patch向量相乘
具体来讲,当data_format为NHWC时:
output[b, i, j, k] =
sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
filter[di, dj, q, k]
input 中的每个patch都作用于filter,每个patch都能获得其他patch对filter的训练
需要满足strides[0] = strides[3] = 1. 大多数水平步长和垂直步长相同的情况下:strides = [1, stride, stride, 1].
下面举例来进行说明
在最基本的例子中,没有padding和stride = 1。让我们假设你的input和kernel有:

当您的内核您将收到以下输出:
,它按以下方式计算:
- 14 = 4 * 1 + 3 * 0 + 1 * 1 + 2 * 2 + 1 * 1 + 0 * 0 + 1 * 0 + 2 * 0 + 4 * 1
- 6 = 3 * 1 + 1 * 0 + 0 * 1 + 1 * 2 + 0 * 1 + 1 * 0 + 2 * 0 + 4 * 0 + 1 * 1
- 6 = 2 * 1 + 1 * 0 + 0 * 1 + 1 * 2 + 2 * 1 + 4 * 0 + 3 * 0 + 1 * 0 + 0 * 1
- 12 = 1 * 1 + 0 * 0 + 1 * 1 + 2 * 2 + 4 * 1 + 1 * 0 + 1 * 0 + 0 * 0 + 2 * 1
TF的conv2d函数批量计算卷积,并使用稍微不同的格式。对于一个输入,它是[batch, in_height, in_width, in_channels]内核的[filter_height, filter_width, in_channels, out_channels]。所以我们需要以正确的格式提供数据:
import tensorflow as tf
k = tf.constant([
[1, 0, 1],
[2, 1, 0],
[0, 0, 1]
], dtype=tf.float32, name='k')
i = tf.constant([
[4, 3, 1, 0],
[2, 1, 0, 1],
[1, 2, 4, 1],
[3, 1, 0, 2]
], dtype=tf.float32, name='i')
kernel = tf.reshape(k, [3, 3, 1, 1], name='kernel')
image = tf.reshape(i, [1, 4, 4, 1], name='image')
之后,卷积用下式计算:
res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))
# VALID means no padding
with tf.Session() as sess:
print sess.run(res)
并将相当于我们手工计算的,输出结果:
[[ 14. 6.]
[ 6. 12.]]
附上一张图:

区别SAME和VALID
VALID
input = tf.Variable(tf.random_normal([1,5,5,5])) filter = tf.Variable(tf.random_normal([3,3,5,1])) op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
输出图形:
.....
.xxx.
.xxx.
.xxx.
.....
SAME
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1])) op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
输出图形:
xxxxx
xxxxx
xxxxx
xxxxx
xxxxx
参考链接
TensorFlow conv2d原理及实践的更多相关文章
- 转:fastText原理及实践(达观数据王江)
http://www.52nlp.cn/fasttext 1条回复 本文首先会介绍一些预备知识,比如softmax.ngram等,然后简单介绍word2vec原理,之后来讲解fastText的原理,并 ...
- 使用腾讯云 GPU 学习深度学习系列之二:Tensorflow 简明原理【转】
转自:https://www.qcloud.com/community/article/598765?fromSource=gwzcw.117333.117333.117333 这是<使用腾讯云 ...
- Atitit 管理原理与实践attilax总结
Atitit 管理原理与实践attilax总结 1. 管理学分类1 2. 我要学的管理学科2 3. 管理学原理2 4. 管理心理学2 5. 现代管理理论与方法2 6. <领导科学与艺术4 7. ...
- Atitit.ide技术原理与实践attilax总结
Atitit.ide技术原理与实践attilax总结 1.1. 语法着色1 1.2. 智能提示1 1.3. 类成员outline..func list1 1.4. 类型推导(type inferenc ...
- Atitit.异步编程技术原理与实践attilax总结
Atitit.异步编程技术原理与实践attilax总结 1. 俩种实现模式 类库方式,以及语言方式,java futuretask ,c# await1 2. 事件(中断)机制1 3. Await 模 ...
- Atitit.软件兼容性原理与实践 v5 qa2.docx
Atitit.软件兼容性原理与实践 v5 qa2.docx 1. Keyword2 2. 提升兼容性的原则2 2.1. What 与how 分离2 2.2. 老人老办法,新人新办法,只新增,少修改 ...
- Atitit 表达式原理 语法分析 原理与实践 解析java的dsl 递归下降是现阶段主流的语法分析方法
Atitit 表达式原理 语法分析 原理与实践 解析java的dsl 递归下降是现阶段主流的语法分析方法 于是我们可以把上面的语法改写成如下形式:1 合并前缀1 语法分析有自上而下和自下而上两种分析 ...
- Atitit.gui api自动化调用技术原理与实践
Atitit.gui api自动化调用技术原理与实践 gui接口实现分类(h5,win gui, paint opengl,,swing,,.net winform,)1 Solu cate1 Sol ...
- Atitit.提升语言可读性原理与实践
Atitit.提升语言可读性原理与实践 表1-1 语言评价标准和影响它们的语言特性1 1.3.1.2 正交性2 1.3.2.2 对抽象的支持3 1.3.2.3 表达性3 .6 语言设计中的权 ...
随机推荐
- Android 图片加载框架Glide4.0源码完全解析(一)
写在之前 上一篇博文写的是Picasso基本使用和源码完全解析,Picasso的源码阅读起来还是很顺畅的,然后就想到Glide框架,网上大家也都推荐使用这个框架用来加载图片,正好我目前的写作目标也是分 ...
- Bash中的特殊变量和位置参量
位置参量:向脚本或函数传递的参数,可以被set命令设置.重置和清空. 1.$$ 当前Shell的PID 2.$- 当前Shell的选项,如果是交互式shell,应该包含字符i,例如$ echo $-h ...
- GA代码中的细节
GA-BLX交叉-Gaussion变异 中的代码细节: 我写了一个GA的代码,在2005测试函数上一直不能得到与实验室其他同学类似的数量级的结果.现在参考其他同学的代码,发现至少有如下问题: 1.在交 ...
- Qt开发陷阱一QSTACKWIDGET
原始日期:2015-10-14 00:55 1.使用QStackWidget控件的setCurrentIndex方法时,要注意参数0对应着ui上StackWidget的page1,而不是page0,没 ...
- C#继承的执行顺序
自己对多态中构造函数.函数重载执行顺序和过程一直有些不理解,经过测试,对其中的运行顺序有了一定的了解,希望对初学者有些帮助. eg1: public class A { public A() { Co ...
- Vulkan Tutorial 22 Index buffer
操作系统:Windows8.1 显卡:Nivida GTX965M 开发工具:Visual Studio 2017 Introduction 在实际产品的运行环境中3D模型的数据往往共享多个三角形之间 ...
- my97自定义事件
onFocus="WdatePicker({onpicked:function(){alert(0);}})"
- 通过java反射得到javabean的属性名称和值参考
通过java反射得到javabean的属性名称和值 Field fields[]=cHis.getClass().getDeclaredFields();//cHis 是实体类名称 String[] ...
- 一些css书写的小技巧
一.css顺序 首先声明,浏览器读取css的方式是从上到下的.我们一般书写css只要元素具备这些属性就会达到我们预期的效果,但是这会给以后的维护和浏览器的渲染效率带来一定的影响,那么该怎么书写css的 ...
- mysql in 和 not in 语句用法
1.mysql in语句 select * from tb_name where id in (10,12,15,16);2.mysql not in 语句 select * from tb_name ...