【PyTorch】state_dict详解
这篇博客来自csdn,完全用于学习。
Introduce
在pytorch中,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数,state_dict作为python的字典对象将每一层的参数映射成tensor张量,需要注意的是torch.nn.Module模块中的state_dict只包含卷积层和全连接层的参数,当网络中存在batchnorm时,例如vgg网络结构,torch.nn.Module模块中的state_dict也会存放batchnorm's running_mean,关于batchnorm详解可见https://blog.csdn.net/wzy_zju/article/details/81262453
torch.optim模块中的Optimizer优化器对象也存在一个state_dict对象,此处的state_dict字典对象包含state和param_groups的字典对象,而param_groups key对应的value也是一个由学习率,动量等参数组成的一个字典对象。
因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。
Sample
通过一个简单的案例来输出state_dict字典对象中存放的变量
#encoding:utf-8 import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as mp
import matplotlib.pyplot as plt
import torch.nn.functional as F #define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(16*5*5,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10) def forward(self,x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=x.view(-1,16*5*5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x def main():
# Initialize model
model = TheModelClass() #Initialize optimizer
optimizer=optim.SGD(model.parameters(),lr=0.001,momentum=0.9) #print model's state_dict
print('Model.state_dict:')
for param_tensor in model.state_dict():
#打印 key value字典
print(param_tensor,'\t',model.state_dict()[param_tensor].size()) #print optimizer's state_dict
print('Optimizer,s state_dict:')
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name]) if __name__=='__main__':
main()
output:
Model.state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer,s state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
【PyTorch】state_dict详解的更多相关文章
- Pytorch框架详解之一
Pytorch基础操作 numpy基础操作 定义数组(一维与多维) 寻找最大值 维度上升与维度下降 数组计算 矩阵reshape 矩阵维度转换 代码实现 import numpy as np a = ...
- 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)
首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...
- 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)
本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...
- Pytorch Sampler详解
关于为什么要用Sampler可以阅读一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系. 本文我们会从源代码的角度了解Sampler. Sampler 首先需要 ...
- 目标检测之Faster-RCNN的pytorch代码详解(模型准备篇)
十月一的假期转眼就结束了,这个假期带女朋友到处玩了玩,虽然经济仿佛要陷入危机,不过没关系,要是吃不上饭就看书,吃精神粮食也不错,哈哈!开个玩笑,是要收收心好好干活了,继续写Faster-RCNN的代码 ...
- Pytorch autograd,backward详解
平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考.以下笔记基于Pytorch1.0 ...
- pytorch之nn.Conv1d详解
转自:https://blog.csdn.net/sunny_xsc1994/article/details/82969867,感谢分享 pytorch之nn.Conv1d详解
- [转载]Pytorch详解NLLLoss和CrossEntropyLoss
[转载]Pytorch详解NLLLoss和CrossEntropyLoss 来源:https://blog.csdn.net/qq_22210253/article/details/85229988 ...
- 【小白学PyTorch】11 MobileNet详解及PyTorch实现
文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...
- 【小白学PyTorch】21 Keras的API详解(上)卷积、激活、初始化、正则
[新闻]:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测.医学图像.时间序列等多个目标为技术学习的分群和水群唠嗑答疑解惑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会.微信:cyx6450 ...
随机推荐
- ios证书免费分享
首先,ios证书能不能分享给别人使用,能否用别人的证书打包呢? 这个问题的答案在技术上是肯定可以的,但是我要解释一下,技术上可以,不代表真的就可以这样做,为什么呢? 首先,假如用别人的苹果开发者账号的 ...
- ELK多租户方案
一.前言 日志分析是目前重要的系统调试和问题排查的重要手段之一,而目前分布式系统由于实例和机器众多,所以构建一套统一日志系统是非常必要的:ELK提供了一整套解决方案,并且都是开源软件,之间互相配合使用 ...
- Linux下搭建Elasticsearch7.6.2集群
使用VMvare创建虚拟机 我的创建的三台分别是: 192.168.115.129 node-1 192.168.115.130 node-2 192.168.115.131 node-3 注意:克隆 ...
- Nuxt.js必读:轻松掌握运行时配置与 useRuntimeConfig
title: Nuxt.js必读:轻松掌握运行时配置与 useRuntimeConfig date: 2024/7/29 updated: 2024/7/29 author: cmdragon exc ...
- 老旧 Linux 系统搭建现代 C++ 开发环境 —— 基于 neovim
问题背景 公司配发的电脑是 macOS,日常开发需要访问 Linux 虚拟机,出于安全方面的考虑,只能通过跳板机登录.这阻止了大多数远程图形界面的使用,让写代码的工作变得复杂起来,市面上非常好用的 V ...
- 关于spring boot中mapper注入到service时IDEA报错的解决办法
虽然这个错误不影响正常运行但是作为强迫症患者看着实属难受,经过在论坛查看资料学习到以下两种解决方法,可以供大家参考以下,如有什么错误的地方还希望各位大佬指定一下. 1.在注解@Autowired后增加 ...
- 【Mybatis】02 快速入门Part2 补完CRUD
这是我们的UserMapper.xml文件 <?xml version="1.0" encoding="UTF-8" ?> <!DOCTYPE ...
- 【Windows】XP系统安装TIM/QQ 数字签名过期问题
需要手动安装数字签名 右键安装包 -> 属性 但是我的TIM没有用,对QQ是有效的 参考自视频: https://www.bilibili.com/video/av413122971/
- csv或excel文件通过plsql导入到oracle数据库中
1.背景 实际开发中经常遇到将数据直接导入到数据库中,操作如下 2.操作 第一步: 第二步:选择要导入的csv文件 第三步:选择数据库表字段与csv的列对应,然后点击导入,完成 完美!
- 微服务全链路跟踪:jaeger坑之NoSuchMethodError: io.jaegertracing.agent.thrift.Agent$Client.sendBaseOneway
在jaeger使用过程中遇到了一个奇怪的问题,本来jaeger运行的好好的,jaeger配置与依赖都没动,就上了一个版本,结果jaeger就没上报监控数据了,由于生产上没打印info日志,后面在本地试 ...