相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist created on 2019.9.28
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score def loss_max_right_class_prob(predictions, y):
return -predictions[numpy.argmax(y)]; def loss_cross_entropy(predictions, y):
return -numpy.dot(y, numpy.log(predictions)); '''
Softmax classifier
linear classifier
'''
class Softmax: def __init__(self, iter_num = 100000, batch_size = 1):
self.__iter_num = iter_num;
self.__batch_size = batch_size; def train(self, train_X, train_Y):
X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
Y = numpy.copy(train_Y); self.L = []; #initialize parameters
self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
self.__step_len = 1e-3; logging.info("weight:%s" % (self.__weight)); for iter_index in range(self.__iter_num):
if iter_index % 1000 == 0:
logging.info("-----iter:%s-----" % (iter_index));
if iter_index % 100 == 0:
l = 0;
for i in range(0, len(X), 100):
predictions = self.forward_pass(X[i]);
#l += loss_max_right_class_prob(predictions, Y[i]);
l += loss_cross_entropy(predictions, Y[i]);
l /= len(X);
self.L.append(l); sample_index = random.randint(0, len(X) - 1);
logging.debug("-----select sample %s-----" % (sample_index)); z = numpy.dot(X[sample_index], self.__weight);
z = z - numpy.max(z);
predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
# dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1));
# dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len; self.__weight -= dw; logging.debug("weight:%s" % (self.__weight));
logging.debug("loss:%s" % (l));
logging.info("weight:%s" % (self.__weight));
logging.info("L:%s" % (self.L)); def forward_pass(self, x):
net = numpy.dot(x, self.__weight);
net = net - numpy.max(net);
net = numpy.exp(net) / numpy.sum(numpy.exp(net));
return net; def predict(self, x):
x = numpy.append(x, 1.0);
return self.forward_pass(x); def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S'); logging.info("trainning begin."); mnist = read_data_sets('../data/MNIST',one_hot=True) # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签 #load data
train_X = mnist.train.images #训练集样本
validation_X = mnist.validation.images #验证集样本
test_X = mnist.test.images #测试集样本
#labels
train_Y = mnist.train.labels #训练集标签
validation_Y = mnist.validation.labels #验证集标签
test_Y = mnist.test.labels #测试集标签 classifier = Softmax();
classifier.train(train_X, train_Y); logging.info("trainning end. predict begin."); test_predict = numpy.array([]);
test_right = numpy.array([]);
for i in range(len(test_X)):
predict_label = numpy.argmax(classifier.predict(test_X[i]));
test_predict = numpy.append(test_predict, predict_label);
right_label = numpy.argmax(test_Y[i]);
test_right = numpy.append(test_right, right_label); logging.info("right:%s, predict:%s" % (test_right, test_predict));
score = accuracy_score(test_right, test_predict);
logging.info("The accruacy score is: %s "% (str(score))); plt.plot(classifier.L)
plt.show(); if __name__ == "__main__":
main();

损失函数收敛情况

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

softmax及python实现的更多相关文章

  1. 机器学习-softmax回归 python实现

    ---恢复内容开始--- Softmax Regression 可以看做是 LR 算法在多分类上的推广,即类标签 y 的取值大于或者等于 2. 假设数据样本集为:$\left \{ \left ( X ...

  2. softmax函数python实现

    import numpy as np def softmax(x): """ 对输入x的每一行计算softmax. 该函数对于输入是向量(将向量视为单独的行)或者矩阵(M ...

  3. TensorFlow(2)Softmax Regression

    Softmax Regression Chapter Basics generate random Tensors Three usual activation function in Neural ...

  4. logistic regression model

    logistic regression model LR softmax classification Fly logistic regression model loss fuction softm ...

  5. [C2W3] Improving Deep Neural Networks : Hyperparameter tuning, Batch Normalization and Programming Frameworks

    第三周:Hyperparameter tuning, Batch Normalization and Programming Frameworks 调试处理(Tuning process) 目前为止, ...

  6. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  7. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  8. 如何用Python计算Softmax?

    Softmax函数,或称归一化指数函数,它能将一个含任意实数的K维向量z"压缩"到另一个K维实向量\(\sigma{(z)}\)中,使得每一个元素的范围都在(0,1)之间,并且所有 ...

  9. 使用python计算softmax函数

    softmax计算公式:                        Softmax是机器学习中一个非常重要的工具,他可以兼容 logistics 算法.可以独立作为机器学习的模型进行建模训练.还可 ...

随机推荐

  1. 初识Arduino

    Arduino是一款便捷灵活.方便上手的开源电子原型平台.包含硬件(各种型号的Arduino板)和软件(Arduino IDE).由一个欧洲开发团队于2005年冬季开发.其成员包括Massimo Ba ...

  2. 后渗透阶段之基于MSF的路由转发

    目录 反弹MSF类型的Shell 添加内网路由 MSF的跳板功能是MSF框架中自带的一个路由转发功能,其实现过程就是MSF框架在已经获取的Meterpreter Shell的基础上添加一条去往“内网” ...

  3. SpringBoot整合Swagger2案例,以及报错:java.lang.NumberFormatException: For input string: ""原因和解决办法

    原文链接:https://blog.csdn.net/weixin_43724369/article/details/89341949 SpringBoot整合Swagger2案例 先说SpringB ...

  4. C与ARM汇编结合实现mini2440串口uart简单程序

    最近学完了ARM的一些基础知识,开始在mini2440上开发一些简单的程序,串口发送程序是一开始涉及多个寄存器的例子,稍有繁多的步骤应该是开发过程中要慢慢适应的境况 下面的程序的目的是实现mini24 ...

  5. 前端每日实战:48# 视频演示如何用纯 CSS 创作一盘传统蚊香

    效果预览 按下右侧的"点击预览"按钮可以在当前页面预览,点击链接可以全屏预览. https://codepen.io/comehope/pen/BVpvMz 可交互视频教程 此视频 ...

  6. C#爬取微博文字、图片、视频(不使用Cookie)

    前两天在网上偶然看到一个大佬OmegaXYZ写的文章,Python爬取微博文字与图片(不使用Cookie) 于是就心血来潮,顺手撸一个C#版本的. 其实原理也很简单,现在网上大多数版本都需要Cooki ...

  7. 编码的来源于格式简介ANSI、GBK、GB2312、UTF-8、GB18030和 UNICODE

    编码一直是让新手头疼的问题,特别是 GBK.GB2312.UTF-8 这三个比较常见的网页编码的区别,更是让许多新手晕头转向,怎么解释也解释不清楚.但是编码又是那么重要,特别在网页这一块.如果你打出来 ...

  8. 内网渗透之权限维持 - MSF与cs联动

    年初六 六六六 MSF和cs联动 msf连接cs 1.在队伍服务器上启动cs服务端 ./teamserver 团队服务器ip 连接密码 2.cs客户端连接攻击机 填团队服务器ip和密码,名字随便 ms ...

  9. 关于PHP命名空间的讨论

    什么是命名空间? 根据php.net官方翻译文档描述,命名空间是这样定义的: 什么是命名空间?从广义上来说,命名空间是一种封装事物的方法. 在PHP中,命名空间用来解决在编写类库或应用程序时创建可重用 ...

  10. 原生ajax动态添加数据

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...