一、网络模型的保存和加载

1、网络模型保存方法1

import torch
import torchvision vgg16 = torchvision.models.vgg16(weights=False)
# 保存方法1:模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

运行上述代码会发现在其同路径下保存了神经网络模型文件:vgg16_model1.pth

加载代码:

import torch

# 方法1  -> 保存方法1,加载模型
model = torch.load("vgg16_method1.pth")
print(model)

结果:

VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)

保存了网络模型及模型的参数:

注:没有预训练的模型不是没有参数,而是参数在初始化的状态

2、网络模型保存方法2

保存的是模型参数(官方推荐):

import torch
import torchvision vgg16 = torchvision.models.vgg16(weights=False)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
import torch

model = torch.load("vgg16_method2.pth")
print(model)

结果:

OrderedDict([('features.0.weight', tensor([[[[-9.0302e-02,  6.1546e-02, -1.7735e-02],
[ 1.1606e-01, -1.7557e-02, -5.4266e-02],
[-3.0833e-02, 2.3019e-02, 2.2968e-02]], [[-3.5706e-02, -3.8619e-02, 2.7329e-02],
[ 1.0525e-02, 7.0172e-02, -4.3097e-02],
[-7.9473e-03, -2.8735e-02, -4.3932e-02]], [[ 6.6814e-02, -6.1849e-02, -9.8496e-02],
[-5.7835e-02, 3.3374e-02, 3.2937e-02],
[-4.3170e-02, -3.1252e-02, 1.1314e-01]]], [[[ 6.6068e-02, -6.5313e-02, -8.0335e-02],
[-1.5587e-02, 1.1784e-02, -8.8468e-03],
[ 7.2871e-02, 7.5150e-02, -7.2230e-02]], [[-3.7871e-02, 1.8217e-02, 1.1531e-01],
[ 5.7616e-02, -1.2748e-01, 2.3816e-02],
[-4.1781e-02, -2.1523e-02, 6.2196e-02]], [[-2.0698e-03, 8.8641e-02, 3.1991e-02],
[-8.9041e-02, -1.1210e-01, -7.8223e-04],
[-2.9659e-02, -1.5199e-01, 3.9977e-06]]],       ......

两种模型保存的大小不一样:

从上述输出结果中得到的结果是字典类型,其中参数的值也一起输出来了,如果想要查看具体的网络结构,则需要增加下述代码:

# 方式2-> 保存方式2,加载模型结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_model2.pth")) # 输出完整的模型结构,与第一种方式输出的模型结构相同
print(vgg16)

pytorch学习笔记(9)--神经网络模型的保存与读取的更多相关文章

  1. matlab学习笔记4--多媒体文件的保存和读取

    一起来学matlab-matlab学习笔记4 数据导入和导出_2 多媒体文件的保存和读取 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考书籍 <matlab 程序设计与综合应用&g ...

  2. TensorFlow学习笔记(8)--网络模型的保存和读取【转】

    转自:http://blog.csdn.net/lwplwf/article/details/62419087 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等 ...

  3. 使用PyTorch简单实现卷积神经网络模型

    这里我们会用 Python 实现三个简单的卷积神经网络模型:LeNet .AlexNet .VGGNet,首先我们需要了解三大基础数据集:MNIST 数据集.Cifar 数据集和 ImageNet 数 ...

  4. barabasilab-networkScience学习笔记3-随机网络模型

    第一次接触复杂性科学是在一本叫think complexity的书上,Allen博士很好的讲述了数据结构与复杂性科学,barabasi是一个知名的复杂性网络科学家,barabasilab则是他所主导的 ...

  5. PyTorch学习笔记6--案例2:PyTorch神经网络(MNIST CNN)

    上一节中,我们使用autograd的包来定义模型并求导.本节中,我们将使用torch.nn包来构建神经网络. 一个nn.Module包含各个层和一个forward(input)方法,该方法返回outp ...

  6. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  7. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  8. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  9. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  10. [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...

随机推荐

  1. immutable.js学习笔记(七)----- Seq

    一.Seq 懒得意思就是"不运算,不执行" 二.运行 当console.log这个值的时候,才去观察 三.任意collection 四.Seq.keyed 五.Seq.Indexe ...

  2. 编程哲学之 C# 篇:006——什么是 .NET

    本章将用本系列第二章中提到的 类比 思维来让读者快速了解什么是.NET. 当年在网上看到一个初学者问<Java编程思想>第一章看不懂怎么办.然后我发现在很多经典的技术书中,如<C#入 ...

  3. Selenium中的option用法实例

    Selenium中的option用法实例 在上一篇文章Selenium中免登录的实现方法一option中我们用到了option,而option的用法是很多的,本文举几个例子 关于无头浏览器,也属于op ...

  4. 职场IT老手教你3步教你玩转可视化大屏设计,让领导眼前一亮!

    我是制造企业的IT中心的研发人员,平常工作就是配合业务部门出出报表,选型一些商业软件,并在内部负责实施运维.最近领导出去参观了一些数字化转型比较领先的工厂和制造企业,回来就甩给我几张图,问能不能我们也 ...

  5. 对Jim博士质疑的质疑

    ​ 我只是中科大一个本科生,不像Jim博士那样顶了博士的帽子.去年他上头条的时候评论了他的一篇文章. 看了他的一些文章,感觉他对国内科研现状以及和美西方的差距非常了解,并且做了大量的调研,站在国家的立 ...

  6. 下篇 | 使用 🤗 Transformers 进行概率时间序列预测

    在<使用 Transformers 进行概率时间序列预测>的第一部分里,我们为大家介绍了传统时间序列预测和基于 Transformers 的方法,也一步步准备好了训练所需的数据集并定义了环 ...

  7. RabbitMQ基础和解疑

    一.基础概念 1. Producer:生产者,就是投递消息的一方 消息一般可以包含2个部分:消息体和标签(Label).消息体也可以称之为payload,在实际应用中,消息体一般是一个带有业务逻辑结构 ...

  8. [EULAR文摘] 利用蛋白组学技术开发一项蛋白评分用于预测TNFi疗效

    利用蛋白组学技术开发一项蛋白评分与临床参数联用可以增强对TNF拮抗剂对RA疗效的预测效能 Cuppen BV, et al. EULAR 2015. Present ID: OP0130. 背景: 对 ...

  9. 地理探测器简介(R语言)

    地理探测器 1. 地理探测器原理 空间分异性是地理现象的基本特点之一.地理探测器是探测和利用空间分异性的工具.地理探测器包括4个探测器. 分异及因子探测:探测Y的空间分异性:以及探测某因子X多大程度上 ...

  10. Postgresql模板数据库之template1 和 template0

    一.简介 template1和template0是PostgreSQL的模板数据库.所谓模板数据库就是创建新database时,PostgreSQL会基于模板数据库制作一份副本,其中会包含所有的数据库 ...