2.1、Softmax Regression模型
Softmax Regression模型
由于Logistics Regression算法复杂度低,容易实现等特点,在工业中的到广泛的使用,但是Logistics Regression算法主要用于处理二分类问题,若需要处理的是多分类问题,如手写字的识别,即识别{0,1,2,3,4,5,6,7,8,9}中的数字,此时需要使用能够处理多分类问题的算法。
Softmax Regression算法是Logistics Regression算法在多分类问题上的推广,主要用于处理多分类问题,其中,任意两个类之间是线性可分的。
多分类问题,他的类标签y的取值个数大于2,如手写字识别,即识别{0,1,2,3,4,5,6,7,8,9}中的数字,手写字如图所示:
1、Softmax Regression算法模型用于解决多分类问题
Softmax Regression算法是Logistics Regression算法的推广,即类标签y的取值大于或等于2.假设有m个训练样本{(X(1),y(1)),(X(2),y(2)),........(X(m),y(m))},对于Softmax Regression算法,其输入特征为X(i)€Rn+1,类标签记为:y(i)€{0,1,.......,k}。假设函数为每一个样本估计其所属的类别的概率P(y=j |X),具体假设函数为:
为了方便起见,我们同样使用符号θ来表示全部的模型参数。在实现softmax回归时,你通常会发现,将θ用一个k×(n+1)的矩阵来表示会十分便利,该矩阵是将θ1,θ2,...,θk按行罗列起来得到的,如下所示:
则对于每一个样本估计其所属的类别的概率为:
2、Softmax Regression算法的代价函数
类似于logistic Regression 算法,在Softmax Regression算法的损失函数中引入指示函数I(.),其具体形式为;
那么对于Softmax Regression算法的损失函数为;
其中,I{y(i)=j}表示的是当y(i)属于第j类时,I{y(i)=j}=1,否则I{y(i)=j}=0.
可以看出softmax是logistic的一个泛化版。logistic是k=2情况下的softmax回归。
3、Softmax Regression算法的求解
对于上述问题,可以使用梯度下降法对其进行求解,首先对其进行求梯度:
最终结果为:
注意,此处的θj表示的是一个向量,通过梯度下降法的公式更新:
现在,来使用Python实现上述Softmax Regression的更新过程
def gradientAscent(feature_data,label_data,k,maxCycle,alpha):
'''利用梯度下降法训练Softmax模型
:param feature_data: 特征
:param label_data: 标签
:param k: 类别个数
:param maxCycle: 最大迭代次数
:param alpha: 学习率
:return weights: 权重
'''
m,n = np.shape(feature_data)
weights = np.mat(np.ones((n,k)))#初始化权重
i = 0
while i<=maxCycle:
err = np.exp(feature_data*weights)
if i % 100 == 0:
print("\t--------iter:",i,\
",cost:",cost(err,label_data))
rowsum = -err.sum(axis=1)
rowsum = rowsum.repeat(k,axis = 1)
err = err/rowsum
for x in range(m):
err[x,label_data[x,0]]+=1
weights = weights+(alpha/m)*feature_data.T*err
i+=1
return weights
cost函数:
def cost(err,label_data):
'''
:param err: exp的值
:param label_data: 标签的值
:return: 损失函数的值
'''
m = np.shape(err)[0]
sum_cost = 0.0
for i in range(m):
if err[i,label_data[i,0]]/np.sum(err[i,:])>0:
sum_cost -=np.log(err[i,label_data[i,0]]/np.sum(err[i,:]))
else:
sum_cost -= 0
return sum_cost / m
4、Softmax Regression与Logistics Regression的关系
有一点需要注意的是,按上述方法用softmax求得的参数并不是唯一的,因为,对每一个参数来说,若都减去一个相同的值,依然是上述的代价函数的值。证明如下:
这表明了softmax回归中的参数是“冗余”的。更正式一点来说,我们的softmax模型被过度参数化了,这意味着对于任何我们用来与数据相拟合的估计值,都会存在多组参数集,它们能够生成完全相同的估值函数hθ将输入x映射到预测值。因此使J(θ)最小化的解不是唯一的。而Hessian矩阵是奇异的/不可逆的,这会直接导致Softmax的牛顿法实现版本出现数值计算的问题。
为了解决这个问题,加入一个权重衰减项到代价函数中:
有了这个权重衰减项以后(对于任意的λ>0),代价函数就变成了严格的凸函数而且hession矩阵就不会不可逆了。
此时的偏导数:
对于上面的Softmax Regression过度参数化问题可以参考博客:https://www.cnblogs.com/bzjia-blog/p/3366780.html
2.1、Softmax Regression模型的更多相关文章
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...
- TensorFlow实战之Softmax Regression识别手写数字
关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...
- R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)
本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用) R中的softmaxreg包,发自2016-09-09,链接:https:// ...
- TensorFlow实现Softmax Regression识别手写数字
本章已机器学习领域的Hello World任务----MNIST手写识别做为TensorFlow的开始.MNIST是一个非常简单的机器视觉数据集,是由几万张28像素*28像素的手写数字组成,这些图片只 ...
- Softmax回归(Softmax Regression)
转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...
- UFLDL实验报告1: Softmax Regression
PS:这些是今年4月份,跟斯坦福UFLDL教程时的实验报告,当时就应该好好整理的…留到现在好凌乱了 Softmax Regression实验报告 1.Softmax Regression实验描述 So ...
- ufldl学习笔记和编程作业:Softmax Regression(softmax回报)
ufldl学习笔记与编程作业:Softmax Regression(softmax回归) ufldl出了新教程.感觉比之前的好,从基础讲起.系统清晰,又有编程实践. 在deep learning高质量 ...
- 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...
- TensorFlow(2)Softmax Regression
Softmax Regression Chapter Basics generate random Tensors Three usual activation function in Neural ...
随机推荐
- LCS(详解)
一,问题描述 给定两个字符串,求解这两个字符串的最长公共子序列(Longest Common Sequence).比如字符串1:BDCABA:字符串2:ABCBDAB 则这两个字符串的最长公共子序列长 ...
- tomcat 7 7.0.73 url 参数 大括号 {} 不支持 , 7.0.67支持
7.0.73 url有JSON.stringify一个对象,然后作为参数拼接.结果请求报400错误,但是tomcat 7.0.67版本没有问题,猜测是高级版本对url参数比较严格. 处理方法: ...
- Python将两个数组合并成一个数组,多维数组变成一维数组
1.extend方法 c1 = ["Red","Green","Blue"] c2 = ["Orange"," ...
- ora-28547:连接服务器失败,可能是 Oracle Net 管理失败
检查如下: 监听程序的配置文件 发现多了 (PROGRAM = extproc) 去掉后如下: # listener.ora Network Configuration \dbhome_1\netwo ...
- 使用DevStack安装openstack(单机环境)
DevStack是一系列可扩展的脚本,用于根据git master的最新版本快速启动完整的OpenStack环境.它以交互方式用作开发环境,并作为OpenStack项目功能测试的基础. 参考源码. 警 ...
- 【HDU5391】Zball in Tina Town
[题目大意] 一个球初始体积为1,一天天变大,第一天变大1倍,第二天变大2倍,第n天变大n倍.问当第 n-1天的时候,体积变为多少.注意答案对n取模. [题解] 根据威尔逊定理:(n-1)! mod ...
- 学习神器!本机安装虚拟机,并安装Linux系统,并部署整套web系统手册(包含自动部署应用脚本,JDK,tomcat,TortoiseSVN,Mysql,maven等软件)
1. 引言 编写目的 本文档的编写目的主要是在Linux系统上部署mis_v3员工管理系统,方便测试,并为以后开发人员进行项目部署提供参考 准备工作 软件部分 软件项 版本 备注 Mysql 5. ...
- unity3d 为什么要烘焙?烘焙作用是为了什么?
可以这样理解.你把物体模型放进了场景里之后, 引擎会计算光线,光线照到你的物体的表面形成反光和阴影. 如果不烘焙, 游戏运行的时候,这些反光和阴影都是由显卡和CPU计算出来的.你烘焙之后,这些反光和阴 ...
- mysql 基本操作 alter
查看数据库 show databases; 新建数据库 命令 create database 库名字. 选择数据库 use 2016test; 创建表:create table 表名(字段1,2, ...
- 关于ptrdiff_t
整除.opencv中的内存一般是通过malloc分配,不能保证都是都能被16整除,此时需要截断,但是剩下的内存要如何维护? CV2.0的这样维护的:在 malloc 是多申请一个指针的空间,这个指针指 ...