相对于自适应神经网络、感知器,softmax巧妙低使用简单的方法来实现多分类问题。

  • 功能上,完成从N维向量到M维向量的映射
  • 输出的结果范围是[0, 1],对于一个sample的结果所有输出总和等于1
  • 输出结果,可以隐含地表达该类别的概率

softmax的损失函数是采用了多分类问题中常见的交叉熵,注意经常有2个表达的形式

这两个版本在求导过程有点不同,但是结果都是一样的,同时损失表达的意思也是相同的,因为在第一种表达形式中,当y不是正确分类时,y_right等于0,当y是正确分类时,y_right等于1。

下面基于mnist数据做了一个多分类的实验,整体能达到85%的精度。

'''
softmax classifier for mnist created on 2019.9.28
author: vince
'''
import math
import logging
import numpy
import random
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
from sklearn.metrics import accuracy_score def loss_max_right_class_prob(predictions, y):
return -predictions[numpy.argmax(y)]; def loss_cross_entropy(predictions, y):
return -numpy.dot(y, numpy.log(predictions)); '''
Softmax classifier
linear classifier
'''
class Softmax: def __init__(self, iter_num = 100000, batch_size = 1):
self.__iter_num = iter_num;
self.__batch_size = batch_size; def train(self, train_X, train_Y):
X = numpy.c_[train_X, numpy.ones(train_X.shape[0])];
Y = numpy.copy(train_Y); self.L = []; #initialize parameters
self.__weight = numpy.random.rand(X.shape[1], 10) * 2 - 1.0;
self.__step_len = 1e-3; logging.info("weight:%s" % (self.__weight)); for iter_index in range(self.__iter_num):
if iter_index % 1000 == 0:
logging.info("-----iter:%s-----" % (iter_index));
if iter_index % 100 == 0:
l = 0;
for i in range(0, len(X), 100):
predictions = self.forward_pass(X[i]);
#l += loss_max_right_class_prob(predictions, Y[i]);
l += loss_cross_entropy(predictions, Y[i]);
l /= len(X);
self.L.append(l); sample_index = random.randint(0, len(X) - 1);
logging.debug("-----select sample %s-----" % (sample_index)); z = numpy.dot(X[sample_index], self.__weight);
z = z - numpy.max(z);
predictions = numpy.exp(z) / numpy.sum(numpy.exp(z));
dw = self.__step_len * X[sample_index].reshape(-1, 1).dot((predictions - Y[sample_index]).reshape(1, -1));
# dw = self.__step_len * X[sample_index].reshape(-1, 1).dot(predictions.reshape(1, -1));
# dw[range(X.shape[1]), numpy.argmax(Y[sample_index])] -= X[sample_index] * self.__step_len; self.__weight -= dw; logging.debug("weight:%s" % (self.__weight));
logging.debug("loss:%s" % (l));
logging.info("weight:%s" % (self.__weight));
logging.info("L:%s" % (self.L)); def forward_pass(self, x):
net = numpy.dot(x, self.__weight);
net = net - numpy.max(net);
net = numpy.exp(net) / numpy.sum(numpy.exp(net));
return net; def predict(self, x):
x = numpy.append(x, 1.0);
return self.forward_pass(x); def main():
logging.basicConfig(level = logging.INFO,
format = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt = '%a, %d %b %Y %H:%M:%S'); logging.info("trainning begin."); mnist = read_data_sets('../data/MNIST',one_hot=True) # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签 #load data
train_X = mnist.train.images #训练集样本
validation_X = mnist.validation.images #验证集样本
test_X = mnist.test.images #测试集样本
#labels
train_Y = mnist.train.labels #训练集标签
validation_Y = mnist.validation.labels #验证集标签
test_Y = mnist.test.labels #测试集标签 classifier = Softmax();
classifier.train(train_X, train_Y); logging.info("trainning end. predict begin."); test_predict = numpy.array([]);
test_right = numpy.array([]);
for i in range(len(test_X)):
predict_label = numpy.argmax(classifier.predict(test_X[i]));
test_predict = numpy.append(test_predict, predict_label);
right_label = numpy.argmax(test_Y[i]);
test_right = numpy.append(test_right, right_label); logging.info("right:%s, predict:%s" % (test_right, test_predict));
score = accuracy_score(test_right, test_predict);
logging.info("The accruacy score is: %s "% (str(score))); plt.plot(classifier.L)
plt.show(); if __name__ == "__main__":
main();

损失函数收敛情况

Sun, 29 Sep 2019 18:08:08 softmax.py[line:104] INFO trainning end. predict begin.
Sun, 29 Sep 2019 18:08:08 softmax.py[line:114] INFO right:[7. 2. 1. ... 4. 5. 6.], predict:[7. 2. 1. ... 4. 8. 6.]
Sun, 29 Sep 2019 18:08:08 softmax.py[line:116] INFO The accruacy score is: 0.8486

softmax及python实现的更多相关文章

  1. 机器学习-softmax回归 python实现

    ---恢复内容开始--- Softmax Regression 可以看做是 LR 算法在多分类上的推广,即类标签 y 的取值大于或者等于 2. 假设数据样本集为:$\left \{ \left ( X ...

  2. softmax函数python实现

    import numpy as np def softmax(x): """ 对输入x的每一行计算softmax. 该函数对于输入是向量(将向量视为单独的行)或者矩阵(M ...

  3. TensorFlow(2)Softmax Regression

    Softmax Regression Chapter Basics generate random Tensors Three usual activation function in Neural ...

  4. logistic regression model

    logistic regression model LR softmax classification Fly logistic regression model loss fuction softm ...

  5. [C2W3] Improving Deep Neural Networks : Hyperparameter tuning, Batch Normalization and Programming Frameworks

    第三周:Hyperparameter tuning, Batch Normalization and Programming Frameworks 调试处理(Tuning process) 目前为止, ...

  6. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  7. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  8. 如何用Python计算Softmax?

    Softmax函数,或称归一化指数函数,它能将一个含任意实数的K维向量z"压缩"到另一个K维实向量\(\sigma{(z)}\)中,使得每一个元素的范围都在(0,1)之间,并且所有 ...

  9. 使用python计算softmax函数

    softmax计算公式:                        Softmax是机器学习中一个非常重要的工具,他可以兼容 logistics 算法.可以独立作为机器学习的模型进行建模训练.还可 ...

随机推荐

  1. 牛客网剑指offer第34题——找到第一个只出现一次的字符

    题目如下: 在一个字符串(0<=字符串长度<=10000,全部由字母组成)中找到第一个只出现一次的字符,并返回它的位置, 如果没有则返回 -1(需要区分大小写). 先上代码: class ...

  2. python打包py为exe程序:PyInstaller

    打包库:PyInstaller python程序编写过程中的脚本文件为py格式的文件,当我们想将编写好的程序移植到其他机器上给其他人使用时,如果目标机器没有安装python环境,py文件将无法运行,而 ...

  3. Python 【基础常识概念】

    深浅拷贝 浅copy与deepcopy 浅copy: 不管多么复杂的数据结构,浅拷贝都只会copy一层 deepcopy : 深拷贝会完全复制原变量相关的所有数据,在内存中生成一套完全一样的内容,我们 ...

  4. openwrt 外挂usb 网卡 RTL8188CU 及添加 RT5572 kernel支持

    RT5572 原来叫 Ralink雷凌 现在被 MTK 收购了,淘宝上买的很便宜50块邮,2.4 5G 双频.在 win10 上插了试试,果然是支持 5G.这上面写着 飞荣 是什么牌子,有知道的和我说 ...

  5. openwrt 上的 upnp wifi 音频推送 gmediarender

    首先是必须启用的模块 Libraries ---> <*> libupnp Sound ---> <*> alsa-utils<*> madplay-a ...

  6. Jenkins构建项目帮助文档

    Jenkins构建项目帮助文档 主要步骤 一.配置jdk 1.1.   下载jdk,安装到自己电脑磁盘的Java目录下(比如:D:\Java\jdk). 1.2.   Jdk环境变量的配置: 1. 鼠 ...

  7. 网络编程---socket模块

    内容中代码都是先写  server端, 再写 client端 1 TCP和UDP对比 TCP(Transmission Control Protocol)可靠的.面向连接的协议(eg:打电话).传输效 ...

  8. FastDFS源码学习(一)FastDFS介绍及源码编译安装

    FastDFS是淘宝的余庆主导开发的一个分布式文件系统,采用C语言开发,性能较优.在淘宝网.京东商城.支付宝和某些网盘等系统均有使用,使用场景十分广泛. 下图来源:https://blog.csdn. ...

  9. Java第一节课考试

    1 package kaoshi; import java.util.Scanner; public class ScoreInformation { Scanner input=new Scanne ...

  10. Druid连接池和springJDbc框架-Java(新手)

    Druid连接池: Druid 由阿里提供 安装步骤: 导包 durid1.0.9 jar包 定义配置文件 properties文件 名字任意位置也任意 加载文件 获得数据库连接池对象 通过Durid ...