在Pytorch上使用稀疏矩阵

最近在写一个NLP的小项目,用到了Pytorch做神经网络模型。但是众所周知NLP的一个特点就是特征矩阵是稀疏矩阵,当时处理稀疏矩阵用的是scipy.sparse,现在要把它放到Pytorch中,还是费了一点周折的

首先,如何把python的二维数组(这里以trainData为例)转换为稀疏矩阵呢?这一步很简单,只需要

from scipy.sparse import coo_matrix,然后使用coo_matrix(trainData)就好了

其实 scipy.sparse下面有三种稀疏矩阵,这篇文章有一个大概的介绍:

scipy.sparse.coo_matrix是三元组,不能按行也不能按列切片

to_csr 是按行压缩的稀疏矩阵,按行切片比较快,可以按列切片

to_csc 是按列压缩的稀疏矩阵,按列切片比较快,可以按行切片

这篇文章介绍了稀疏矩阵的COO和CSR存储方式:https://blog.csdn.net/u012101561/article/details/90348288

这里我们使用coo_matrix就好,是因为我们等会要重新创建torch上的稀疏矩阵,这里只要参数就好了。

如何将scipy上的稀疏矩阵转换为torch上的:

values =X_train.data
indices = np.vstack((X_train.row, X_train.col))
i = torch.LongTensor(indices)
v = torch.FloatTensor(values)
shape = X_train.shape
X_train=torch.sparse.FloatTensor(i, v, torch.Size(shape))

上面这部分可以写成个函数

要恢复为完整的二维tensor,直接调用X_train的to_dense()方法就好了,返回值就是普通的tensor

但是,遇到了新的问题,Torch上的稀疏矩阵怎么作为神经网络模型的输入呢?我在网上查了半天也没看到,只有一个keras的教程:https://www.jianshu.com/p/a7dadd842f78 。个人觉得在torch上应该也是有办法的,遇到了这个问题的同学可以在Github上查找一些torch做NLP的项目,因为我不是做这个方向的,所以没有深究。

什么,你问我是怎么解决的?我把项目放到内存比个人电脑大得多的服务器上运行了2333

在Pytorch上使用稀疏矩阵的更多相关文章

  1. [转载]PyTorch上的contiguous

    [转载]PyTorch上的contiguous 来源:https://zhuanlan.zhihu.com/p/64551412 这篇文章写的非常好,我这里就不复制粘贴了,有兴趣的同学可以去看原文,我 ...

  2. 将TVM集成到PyTorch上

    将TVM集成到PyTorch上 随着TVM不断展示出对深度学习执行效率的改进,很明显PyTorch将从直接利用编译器堆栈中受益.PyTorch的主要宗旨是提供无缝且强大的集成,而这不会妨碍用户.为此, ...

  3. matlab——sparse函数和full函数(稀疏矩阵和非稀疏矩阵转换)

    函数功能:生成稀疏矩阵 使用方法 :S = sparse(A) 将矩阵A转化为稀疏矩阵形式,即矩阵A中任何0元素被去除,非零元素及其下标组成矩阵S.如果A本身是稀疏的,sparse(S)返回S. S ...

  4. Highway Networks Pytorch

    导读 本文讨论了深层神经网络训练困难的原因以及如何使用Highway Networks去解决深层神经网络训练的困难,并且在pytorch上实现了Highway Networks. 一 .Highway ...

  5. 基于pytorch实现HighWay Networks之Highway Networks详解

    (一)简述---承接上文---基于pytorch实现HighWay Networks之Train Deep Networks 上文已经介绍过Highway Netwotrks提出的目的就是解决深层神经 ...

  6. 【转载】 Caffe BN+Scale层和Pytorch BN层的对比

    原文地址: https://blog.csdn.net/elysion122/article/details/79628587 ------------------------------------ ...

  7. pytorch使用tensorboardX进行网络可视化

    我们知道,对于pytorch上的搭建动态图的代码的可读性非常高,实际上对于一些比较简单的网络,比如alexnet,vgg阅读起来就能够脑补它们的网络结构,但是对于比较复杂的网络,如unet,直接从代码 ...

  8. 库、教程、论文实现,这是一份超全的PyTorch资源列表(Github 2.2K星)

    项目地址:https://github.com/bharathgs/Awesome-pytorch-list 列表结构: NLP 与语音处理 计算机视觉 概率/生成库 其他库 教程与示例 论文实现 P ...

  9. 【PyTorch深度学习】学习笔记之PyTorch与深度学习

    第1章 PyTorch与深度学习 深度学习的应用 接近人类水平的图像分类 接近人类水平的语音识别 机器翻译 自动驾驶汽车 Siri.Google语音和Alexa在最近几年更加准确 日本农民的黄瓜智能分 ...

随机推荐

  1. H5本地存储详解

    H5之前存储数据一般是通过 cookie ,但是 cookie 存的数据容量比较少.H5 中扩充了文件存储能力,可存储多达 5MB 的数据.现在就实际开发经验来对本地存储 ( Storage ) 的使 ...

  2. concurrency parallel 并发 并行 parallelism

    在传统的多道程序环境下,要使作业运行,必须为它创建一个或几个进程,并为之分配必要的资源.当进程运行结束时,立即撤销该进程,以便能及时回收该进程所占用的各类资源.进程控制的主要功能是为作业创建进程,撤销 ...

  3. OOM异常的发生原因

    一,jvm内存区域 1,程序计数器 一块很小的内存空间,作用是当前线程所执行的字节码的行号指示器. 2,java栈 与程序计数器一样,java栈(虚拟机栈)也是线程私有的,其生命周期与线程相同.通常存 ...

  4. C++ STL copy copy_backward

    #include <iostream>#include <algorithm>#include <vector>#include <functional> ...

  5. JAVA WEB开放中的编码问题

    1.getParamter获取GET方式传来的中文参数乱码 场景:A B 两端都为JAVA 所有编码都为UTF-8.GET得到的参数是乱码 原因,getParamter会将中文参数先URLDECODE ...

  6. windows7-tomcat配置

    1.下载 2.解压缩 3.配置环境变量 (1)计算机属性--高级系统配置--高级--环境变量--系统变量--新建 (2)CATALINA_HOME 如:C:\apache-tomcat-7.0.73 ...

  7. JUC AQS ReentrantLock源码分析

    警告⚠️:本文耗时很长,先做好心理准备,建议PC端浏览器浏览效果更佳. Java的内置锁一直都是备受争议的,在JDK1.6之前,synchronized这个重量级锁其性能一直都是较为低下,虽然在1.6 ...

  8. 【ARM-Linux开发】打包解包命令

    tar命令 解包:tar zxvf FileName.tar 打包:tar czvf FileName.tar DirName gz命令 解压1:gunzip FileName.gz 解压2:gzip ...

  9. 在Linux上显示某个进程的线程的几种方式

    方法一:PS 在ps命令中,"-T"选项可以开启线程查看.下面的命令列出了由进程号为的进程创建的所有线程. 1.$ ps -T -p 方法二: Top top命令可以实时显示各个线 ...

  10. 使用kubeadm进行单master(single master)和高可用(HA)kubernetes集群部署

    kubeadm部署k8s 使用kubeadm进行k8s的部署主要分为以下几个步骤: 环境预装: 主要安装docker.kubeadm等相关工具. 集群部署: 集群部署分为single master(单 ...