在上一节中,我们介绍了如何使用Pytorch来搭建一个经典的分类神经网络。一般情况下,搭建完模型后训练不会一次就能达到比较好的效果,这样,就需要不断的调整和优化模型的各个部分。从而引出了本文的主旨:如何优化模型。

在本节中,我们将介绍从数据集到模型各个部分的调整,从而可以有一个完整的解决思路。

1、数据集部分

1.1 数据集划分

一般情况下,我们会把数据集分成三个部分:训练集,验证集和测试集。依据数据集的大小,如果数据集比较大,数万或数十万个,可以将数据集采用7:2:1或8:1:1的比例来划分。而如果数据集比较小,只有几百条,就不能简单的使用这个方法了。这时,需要使用K折验证法(具体方法可自行百度)。

当然,还有一些需要考虑的问题:数据表征,时间敏感性和数据冗余。在数据表征中,随机打乱(shuffle)是一个不错的选择;时间敏感性主要是针对回归问题象预测股票,不同的月份对回归结果有一个不同的贡献;数据冗余指的是,在数据集中,存在着一些相同的数据会对训练和测试结果产生影响,所以,需要事先过滤掉。

1.2数据预处理

数据向量化:数据源形式各异,需要提前把它转换成框架可以识别的形式,Pytorch统一使用向量(Vector)来表示数据。

正则化:数据的范围大小不一,如果直接使用,训练的收敛会很慢,甚至会出现异常。所以,需要统一数据的范围大小,也就是去除纲量,使用【0,1】区间来统一度量。

缺失数据的处理:如果没有对缺失数据进行处理,训练过程中会直接导致数据的权重分配异常,进而直接影响训练效果。

特征工程:对数据集的特征进行有效提取,是保证模型正常训练的前提。

1.3过拟合与欠拟合

过拟合:训练效果好而验证效果不好。

欠拟合:训练效果不好。

欠拟合的处理相对容易些,针对欠拟合,我们一般采用加大训练周期,降低训练损失,提高训练精度。

过拟合策略:

1、获取更多数据

2、减小网络规模

原始模型:
class Architecture1(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(Architecture1, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
self.relu = nn.ReLU()
self.fc3 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
return out
减小规模后的模型:
class Architecture2(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(Architecture2, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out

3、使用权重正则化

正则化分为1阶正则化和2阶正则化

     1阶正则化是将权重协相关系数的相差绝对值加入权重。

2阶正则化是将权重协相关系数的相差平方和加入权重。示例如下:

model = Architecture1(10,20,2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

4、使用DROPOUT

在隐藏层中去除某些节点,以达到防止过拟合的问题。

dropout的比率为0.2:

dropout的比率为0.5

nn.dropout(x, training=True)

1.4问题定义与数据集获取

首先需要明确两个事情:问题的类别与数据的输入,确定是分类问题还是回归问题。

不同类别的问题有着不同的处理方法,对数据集的获取也是必须面对的一大难题。

1.5模型评估

对于分类问题,一般采用精度,ROC,AUC等方法来进行评估。

而对于排名问题,一般采用mAp。

2、模型部分

2.1 搭建完基础模型后,为了使该模型能够正常工作,我们需要做以下三部分工作:

1、选择网络输出的最后一层

不同的任务,输出最后一层也不尽相同。一般的回归问题只要输出一个标量就可以;向量回归问题则需要输出相同层的向量;对于BBOX问题,则需要输出四个值;对于

二分类,我们需要使用Sigmoid,对于多分类则使用softmax。

2、选择损失函数

对于分类问题,一般采用交叉熵损失;而对于回归问题,则采用均方差。

3、优化器

如何选择一个优化器及配置相关参数是一件非常有艺术性的事。有时需要通过实验来得到。很多时候:Adam和RMSProp是个不错的选择。

Problem type                   Activation function            Loss function
Binary classification Sigmoid activation nn.CrossEntropyLoss()
Multi-class classification Softmax activation nn.CrossEntropyLoss()
Multi-label classification Sigmoid activation nn.CrossEntropyLoss()
Regression None MSE
Vector regression None MSE

2.2 提高模型规模

对于一个已搭建好的模型,如何提高模型的推理能力。可以从这三方面来提高:

1、增加更多的层

2、加入更多的权重系数

3、提高训练周期

2.3 加入泛化策略

1、加入dropout

2、使用不同的架构,不同的参数,不同的网络层数,权重。

3、使用L1或L2正则化

4、尝试不同的学习率

5、增加更多的数据或特征

2.4学习率的设置

学习率对于模型来说,是一个非常重要的超参数。它的设置很多时候直接决定着模型训练效果的好坏。所以,如何设置该参数就变得非常重要。有大量的研究就是针对于该参数进行的。

在Pytorch中,有一系列的方法:

1、stepLR:

scheduler = StepLR(optimizer, step_size=30, gamma=0.1)   #step_size:多少个周期后学习率发生改变  gamma:学习率如何你改变
for epoch in range(100):
scheduler.step()
train(...)
validate(...)

2、MultiStepLR

scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
#milestones:多少个周期后学习率发生改变 gamma:学习率如何你改变
for epoch in range(100): scheduler.step() train(...) validate(...)

3、ExponentialLR

4、ReduceLROnPlateau

optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):
train(...)
val_loss = validate(...)
# Note that step should be called after validate()
scheduler.step(val_loss)

上一篇:

如何入门Pytorch之二:如何搭建实用神经网络

下一篇:

待更新。。。

如何入门Pytorch之三:如何优化神经网络的更多相关文章

  1. 如何入门Pytorch之四:搭建神经网络训练MNIST

    上一节我们学习了Pytorch优化网络的基本方法,本节我们将以MNIST数据集为例,通过搭建一个完整的神经网络,来加深对Pytorch的理解. 一.数据集 MNIST是一个非常经典的数据集,下载链接: ...

  2. 如何入门Pytorch之二:如何搭建实用神经网络

    上一节中,我们介绍了Pytorch的基本知识,如数据格式,梯度,损失等内容. 在本节中,我们将介绍如何使用Pytorch来搭建一个经典的分类神经网络. 搭建一个神经网络并训练,大致有这么四个部分: 1 ...

  3. 60 分钟极速入门 PyTorch

    2017 年初,Facebook 在机器学习和科学计算工具 Torch 的基础上,针对 Python 语言发布了一个全新的机器学习工具包 PyTorch. 因其在灵活性.易用性.速度方面的优秀表现,经 ...

  4. 如何入门Pytorch之一:Pytorch基本知识介绍

    前言 PyTorch和Tensorflow是目前最为火热的两大深度学习框架,Tensorflow主要用户群在于工业界,而PyTorch主要用户分布在学术界.目前视觉三大顶会的论文大多都是基于PyTor ...

  5. 新手如何入门pytorch?

    我最近的文章中,专门为想学Pytorch的新手推荐了一些学习资源,包括教程.视频.项目.论文和书籍.希望能对你有帮助:一.PyTorch学习教程.手册 (1)PyTorch英文版官方手册:https: ...

  6. 【OpenCV入门教程之三】 图像的载入,显示和输出 一站式完全解析(转)

    本系列文章由@浅墨_毛星云 出品,转载请注明出处. 文章链接:http://blog.csdn.net/poem_qianmo/article/details/20537737 作者:毛星云(浅墨)  ...

  7. Asp.Net MVC4.0 官方教程 入门指南之三--添加一个视图

    Asp.Net MVC4.0 官方教程 入门指南之三--添加一个视图 在本节中,您需要修改HelloWorldController类,从而使用视图模板文件,干净优雅的封装生成返回到客户端浏览器HTML ...

  8. PyTorch-Adam优化算法原理,公式,应用

    概念:Adam 是一种可以替代传统随机梯度下降过程的一阶优化算法,它能基于训练数据迭代地更新神经网络权重.Adam 最开始是由 OpenAI 的 Diederik Kingma 和多伦多大学的 Jim ...

  9. 深度学习之入门Pytorch(1)------基础

    目录: Pytorch数据类型:Tensor与Storage 创建张量 tensor与numpy数组之间的转换 索引.连接.切片等 Tensor操作[add,数学运算,转置等] GPU加速 自动求导: ...

随机推荐

  1. python中Requests的重试机制

    requests原生支持 import requests from requests.adapters import HTTPAdapter s = requests.Session() # 重试次数 ...

  2. Tomcat 部署方式

    显示 部署 1.添加context元素方式(server.xml) <Host appBase="webapps" autoDeploy="true" n ...

  3. 如何使用 python 接入虹软 ArcFace SDK

    公司需要在项目中使用人脸识别SDK,并且对信息安全的要求非常高,在详细了解市场上几个主流人脸识别SDK后,综合来看虹软的Arcface SDK比较符合我们的需求,它提供了免费版本,并且可以在离线环境下 ...

  4. 【DSP开发】DSP COFF 与 ELF文件

    本文介绍了C6000最新的v7.2或者之后的编译器如何支持ELF(EABI)和COFF-ABI格式,首先由ARM引入嵌入式(Embedded) EABI的介绍,之后比较了COFF-ABI和EABI的区 ...

  5. 冲刺Noip2017模拟赛3 解题报告——五十岚芒果酱

    题1  素数 [问题描述] 给定一个正整数N,询问1到N中有多少个素数. [输入格式]primenum.in 一个正整数N. [输出格式]primenum.out 一个数Ans,表示1到N中有多少个素 ...

  6. vscode 安装一些快捷配置

    Visual Studio Code 最好的功能.插件和设置   小编推荐:掘金是一个高质量的技术社区,从 ECMAScript 6 到 Vue.js,性能优化到开源类库,让你不错过前端开发的每一个技 ...

  7. 三分钟搞定Python中的装饰器

    python的装饰器是python的特色高级功能之一,言简意赅得说,其作用是在不改变其原有函数和类的定义的基础上,给他们增添新的功能. 装饰器存在的意义是什么呢?我们知道,在python中函数可以调用 ...

  8. ~ android与ios的区别

    Oracle与Mysql的区别 项目类别 android ios 应用上 可以使用常用的android模拟器,来模拟各种android设备 只能直接使用iphone或ipad进行测试 开发语言 基于L ...

  9. 【背包问题】PACKING

    题目描述 It was bound to happen.  Modernisation has reached the North Pole.  Faced with escalating costs ...

  10. MySQL 聚合函数(三)MySQL对GROUP BY的处理

    原文来自MySQL 5.7 官方手册:12.20.3 MySQL Handling of GROUP BY SQL-92和更早版本不允许SELECT列表,HAVING条件或ORDER BY列表引用未在 ...