1.模型转为cuda

gpus = [0]   #使用哪几个GPU进行训练,这里选择0号GPU
cuda_gpu = torch.cuda.is_available() #判断GPU是否存在可用
net = Net(12288, 25, 16, 6)
if(cuda_gpu):
net = torch.nn.DataParallel(net, device_ids=gpus).cuda() #将模型转为cuda类型

2.数据转为cuda

(minibatchX, minibatchY) = minibatch
minibatchX = minibatchX.astype(np.float32).T
minibatchY = minibatchY.astype(np.float32).T
if(cuda_gpu):
b_x = Variable(torch.from_numpy(minibatchX).cuda()) #将数据转为cuda类型
b_y = Variable(torch.from_numpy(minibatchY).cuda())
else:
b_x = Variable(torch.from_numpy(minibatchX))
b_y = Variable(torch.from_numpy(minibatchY))

3.输出数据去cuda,转为numpy

correct_prediction = sum(torch.max(output, 1)[1].data.squeeze() == torch.max(b_y, 1)[1].data.squeeze())
if(cuda_gpu):
correct_prediction = correct_prediction.cpu().numpy() #.cpu将cuda转为tensor类型,.numpy将tensor转为numpy类型
else:
correct_prediction = correct_prediction.numpy()

linux输入nvidia-smi,可以看到调用GPU成功!

Pytorch:使用GPU训练的更多相关文章

  1. Pytorch多GPU训练

    Pytorch多GPU训练 临近放假, 服务器上的GPU好多空闲, 博主顺便研究了一下如何用多卡同时训练 原理 多卡训练的基本过程 首先把模型加载到一个主设备 把模型只读复制到多个设备 把大的batc ...

  2. pytorch 多GPU训练总结(DataParallel的使用)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/weixin_40087578/artic ...

  3. pytorch 指定GPU训练

    # 1: torch.cuda.set_device(1) # 2: device = torch.device("cuda:1") # 3:(官方推荐)import os os. ...

  4. pytorch 多GPU训练过程中出现ap=0情况

    原因可能是pytorch 自带的BN bug:安装nvidia apex 可以解决: $ git clone https://github.com/NVIDIA/apex $ cd apex $ pi ...

  5. Pytorch中多GPU训练指北

    前言 在数据越来越多的时代,随着模型规模参数的增多,以及数据量的不断提升,使用多GPU去训练是不可避免的事情.Pytorch在0.4.0及以后的版本中已经提供了多GPU训练的方式,本文简单讲解下使用P ...

  6. PyTorch Tutorials 4 训练一个分类器

    %matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...

  7. Pytorch使用分布式训练,单机多卡

    pytorch的并行分为模型并行.数据并行 左侧模型并行:是网络太大,一张卡存不了,那么拆分,然后进行模型并行训练. 右侧数据并行:多个显卡同时采用数据训练网络的副本. 一.模型并行 二.数据并行 数 ...

  8. MinkowskiEngine多GPU训练

    MinkowskiEngine多GPU训练 目前,MinkowskiEngine通过数据并行化支持Multi-GPU训练.在数据并行化中,有一组微型批处理,这些微型批处理将被送到到网络的一组副本中. ...

  9. 使用Deeplearning4j进行GPU训练时,出错的解决方法

    一.问题 使用deeplearning4j进行GPU训练时,可能会出现java.lang.UnsatisfiedLinkError: no jnicudnn in java.library.path错 ...

  10. tensorflow使用多个gpu训练

    关于多gpu训练,tf并没有给太多的学习资料,比较官方的只有:tensorflow-models/tutorials/image/cifar10/cifar10_multi_gpu_train.py ...

随机推荐

  1. 客户端相关知识学习(八)之Android“.9.png”

    参考 Android中.9图片的含义及制作教程 .9.png Android .9.png 的介绍

  2. C++:函数先声明后实现

    贼神奇的是,直到昨天在写flex规则的时候我才知道C++中的函数要么在使用之前先定义,要么将实现放在调用之前,不允许先调用后实现.之前一年多竟然不知道这件事,汗````,当然也是可能这件事本身和我思考 ...

  3. Rsyslog服务器的安装与配置

    一.Rsyslog服务器的安装与配置 1.清空iptabels, 关闭selinux避免安装过中报错 清空iptables iptables -F service iptables save 关闭se ...

  4. Html5+Css3小试牛刀

    前因: 我开始做个收款系统,突然客户跑来要插进一个任务,据说他们老板挺在意的,一个小商场,一个首页,一个详情页,UI无自由发挥,要求,尽量好看点. 一番交谈后,确认这是一个对外的网站,最好移动端也能正 ...

  5. Java反射【三、方法的反射】

    获取一个类下的所有方法 可以获取类类型后,获取到所有方法及相关信息 Method[] ms = c.getMethods(); 获取方法列表(public) Method[] ms = c.getDe ...

  6. centos8/redhat8 无法上网,通过启动systemctl start NetworkManger搞定

    在systemd里面,可以直接使用systemctl进行管理 启动:systemctl start NetworkManger 关闭:systemctl stop NetworkManager 开机启 ...

  7. idea 下gradle创建springboot 项目

    InterlijIdea 开发环境下创建基于springBoot的项目. 环境 1.jdk1.5以上 2.interlijidea 15 以上 步骤 1.File –>new –>Proj ...

  8. maven 项目下 Maven Dependencies 下列表为空

    问题如题,如下图: 解决: 选中 Maven Dependencies ,右键 属性 如下图: 把   resolve dependencies from workspace projects   这 ...

  9. div contenteditable 代替Textarea,做成Vue属性动态绑定

    前言 一般都是用Textarea 文本来编辑,但发现可以用 div contenteditable = “true”,这个属性来搞定 <div contenteditable=true plac ...

  10. nodejs+mysql 批量更新

    没办法,只能通过循环一次次更新: var updateMysql = function(){ data = excel[0].data; for(var i=1; i<data.length; ...