模型训练的三要素:数据处理、损失函数、优化算法 

   数据处理(模块torch.utils.data)

从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始

from torch.nn import init   # pytorch的init模块提供了多中参数初始化方法

init.normal_(net[0].weight, mean=0, std=0.01)    #初始化net[0].weight的期望为0,标准差为0.01的正态分布tensor
init.constant_(net[0].bias, val=0) #初始化net[0].bias,值为0的常数tensor
# 此外还封装了好多
# init.ones_(w) 初始化一个形状如w的全1分布的tensor,如w是3行5列,则初始化为3行5列的全1tensor
# init.zeros_(w) 初始化一个形状如w的全0分布的tensor
# init.eye_(w) 初始化一个形状如w的对角线为1,其余为0的tensor
# init.sparse_(w,sparsity=0.1) 初始化一个形状如w稀疏性为0.1的稀疏矩阵

 损失函数(模块torch.nn含有大量的神经网络层)

 pytorch的nn模块中定义了各种损失函数,这些损失函数可以看成一种特殊的网络层 

loss = nn.MSELoss()  # 均方误差损失函数
# torch.nn.MSELoss(reduce=True, size_average=True)
# reduce=True,返回标量形式的loss,reduce=False,返回向量形式的loss
# size_average=True,返回loss.mean(),size_average=False,返回loss.sum()
# 默认两者皆为True

 优化算法(模块torch.optim)

torch.optim模块定义了很多的优化算法,如SGD、Adam、RMSProp等

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.03)
print(optimizer) # 对不同的子网络设置不同的学习率
optimizer = optim.SGD([
          # 如果对某个参数不指定学习率,就使用最外层的默认学习率
          {'params':net.subnet1.parameters()}, # lr=0.03
          {'params':net.subnet2.parameters(),'lr':0.01}
],lr=0.03)

  设置动态学习率,不是固定一个常数

  方法1、修改optimizer.param_groups中的学习率

#调整学习率
for param_group in optimizer.param_groups:
param_group['lr'] *= 0.1 # 学习率是之前的0.1倍

  方法2、新建优化器,即构建新的optimizer。使用动量的优化器(如Adam),可能会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况。

optimizer = optim.SGD([
{'param':net.subnet1.parameters()},
{'param':net.subnet2.parameters(),'lr':old_lr*0.1}],lr=0.03)

  上述代码若不理解net.subnet1.parameters(),可参考博客 https://www.cnblogs.com/hellcat/p/8496727.html   万分感谢博主

view(-1,1)   # -1是不确定几行的意思,在这就是我不确定要取几行,但是肯定是一列,故view(-1,1);

  torch.view()和numpy.reshape()效果一样,view操作的是tensor,且view后的tensor和原tensor共享内存,修改其中一个,另一个也会改变,reshape()操作的是nparray。

  线性回归  

  torch.nn.Linear(in_features,out_features,bias)

  参数解析:

    in_features:输入特征的数量(或称为特征数特征向量X的维度),即在房价预测中仅和房龄与面积有关,则in_features=2

    out_features:输出特征的数量(同in_features)

    bias:偏置,默认为True

  例子请参考 https://www.cnblogs.com/Archer-Fang/p/10645473.html  感谢博主

小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()的更多相关文章

  1. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  2. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  3. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  4. 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)

    本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...

  5. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  6. 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())

    先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...

  7. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  8. Pytorch修改ResNet模型全连接层进行直接训练

    之前在用预训练的ResNet的模型进行迁移训练时,是固定除最后一层的前面层权重,然后把全连接层输出改为自己需要的数目,进行最后一层的训练,那么现在假如想要只是把 最后一层的输出改一下,不需要加载前面层 ...

  9. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

随机推荐

  1. springboot - 映射HTTP Response Status Codes 到 静态 HTML页面

    1.总览 2.代码 1).pom.xml <dependencies> <dependency> <groupId>org.springframework.boot ...

  2. css 径向渐变

    .example { width: 150px; height: 80px; background: -webkit-radial-gradient(red, green, blue); /* Saf ...

  3. 在登陆退出时候使用Vuex

    1.登陆的时候,在登陆模块请求接口,然后获取一个access_token,获取用户权限.保存到缓存里面. 2.退出的时候,请求退出接口,把缓存里面的access_token清除. 一旦要在登陆里面做一 ...

  4. POJ 2187:Beauty Contest 求给定一些点集里最远的两个点距离

    Beauty Contest Time Limit: 3000MS   Memory Limit: 65536K Total Submissions: 31414   Accepted: 9749 D ...

  5. 从零开始Windows环境下安装python+tensorflow

    从零开始Windows环境下安装python+tensorflow 2017年07月12日 02:30:47 qq_16257817 阅读数:29173 标签: windowspython机器学习te ...

  6. 3,Structured Streaming使用checkpoint进行故障恢复

    使用checkpoint进行故障恢复 如果发生故障或关机,可以恢复之前的查询的进度和状态,并从停止的地方继续执行.这是使用Checkpoint和预写日志完成的.您可以使用检查点位置配置查询,那么查询将 ...

  7. quartz详解3:quartz数据库集群-锁机制

    http://blog.itpub.NET/11627468/viewspace-1764753/ 一.quartz数据库锁 其中,QRTZ_LOCKS就是Quartz集群实现同步机制的行锁表,其表结 ...

  8. mysql第四篇:数据操作之单表查询

    单表查询 一.简单查询 -- 创建表 DROP TABLE IF EXISTS `person`; CREATE TABLE `person` ( `id` ) NOT NULL AUTO_INCRE ...

  9. Java集合--list接口

    list是一个接口,实现类:Arraylist,Vector,Linkedlist list接口(有序): 常用方法 排除Collection中具有的之外的 添加功能 void add(int ind ...

  10. git 一些操作

    1. 代码相关 克隆代码 git clone xxx.git 拉取代码 git pull 查看 修改的 状态 git status 推送代码 git push add 或者 修改代码之后 回滚到 未修 ...