不是python层面Tensor的剖析,是C层面的剖析。

看pytorch下lib库中的TH好一阵子了,TH也是torch7下面的一个重要的库。

可以在torch的github上看到相关文档。看了半天才发现pytorch借鉴了很多torch7的东西。

pytorch大量借鉴了torch7下面lua写的东西并且做了更好的设计和优化。

https://github.com/torch/torch7/tree/master/doc

pytorch中的Tensor是在TH中实现的。TH = torch

TH中先实现了一个THStorage,再在THStorage的基础上实现了THTensor。

THStorage定义如下,定义在TH/generic/THStorage.h中

 typedef struct THStorage
{
real *data;
ptrdiff_t size;
int refcount;
char flag;
THAllocator *allocator;
void *allocatorContext;
struct THStorage *view;
} THStorage;

这些成员里重点关注*data和size就可以了。

real *data中的real会在预编译的时候替换成预先设计的数据类型,比如int,float,byte等。

比如 int a[3] = {1,2,3},data是数组a的地址,对应的size是3,不是sizeof(a)。

所以*data指向的是一段连续内存。是一维的!

讲Tensor前先回顾下数组在内存中的排列方式。参看《C和指针》8.2节相关内容。

比如 int a[3][6]; 内存中的存储顺序为:

00 01 02 03 04 05 10 11 12 13 14 15 20 21 22 23 24 25

是连续存储的。存储顺序按照最右边的下标率先变化。

然后数组a是2维的,nDimension = 2。dimension从0开始算起。

size(a) = {3,6}
[3] 是 dimension 0    size[0] = 3
[6] 是 dimension 1    size[1] = 6
nDimension = 2

THTensor定义如下,定义在TH/generic/THTensor.h中

 typedef struct THTensor
{
int64_t *size; // 注意是指针
int64_t *stride; // 注意是指针
int nDimension; // Note: storage->size may be greater than the recorded size
// of a tensor
THStorage *storage;
ptrdiff_t storageOffset;
int refcount;
char flag;
} THTensor;

比如

z = torch.Tensor(2,3,4)   // 新建一个张量,size为 2,3,4

size(z) = {2,3,4}
[2] 是 dimension 0    size[0] = 2
[3] 是 dimension 1    size[1] = 3
[4] 是 dimension 2    size[2] = 4
nDimension = 3

THStorage只管理内存,是一维的。

THTensor通过size和nDimension将THStorage管理的一维内存映射成逻辑上的多维张量,

底层还是一维的。但是注意,代表某个Tensor的底层内存是一维的但是未必是连续的!

把Tensor按照数组来理解好了。

Tensor a[3][6]  裁剪(narrow函数)得到一个 Tensor b[3][4],在内存中就是

Tensor a:
Tensor b: x x x x x x

narrow函数并不会真正创建一个新的Tensor,Tensor b还是指向Tensor a的那段内存。

所以Tensor b在内存上就不是连续的了。

那么怎么体现Tensor在内存中是连续的呢?就靠THTensor结构体中的

size,stride,nDimension共同判断了。

pytorch的Tensor有个 contiguous 函数,C层面也有一个对应的函数:

int THTensor_(isContiguous)(const THTensor *self)
判断 Tensor 在内存中是否连续。定义在 TH/generic/THTensor.c 中。
 int THTensor_(isContiguous)(const THTensor *self)
{
int64_t z = ;
int d;
for(d = self->nDimension-; d >= ; d--)
{
if(self->size[d] != )
{
if(self->stride[d] == z)
z *= self->size[d]; // 如果是连续的,应该在这循环完然后跳到下面return 1
else
return ;
}
}
return ;
}

把Tensor a[3][6] 作为这个函数的参数:

size[0] = 3    size[1] = 6    nDimension = 2      z =1
d = 1   if size(1) = 6 != 1   if stride[1] == 1   z = z*size(d)=6
d = 0   if size(0) = 3 != 1   if stride[0] == 6   z = z*size(d)=6*3 = 18
因此,对于连续存储的a
stride = {6,1}
size = {3,6}

再举一个Tensor c[2][3][4]的例子,如果c是连续存储的,则:

stride = {12,4,1}
size =    { 2,3,4}  // 2所对应的stride就是 右边的数相乘(3x4), 3所对应的stride就是右边的数相乘(4)

stride(i)返回第i维的长度。stride又被翻译成步长。

比如第0维,就是[2]所在的维度,Tensor c[ i ][ j ][ k ]跟Tensor c[ i+1 ][ j ][ k ]

在连续内存上就距离12个元素的距离。

对于内存连续的stride,计算方式就是相应的size数右边的数相乘。

所以不连续呢?

对于a[3][6]

stride = {6,1} 
size =   {3,6}

对于从a中裁剪出来的b[3][4]

stride = {6,1} 
size =   {3,4}

stride和size符合不了 右边的数相乘 的计算方法,所以就不连续了。

所以一段连续的一维内存,可以根据size和stride 解释 成  逻辑上变化万千,内存上是否连续 的张量。

比如24个元素,可以解释成 4 x 6 的2维张量,也可以解释成 2 x 3 x 4 的3维张量。

THTensor中的 storageOffset 就是说要从 THStorage 的第几个元素开始 解释 了。

连续的内存能给程序并行化和最优化算法提供很大的便利。

其实写这篇博客是为了给理解 TH 中的 TH_TENSOR_APPLY2 等宏打基础。

这个宏就像是在C中实现了broadcast。

2017年12月11日01:00:22

最近意识到,用 H x W x C 和 C x H x W 哪个来装图像更好,取决于矩阵在内存中是行存储还是

列存储,这个会影响内存读取速度,进而影响算法用时。

后来意识到,这就是个cache-friendly的问题,大部分对程序性能的要求还上升不到要研究算法复杂度

这个地步,常规优化的话注意下缓存友好等问题就好了,再优化就要靠更专业团队写的库或者榨干硬件了。

看了下numpy的文档,怪不得说pytorch是numpy的gpu版本。。。

后来又看了下opencv的mat的数据结构,原来矩阵库都是一毛一样的。。。

对pytorch中Tensor的剖析的更多相关文章

  1. pytorch中tensor数据和numpy数据转换中注意的一个问题

    转载自:(pytorch中tensor数据和numpy数据转换中注意的一个问题)[https://blog.csdn.net/nihate/article/details/82791277] 在pyt ...

  2. [Pytorch]Pytorch中tensor常用语法

    原文地址:https://zhuanlan.zhihu.com/p/31494491 上次我总结了在PyTorch中建立随机数Tensor的多种方法的区别. 这次我把常用的Tensor的数学运算总结到 ...

  3. pytorch中tensor张量数据基础入门

    pytorch张量数据类型入门1.对于pytorch的深度学习框架,其基本的数据类型属于张量数据类型,即Tensor数据类型,对于python里面的int,float,int array,flaot ...

  4. pytorch中tensor的属性 类型转换 形状变换 转置 最大值

    import torch import numpy as np a = torch.tensor([[[1]]]) #只有一个数据的时候,获取其数值 print(a.item()) #tensor转化 ...

  5. pytorch中tensor张量的创建

    import torch import numpy as np print(torch.tensor([1,2,3])) print(torch.tensor(np.arange(15).reshap ...

  6. Pytorch 中 tensor的维度拼接

    torch.stack() 和 torch.cat() 都可以按照指定的维度进行拼接,但是两者也有区别,torch.satck() 是增加新的维度进行堆叠,即其维度拼接后会增加一个维度:而torch. ...

  7. pytorch 中的数据类型,tensor的创建

    pytorch中的数据类型 import torch a=torch.randn(2,3) b=a.type() print(b) #检验是否是该数据类型 print(isinstance(a,tor ...

  8. pytorch之dataloader深入剖析

    PyTorch学习笔记(6)——DataLoader源代码剖析 - dataloader本质是一个可迭代对象,使用iter()访问,不能使用next()访问: - 使用iter(dataloader) ...

  9. PyTorch官方中文文档:PyTorch中文文档

    PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...

随机推荐

  1. 页面检测网络外网连接- 网页基础模块(JavaScript)

    方法一 html 添加图片标签 加载外站图片 <img id="connect-test" style="display:none;" onload=&q ...

  2. PI接口开发之调java WS接口

    java提供的WSDL:http://XXX.XXX.XXX.XX/XXXXXXXcrm/ws/financialStatementsService?wsdl 登陆PI,下载Enterprise Se ...

  3. jvm回收器回收过程一:CMS和 G1的初认知(持续更新中)

    CMS:介绍: 1.CMS(Concurrent Mark-Sweep)是以牺牲吞吐量为代价来获得最短回收停顿时间的垃圾回收器.对于要求服务器响应速度的应用上,这种垃圾回收器非常适合. 在启动JVM参 ...

  4. FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. MetaException(me

    FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. MetaException(me ...

  5. Python-接口自动化(二)

    python基础知识(二) (二)常用控制流 1.控制语句 分支语句:起到一个分支分流的作用,类似马路上的红绿灯 循环语句:for while 可以使代码不断重复的执行 2.判断语句:关键字是if.. ...

  6. Java易错题(1)

    检查程序,是否存在问题,如果存在指出问题所在,如果不存在,说明输出结果. public class HelloB extends HelloA { public HelloB() { } { Syst ...

  7. python笔记13-文件读写

    1.打开文件 f=open('a.txt','a+',encoding='utf-8')#f代表的是文件对象,叫句柄 f.seek(0)把文件指针到最前 文件打开模式有3种: 1:w写模式,它是不能读 ...

  8. wed

    先有一个无后缀的flag 文件 第一次改成 TXT 收索FLAG 得到了一段 flag.txt f返回到第一次修改后缀 改成RAR 打开RAR 发现一个 flag.txt 的文件 打开,即得到 fla ...

  9. Tcl脚本整理照片

    我那个媳妇啊,典型的只管照不管 理,32G的卡竟然被弄满了. 费好大劲好不容易整理到电脑上,可是都是数字名字,看着都头疼,索性整理下. 首先安装tcl编译环境tcl86,度娘搞的,然后开动: proc ...

  10. ubuntu 升级 python3.5到 python3.6

    首先是在Ubuntu中安装python3.6 sudo apt-get install software-properties-common sudo add-apt-repository ppa:j ...