小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播
下面要说的基本都是《动手学深度学习》这本花书上的内容,图也采用的书上的
首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望)
模型选择
验证数据集(validation data set),又叫验证集(validation set),指用于模型选择的在train set和test set之外预留的一小部分数据集
若训练数据不够时,预留验证集也是一种luxury。常采用的方法为K折交叉验证。原理为:把train set分割成k个不重合的子数据集(SubDataset),然后做k次模型训练和验证。每次训练中,用一个SubDataset作为validation set,其余k-1个SubDataset作为train set。最后对k次训练误差和验证误差求平均(mean)
欠拟合:模型无法得到较低的训练误差,即训练误差降低不了
过拟合:模型训练误差远小于在测试集上的误差
解决欠拟合和过拟合的方法有二:其一,针对数据选择合适的复杂度模型(模型复杂度过高,易出现过拟合;否则易出现欠拟合)。其二,训练数据集大小(train set过少,则容易过拟合。没有否则)
torch.pow():求tensor的幂次(pow是power(有幂次的意思)的缩写),比如求tensor a的平方,则torch.pow(a,2)
torch.cat((A,B),dim):cat是concatenate(拼接,连接在一起)的缩写,参考博客 https://www.cnblogs.com/JeasonIsCoding/p/10162356.html 解释的很好,感谢博主。我多加一句:连接tensor A和B,就是扩增dim维,比如两个矩阵,dim=1,则扩增列,即横着拼接
torch.utils.data.TensorDataset(x,y):大概意思是整合x和y,使其对应。即x的每一行对应y的每一行。
torch.utils.data.DatasetLoader(dataset=dataset,batch_size=batch_size,shuffle=True,num_workers=2):dataset(通过TensorDataset整合的);batch_size(批量大小);shuffle(是否打乱);num_workers(线程数)
权重衰减(weight decay)
权重衰减又叫L2范数正则化,即在原损失函数基础上添加L2范数惩罚项。
范数公式$ ||x|| _{p}= (\sum_{i=1}^{n}|x_{i}|^{p})^{1/p} $ L2范数为:$ ||x||_{2} = (\sum_{i=1}^{n}|x_{i}|^{2})^{1/2} $
带L2范数惩罚项的新损失函数为:$ \iota (w_{1},w_{2},b) + \frac{\lambda }{2n}||x||^2 $ torch.norm(input, p=)求范数
丢弃法(dropout)
隐藏单元采用一定的概率进行丢弃。使用丢弃法重新计算新的隐藏单元公式为
$ h_{i}^{'} = \frac{\xi _{i}}{1-p}h_{i} $
其中$ h_{i}$ 为隐藏单元$ h_{i} = \O (x_{1}w_{1i} + x_{2}w_{2i} + x_{3}w_{3i} + x_{4}w_{4i} + b_{i}) $,随机变量$\xi_{i}$取值为0(概率为p)和1(概率为1-p)
def dropout(X, drop_prob):
X = X.float()
assert 0 <= drop_prob <= 1 #drop_prob的值必须在0-1之间,和数据库中的断言一个意思
#这种情况下把全部元素丢弃
if keep_prob == 0: #keep_prob=0等价于1-p=0,这是$\xi_{i}$值为1的概率为0
return torch.zeros_like(X)
mask = (torch.rand(X.shape) < keep_prob).float() #torch.rand()均匀分布,小于号<判别,若真,返回1,否则返回0
return mask * X / keep_prob # 重新计算新的隐藏单元的公式实现
model.train():启用BatchNormalization和Dropout
model.eval():禁用BatchNormalization和Dropout
正向传播和反向传播
在深度学习模型训练时,正向传播和反向传播之间相互依赖。下面1和2看不懂的可先看《动手学深度学习》3.14.1和3.14.2
1.正向传播的计算可能依赖模型参数的当前值,而这些模型参数是在反向传播的梯度计算后通过优化算法迭代的。
如正则化项$ s = ({\lambda }/{2})(\left \| W^{(1)} \right \|_{F}^{2} + \left \| W^{(2)} \right \|_{F}^{2}) $依赖模型参数$W^{(1)}$和$W^{(2)}$的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。
2.反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值是通过正向传播计算得到的。
如参数梯度$ \frac{\partial J}{\partial W^{(2))}} = (\frac{\partial J}{\partial o}h^{T} + \lambda W^{(2)}) $的计算需要依赖隐藏层变量的当前值h。这个当前值是通过从输入层到输出层的正向传播计算并存储得到的。
小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播的更多相关文章
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())
先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 用交叉验证改善模型的预测表现-着重k重交叉验证
机器学习技术在应用之前使用“训练+检验”的模式(通常被称作”交叉验证“). 预测模型为何无法保持稳定? 让我们通过以下几幅图来理解这个问题: 此处我们试图找到尺寸(size)和价格(price)的关系 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 模型选择---KFold,StratifiedKFold k折交叉切分
StratifiedKFold用法类似Kfold,但是他是分层采样,确保训练集,测试集中各类别样本的比例与原始数据集中相同. 例子: import numpy as np from sklearn.m ...
随机推荐
- PGSQL基本操作语句
; --更新数据 ,,) ; --插入数据 ORDER BY app_name,flag asc/desc ; --查询数据并且排序 offset ; --查询起点0开始查询,返回5条数据 ORDER ...
- 【剑指Offer】面试题24. 反转链表
题目 定义一个函数,输入一个链表的头节点,反转该链表并输出反转后链表的头节点. 示例: 输入: 1->2->3->4->5->NULL 输出: 5->4->3 ...
- python --- 对于需要关联的接口处理方法
1.unittest对于需要关联的请求,怎么处理(如购物接口,需要先登录) a)把登录请求写到测试用例类的setUP函数中,这样每次调用测试用例,都会先执行setUP函数 b)全局变量的形式声明. c ...
- 【Android】家庭记账本手机版开发报告六
一.说在前面 昨天 1.创建登入和注册界面:2.向数据库添加一张用户表 今天 用图标显示账单情况 问题 1.使用第三方库 hellochart,时添加依赖构建失败 2.在 chertFragmen ...
- dns、网关、IP地址,主要是配置resolv.conf\network\ifcfg-eth0
Ubuntu sudo vi /etc/network/interfac 添加 dns-nameservers 192.168.1.254dns-search stonebean.com cent ...
- HDU-1114 完全背包+恰好装满问题
B - Piggy-Bank Time Limit:1000MS Memory Limit:32768KB 64bit IO Format:%I64d & %I64u Subm ...
- ORACLE增删改查以及case when的基本用法
1.创建table create table test01( id int not null primary key, name ) not null, gender ) not null, age ...
- centos socket通信时 connect refused 主要是防火墙问题
centos socket通信时 connect refused 主要是防火墙问题,可以关闭防火墙,或者开放程序中的端口
- sqli-labs注入lesson3-4闯关秘籍
·lesson 3 与第一二关不同的是,这一关是基于错误的get单引号变形字符型注入 要使用 ') 进行闭合 (ps:博主自己理解为字符型注入,是不过是需要加括号进行闭合,适用于博主自己的方便记忆的 ...
- delphi 单例模式
unit Singleton; (* 单例模式适用于辅助类, 一般伴随于单元的生命周期 *) interface uses SysUtils; type TSingleton = class publ ...