一个深度学习项目包括了: 模型设计、损失函数设计、梯度更新方法、模型保存和加载和模型训练,其中损失函数就像一把衡量模型学习效果的尺子,训练模型的过程就是优化损失函数的过程,模型拿到数据之后有一个非常重要的环节: 将模型自己的判断结果和数据真实的情况做比较,如果偏差或者差异特别大,那么模型就要去纠正自己的判断,用某种方式去减少这种偏差,然后反复这个过程,知道最后模型能够对数据进行正确的判断

损失函数和代价函数介绍

例如在二维空间中,任意一个点对应的真实函数为F(x),通过模型的学习拟合出来的函数为f(x),F(x)和f(x)之间就存在着一个误差,定义为L(x),于是有:

\[L(x)=(F(x)-f(x))^2
\]

L(x)提供了一个评价你和函数表现效果"好坏"的度量指标,这个指标函数称作损失函数,根据公式可知,损失函数越小,拟合函数对于真实情况的拟合效果就越好,但损失函数的种类有很多中,L(x)其中一个

如果将数据从刚才的任意一个点,扩大到所有的点,那么这些点实际上就是一个训练集合,将集合所有的点对应的拟合误差做平均:

\[\frac{1}{N}\sum(F(x)-f(x))^2
\]

这个函数叫作代价函数,就是在训练样本集合上,所有的样本的拟合误差的平均值,也称经验风险

常见损失函数

损失函数的种类是无穷多的,因为损失函数用来度量模型拟合效果和真实值之间的差距,而度量方式要根据问题的特点或者需要优化的方面具体定制,下面列举一些常用的

0-1损失函数

如果模型判断的结果只有两种: 是或非,那么这是一个最为简单的评估方式,如果预测对了损失函数的值为0,因为没有误差,如果错了,损失函数值就为1,这就是最简单的0-1损失函数

\[L(F(x),f(x))=\begin{cases}
0 & ifF(x) \neq f(x) \\
1 & ifF(x) = f(x)
\end{cases}
\]

其中F(x)是输入数据的真实类别,f(x)是模型预测的类别,但是0-1损失函数在模型训练中很少用到,因为其导数值为0

平方损失函数

上述列举的L(x)就属于平方损失函数,是可求导损失函数中最简单的一种,它直接度量了模型拟合结果和真实结果之间的距离

均方差损失函数和平均绝对误差损失函数

均方误差是回归问题损失函数中最常用的一个,是预测值与目标值之间差值的平方和:

\[MSE=\frac{\sum_{i=1}^{n}(s_i-y_i^p)^2}{n}
\]

其中s为目标值的向量表示,y为预测值的向量表示

平均绝对误差损失函数是另一种常用于回归问题的损失函数,其目标是度量真实值和预测值差异的绝对值之和,定义如下:

\[MAE=\frac{\sum_{i=1}^{n}|y_i-y_i^p|}{n}
\]

交叉熵损失函数

熵表示了一个系统的混乱程度或无序程度,如果一个系统越混乱,那么熵就越大

公式:

\[H=-\sum_{i=1}^{n}p(x_i)log(q(x_i))
\]

p(x)表示真实概率分布,q(x)表示预测概率分布,该函数就是交叉熵损失函数,这个公式同时衡量了真实概率分布和预测概率分布两方面,所以这个函数实际上就是通过衡量并不断去尝试缩小两个概率分布的误差,使预测的概率分布尽可能达到真实概率分布

softmax损失函数

在某些场景下,一些数值大小范围分布非常广,而为了方便计算,或者梯度更好的更新,将输入的数值映射为0-1之间的实数,并且归一化后能够保证几个数的和为1,公式化表示:

\[S_j=\frac{e^{a_j}}{\sum_{k=1}^{T}e^{a_k}}
\]

Pytorch实践模型训练(损失函数)的更多相关文章

  1. pytorch seq2seq模型训练测试

    num_sequence.py """ 数字序列化方法 """ class NumSequence: """ ...

  2. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  3. 【新人赛】阿里云恶意程序检测 -- 实践记录10.13 - Google Colab连接 / 数据简单查看 / 模型训练

    1. 比赛介绍 比赛地址:阿里云恶意程序检测新人赛 这个比赛和已结束的第三届阿里云安全算法挑战赛赛题类似,是一个开放的长期赛. 2. 前期准备 因为训练数据量比较大,本地CPU跑不起来,所以决定用Go ...

  4. pytorch 中模型的保存与加载,增量训练

     让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...

  5. 轻量化模型训练加速的思考(Pytorch实现)

    0. 引子 在训练轻量化模型时,经常发生的情况就是,明明 GPU 很闲,可速度就是上不去,用了多张卡并行也没有太大改善. 如果什么优化都不做,仅仅是使用nn.DataParallel这个模块,那么实测 ...

  6. 【机器学习PAI实践十】深度学习Caffe框架实现图像分类的模型训练

    背景 我们在之前的文章中介绍过如何通过PAI内置的TensorFlow框架实验基于Cifar10的图像分类,文章链接:https://yq.aliyun.com/articles/72841.使用Te ...

  7. Pytorch线性规划模型 学习笔记(一)

    Pytorch线性规划模型 学习笔记(一) Pytorch视频学习资料参考:<PyTorch深度学习实践>完结合集 Pytorch搭建神经网络的四大部分 1. 准备数据 Prepare d ...

  8. 谷歌大规模机器学习:模型训练、特征工程和算法选择 (32PPT下载)

    本文转自:http://mp.weixin.qq.com/s/Xe3g2OSkE3BpIC2wdt5J-A 谷歌大规模机器学习:模型训练.特征工程和算法选择 (32PPT下载) 2017-01-26  ...

  9. PyTorch的十七个损失函数

    本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击: tensor-yu/PyTorch_Tutorial​github.com 版权声明:本文为博主原创文章,转载请附上 ...

  10. [炼丹术]使用Pytorch搭建模型的步骤及教程

    使用Pytorch搭建模型的步骤及教程 我们知道,模型有一个特定的生命周期,了解这个为数据集建模和理解 PyTorch API 提供了指导方向.我们可以根据生命周期的每一个步骤进行设计和优化,同时更加 ...

随机推荐

  1. vscode 一些扩展的推荐(前端)

    - `Auto Rename Tag`:成对修改 HTML 标签名 - `Bracket Pair Colorizer`:括号匹配高亮 - `Color Highlight`:显示颜色代码的颜色 -  ...

  2. windows-git-tagslist

    windows平台使用 Git-bash + vim + Taglist + ctags + cscope 安装Git for win版 安装ctags for win版,目录添加到环境变量 下载 T ...

  3. HttpRunner4.x版本调试测试用例时报错 run testcase failed error="abort running due to failfast setting: variable XXX not found" 解决方法

    httprunner脚本调试报错 未知变量名称未定义问题 解决了,由于请求的requestBody证件照片链接包含$关键字,需要使用$$转义.   执行脚本报错截图 接口requestBody参数截图 ...

  4. Python机器学习/LogisticRegression(逻辑回归模型)(附源码)

    LogisticRegression(逻辑回归) 逻辑回归虽然名称上带回归,但实际上它属于监督学习中的分类算法. 1.算法基础 LogisticRegression基本架构源自于Adline算法,只是 ...

  5. php上传文件时出现 caution: request is not finished yet

    其中的一个原因:是wamp64下的tmp文件夹中的临时文件太多,把这个文件夹的临时文件清理后就可以了.

  6. 【NAS使用心得】使用Synology Photos管理照片

    整理方式 1.本地没有整理或只按年份整理的:时间线模式下直接上传,让软件自己按照片创建时间生成文件夹:有按年份生成相册需求的,可以用"选择照片以创建相册"功能,找到年份文件夹,全选 ...

  7. idea安装阿里规范审查插件

    Install from repositories Settings >> Plugins >> Browse repositories... Search plugin by ...

  8. JavaScript之jQuery要点记录

    一 属性和属性节点 1.什么是属性? 对象身上保存的变量就是属性 2.如何操作属性? 对象.属性名称 = 值; 对象.属性名称; 对象["属性名称"] = 值; 对象[" ...

  9. Action: Consider the following: If you want an embedded database (H2, HSQL or Derby), please put it on the classpath.

    错误原因 在pom中引入了mybatis-spring-boot-starter ,Spring boot默认会加载org.springframework.boot.autoconfigure.jdb ...

  10. MySql.Data 链接MySql数据库 查询语句中带有中文的奇怪问题

    首先Nuget管理器安装MySql.Data 1.ado.net 直接链接 public static void Test() { MySqlConnection myconn = null; MyS ...