[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载预测模型问题

上一篇实践学习中,遇到了在多/单个GPU、GPU与CPU的不同环境下训练保存、加载使用使用模型的问题,如果保存、加载的上述三类环境不同,加载时会出错。就去研究了一下,做了实验,得出以下结论:

多/单GPU训练保存模型参数、CPU加载使用模型

#保存
PATH = 'cifar_net.pth'
torch.save(net.module.state_dict(), PATH) #加载
net = Net()
net.load_state_dict(torch.load(PATH))

多GPU训练模型、单GPU加载使用模型

#保存
PATH = 'cifar_net.pth'
torch.save(net.state_dict(), PATH) #加载
net = Net()
net = nn.DataParallel(net) #保存多GPU的,在加载时需要把网络也转成DataParallel的
net.to(device) #放到GPU上
net.load_state_dict(torch.load(PATH)) # 然后测试数据也需要放到GPU上
images, labels = images.to(device), labels.to(device)

多GPU训练保存模型参数、多GPU加载使用模型

#保存
PATH = 'cifar_net.pth'
torch.save(net.state_dict(), PATH) #加载
net = Net()
net = nn.DataParallel(net) #保存多GPU的,在加载时需要把网络也转成DataParallel的
net.to(device) #放到GPU上
net.load_state_dict(torch.load(PATH)) # 然后测试数据也需要放到GPU上
images, labels = images.to(device), labels.to(device)

可以看到,单GPU和多GPU加载数据的方法其实是一样的,经运行验证,只要按上述代码写,有多个GPU就调用多个,只有一个就调用一个。

另外,保存、加载网络模型有三种不同的做法

1.保存整个网络模型

2.只保存模型参数(我们用的就是这种)

3.自定义保存

详细方法,请参考:https://blog.csdn.net/Code_Mart/article/details/88254444

[深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题的更多相关文章

  1. VSTO学习笔记(三) 开发Office 2010 64位COM加载项

    原文:VSTO学习笔记(三) 开发Office 2010 64位COM加载项 一.加载项简介 Office提供了多种用于扩展Office应用程序功能的模式,常见的有: 1.Office 自动化程序(A ...

  2. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  3. 使用Pytorch在多GPU下保存和加载训练模型参数遇到的问题

    最近使用Pytorch在学习一个深度学习项目,在模型保存和加载过程中遇到了问题,最终通过在网卡查找资料得已解决,故以此记之,以备忘却. 首先,是在使用多GPU进行模型训练的过程中,在保存模型参数时,应 ...

  4. 深度学习“引擎”之争:GPU加速还是专属神经网络芯片?

    深度学习“引擎”之争:GPU加速还是专属神经网络芯片? 深度学习(Deep Learning)在这两年风靡全球,大数据和高性能计算平台的推动作用功不可没,可谓深度学习的“燃料”和“引擎”,GPU则是引 ...

  5. [深度学习] Pytorch学习(一)—— torch tensor

    [深度学习] Pytorch学习(一)-- torch tensor 学习笔记 . 记录 分享 . 学习的代码环境:python3.6 torch1.3 vscode+jupyter扩展 #%% im ...

  6. pytorch在CPU和GPU上加载模型

    pytorch允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上.CPU->CPU,GPU->GPU torch.load('gen_500000.pkl ...

  7. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  8. 炫!一组单元素实现的 CSS 加载进度提示效果

    之前的文章个大家分享过各种类型的加载效果(Loading Effects),这里再给大家奉献一组基于单个元素实现的 CSS 加载动画集合.这些加载效果都是基于一个 DIV 元素实现的,十分强悍. 温馨 ...

  9. Android 学习笔记之Volley(八)实现网络图片的数据加载

    PS:最后一篇关于Volley框架的博客... 学习内容: 1.使用ImageRequest.java实现网络图片加载 2.使用ImageLoader.java实现网络图片加载 3.使用NetWork ...

随机推荐

  1. commons-fileload图片文件上传工具 , servlet文件图片上传案列

    本案列是java  maven工程小项目,提供个大家学习! 1.在pom.xml文件中导入依赖: <!--文件上传依赖--><dependency> <groupId&g ...

  2. Redis 分布式锁(一)

    前言 本文力争以最简单的语言,以博主自己对分布式锁的理解,按照自己的语言来描述分布式锁的概念.作用.原理.实现.如有错误,还请各位大佬海涵,恳请指正.分布式锁分两篇来讲解,本篇讲解客户端,下一篇讲解r ...

  3. Ethical Hacking - GAINING ACCESS(10)

    CLIENT SIDE ATTACKS Use if server-side attacks fail. If IP is probably useless. Require user interac ...

  4. Python Ethical Hacking - VULNERABILITY SCANNER(3)

    Polish the Python code using sending requests in a session Class Scanner. #!/usr/bin/env python impo ...

  5. T4 分配时间 题解

    问题描述 小王参加的考试是几门科目的试卷放在一起考,一共给 t 分钟来做.他现在已经知道每 门科目花的时间和得到的分数的关系,还有写名字要的时间(他写自己的名字很慢)请帮他 算一下他最高能得几分.总分 ...

  6. three.js 将图片马赛克化

    这篇郭先生来说说BufferGeometry,类型化数组和粒子系统的使用,并且让图片有马赛克效果(同理可以让不清晰的图片清晰化),如图所示.在线案例点击博客原文 1. 解析图片 解析图片和上一篇一样 ...

  7. 学习jvm(一)--java内存区域

    前言 通过学习深入理解java虚拟机的教程,以及自己在网上的查询的资料,做一个对jvm学习过程中的小总结. 本文章内容首先讲解java的内存分布区域,之后讲内存的分配原则以及内存的监控工具.再下来会着 ...

  8. 17 个 Python 特别实用的操作技巧,记得收藏!

    Python 是一门非常优美的语言,其简洁易用令人不得不感概人生苦短.在本文中,作者 Gautham Santhosh 带我们回顾了 17 个非常有用的 Python 技巧,例如查找.分割和合并列表等 ...

  9. jmeter压力测试报错:java.net.BindException: Address already in use: connect || java.net.SocketException: Socket closed

    windows提供给TCP/IP链接的端口为 1024-5000,并且要四分钟来循环回收它们,就导致我们在短时间内跑大量的请求时将端口占满了,导致如上报错. 解决办法(在jmeter所在服务器操作): ...

  10. django-rest-framework-源码解析002-序列化/请求模块/响应模块/异常处理模块/渲染模块/十大接口

    简介 当我们使用django-rest-framework框架时, 项目必定是前后端分离的, 那么前后端进行数据交互时, 常见的数据类型就是xml和json(现在主流的是json), 这里就需要我们d ...