从极大似然估计的角度理解深度学习中loss函数
从极大似然估计的角度理解深度学习中loss函数
为了理解这一概念,首先回顾下最大似然估计的概念:
最大似然估计常用于利用已知的样本结果,反推最有可能导致这一结果产生的参数值,往往模型结果已经确定,用于反推模型中的参数.即在参数空间中选择最有可能导致样本结果发生的参数.因为结果已知,则某一参数使得结果产生的概率最大,则该参数为最优参数.
似然函数:\[ l(\theta) = p(x_1,x_2,...,x_N|\theta) = \prod_{i=1}^{N}{p(x_i|\theta)}\]
为了便于分析和计算,常使用对数似然函数:\[ H(\theta) = ln[l(\theta)]\]
1. logistics regression中常用的loss function:
在logistic regression中常定义的loss function为:\[ l(w) = -(ylog\hat y+(1-y)log(1-\hat y))\]
为什么选择这个函数作为loss function? 一个原因是相比于误差平方和函数的非凸性,交叉熵函数是凸的,因此可以通过梯度下降法求得全局最优点,详细原理请参考凸优化相关理论.
此处重点介绍另一个原因,即从最大似然估计得的角度来理解loss function的选择,Andrew Ng 也是从这个角度进行解释的.对于logsitic regression问题,我们实际上做出了如下假设,即训练样本(x,y)服从以下分布:
\[ P(x,y|\theta) = \begin{cases}\sigma(z),&y=1 \\ 1-\sigma(z),&y=0\end{cases}\]
其中,\(z = w^Tx+b\),意思是,在参数\(\theta\)下,训练样本(x,y)出现的概率为\(P(x,y|\theta)\).
上面的概率分布函数也可以写为整体的形式:
\[p(x,y|\theta) = \sigma(z)^y(1-\sigma(z))^{1-y}\]
对于极大似然估计而言,我们的目的就是在参数空间中,寻找使得\(p(x,y|\theta)\)取得最大的w和b,因为因为训练样本(x,y)已经经过采样得到了,所以使得他们出现概率最大(越接近1)的参数就是最优的参数.
- 对于单个样本\((x_i,y_i)\),其对应的对数似然函数为\(ln[p(x_i,y_i|\theta)]= y_iln(\sigma(z_i))+(1-y_i)ln(1-\sigma(z_i))\)(即在参数\(\theta(w,b)\)下,\((x_i,y_i)\)出现的概率),其中,\(\sigma(z_i)=w^Tx_i+b\).
因为cost function 一般向小的方向优化,所以在似然函数前加上负号,就变为loss function - 对于整个样本集来说,对应的似然函数为\[ln(\prod_{i=1}^{N} p(x_i,y_i|\theta)) = \sum_{i=1}^N{y_iln(\sigma(z_i))+(1-y_i)ln(1-\sigma(z_i))}\]
2. softmax regression中常用的loss function:
softmax regression中常使用如下loss函数:
\[ l(w) = -\sum_{i=1}^{C}y_ilog\hat y_i\]
此处,C指的是样本y的维度(分类的数目),\(y_i\)指的是样本标签第i个分量,\(\hat y_i\)同义.
接下来,同样从最大似然估计的角度进行理解.对于softmax regression,我们实际上也做出了假设,即训练样本(x,y)服从以下分布:\[P(x,y|\theta) = \hat y_l = \sum_{i=1}^{C}y_i\hat y_i\],其中l是样本标签y中唯一为1的序号
- 对于单个训练样本,其对数似然函数为\(ln[p(x_i,y_i|\theta)] = ln(\sum_{i=1}^{C}y_i\hat y_i)\),可以进一步写为\(ln[p(x_i,y_i|\theta)] = \sum_{i=1}^{C}y_iln(\hat y_i)\),因为y中只有唯一的一个维度等于1,其余全为0,通过简单的推理就可以得到化简后的结果.取负号后,得到单样本的loss函数.
- 对于整个训练样本集而言,其对数似然函数为\[ln(\prod_{i=1}^{N} p(x_i,y_i|\theta)) =\sum_{j=1}^{m}\sum_{i=1}^{C}y_i^{(j)}ln(\hat y_i^{(j)})\]
其中,\(y_i^{(j)}\)指的是训练样本集中第j个训练样本标签的第i个维度的值,\(\hat y_i^{(j)}\)同理.取负号求平均后,得到整个训练样本集的coss函数.
从极大似然估计的角度理解深度学习中loss函数的更多相关文章
- 【论文笔记】如何理解深度学习中的End to End
End to end:指的是输入原始数据,输出的是最后结果,应用在特征学习融入算法,无需单独处理. end-to-end(端对端)的方法,一端输入我的原始数据,一端输出我想得到的结果.只关心输入和输出 ...
- 如何理解深度学习中的Transposed Convolution?
知乎上的讨论:https://www.zhihu.com/question/43609045?sort=created 不过看的云里雾里,越看越糊涂. 直到看到了这个:http://deeplearn ...
- 从两个角度理解为什么 JS 中没有函数重载
函数重载是指在同一作用域内,可以有一组具有相同函数名,不同参数列表(参数个数.类型.顺序)的函数,这组函数被称为重载函数.重载函数通常用来声明一组功能相似的函数,这样做减少了函数名的数量,避免了名字空 ...
- 深度学习中loss总结
一.分类损失 1.交叉熵损失函数 公式: 交叉熵的原理 交叉熵刻画的是实际输出(概率)与期望输出(概率)的距离,也就是交叉熵的值越小,两个概率分布就越接近.假设概率分布p为期望输出,概率分布q为实际输 ...
- 【转载】深度学习中softmax交叉熵损失函数的理解
深度学习中softmax交叉熵损失函数的理解 2018-08-11 23:49:43 lilong117194 阅读数 5198更多 分类专栏: Deep learning 版权声明:本文为博主原 ...
- 如何正确理解深度学习(Deep Learning)的概念
现在深度学习在机器学习领域是一个很热的概念,不过经过各种媒体的转载播报,这个概念也逐渐变得有些神话的感觉:例如,人们可能认为,深度学习是一种能够模拟出人脑的神经结构的机器学习方式,从而能够让计算机具有 ...
- 深度学习中交叉熵和KL散度和最大似然估计之间的关系
机器学习的面试题中经常会被问到交叉熵(cross entropy)和最大似然估计(MLE)或者KL散度有什么关系,查了一些资料发现优化这3个东西其实是等价的. 熵和交叉熵 提到交叉熵就需要了解下信息论 ...
- 深度学习中的batch_size,iterations,epochs等概念的理解
在自己完成的几个有关深度学习的Demo中,几乎都出现了batch_size,iterations,epochs这些字眼,刚开始我也没在意,觉得Demo能运行就OK了,但随着学习的深入,我就觉得不弄懂这 ...
- 利用Theano理解深度学习——Multilayer Perceptron
一.多层感知机MLP 1.MLP概述 对于含有单个隐含层的多层感知机(single-hidden-layer Multi-Layer Perceptron, MLP),可以将其看成是一个特殊的Logi ...
随机推荐
- ridis 集群配置
./redis-cli -h 192.168.106.128 -p 6379 redis 1.ping 2.set str1 abc get str1 3. mkdir ../redis-c ...
- VueJS 开发常见问题集锦
由于公司的前端开始转向 VueJS,最近开始使用这个框架进行开发,遇到一些问题记录下来,以备后用. 主要写一些 官方手册 上没有写,但是实际开发中会遇到的问题,需要一定知识基础. 涉及技术栈 CLI: ...
- css处理图片下方留白问题
引用图片的时候,图片和下方内容会有一点小空白,大概如下图紫色横条: 不是说有margin还是padding,是因为ing是行级元素,浏览器就会默认留白了,这时候处理方法很简单,给img加上样式disp ...
- (转)webpack从零开始第6课:在Vue开发中使用webpack
vue官方已经写好一个vue-webpack模板vue_cli,原本自己写一个,发现官方写得已经够好了,自己写显得有点多余,但为了让大家熟悉webpack,决定还是一步一步从0开始写,但源文件就直接拷 ...
- 3) 十分钟学会android--建立第一个APP,建立简单的用户界面
在本小节里,我们将学习如何用 XML 创建一个带有文本输入框和按钮的界面.下一节课将学会使 APP 对按钮做出响应——按钮被按下时,文本框里的内容被发送到另外一个 Activity. Android ...
- [ Tools ] [ MobaXterm ] [ SSH ] [ Linux ] export and import saved session
How to export MobaXterm sessions to another computer? https://superuser.com/questions/858973/how-to- ...
- spring之interceptor篇
springmvc中要写一个拦截器非常的简单,有两种方式:要么实现HandlerInterceptor接口或者继承实现了该接口的类,如spring已经为我们写好的一个HandlerIntercepto ...
- [Intermediate Algorithm] - Sum All Primes
题目 求小于等于给定数值的质数之和. 只有 1 和它本身两个约数的数叫质数.例如,2 是质数,因为它只能被 1 和 2 整除.1 不是质数,因为它只能被自身整除. 给定的数不一定是质数. 提示 For ...
- .apply和.call用法和区别
apply:方法能劫持另外一个对象的方法,继承另外一个对象的属性. Function.apply(obj,args)方法能接收两个参数obj:这个对象将代替Function类里this对象args:这 ...
- PHP 数组 & 字符串处理
1:数组分割为字符串 implode 2:字符串分割为数组 explode() 3:替换字符串 eg: $a = "Hello world" str_replace(“H”,“ ...