1. # -*- coding: utf-8 -*-
  2. """
  3. 用神经网络搭建的softmax线性分离器
  4. Softmax是用于分类过程,用来实现多分类的,简单来说,它把一些输出的神经元映射到(0-1)之间的实数,并且归一化保证和为1,从而使得多分类的概率之和也刚好为1。
  5. Softmax可以分为soft和max,max也就是最大值,假设有两个变量a,b。如果a>b,则max为a,反之为b。
  6. 那么在分类问题里面,如果只有max,输出的分类结果只有a或者b,是个非黑即白的结果。
  7. 但是在现实情况下,我们希望输出的是取到某个分类的概率 我们希望分值大的那一项被经常取到,而分值较小的那一项也有一定的概率偶尔被取到,所以我们就应用到了soft的概念,即最后的输出是每个分类被取到的概率。
  8. """
  9. from keras.datasets import mnist
  10. from keras.models import Sequential
  11. from keras.layers import Dense,Activation
  12. from keras.optimizers import SGD
  13. from keras.utils import np_utils
  14. import numpy as np
  15. ############################################
  16. import os
  17. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  18. ###########################################
  19. np.random.seed(0)
  20. #导入数据 分离成为训练集+标签 和测试集+标签
  21. (X_train,y_train),(X_test,y_test) = mnist.load_data()
  22. # 训练集的个数
  23. # 60000 行的训练集分拆为 55000 行的训练集和 5000 行的验证集。
  24. # 每个样本都是一张28 * 28像素的灰度手写数字图片。
  25. print(X_train.shape,y_train.shape)
  26. # 测试集的个数
  27. print(X_test.shape,y_test.shape)
  28. #数据变换,变为10个类别
  29. nb_classes = 10
  30. X_train_1 = X_train.reshape(60000,784)
  31. # 归一化的意思 把所有的数弄到【0,1】之间
  32. X_train_1 = X_train_1/255
  33. print(X_train_1.shape)
  34. X_train_1 = X_train_1.astype('float32')
  35. # 使用np_utils.to_categorical(y_train, 10)将原来标签是一列的[1,0,0,0,1…]的转换为一行10列的独热码。
  36. y_train_1 = np_utils.to_categorical(y_train,nb_classes)
  37.  
  38. X_test_1 = X_test.reshape(10000,784)
  39. X_test_1 = X_test_1.astype('float32')
  40. y_test_1 = np_utils.to_categorical(y_test,nb_classes)
  41. print(X_test_1.shape,y_test_1.shape)
  42. print( '-----变换后的数据结构----->')
  43. print(X_train_1.shape)
  44. print(y_train_1.shape)
  45. print('Success!')
  46.  
  47. #建立模型
  48. model = Sequential()
  49. model.add(Dense(nb_classes,input_shape=(784,)))
  50. model.add(Activation('softmax'))
  51.  
  52. # 编译模型
  53. # 梯度下降 SGD 每次更新时对每个样本进行梯度更新
  54. sgd = SGD(lr=0.01)
  55. # 交叉熵函数
  56. model.compile(loss='binary_crossentropy',
  57. optimizer=sgd,
  58. metrics = ['accuracy'])
  59. #model 概要
  60. model.summary()
  61. #训练模型
  62. # X_train_1 输入数据
  63. # y_train_1 标签,
  64. # epochs:训练的总轮数
  65. # verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
  66. # batch_size:整数,指定进行梯度下降时每个batch包含的样本数
  67. # 最终得到模型 modela
  68. model.fit(X_train_1,y_train_1,
  69. epochs = 20,
  70. verbose=1,
  71. batch_size=100
  72. )
  73. #模型的测试误差指标
  74. print(model.metrics_names)
  75. #对测试数据进行测试
  76. # model.evaluate输入数据(data)和标签(label),然后将预测结果与标签比较,得到两者误差并输出.
  77. loss,accu = model.evaluate(X_test_1,y_test_1,
  78. verbose=2,
  79. batch_size = 100)
  80. print(loss,accu)

图片:

Python+Softmax+MNIST的更多相关文章

  1. python读取mnist

    python读取mnist 其实就是python怎么读取binnary file mnist的结构如下,选取train-images TRAINING SET IMAGE FILE (train-im ...

  2. Python读取MNIST数据集

    MNIST数据集获取 MNIST数据集是入门机器学习/模式识别的最经典数据集之一.最早于1998年Yan Lecun在论文: Gradient-based learning applied to do ...

  3. python 将Mnist数据集转为jpg,并按比例/标签拆分为多个子数据集

    现有条件:Mnist数据集,下载地址:跳转 下载后的四个.gz文件解压后放到同一个文件夹下,如:/raw Step 1:将Mnist数据集转为jpg图片(代码来自这篇博客) 1 import os 2 ...

  4. caffe的python接口学习(4):mnist实例---手写数字识别

    深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了.由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文: 数据层及参数 ...

  5. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  6. python读取,显示,保存mnist图片

    python处理二进制 python的struct模块可以将整型(或者其它类型)转化为byte数组.看下面的代码. # coding: utf-8 from struct import * # 包装成 ...

  7. mnist手写数字检测

    # -*- coding: utf-8 -*- """ Created on Tue Apr 23 06:16:04 2019 @author: 92958 " ...

  8. python调用Opencv库和dlib库

    python是一门胶水语言,可以调用C++编译好的dll库 python调用opencv-imggui.dll文件 https://www.cnblogs.com/zhangxian/articles ...

  9. C++基于文件流和armadillo读取mnist

    发现网上大把都是用python读取mnist的,用C++大都是用opencv读取的,但我不怎么用opencv,因此自己摸索了个使用文件流读取mnist的方法,armadillo仅作为储存矩阵的一种方式 ...

  10. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

随机推荐

  1. navicate的安装使用

    1 navicat概述 Navicat for MySQL 是管理和开发 MySQL 或 MariaDB 的理想解决方案. 这套全面的前端工具为数据库管理.开发和维护提供了一款直观而强大的图形界面. ...

  2. 痞子衡嵌入式:从功耗测试角度了解i.MXRTxxx系列片内SRAM分区电源控制

    大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是从功耗测试角度了解i.MXRTxxx系列片内SRAM分区电源控制. 我们知道配合 MCU 一起工作的存储器包含 ROM(Flash) 和 ...

  3. 【SpringBoot】 集成 Ehcache

    SpringBoot ehcache 缓存 简介 EhCache 是一个纯 Java 的进程内缓存框架,具有快速.精干等特点, 是 Hibernate 中默认CacheProvider.Ehcache ...

  4. 【活动回顾】WebRTC服务端工程实践和优化探索

    11月7日,即构和上海GDG技术社区联合举办了实时音视频技术云上技术分享专场,来自即构科技和Bilibili的资深技术专家进行了深度分享.大会吸引了众多开发人员交流.观看,并在活动过程中与分享嘉宾进行 ...

  5. ObjectInputStream_报错问题

    报错: Exception in thread "main" java.io.StreamCorruptedException: invalid stream header: CE ...

  6. 【Python】从同步到异步多核:测试桩性能优化,加速应用的开发和验证

    测试工作中常用到的测试桩mock能力 在我们的测试工作过程中,可能会遇到多个项目并行开发的时候,后端服务还没有开发完成,或者我们需要压测某个服务,这个服务测在试环境的依赖组件(如 MQ) 无法支撑我们 ...

  7. HTB靶场之Sandworm

    准备: 攻击机:虚拟机kali. 靶机:Sandworm,htb网站:https://www.hackthebox.com/,靶机地址:https://app.hackthebox.com/machi ...

  8. 2021-7-6 new tcpip

    using System; using System.Collections.Generic; using System.Linq; using System.Net; using System.Ne ...

  9. 如何使用iptables防火墙模拟远程服务超时

    前言 超时,应该是程序员很不爱处理的一种状态.当我们调用某服务.某个中间件.db时,希望对方能快速回复,正确就正常,错误就错误,而不是一直不回复.目前在后端领域来说,如java领域,调用服务时以同步阻 ...

  10. QLabel类中的常用方法&信号

    setAlignment: 按固定值方式对齐文本 Qt.AlignLeft:水平方向靠左对齐 Qt.AlignRight:水平方向靠右对齐 Qt.AlignCenter:水平方向居中对齐 Qt.Ali ...