MNIST 数据集
包含60 000 张训练图像和10 000 张测试图像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即MNIST 中
的NIST)在20 世纪80 年代收集得到。
 
类和标签
在机器学习中,分类问题中的某个类别叫作类(class)。数据点叫作样本(sample)。某
个样本对应的类叫作标签(label)。
 
MNIST 数据集预先加载在Keras 库中,其中包括4 个Numpy 数组。
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
train_images 和train_labels 组成了训练集(training set),模型将从这些数据中进行
学习。然后在测试集(test set,即test_images 和test_labels)上对模型进行测试。
 
图像被编码为Numpy 数组,而标签是数字数组,取值范围为0~9。图像和标签一一对应。
 
我们来看一下训练数据:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
 
 
测试数据:
>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)
 
 
神经网络架构
 
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
 
本例中的网络包含2 个Dense 层,它们是密集连接(也叫全连接)的神经层。第二层(也
是最后一层)是一个10 路softmax 层,它将返回一个由10 个概率值(总和为1)组成的数组。
每个概率值表示当前数字图像属于10 个数字类别中某一个的概率。
 
要想训练网络,我们还需要选择编译(compile)步骤的三个参数。
损失函数(loss function):网络如何衡量在训练数据上的性能,即网络如何朝着正确的
方向前进。
优化器(optimizer):基于训练数据和损失函数来更新网络的机制。
在训练和测试过程中需要监控的指标(metric):本例只关心精度,即正确分类的图像所
占的比例。
 
编译步骤
network.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
 
在开始训练之前,我们将对数据进行预处理,将其变换为网络要求的形状,并缩放到所
有值都在[0, 1] 区间。比如,之前训练图像保存在一个uint8 类型的数组中,其形状为
(60000, 28, 28),取值区间为[0, 255]。我们需要将其变换为一个float32 数组,其形
状为(60000, 28 * 28),取值范围为0~1。
 
准备图像数据
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
 
 
 
准备标签
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
 
 
 
开始训练网络
>>> network.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [=============================] - 9s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=======================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692
 
 
 
检查模型在测试集上的性能
>>> test_loss, test_acc = network.evaluate(test_images, test_labels)
>>> print('test_acc:', test_acc)
test_acc: 0.9785
 
 
 
 
 
 
 
 
 
 
 
 
 

Python深度学习读书笔记-2.初识神经网络的更多相关文章

  1. Python深度学习读书笔记-3.神经网络的数据表示

    标量(0D 张量) 仅包含一个数字的张量叫作标量(scalar,也叫标量张量.零维张量.0D 张量).在Numpy 中,一个float32 或float64 的数字就是一个标量张量(或标量数组).你可 ...

  2. Python深度学习读书笔记-4.神经网络入门

    神经网络剖析   训练神经网络主要围绕以下四个方面: 层,多个层组合成网络(或模型) 输入数据和相应的目标 损失函数,即用于学习的反馈信号 优化器,决定学习过程如何进行   如图 3-1 所示:多个层 ...

  3. Python深度学习读书笔记-1.什么是深度学习

    人工智能 什么是人工智能.机器学习与深度学习(见图1-1)?这三者之间有什么关系?

  4. Python深度学习读书笔记-5.Keras 简介

    Keras 重要特性 相同的代码可以在 CPU 或 GPU 上无缝切换运行. 具有用户友好的 API,便于快速开发深度学习模型的原型. 内置支持卷积网络(用于计算机视觉).循环网络(用于序列处理)以及 ...

  5. Python深度学习读书笔记-6.二分类问题

    电影评论分类:二分类问题   加载 IMDB 数据集 from keras.datasets import imdb (train_data, train_labels), (test_data, t ...

  6. 深度学习读书笔记之RBM(限制波尔兹曼机)

    深度学习读书笔记之RBM 声明: 1)看到其他博客如@zouxy09都有个声明,老衲也抄袭一下这个东西 2)该博文是整理自网上很大牛和机器学习专家所无私奉献的资料的.具体引用的资料请看参考文献.具体的 ...

  7. [1天搞懂深度学习] 读书笔记 lecture I:Introduction of deep learning

    - 通常机器学习,目的是,找到一个函数,针对任何输入:语音,图片,文字,都能够自动输出正确的结果. - 而我们可以弄一个函数集合,这个集合针对同一个猫的图片的输入,可能有多种输出,比如猫,狗,猴子等, ...

  8. 深度学习课程笔记(一)CNN 卷积神经网络

    深度学习课程笔记(一)CNN 解析篇 相关资料来自:http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17_2.html 首先提到 Why CNN for I ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. luogu P5342 [TJOI2019]甲苯先生的线段树

    传送门 你个好好的省选怎么可以出CF原题啊,你们这个题害人不浅啊,这样子出题像极了cxk,说到cxk,我又想起了他是NBA形象大使,跟我是西游文化大使一样一样的,今年下半年... 别说了,jinsai ...

  2. 利用sql报错帮助进行sql注入

    我们可以利用sql报错帮助进行sql注入,这里以sql server 为例: sql查询时,若用group by子句时,该子句中的字段必须跟select 条件中的字段(非聚合函数)完全匹配,如果是se ...

  3. 四、Vue CLI-异步请求(axios)

    一.模块的安装 npm install axios --save #--save可以不用写 如图: 二.配置main.js import axios from 'axios' Vue.prototyp ...

  4. linux下利用tcpdump抓包工具排查nginx获取客户端真实IP实例

    一.nginx后端负载服务器的API在获取客户端IP时始终只能获取nginx的代理服务器IP,排查nginx配置如下 upstream sms-resp { server ; server ; } s ...

  5. 删除表A的记录时,Oracle 报错:“ORA-02292:违反完整约束条件(XXX.FKXXX)- 已找到子记录

    1.找到以”FKXXX“为外键的表A的子表,直接运行select a.constraint_name, a.table_name, b.constraint_name from user_constr ...

  6. token和sign

    前言 在app开放接口api的设计中,避免不了的就是安全性问题,因为大多数接口涉及到用户的个人信息以及一些敏感的数据,所以对这些接口需要进行身份的认证,那么这就需要用户提供一些信息,比如用户名密码等, ...

  7. react-native启动时红屏报错:Unable to load script.Make sure you're either running a metro server or that ....

    一.报错信息内容 我是在Android Studio中运行启动react-native项目时报的这个错误 1.报错提示:Unable to load script.Make sure you're e ...

  8. 给oracle命令的参数赋值

    ''' <summary>    '''   给oracle命令的参数赋值    ''' </summary>    ''' <param name="cmd& ...

  9. UIWebView半透明设置

    在项目中有时候需要弹出活动弹框,由于原生的样式会固定,所以考虑h5显示,这就需要webView的背景色半透明,如图: 让 UIWebView 背景透明需要以下设置

webView.backgroun ...

  10. python gitlab 学习笔记

    gitlab创建个人访问令牌(personal access token) https://blog.csdn.net/NGU2028070003/article/details/86634474 P ...