相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist created on 2019.9.28
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score def loss_max_right_class_prob(predictions, y):
return -predictions[numpy.argmax(y)]; def loss_cross_entropy(predictions, y):
return -numpy.dot(y, numpy.log(predictions)); '''
Softmax classifier
linear classifier
'''
class Softmax: def __init__(self, iter_num = 100000, batch_size = 1):
self.__iter_num = iter_num;
self.__batch_size = batch_size; def train(self, train_X, train_Y):
X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
Y = numpy.copy(train_Y); self.L = []; #initialize parameters
self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
self.__step_len = 1e-3; logging.info("weight:%s" % (self.__weight)); for iter_index in range(self.__iter_num):
if iter_index % 1000 == 0:
logging.info("-----iter:%s-----" % (iter_index));
if iter_index % 100 == 0:
l = 0;
for i in range(0, len(X), 100):
predictions = self.forward_pass(X[i]);
#l += loss_max_right_class_prob(predictions, Y[i]);
l += loss_cross_entropy(predictions, Y[i]);
l /= len(X);
self.L.append(l); sample_index = random.randint(0, len(X) - 1);
logging.debug("-----select sample %s-----" % (sample_index)); z = numpy.dot(X[sample_index], self.__weight);
z = z - numpy.max(z);
predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
# dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1));
# dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len; self.__weight -= dw; logging.debug("weight:%s" % (self.__weight));
logging.debug("loss:%s" % (l));
logging.info("weight:%s" % (self.__weight));
logging.info("L:%s" % (self.L)); def forward_pass(self, x):
net = numpy.dot(x, self.__weight);
net = net - numpy.max(net);
net = numpy.exp(net) / numpy.sum(numpy.exp(net));
return net; def predict(self, x):
x = numpy.append(x, 1.0);
return self.forward_pass(x); def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S'); logging.info("trainning begin."); mnist = read_data_sets('../data/MNIST',one_hot=True) # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签 #load data
train_X = mnist.train.images #训练集样本
validation_X = mnist.validation.images #验证集样本
test_X = mnist.test.images #测试集样本
#labels
train_Y = mnist.train.labels #训练集标签
validation_Y = mnist.validation.labels #验证集标签
test_Y = mnist.test.labels #测试集标签 classifier = Softmax();
classifier.train(train_X, train_Y); logging.info("trainning end. predict begin."); test_predict = numpy.array([]);
test_right = numpy.array([]);
for i in range(len(test_X)):
predict_label = numpy.argmax(classifier.predict(test_X[i]));
test_predict = numpy.append(test_predict, predict_label);
right_label = numpy.argmax(test_Y[i]);
test_right = numpy.append(test_right, right_label); logging.info("right:%s, predict:%s" % (test_right, test_predict));
score = accuracy_score(test_right, test_predict);
logging.info("The accruacy score is: %s "% (str(score))); plt.plot(classifier.L)
plt.show(); if __name__ == "__main__":
main();

损失函数收敛情况

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

softmax及python实现的更多相关文章

  1. 机器学习-softmax回归 python实现

    ---恢复内容开始--- Softmax Regression 可以看做是 LR 算法在多分类上的推广,即类标签 y 的取值大于或者等于 2. 假设数据样本集为:$\left \{ \left ( X ...

  2. softmax函数python实现

    import numpy as np def softmax(x): """ 对输入x的每一行计算softmax. 该函数对于输入是向量(将向量视为单独的行)或者矩阵(M ...

  3. TensorFlow(2)Softmax Regression

    Softmax Regression Chapter Basics generate random Tensors Three usual activation function in Neural ...

  4. logistic regression model

    logistic regression model LR softmax classification Fly logistic regression model loss fuction softm ...

  5. [C2W3] Improving Deep Neural Networks : Hyperparameter tuning, Batch Normalization and Programming Frameworks

    第三周:Hyperparameter tuning, Batch Normalization and Programming Frameworks 调试处理(Tuning process) 目前为止, ...

  6. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  7. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  8. 如何用Python计算Softmax?

    Softmax函数,或称归一化指数函数,它能将一个含任意实数的K维向量z"压缩"到另一个K维实向量\(\sigma{(z)}\)中,使得每一个元素的范围都在(0,1)之间,并且所有 ...

  9. 使用python计算softmax函数

    softmax计算公式:                        Softmax是机器学习中一个非常重要的工具,他可以兼容 logistics 算法.可以独立作为机器学习的模型进行建模训练.还可 ...

随机推荐

  1. Samtec 5G探索之路

    序言:时代在发展,2020年5G作为元年.5G全程第五代移动通信技术(英语:5th generation mobile networks或5th generation wireless systems ...

  2. 那是我夕阳下的奔跑,电商网站PC端详情页图片放大效果实现

    在详情页浏览时商品大图还是不能完全看清楚商品的细节,该特效实现鼠标悬停在商品大图上时,在商品大图右侧出现放大镜效果并根据鼠标的位置来改变右侧大图的显示内容,放大镜中的内容和鼠标悬停位置的内容相同.该特 ...

  3. Vue2.0组件的继承与扩展

    如果有需要源代码,请猛戳源代码 希望文章给大家些许帮助和启发,麻烦大家在GitHub上面点个赞!!!十分感谢 前言 本文将介绍vue2.0中的组件的继承与扩展,主要分享slot.mixins/exte ...

  4. 一个轻量级的基于 .NET Core 的 ORM 框架 HSQL

    HSQL 是一种轻量级的基于 .NET Core 的数据库对象关系映射「ORM」框架 HSQL 是一种可以使用非常简单且高效的方式进行数据库操作的一种框架,通过简单的语法,使数据库操作不再成为难事.目 ...

  5. 在windows上极简安装GPU版AI框架(Tensorflow、Pytorch)

    在windows上极简安装GPU版AI框架 如果我们想在windows系统上安装GPU版本的AI框架,比如GPU版本的tesnorflow,通常我们会看到类似下面的安装教程 官方版本 安装CUDA 安 ...

  6. 第六章、Vue项目预热

    6-5. (1). (2).引入reset.css及border.css (3).解决手机点击延迟300ms的问题 a.安装 b.引入fastclick 6-2.项目的整体架构 src--->整 ...

  7. js中所有函数的参数(按值和按引用)都是按值传递的,怎么理解?

    我觉着我可能对这块有点误解,所以单独开个博说下自己的理解,当然是研究后的正解了. 1,参数传递是基本类型,看个例子: function addTen(num){ num += 10; return n ...

  8. SUCTF checkin

    复现的时候看了源码...... 发现文件上传时会对文件内容以及后缀进行严格的检测 同时还有exif_imagetype   这个就用图片马就行绕过,绕过文件后缀试一下传图片马解析为php 但是常规解析 ...

  9. SpringBoot图文教程15—项目异常怎么办?「跳转404错误页面」「全局异常捕获」

    有天上飞的概念,就要有落地的实现 概念十遍不如代码一遍,朋友,希望你把文中所有的代码案例都敲一遍 先赞后看,养成习惯 SpringBoot 图文教程系列文章目录 SpringBoot图文教程1-Spr ...

  10. 1,Linux(CentOS)中的基本配置

    1,hostname(主机名) 查看主机名:hostname 临时修改主机名:hostname hadoop1 永久修改主机名:vi etc/sysconfig/network :  [NETWORK ...