TensorFlow conv2d原理及实践
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
官方教程说明:
给定四维的input和filter tensor,计算一个二维卷积
Args:
input: ATensor. type必须是以下几种类型之一:half,float32,float64.filter: ATensor. type和input必须相同strides: A list ofints.一维,长度4, 在input上切片采样时,每个方向上的滑窗步长,必须和format指定的维度同阶padding: Astringfrom:"SAME", "VALID". padding 算法的类型use_cudnn_on_gpu: An optionalbool. Defaults toTrue.data_format: An optionalstringfrom:"NHWC", "NCHW", 默认为"NHWC"。
指定输入输出数据格式,默认格式为"NHWC", 数据按这样的顺序存储:[batch, in_height, in_width, in_channels]
也可以用这种方式:"NCHW", 数据按这样的顺序存储:[batch, in_channels, in_height, in_width]name: 操作名,可选.
Returns:
A Tensor. type与input相同
Given an input tensor of shape [batch, in_height, in_width, in_channels]
and a filter / kernel tensor of shape[filter_height, filter_width, in_channels, out_channels]
conv2d实际上执行了以下操作:
- 将filter转为二维矩阵,shape为
[filter_height * filter_width * in_channels, output_channels]. - 从input tensor中提取image patches,每个patch是一个virtual tensor,shape
[batch, out_height, out_width, filter_height * filter_width * in_channels]. - 将每个filter矩阵和image patch向量相乘
具体来讲,当data_format为NHWC时:
output[b, i, j, k] =
sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
filter[di, dj, q, k]
input 中的每个patch都作用于filter,每个patch都能获得其他patch对filter的训练
需要满足strides[0] = strides[3] = 1. 大多数水平步长和垂直步长相同的情况下:strides = [1, stride, stride, 1].
下面举例来进行说明
在最基本的例子中,没有padding和stride = 1。让我们假设你的input和kernel有:

当您的内核您将收到以下输出:
,它按以下方式计算:
- 14 = 4 * 1 + 3 * 0 + 1 * 1 + 2 * 2 + 1 * 1 + 0 * 0 + 1 * 0 + 2 * 0 + 4 * 1
- 6 = 3 * 1 + 1 * 0 + 0 * 1 + 1 * 2 + 0 * 1 + 1 * 0 + 2 * 0 + 4 * 0 + 1 * 1
- 6 = 2 * 1 + 1 * 0 + 0 * 1 + 1 * 2 + 2 * 1 + 4 * 0 + 3 * 0 + 1 * 0 + 0 * 1
- 12 = 1 * 1 + 0 * 0 + 1 * 1 + 2 * 2 + 4 * 1 + 1 * 0 + 1 * 0 + 0 * 0 + 2 * 1
TF的conv2d函数批量计算卷积,并使用稍微不同的格式。对于一个输入,它是[batch, in_height, in_width, in_channels]内核的[filter_height, filter_width, in_channels, out_channels]。所以我们需要以正确的格式提供数据:
import tensorflow as tf
k = tf.constant([
[1, 0, 1],
[2, 1, 0],
[0, 0, 1]
], dtype=tf.float32, name='k')
i = tf.constant([
[4, 3, 1, 0],
[2, 1, 0, 1],
[1, 2, 4, 1],
[3, 1, 0, 2]
], dtype=tf.float32, name='i')
kernel = tf.reshape(k, [3, 3, 1, 1], name='kernel')
image = tf.reshape(i, [1, 4, 4, 1], name='image')
之后,卷积用下式计算:
res = tf.squeeze(tf.nn.conv2d(image, kernel, [1, 1, 1, 1], "VALID"))
# VALID means no padding
with tf.Session() as sess:
print sess.run(res)
并将相当于我们手工计算的,输出结果:
[[ 14. 6.]
[ 6. 12.]]
附上一张图:

区别SAME和VALID
VALID
input = tf.Variable(tf.random_normal([1,5,5,5])) filter = tf.Variable(tf.random_normal([3,3,5,1])) op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
输出图形:
.....
.xxx.
.xxx.
.xxx.
.....
SAME
input = tf.Variable(tf.random_normal([1,5,5,5]))
filter = tf.Variable(tf.random_normal([3,3,5,1])) op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
输出图形:
xxxxx
xxxxx
xxxxx
xxxxx
xxxxx
参考链接
TensorFlow conv2d原理及实践的更多相关文章
- 转:fastText原理及实践(达观数据王江)
http://www.52nlp.cn/fasttext 1条回复 本文首先会介绍一些预备知识,比如softmax.ngram等,然后简单介绍word2vec原理,之后来讲解fastText的原理,并 ...
- 使用腾讯云 GPU 学习深度学习系列之二:Tensorflow 简明原理【转】
转自:https://www.qcloud.com/community/article/598765?fromSource=gwzcw.117333.117333.117333 这是<使用腾讯云 ...
- Atitit 管理原理与实践attilax总结
Atitit 管理原理与实践attilax总结 1. 管理学分类1 2. 我要学的管理学科2 3. 管理学原理2 4. 管理心理学2 5. 现代管理理论与方法2 6. <领导科学与艺术4 7. ...
- Atitit.ide技术原理与实践attilax总结
Atitit.ide技术原理与实践attilax总结 1.1. 语法着色1 1.2. 智能提示1 1.3. 类成员outline..func list1 1.4. 类型推导(type inferenc ...
- Atitit.异步编程技术原理与实践attilax总结
Atitit.异步编程技术原理与实践attilax总结 1. 俩种实现模式 类库方式,以及语言方式,java futuretask ,c# await1 2. 事件(中断)机制1 3. Await 模 ...
- Atitit.软件兼容性原理与实践 v5 qa2.docx
Atitit.软件兼容性原理与实践 v5 qa2.docx 1. Keyword2 2. 提升兼容性的原则2 2.1. What 与how 分离2 2.2. 老人老办法,新人新办法,只新增,少修改 ...
- Atitit 表达式原理 语法分析 原理与实践 解析java的dsl 递归下降是现阶段主流的语法分析方法
Atitit 表达式原理 语法分析 原理与实践 解析java的dsl 递归下降是现阶段主流的语法分析方法 于是我们可以把上面的语法改写成如下形式:1 合并前缀1 语法分析有自上而下和自下而上两种分析 ...
- Atitit.gui api自动化调用技术原理与实践
Atitit.gui api自动化调用技术原理与实践 gui接口实现分类(h5,win gui, paint opengl,,swing,,.net winform,)1 Solu cate1 Sol ...
- Atitit.提升语言可读性原理与实践
Atitit.提升语言可读性原理与实践 表1-1 语言评价标准和影响它们的语言特性1 1.3.1.2 正交性2 1.3.2.2 对抽象的支持3 1.3.2.3 表达性3 .6 语言设计中的权 ...
随机推荐
- 使用java实现发送邮件的功能
首先要在maven添加javamail支持 <dependency> <groupId>javax.activation</groupId> <artifac ...
- 小白审计JACKSON反序列化漏洞
1. JACKSON漏洞解析 poc代码:main.java import com.fasterxml.jackson.databind.ObjectMapper; import com.sun.or ...
- js判断是否是ie浏览器且给出ie版本
之前懒得写判断ie版本js,因为网上关于这方面的代码太多了,所以从网上拷贝了一个,放到项目上才发现由于时效性的问题,代码不生效.就自己写一个吧. 怎么去看浏览器的内核等信息 ---- js的全局对象w ...
- 学习总结------用JDBC连接MySQL
1.下载MySQL的JDBC驱动 地址:https://dev.mysql.com/downloads/connector/ 为了方便,直接就选择合适自己的压缩包 跳过登录,选择直接下载 下载完成后, ...
- Ubuntu下录音机程序的使用
在Ubuntu中使用系统自带的录音机程序可以录制电脑的音频输出(比如,电脑正在播放视频的声音),或录制外部环境音频输入(比如,自己说话的声音) 1.录制电脑音频输出 在“硬件”选项中,将”选中设备的设 ...
- eclipse下启动tomcat项目,访问tomcat默认端口显示404错误
解决:打开eclipse的server视图,双击你配置的那个tomcat,打开编辑窗口,查看server locations,看看是否选择了第一个选项(默认是第一个选项),即use workspace ...
- WCF入门, 到创建一个简单的WCF应用程序
什么是WCF? WCF, 英文全称(windows Communication Foundation) , 即为windows通讯平台. windows想到这里大家都知道了 , WCF也正是由微软公 ...
- Laravel踩坑笔记——illuminate/html被抛弃
起因 在使用如下代码的时候发生报错 {!! Form::open() !!} 错误信息 [Symfony\Component\Debug\Exception\FatalErrorException] ...
- 摘记:Web应用系统测试内容
表示层: 内容测试,包括整体审美.字体.色彩.拼写.内容准确性和默认值 Web站点结构,包括无效的链接或图形 用户环境,包括Web浏览器版本和操作系统配置(每一个浏览器都有不同的脚本引擎或虚拟机在客户 ...
- 移动端APP CSS Reset及注意事项CSS重置
@charset "utf-8"; * { -webkit-box-sizing: border-box; box-sizing: border-box; } //禁止文本缩放 h ...