Softmax Regression模型

  由于Logistics Regression算法复杂度低,容易实现等特点,在工业中的到广泛的使用,但是Logistics Regression算法主要用于处理二分类问题,若需要处理的是多分类问题,如手写字的识别,即识别{0,1,2,3,4,5,6,7,8,9}中的数字,此时需要使用能够处理多分类问题的算法。

  Softmax Regression算法是Logistics Regression算法在多分类问题上的推广,主要用于处理多分类问题,其中,任意两个类之间是线性可分的。

  多分类问题,他的类标签y的取值个数大于2,如手写字识别,即识别{0,1,2,3,4,5,6,7,8,9}中的数字,手写字如图所示:

1、Softmax Regression算法模型用于解决多分类问题

  Softmax Regression算法是Logistics Regression算法的推广,即类标签y的取值大于或等于2.假设有m个训练样本{(X(1),y(1)),(X(2),y(2)),........(X(m),y(m))},对于Softmax Regression算法,其输入特征为X(i)€Rn+1,类标签记为:y(i)€{0,1,.......,k}。假设函数为每一个样本估计其所属的类别的概率P(y=j |X),具体假设函数为:

为了方便起见,我们同样使用符号θ来表示全部的模型参数。在实现softmax回归时,你通常会发现,将θ用一个k×(n+1)的矩阵来表示会十分便利,该矩阵是将θ1,θ2,...,θk按行罗列起来得到的,如下所示:

则对于每一个样本估计其所属的类别的概率为:

2、Softmax Regression算法的代价函数

  类似于logistic Regression 算法,在Softmax Regression算法的损失函数中引入指示函数I(.),其具体形式为;

那么对于Softmax Regression算法的损失函数为;

其中,I{y(i)=j}表示的是当y(i)属于第j类时,I{y(i)=j}=1,否则I{y(i)=j}=0.

可以看出softmax是logistic的一个泛化版。logistic是k=2情况下的softmax回归。

3、Softmax Regression算法的求解

  对于上述问题,可以使用梯度下降法对其进行求解,首先对其进行求梯度:

最终结果为:

注意,此处的θj表示的是一个向量,通过梯度下降法的公式更新:

现在,来使用Python实现上述Softmax Regression的更新过程

def gradientAscent(feature_data,label_data,k,maxCycle,alpha):
'''利用梯度下降法训练Softmax模型
:param feature_data: 特征
:param label_data: 标签
:param k: 类别个数
:param maxCycle: 最大迭代次数
:param alpha: 学习率
:return weights: 权重
'''
m,n = np.shape(feature_data)
weights = np.mat(np.ones((n,k)))#初始化权重
i = 0
while i<=maxCycle:
err = np.exp(feature_data*weights)
if i % 100 == 0:
print("\t--------iter:",i,\
",cost:",cost(err,label_data))
rowsum = -err.sum(axis=1)
rowsum = rowsum.repeat(k,axis = 1)
err = err/rowsum
for x in range(m):
err[x,label_data[x,0]]+=1
weights = weights+(alpha/m)*feature_data.T*err
i+=1
return weights

cost函数:

 def cost(err,label_data):
'''
:param err: exp的值
:param label_data: 标签的值
:return: 损失函数的值
'''
m = np.shape(err)[0]
sum_cost = 0.0
for i in range(m):
if err[i,label_data[i,0]]/np.sum(err[i,:])>0:
sum_cost -=np.log(err[i,label_data[i,0]]/np.sum(err[i,:]))
else:
sum_cost -= 0
return sum_cost / m

4、Softmax Regression与Logistics Regression的关系

有一点需要注意的是,按上述方法用softmax求得的参数并不是唯一的,因为,对每一个参数来说,若都减去一个相同的值,依然是上述的代价函数的值。证明如下:

  

这表明了softmax回归中的参数是“冗余”的。更正式一点来说,我们的softmax模型被过度参数化了,这意味着对于任何我们用来与数据相拟合的估计值,都会存在多组参数集,它们能够生成完全相同的估值函数hθ将输入x映射到预测值。因此使J(θ)最小化的解不是唯一的。而Hessian矩阵是奇异的/不可逆的,这会直接导致Softmax的牛顿法实现版本出现数值计算的问题。

为了解决这个问题,加入一个权重衰减项到代价函数中:

有了这个权重衰减项以后(对于任意的λ>0),代价函数就变成了严格的凸函数而且hession矩阵就不会不可逆了。

此时的偏导数:

对于上面的Softmax Regression过度参数化问题可以参考博客:https://www.cnblogs.com/bzjia-blog/p/3366780.html

2.1、Softmax Regression模型的更多相关文章

  1. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  2. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  3. R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

    本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用) R中的softmaxreg包,发自2016-09-09,链接:https:// ...

  4. TensorFlow实现Softmax Regression识别手写数字

    本章已机器学习领域的Hello World任务----MNIST手写识别做为TensorFlow的开始.MNIST是一个非常简单的机器视觉数据集,是由几万张28像素*28像素的手写数字组成,这些图片只 ...

  5. Softmax回归(Softmax Regression)

    转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...

  6. UFLDL实验报告1: Softmax Regression

    PS:这些是今年4月份,跟斯坦福UFLDL教程时的实验报告,当时就应该好好整理的…留到现在好凌乱了 Softmax Regression实验报告 1.Softmax Regression实验描述 So ...

  7. ufldl学习笔记和编程作业:Softmax Regression(softmax回报)

    ufldl学习笔记与编程作业:Softmax Regression(softmax回归) ufldl出了新教程.感觉比之前的好,从基础讲起.系统清晰,又有编程实践. 在deep learning高质量 ...

  8. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  9. TensorFlow(2)Softmax Regression

    Softmax Regression Chapter Basics generate random Tensors Three usual activation function in Neural ...

随机推荐

  1. LCS(详解)

    一,问题描述 给定两个字符串,求解这两个字符串的最长公共子序列(Longest Common Sequence).比如字符串1:BDCABA:字符串2:ABCBDAB 则这两个字符串的最长公共子序列长 ...

  2. tomcat 7 7.0.73 url 参数 大括号 {} 不支持 , 7.0.67支持

    7.0.73 url有JSON.stringify一个对象,然后作为参数拼接.结果请求报400错误,但是tomcat 7.0.67版本没有问题,猜测是高级版本对url参数比较严格.  处理方法:   ...

  3. Python将两个数组合并成一个数组,多维数组变成一维数组

    1.extend方法 c1 = ["Red","Green","Blue"] c2 = ["Orange"," ...

  4. ora-28547:连接服务器失败,可能是 Oracle Net 管理失败

    检查如下: 监听程序的配置文件 发现多了 (PROGRAM = extproc) 去掉后如下: # listener.ora Network Configuration \dbhome_1\netwo ...

  5. 使用DevStack安装openstack(单机环境)

    DevStack是一系列可扩展的脚本,用于根据git master的最新版本快速启动完整的OpenStack环境.它以交互方式用作开发环境,并作为OpenStack项目功能测试的基础. 参考源码. 警 ...

  6. 【HDU5391】Zball in Tina Town

    [题目大意] 一个球初始体积为1,一天天变大,第一天变大1倍,第二天变大2倍,第n天变大n倍.问当第 n-1天的时候,体积变为多少.注意答案对n取模. [题解] 根据威尔逊定理:(n-1)! mod ...

  7. 学习神器!本机安装虚拟机,并安装Linux系统,并部署整套web系统手册(包含自动部署应用脚本,JDK,tomcat,TortoiseSVN,Mysql,maven等软件)

    1.   引言 编写目的 本文档的编写目的主要是在Linux系统上部署mis_v3员工管理系统,方便测试,并为以后开发人员进行项目部署提供参考 准备工作 软件部分 软件项 版本 备注 Mysql 5. ...

  8. unity3d 为什么要烘焙?烘焙作用是为了什么?

    可以这样理解.你把物体模型放进了场景里之后, 引擎会计算光线,光线照到你的物体的表面形成反光和阴影. 如果不烘焙, 游戏运行的时候,这些反光和阴影都是由显卡和CPU计算出来的.你烘焙之后,这些反光和阴 ...

  9. mysql 基本操作 alter

    查看数据库 show  databases; 新建数据库 命令 create database 库名字. 选择数据库 use  2016test; 创建表:create table 表名(字段1,2, ...

  10. 关于ptrdiff_t

    整除.opencv中的内存一般是通过malloc分配,不能保证都是都能被16整除,此时需要截断,但是剩下的内存要如何维护? CV2.0的这样维护的:在 malloc 是多申请一个指针的空间,这个指针指 ...