下面要说的基本都是《动手学深度学习》这本花书上的内容,图也采用的书上的

首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望)

模型选择

  验证数据集(validation data set),又叫验证集(validation set),指用于模型选择的在train settest set之外预留的一小部分数据集

  若训练数据不够时,预留验证集也是一种luxury。常采用的方法为K折交叉验证。原理为:把train set分割成k个不重合的子数据集(SubDataset),然后做k次模型训练和验证。每次训练中,用一个SubDataset作为validation set,其余k-1个SubDataset作为train set。最后对k次训练误差和验证误差求平均(mean)

欠拟合:模型无法得到较低的训练误差,即训练误差降低不了

过拟合:模型训练误差远小于在测试集上的误差

解决欠拟合和过拟合的方法有二:其一,针对数据选择合适的复杂度模型(模型复杂度过高,易出现过拟合;否则易出现欠拟合)。其二,训练数据集大小(train set过少,则容易过拟合。没有否则)

torch.pow():求tensor的幂次(pow是power(有幂次的意思)的缩写),比如求tensor a的平方,则torch.pow(a,2)

torch.cat((A,B),dim):cat是concatenate(拼接,连接在一起)的缩写,参考博客 https://www.cnblogs.com/JeasonIsCoding/p/10162356.html  解释的很好,感谢博主。我多加一句:连接tensor A和B,就是扩增dim维,比如两个矩阵,dim=1,则扩增列,即横着拼接

torch.utils.data.TensorDataset(x,y):大概意思是整合x和y,使其对应。即x的每一行对应y的每一行。

torch.utils.data.DatasetLoader(dataset=dataset,batch_size=batch_size,shuffle=True,num_workers=2):dataset(通过TensorDataset整合的);batch_size(批量大小);shuffle(是否打乱);num_workers(线程数)

权重衰减(weight decay)

  权重衰减又叫L2范数正则化,即在原损失函数基础上添加L2范数惩罚项。
  范数公式$ ||x|| _{p}= (\sum_{i=1}^{n}|x_{i}|^{p})^{1/p} $   L2范数为:$ ||x||_{2} = (\sum_{i=1}^{n}|x_{i}|^{2})^{1/2} $

  带L2范数惩罚项的新损失函数为:$ \iota (w_{1},w_{2},b) + \frac{\lambda }{2n}||x||^2 $      torch.norm(input, p=)求范数

丢弃法(dropout)

  隐藏单元采用一定的概率进行丢弃。使用丢弃法重新计算新的隐藏单元公式为

  $ h_{i}^{'} = \frac{\xi _{i}}{1-p}h_{i} $

  其中$ h_{i}$ 为隐藏单元$ h_{i} = \O (x_{1}w_{1i} + x_{2}w_{2i} + x_{3}w_{3i} + x_{4}w_{4i} + b_{i}) $,随机变量$\xi_{i}$取值为0(概率为p)和1(概率为1-p)

def dropout(X, drop_prob):
X = X.float()
assert 0 <= drop_prob <= 1 #drop_prob的值必须在0-1之间,和数据库中的断言一个意思
#这种情况下把全部元素丢弃
if keep_prob == 0: #keep_prob=0等价于1-p=0,这是$\xi_{i}$值为1的概率为0
return torch.zeros_like(X)
mask = (torch.rand(X.shape) < keep_prob).float() #torch.rand()均匀分布,小于号<判别,若真,返回1,否则返回0
return mask * X / keep_prob # 重新计算新的隐藏单元的公式实现

model.train():启用BatchNormalization和Dropout

model.eval():禁用BatchNormalization和Dropout

正向传播和反向传播

  在深度学习模型训练时,正向传播和反向传播之间相互依赖。下面1和2看不懂的可先看《动手学深度学习》3.14.1和3.14.2

  1.正向传播的计算可能依赖模型参数的当前值,而这些模型参数是在反向传播梯度计算后通过优化算法迭代的。

    如正则化项$ s = ({\lambda }/{2})(\left \| W^{(1)} \right \|_{F}^{2} + \left \| W^{(2)} \right \|_{F}^{2}) $依赖模型参数$W^{(1)}$和$W^{(2)}$的当前值,而这些当前值是优化算法最近一次根据反向传播算出梯度后迭代得到的。

  2.反向传播的梯度计算可能依赖于各变量的当前值,而这些变量的当前值是通过正向传播计算得到的。

    如参数梯度$ \frac{\partial J}{\partial W^{(2))}} = (\frac{\partial J}{\partial o}h^{T} + \lambda W^{(2)}) $的计算需要依赖隐藏层变量的当前值h。这个当前值是通过从输入层到输出层的正向传播计算并存储得到的。

小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播的更多相关文章

  1. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  2. 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)

    本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...

  3. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  4. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  5. 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())

    先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...

  6. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  7. 用交叉验证改善模型的预测表现-着重k重交叉验证

    机器学习技术在应用之前使用“训练+检验”的模式(通常被称作”交叉验证“). 预测模型为何无法保持稳定? 让我们通过以下几幅图来理解这个问题: 此处我们试图找到尺寸(size)和价格(price)的关系 ...

  8. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  9. 模型选择---KFold,StratifiedKFold k折交叉切分

    StratifiedKFold用法类似Kfold,但是他是分层采样,确保训练集,测试集中各类别样本的比例与原始数据集中相同. 例子: import numpy as np from sklearn.m ...

随机推荐

  1. C# Stream篇(五) -- MemoryStream

    MemoryStream 目录: 1 简单介绍一下MemoryStream 2 MemoryStream和FileStream的区别 3 通过部分源码深入了解下MemoryStream 4 分析Mem ...

  2. Day3-T4

    原题目 Describe:有点恶心的DP+最短路 code: #include<bits/stdc++.h> using namespace std; long long A,B,C,z, ...

  3. 线性数据结构案例1 —— 单向链表中获取倒数k个节点

    一.介绍  先遍历整个链表获取链表长度length,然后通过 (length-index) 方式得到我们想要节点在链表中的位置. 二.代码 public Node findLastIndexNode( ...

  4. hibernate.QueryException: Legacy-style query parameters (`?`) are no longer supported

    传统样式查询参数(`?`)不再支持:使用JPA样式的序号参数(例如,`?1’) hibernate4.1之后已经对HQL查询参数中的占位符做了改进: 更改代码:

  5. python3 sort list

    1. 对元素指定的某一部分进行排序,关键字排序 s = ['release.10.txt','release.1.txt','release.2.txt','release.14.txt','rele ...

  6. IDEA--安装

    1:下载IDEA 官网:http://www.jetbrains.com/idea/download/#section=windows(选择下载.zip) 2:解压 3:破解: 1)在C:\Windo ...

  7. 冒泡排序_python

    def popdata(ls): for i in range(len(ls)): for j in range(i+1,len(ls)): if ls[i]>ls[j]: # tmp=ls[i ...

  8. CTF-域渗透--HTTP服务--命令注入1

    开门见山 1. 扫描靶机ip,发现PCS 192.168.31.210 2. 用nmap扫描开放服务和服务版本 3. 再扫描全部信息 4. 探测http服务的目录信息 5. 再用dirb扫描 6. 查 ...

  9. 19.3.8 HTML+css 课程

    form 归属于 form 通过id产生联系 ​<form id = "testform" method = "get" action = "s ...

  10. windows和ubuntu双系统设置开机默认系统

    1.记住grub界面中windows的位置 我的界面如下:windows在第3行 2.选择进入ubuntu系统 3.打开终端,输入如下命令 sudo vim /etc/default/grub 4.看 ...