PyTorch下,使用list放置模块,导致计算设备不一的报错
报错
在复现 Transformer 代码的训练阶段时,发生报错:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
解决方案
通过next(linear.parameters()).device确定 model 已经在 cuda:0 上了,同时输入 model.forward()的张量也位于 cuda:0。输入的张量没什么好推敲的,于是考虑到模型具有多层结构,遂输出每层结构的设备信息,model.encoder -> model.encoder.sublayer[0] ··· ···
测试发现,model.encoder.sublayer[0] 之后的模块的设备信息均位于 cpu,原因是构造这部分模块时,由于需要多个相同的模块,使用了 list 来存放模块:
# module: 需要深拷贝的模块
# n: 拷贝的次数
# return: 深拷贝后的模块列表
def clones(module, n: int) -> list:
return [copy.deepcopy(module) for _ in range(n)]
显然 list 不支持 GPU,需要用 PyTorch 提供的代替:
def clones(module, n: int):
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
ModuleList 把子模块存入列表,能像 Python 里普通的列表被索引,最重要的是能使内部的模块被正确注册,并对所有的 Module 方法可见。[Source]
成功解决!
相关环境
python 3.11.7 he1021f5_0
pytorch 2.1.2 py3.11_cuda12.1_cudnn8_0
PyTorch下,使用list放置模块,导致计算设备不一的报错的更多相关文章
- linux 下通过xhost进入图形界面,经常会出现报错“unable to open display”
linux 下通过xhost进入图形界面,经常会出现报错“unable to open display” linux下的操作步骤如下: [root@localhost ~]# vncserver N ...
- webRTC中回声消除(AEC)模块编译时aec_rdft.c文件报错:
webRTC中回声消除(AEC)模块编译时aec_rdft.c文件报错. 原因是: 局部变量ip跟全局变量冲突的问题,可以将局部变量重新命名一下,就可以通过编译了. aec_rdft.c修改以后文件代 ...
- Token过期导致页面多个请求报错提示多次
关于Token过期导致页面多个请求报错提示的问题 我们先在全局定义一个变量(global.js)来控制token是否过期 export default { // token无效标记 TokenInva ...
- C# UTF8的BOM导致XML序列化与反序列化报错:Data at the root level is invalid. Line 1, position 1.
最近在写一个xml序列化及反序列化实现时碰到个问题,大致类似下面的代码: class Program { static void Main1(string[] args) { var test = n ...
- CentOS下httpd下php 连接mysql 本机可以,外网报错Could not connect: Can't connect to MySQL server on '127.0.0.1' (13)2003 原因解析
php代码很简单: $server="127.0.0.1"; println("Begin"); $link = mysql_connect($server,& ...
- Django时区导致的datetime时间比较报错
我们使用python 的datetime模块比较Django数据库Datetime字段的时候,可能会出现报错: TypeError: can't compare offset-naive and of ...
- win下python脚本以unix风格换行保存将会报错为编码问题 SyntaxError: encoding problem:gbk
utf-8与gbk编码都报错 从别人的github拉下来一个python脚本. 直接运行,python报错如下: File ".\drag_files_do_event.py", ...
- 用pip下载的python模块怎么在PyCharm中引入报错
在IDE中导入下载的模块,比如:numpy模块 你会发现虽然你安装了numpy模块,在CMD中python可以import numpy,但是你在PyCharm引不进去,为什么呢?你要是有注意的话,安装 ...
- window下用notepad++编辑了脚本文件然后放在linux报错显示无法运行
首先vi :set ff 查看文件类型 接着 下载dos2unix root用户下yum -y install dos2unix 然后 dos2unix 文件.sh 转换格式 接着在正常启动即可
- 接口拿到的id和传到后台的id不一致,导致查询详情和编辑报错
碰到这个问题真是百思不得其解.接口上打印的值和数据库一致,浏览器查看response的反馈也一致.但是一在页面打印请求回来的值,就变了,变成了另一个id,但是其他数据又和数据库一致. 查了一圈也没有查 ...
随机推荐
- windows mysql安装及常用命令
安装windows版本mysql只是为本地代码调试,不建议用于生产.觉得步骤麻烦也可以直接下载集成环境(如xampp),一键安装即可用.之前本地测试都用一键安装,今天换个方法玩玩,安装步骤如下: my ...
- Linux查看文件内容与处理文件
Linux查看文件内容与处理文件 目录 Linux查看文件内容与处理文件 查看文件内容 1.查看文件类型 2.查看整个文件 3.查看部分文件 处理文件 1.创建空文件 2.过滤文件内容 3.统计文件内 ...
- [IDEA] - tomcat VM配置
-Dfile.encoding=UTF-8
- VIte+Vue3 打包在本地 双击 index.html 打开项目
npm i @vitejs/plugin-legacy --save import legacy from '@vitejs/plugin-legacy'; export default define ...
- Spring Boot对接Oracle数据库
Spring Boot对接Oracle数据库 最近学习了Oracle数据库,那么如何使用Spring Boot和MyBatis Plus对接Oracle数据库呢? 这就有了这篇随记,具体流程如下 1. ...
- SQL函数——时间函数
1.使用 NOW() . CURDATE().CURTIME() 获取当前时间 在这里我有一个问题想问问大家,你们平时都是怎么样子获取时间的呢?是不是通过手表.手机.电脑等设备了解到的,那么你们有没有 ...
- [转帖]拯救关键业务上线:DBA 的惊魂24小时
一个电话,打破深夜的宁静 9月20日晚上10点 刚完成外地一个重点项目为期2周的现场支持,从机场回家的路上,一阵急促的铃声惊醒了出租车上昏昏欲睡的我,多年的工作经验告诉我这么晚来电一定是出事了,接起电 ...
- [转帖]一次 Java 进程 OOM 的排查分析(glibc 篇)
https://juejin.cn/post/6854573220733911048 遇到了一个 glibc 导致的内存回收问题,查找原因和实验的的过程是比较有意思的,主要会涉及到下面这些: Linu ...
- [转帖]Linux-文本处理三剑客awk详解+企业真实案例(变量、正则、条件判断、循环、数组、分析日志)
https://developer.aliyun.com/article/885607?spm=a2c6h.24874632.expert-profile.313.7c46cfe9h5DxWK 简介: ...
- Linux下面sysstat的安装与简介
https://blog.51cto.com/smoke520/2160073 在Linux系统下获取sysstat-10.0.5.tar.gz的两种方式: 方式一: 下载sysstat-10.0 ...