TensorFlow conv2d原理及实践
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
官方教程说明:
给定四维的input和filter tensor,计算一个二维卷积
Args:
input: ATensor. type必须是以下几种类型之一:half,float32,float64.filter: ATensor. type和input必须相同strides: A list ofints.一维,长度4, 在input上切片采样时,每个方向上的滑窗步长,必须和format指定的维度同阶padding: Astringfrom:"SAME", "VALID". padding 算法的类型use_cudnn_on_gpu: An optionalbool. Defaults toTrue.data_format: An optionalstringfrom:"NHWC", "NCHW", 默认为"NHWC"。
指定输入输出数据格式,默认格式为"NHWC", 数据按这样的顺序存储:[batch, in_height, in_width, in_channels]
也可以用这种方式:"NCHW", 数据按这样的顺序存储:[batch, in_channels, in_height, in_width]name: 操作名,可选.
Returns:
A Tensor. type与input相同
Given an input tensor of shape [batch, in_height, in_width, in_channels]
and a filter / kernel tensor of shape[filter_height, filter_width, in_channels, out_channels]
conv2d实际上执行了以下操作:
- 将filter转为二维矩阵,shape为
[filter_height * filter_width * in_channels, output_channels]. - 从input tensor中提取image patches,每个patch是一个virtual tensor,shape
[batch, out_height, out_width, filter_height * filter_width * in_channels]. - 将每个filter矩阵和image patch向量相乘
具体来讲,当data_format为NHWC时:
output[b, i, j, k] =
sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
filter[di, dj, q, k]
input 中的每个patch都作用于filter,每个patch都能获得其他patch对filter的训练
需要满足strides[0] = strides[3] = 1. 大多数水平步长和垂直步长相同的情况下:strides = [1, stride, stride, 1].
下面举例来进行说明
在最基本的例子中,没有padding和stride = 1。让我们假设你的input和kernel有:

当您的内核您将收到以下输出:
,它按以下方式计算:
- 14 = 4 * 1 + 3 * 0 + 1 * 1 + 2 * 2 + 1 * 1 + 0 * 0 + 1 * 0 + 2 * 0 + 4 * 1
- 6 = 3 * 1 + 1 * 0 + 0 * 1 + 1 * 2 + 0 * 1 + 1 * 0 + 2 * 0 + 4 * 0 + 1 * 1
- 6 = 2 * 1 + 1 * 0 + 0 * 1 + 1 * 2 + 2 * 1 + 4 * 0 + 3 * 0 + 1 * 0 + 0 * 1
- 12 = 1 * 1 + 0 * 0 + 1 * 1 + 2 * 2 + 4 * 1 + 1 * 0 + 1 * 0 + 0 * 0 + 2 * 1
TF的conv2d函数批量计算卷积,并使用稍微不同的格式。对于一个输入,它是[batch, in_height, in_width, in_channels]内核的[filter_height, filter_width, in_channels, out_channels]。所以我们需要以正确的格式提供数据:
import tensorflow as tf
k = tf.constant([
[1, 0, 1],
[2, 1, 0],
[0, 0, 1]
], dtype=tf.float32, name='k')
i = tf.constant([
[4, 3, 1, 0],
[2, 1, 0, 1],
[1, 2, 4, 1],
[3, 1, 0, 2]
], dtype=tf.float32, name='i')
kernel = tf.reshape(k, [3, 3, 1, 1], name='kernel')
image = tf.reshape(i, [1, 4, 4, 1], name='image')
之后,卷积用下式计算:
res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))
# VALID means no padding
with tf.Session() as sess:
print sess.run(res)
并将相当于我们手工计算的,输出结果:
[[ 14. 6.]
[ 6. 12.]]
附上一张图:

区别SAME和VALID
VALID
input = tf.Variable(tf.random_normal([1,5,5,5])) filter = tf.Variable(tf.random_normal([3,3,5,1])) op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
输出图形:
.....
.xxx.
.xxx.
.xxx.
.....
SAME
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1])) op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
输出图形:
xxxxx
xxxxx
xxxxx
xxxxx
xxxxx
参考链接
TensorFlow conv2d原理及实践的更多相关文章
- 转:fastText原理及实践(达观数据王江)
http://www.52nlp.cn/fasttext 1条回复 本文首先会介绍一些预备知识,比如softmax.ngram等,然后简单介绍word2vec原理,之后来讲解fastText的原理,并 ...
- 使用腾讯云 GPU 学习深度学习系列之二:Tensorflow 简明原理【转】
转自:https://www.qcloud.com/community/article/598765?fromSource=gwzcw.117333.117333.117333 这是<使用腾讯云 ...
- Atitit 管理原理与实践attilax总结
Atitit 管理原理与实践attilax总结 1. 管理学分类1 2. 我要学的管理学科2 3. 管理学原理2 4. 管理心理学2 5. 现代管理理论与方法2 6. <领导科学与艺术4 7. ...
- Atitit.ide技术原理与实践attilax总结
Atitit.ide技术原理与实践attilax总结 1.1. 语法着色1 1.2. 智能提示1 1.3. 类成员outline..func list1 1.4. 类型推导(type inferenc ...
- Atitit.异步编程技术原理与实践attilax总结
Atitit.异步编程技术原理与实践attilax总结 1. 俩种实现模式 类库方式,以及语言方式,java futuretask ,c# await1 2. 事件(中断)机制1 3. Await 模 ...
- Atitit.软件兼容性原理与实践 v5 qa2.docx
Atitit.软件兼容性原理与实践 v5 qa2.docx 1. Keyword2 2. 提升兼容性的原则2 2.1. What 与how 分离2 2.2. 老人老办法,新人新办法,只新增,少修改 ...
- Atitit 表达式原理 语法分析 原理与实践 解析java的dsl 递归下降是现阶段主流的语法分析方法
Atitit 表达式原理 语法分析 原理与实践 解析java的dsl 递归下降是现阶段主流的语法分析方法 于是我们可以把上面的语法改写成如下形式:1 合并前缀1 语法分析有自上而下和自下而上两种分析 ...
- Atitit.gui api自动化调用技术原理与实践
Atitit.gui api自动化调用技术原理与实践 gui接口实现分类(h5,win gui, paint opengl,,swing,,.net winform,)1 Solu cate1 Sol ...
- Atitit.提升语言可读性原理与实践
Atitit.提升语言可读性原理与实践 表1-1 语言评价标准和影响它们的语言特性1 1.3.1.2 正交性2 1.3.2.2 对抽象的支持3 1.3.2.3 表达性3 .6 语言设计中的权 ...
随机推荐
- c++ thread
Either pthread_join(3) or pthread_detach() should be called for each thread,that an application crea ...
- bootstrap 架构知识点
.col-md-pull-2 向右相对定位偏移量 .col-md-push-2 向左相对定位偏移量 .pull-left 左浮动 .pull-right 右浮动 改变大小写 通过这几个类可以改 ...
- Elasticsearch重要配置
虽然Elasticsearch需要很少的配置,但是有一些设置需要手动配置,并且必须在进入生产之前进行配置. path.data and path.logs cluster.name node.nam ...
- 详解 try-with-resource
[TOC] Oracle官方文档: http://docs.oracle.com/javase/7/docs/technotes/guides/language/try-with-resources. ...
- VMware-VCSA-6.5安装过程
1.新建虚拟机 2.选择从OVF或OVA文件导入 3.给虚拟机命名,并选择OVF文件. 4.选择虚拟机的存储位置.这里没有配置共享存储宿,这里选择的宿主机的存储. 5.许可协议同意就OK了. 6.部署 ...
- jquery移出select指定option
$("#selectLine option[value!='']").remove();
- 找到你在网页中缓存起来的flash文件
通过IE浏览器工具->Internet选项->常规->设置->Internet临时文件->查看文件(找到你在网页中缓存起来的flash文件)
- mac如何进入应用程序的内部文件夹?
在程序点击右键,选择显示包内容,就可以看到了
- MySQL各模块工作配合
MySQL各模块工作配合 在了解了 MySQL 的各个模块之后,我们再看看 MySQL 各个模块间是如何相互协同工作的 .接下来,我们通过启动 MySQL,客户端连接,请求 query,得到返回结果, ...
- input 等替换元素的baseline问题
行内标签和设置为block:inline;形式的标签与input并排放置时,为何会错位?例如下面的. 因为在同一行中,所有行内元素默认 baseline 对齐.但是,input(还有textarea. ...