小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播
下面要说的基本都是《动手学深度学习》这本花书上的内容,图也采用的书上的
首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望)
模型选择
验证数据集(validation data set),又叫验证集(validation set),指用于模型选择的在train set和test set之外预留的一小部分数据集
若训练数据不够时,预留验证集也是一种luxury。常采用的方法为K折交叉验证。原理为:把train set分割成k个不重合的子数据集(SubDataset),然后做k次模型训练和验证。每次训练中,用一个SubDataset作为validation set,其余k-1个SubDataset作为train set。最后对k次训练误差和验证误差求平均(mean)
欠拟合:模型无法得到较低的训练误差,即训练误差降低不了
过拟合:模型训练误差远小于在测试集上的误差
解决欠拟合和过拟合的方法有二:其一,针对数据选择合适的复杂度模型(模型复杂度过高,易出现过拟合;否则易出现欠拟合)。其二,训练数据集大小(train set过少,则容易过拟合。没有否则)

torch.pow():求tensor的幂次(pow是power(有幂次的意思)的缩写),比如求tensor a的平方,则torch.pow(a,2)
torch.cat((A,B),dim):cat是concatenate(拼接,连接在一起)的缩写,参考博客 https://www.cnblogs.com/JeasonIsCoding/p/10162356.html 解释的很好,感谢博主。我多加一句:连接tensor A和B,就是扩增dim维,比如两个矩阵,dim=1,则扩增列,即横着拼接
torch.utils.data.TensorDataset(x,y):大概意思是整合x和y,使其对应。即x的每一行对应y的每一行。
torch.utils.data.DatasetLoader(dataset=dataset,batch_size=batch_size,shuffle=True,num_workers=2):dataset(通过TensorDataset整合的);batch_size(批量大小);shuffle(是否打乱);num_workers(线程数)
权重衰减(weight decay)
权重衰减又叫L2范数正则化,即在原损失函数基础上添加L2范数惩罚项。
范数公式$ ||x|| _{p}= (\sum_{i=1}^{n}|x_{i}|^{p})^{1/p} $ L2范数为:$ ||x||_{2} = (\sum_{i=1}^{n}|x_{i}|^{2})^{1/2} $
带L2范数惩罚项的新损失函数为:$ \iota (w_{1},w_{2},b) + \frac{\lambda }{2n}||x||^2 $ torch.norm(input, p=)求范数
丢弃法(dropout)
隐藏单元采用一定的概率进行丢弃。使用丢弃法重新计算新的隐藏单元公式为
$ h_{i}^{'} = \frac{\xi _{i}}{1-p}h_{i} $
其中$ h_{i}$ 为隐藏单元$ h_{i} = \O (x_{1}w_{1i} + x_{2}w_{2i} + x_{3}w_{3i} + x_{4}w_{4i} + b_{i}) $,随机变量$\xi_{i}$取值为0(概率为p)和1(概率为1-p)
def dropout(X, drop_prob):
X = X.float()
assert 0 <= drop_prob <= 1 #drop_prob的值必须在0-1之间,和数据库中的断言一个意思
#这种情况下把全部元素丢弃
if keep_prob == 0: #keep_prob=0等价于1-p=0,这是$\xi_{i}$值为1的概率为0
return torch.zeros_like(X)
mask = (torch.rand(X.shape) < keep_prob).float() #torch.rand()均匀分布,小于号<判别,若真,返回1,否则返回0
return mask * X / keep_prob # 重新计算新的隐藏单元的公式实现
model.train():启用BatchNormalization和Dropout
model.eval():禁用BatchNormalization和Dropout
正向传播和反向传播
在深度学习模型训练时,正向传播和反向传播之间相互依赖。下面1和2看不懂的可先看《动手学深度学习》3.14.1和3.14.2
1.正向传播的计算可能依赖模型参数的当前值,而这些模型参数是在反向传播的梯度计算后通过优化算法迭代的。
如正则化项$ s = ({\lambda }/{2})(\left \| W^{(1)} \right \|_{F}^{2} + \left \| W^{(2)} \right \|_{F}^{2}) $依赖模型参数$W^{(1)}$和$W^{(2)}$的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。
2.反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值是通过正向传播计算得到的。
如参数梯度$ \frac{\partial J}{\partial W^{(2))}} = (\frac{\partial J}{\partial o}h^{T} + \lambda W^{(2)}) $的计算需要依赖隐藏层变量的当前值h。这个当前值是通过从输入层到输出层的正向传播计算并存储得到的。
小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播的更多相关文章
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)
本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)
我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...
- 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())
先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- 用交叉验证改善模型的预测表现-着重k重交叉验证
机器学习技术在应用之前使用“训练+检验”的模式(通常被称作”交叉验证“). 预测模型为何无法保持稳定? 让我们通过以下几幅图来理解这个问题: 此处我们试图找到尺寸(size)和价格(price)的关系 ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 模型选择---KFold,StratifiedKFold k折交叉切分
StratifiedKFold用法类似Kfold,但是他是分层采样,确保训练集,测试集中各类别样本的比例与原始数据集中相同. 例子: import numpy as np from sklearn.m ...
随机推荐
- Dynamic Route Matching Vue路由(1)
Dynamic Route Matching 动态的 路由 匹配 Very often we will need to map routes with the given pattern to the ...
- P 1020 月饼
转跳点:
- vue坑 - vue安装vue-cli报错coffee-script@1.12.7: CoffeeScript on NPM has moved to "coffeescript" (no hyphen)或者说不支
$ npm install -g vue-cli npm WARN deprecated coffee-script@1.12.7: CoffeeScript on NPM has moved to ...
- git使用代理
在使用git科隆一个repo的时候,因为这个repo的子模块是托管在google上的,还是因为gfw导致子模块科隆不下来 只好使用代理了,那么怎么配置git使用代理呢 代码如下 因为我用的是ss所以这 ...
- mariadb主从
实验环境: 两台centos7 master:192.168.1.6 slave:192.168.1.7 一.安装mariadb服务 [root@master ~]# yum -y install m ...
- css笔记01
CSS样式(Cascading Style Sheets) 表格布局缺陷: 嵌套太多,一旦顺序错乱页面达不到预期效果 表格布局页面不灵活,动一块整个布局全都要变 语法: 在style标签中 ...
- kettle 数据库连接失败
kettle 数据库连接失败 测试连接提示缺少驱动. 提示错误信息:Driver class 'oracle.jdbc.driver.OracleDriver' could not be found, ...
- Day 12:枚举值、枚举类
jdk1.5新特性之-----枚举 问题:某些方法所接收的数据必须是在固定范围之内的, 解决方案: 这时候我们的解决方案就是自定义一个类,然后是私有化构造函数,在自定义类中创建本类的对象对外使用. ...
- CCCC 红色警报
题意: 战争中保持各个城市间的连通性非常重要.本题要求你编写一个报警程序,当失去一个城市导致国家被分裂为多个无法连通的区域时,就发出红色警报.注意:若该国本来就不完全连通,是分裂的k个区域,而失去一个 ...
- python类(3)感悟
1.关于类属性attribute和实例(对象)特性property思考 为什么特性会出现,类属性不能完全替代它吗? 属性: python在为属性赋值时,只会搜索对象本身的__dict__,如果找不到对 ...