theano中的dimshuffle
theano中的dimshuffle函数用于对张量的维度进行操作,可以增加维度,也可以交换维度,删除维度。
注意的是只有shared才能调用dimshuffle()
'x'表示增加一维,从0d scalar到1d vector
(0, 1)表示一个与原先相同的2D向量
(1, 0)表示将2D向量的两维交换
(‘x’, 0) 表示将一个1d vector变为一个1xN矩阵
(0, ‘x’)将一个1d vector变为一个Nx1矩阵
(2, 0, 1) -> AxBxC to CxAxB (2表示第三维也就是C,0表示第一维A,1表示第二维B)
(0, ‘x’, 1) -> AxB to Ax1xB 表示A,B顺序不变在中间增加一维
(1, ‘x’, 0) -> AxB to Bx1xA 同理自己理解一下
(1,) -> 删除维度0,(1xA to A)
写了个小程序来验证猜想
from __future__ import print_function
import theano
import numpy as np
def print_hline(file):
print('------------------------------------------',file=file,end='\r\n')
write_file=open('G:\data\dimshuffle_output.txt','wb')
v = theano.shared(np.arange(3))
# v.shape is a symbol expression, need theano.function or eval to compile it
print_hline(write_file)
v_disp = v.dimshuffle(0)
print('v.dimshuffle(0):',v_disp.eval(),file=write_file,end='\r\n')
print('v.dimshuffle(0).shape:',v_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
v_disp = v.dimshuffle('x', 0)
print("v.dimshuffle('x',0):",v_disp.eval(),file=write_file,end='\r\n')
print("v.dimshuffle('x',0).shape:",v_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
v_disp = v.dimshuffle(0,'x')
print("v.dimshuffle(0,'x'):",v_disp.eval(),file=write_file,end='\r\n')
print("v.dimshuffle(0,'x').shape:",v_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
v_disp = v.dimshuffle(0,'x','x')
print("v.dimshuffle(0,'x','x'):",v_disp.eval(),file=write_file,end='\r\n')
print("v.dimshuffle(0,'x','x').shape:",v_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
v_disp = v.dimshuffle('x',0,'x')
print("v.dimshuffle('x',0,'x'):",v_disp.eval(),file=write_file,end='\r\n')
print("v.dimshuffle('x',0,'x').shape:",v_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
v_disp = v.dimshuffle('x','x',0)
print("v.dimshuffle('x','x',0):",v_disp.eval(),file=write_file,end='\r\n')
print("v.dimshuffle('x','x',0).shape:",v_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
m = theano.shared(np.arange(6).reshape(2,3))
print("m:",m.eval(),file=write_file,end='\r\n')
print("m.shape:",m.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
m_disp = m.dimshuffle(0,'x',1)
print("m.dimshuffle(0,'x',1):",m_disp.eval(),file=write_file,end='\r\n')
print("m.dimshuffle(0,'x',1).shape:",m_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
m_disp = m.dimshuffle('x',0,1)
print("m.dimshuffle('x',0,1):",m_disp.eval(),file=write_file,end='\r\n')
print("m.dimshuffle('x',0,1).shape:",m_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
m_disp = m.dimshuffle(0,1,'x')
print("m.dimshuffle(0,1,'x'):",m_disp.eval(),file=write_file,end='\r\n')
print("m.dimshuffle(0,1,'x').shape:",m_disp.shape.eval(),file=write_file,end='\r\n')
print_hline(write_file)
# amount to transpose
m_disp = m.dimshuffle(1,'x',0)
print("m.dimshuffle(1,'x',0):",m_disp.eval(),file=write_file,end='\r\n')
print("m.dimshuffle(1,'x',0).shape:",m_disp.shape.eval(),file=write_file,end='\r\n')
write_file.close()
首先定义了一个[0 1 2]的1D vector:v,v.dimshuffle(0)中的0表示第一维:3,也只有一维,所以不变。因为是1D的,所以shape只有(3,)
v.dimshuffle(0): [0 1 2]
v.dimshuffle(0).shape: [3]
v.dimshuffle('x',0)表示在第一维前加入一维,只要记住加了'x'就加了一维,所以大小变成了1x3
v.dimshuffle('x',0): [[0 1 2]]
v.dimshuffle('x',0).shape: [1 3]
剩下的同理可理解
v.dimshuffle(0,'x'): [[0]
[1]
[2]]
v.dimshuffle(0,'x').shape: [3 1]
v.dimshuffle(0,'x','x'): [[[0]]
[[1]]
[[2]]]
v.dimshuffle(0,'x','x').shape: [3 1 1]
v.dimshuffle('x',0,'x'): [[[0]
[1]
[2]]]
v.dimshuffle('x',0,'x').shape: [1 3 1]
v.dimshuffle('x','x',0): [[[0 1 2]]]
v.dimshuffle('x','x',0).shape: [1 1 3]
第二个例子,m是一个2x3矩阵
m: [[0 1 2]
[3 4 5]]
m.shape: [2 3]
先确定0,'x',1的维数,0对应第一维(2),1表示第二维(3),'x'表示新加入的维度(1)
所以结果维度是2x1x3
加括号的顺序按照从左到右(外->内)的顺序
1.先加最内层3,3表示括号内有3个数,因此是[0 1 2]和[3 4 5]
2.再加中间层1,1表示括号内只有一个匹配的"[]",因此是[[0 1 2]],[[3 4 5]]
3.最后加最外层2,2表示括号内有两个匹配的"[]"(只算最外层的匹配),于是最后结果是
[[[0 1 2]]
[[3 4 5]]]
m.dimshuffle(0,'x',1): [[[0 1 2]]
[[3 4 5]]]
m.dimshuffle(0,'x',1).shape: [2 1 3]
剩下的同理可以理解
m.dimshuffle('x',0,1): [[[0 1 2]
[3 4 5]]]
m.dimshuffle('x',0,1).shape: [1 2 3]
m.dimshuffle(0,1,'x'): [[[0]
[1]
[2]]
[[3]
[4]
[5]]]
m.dimshuffle(0,1,'x').shape: [2 3 1]
m.dimshuffle(1,'x',0): [[[0 3]]
[[1 4]]
[[2 5]]]
m.dimshuffle(1,'x',0).shape: [3 1 2]
theano中的dimshuffle的更多相关文章
- Theano入门笔记1:Theano中的Graph Structure
译自:http://deeplearning.net/software/theano/extending/graphstructures.html#graphstructures 理解Theano计算 ...
- theano中的scan用法
scan函数是theano中的循环函数,相当于for loop.在读别人的代码时第一次看到,有点迷糊,不知道输入.输出怎么定义,网上也很少有example,大多数都是相互转载同一篇.所以,还是要看官方 ...
- Theano中的导数
计算梯度 现在让我们使用Theano来完成一个稍微复杂的任务:创建一个函数,该函数计算相对于其参数x的某个表达式y的导数.为此,我们将使用宏T.grad.例如,我们可以计算相对于的梯度 import ...
- theano中对图像进行convolution 运算
(1) 定义计算过程中需要的symbolic expression """ 定义相关的symbolic experssion """ # c ...
- theano中的concolutional_mlp.py学习
(1) evaluate _lenet5中的导入数据部分 # 导入数据集,该函数定义在logistic_sgd中,返回的是一个list datasets = load_data(dataset) # ...
- Theano2.1.21-基础知识之theano中多核的支持
来自:http://deeplearning.net/software/theano/tutorial/multi_cores.html Multi cores support in Theano 一 ...
- theano中的logisticregression代码学习
1 class LogisticRegression (object): 2 def __int__(self,...): 3 4 #定义一些与逻辑回归相关的各种函数 5 6 def method1( ...
- theano中tensor的构造方法
import theano.tensor as T x = T.scalar('myvar') myvar = 256 print type(x),x,myvar 运行结果: <class 't ...
- Theano入门笔记2:scan函数等
1.Theano中的scan函数 目前先弱弱的认为:相当于symbolic的for循环吧,或者说计算图上的for循环,也可以用来替代repeat-until. 与scan相比,scan_checkpo ...
随机推荐
- HTTP 2.0的那些事
转自:http://www.admin10000.com/document/9310.html 在我们所处的互联网世界中,HTTP协议算得上是使用最广泛的网络协议.最近http2.0的诞生使得它再次互 ...
- 常用的sql数据库语句
1.说明:复制表(只复制结构,源表名:a 新表名:b) (Access可用)法一:select * into b from a where 1 <>1法二:select top 0 * i ...
- MySQL的索引类型和左前缀索引
1.索引类型: 1.1B-tree索引: 注:名叫btree索引,大的方面看,都用的是平衡树,但具体的实现上,各引擎稍有不同,比如,严格的说,NDB引擎,使用的是T-tree,但是在MyISAM,In ...
- percona-toolkit工具包的安装和使用
1.安装与Perl相关的模块 PT工具是使用Perl语言编写和执行的,所以需要系统中有Perl环境 # yum install -y perl perl-devel perl-Time-HiRes p ...
- mac 查看无线wifi的密码
finder->应用程序->实用工具->钥匙串访问->右上角输入wifi名查找->显示密码(需要管理员账号)
- 重走java---Step 1
开发环境 1.使用java开发,首先要完成java运行环境的安装配置,JVM可以说是java最大的优点之一,就是它实现了java一次编译多次运行,关于JVM以后再详谈.安装配置JDK,完成java开发 ...
- 遭遇flash播放mp3诡异问题
在部分ie10+flash player 播放mp3,播放第二句话时,中断无法正常播放,(客户的机器上),自己公司的机器测试了几个,都没发现这个问题.其它浏览器(chrome,firefox)也没发现 ...
- LAMP环境
LAMP = Linux + Apache + MySQL + PHP [1] [2] [3] [4] [1]Linux是一套免费使用和自由传播的类Unix操作系统, ...
- python之路-Day8
抽象接口 class Alert(object): '''报警基类''' def send(self): raise NotImplementedError class MailAlert(Alert ...
- U-Mail邮件系统六项特色服务铸就金口碑
评价一款邮件系统优劣的标准或许有很多,左右你是否选择某个平台的需求或许有不同,但是U-Mail小编必须提醒你:服务水准不可等闲视之!之所以如此, 这是因为:现代社会垃圾邮件猖獗,病毒层出不穷令人防不胜 ...