MNIST 数据集
包含60 000 张训练图像和10 000 张测试图像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即MNIST 中
的NIST)在20 世纪80 年代收集得到。
 
类和标签
在机器学习中,分类问题中的某个类别叫作类(class)。数据点叫作样本(sample)。某
个样本对应的类叫作标签(label)。
 
MNIST 数据集预先加载在Keras 库中,其中包括4 个Numpy 数组。
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
train_images 和train_labels 组成了训练集(training set),模型将从这些数据中进行
学习。然后在测试集(test set,即test_images 和test_labels)上对模型进行测试。
 
图像被编码为Numpy 数组,而标签是数字数组,取值范围为0~9。图像和标签一一对应。
 
我们来看一下训练数据:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
 
 
测试数据:
>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)
 
 
神经网络架构
 
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
 
本例中的网络包含2 个Dense 层,它们是密集连接(也叫全连接)的神经层。第二层(也
是最后一层)是一个10 路softmax 层,它将返回一个由10 个概率值(总和为1)组成的数组。
每个概率值表示当前数字图像属于10 个数字类别中某一个的概率。
 
要想训练网络,我们还需要选择编译(compile)步骤的三个参数。
损失函数(loss function):网络如何衡量在训练数据上的性能,即网络如何朝着正确的
方向前进。
优化器(optimizer):基于训练数据和损失函数来更新网络的机制。
在训练和测试过程中需要监控的指标(metric):本例只关心精度,即正确分类的图像所
占的比例。
 
编译步骤
network.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
 
在开始训练之前,我们将对数据进行预处理,将其变换为网络要求的形状,并缩放到所
有值都在[0, 1] 区间。比如,之前训练图像保存在一个uint8 类型的数组中,其形状为
(60000, 28, 28),取值区间为[0, 255]。我们需要将其变换为一个float32 数组,其形
状为(60000, 28 * 28),取值范围为0~1。
 
准备图像数据
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
 
 
 
准备标签
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
 
 
 
开始训练网络
>>> network.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [=============================] - 9s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=======================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692
 
 
 
检查模型在测试集上的性能
>>> test_loss, test_acc = network.evaluate(test_images, test_labels)
>>> print('test_acc:', test_acc)
test_acc: 0.9785
 
 
 
 
 
 
 
 
 
 
 
 
 

Python深度学习读书笔记-2.初识神经网络的更多相关文章

  1. Python深度学习读书笔记-3.神经网络的数据表示

    标量(0D 张量) 仅包含一个数字的张量叫作标量(scalar,也叫标量张量.零维张量.0D 张量).在Numpy 中,一个float32 或float64 的数字就是一个标量张量(或标量数组).你可 ...

  2. Python深度学习读书笔记-4.神经网络入门

    神经网络剖析   训练神经网络主要围绕以下四个方面: 层,多个层组合成网络(或模型) 输入数据和相应的目标 损失函数,即用于学习的反馈信号 优化器,决定学习过程如何进行   如图 3-1 所示:多个层 ...

  3. Python深度学习读书笔记-1.什么是深度学习

    人工智能 什么是人工智能.机器学习与深度学习(见图1-1)?这三者之间有什么关系?

  4. Python深度学习读书笔记-5.Keras 简介

    Keras 重要特性 相同的代码可以在 CPU 或 GPU 上无缝切换运行. 具有用户友好的 API,便于快速开发深度学习模型的原型. 内置支持卷积网络(用于计算机视觉).循环网络(用于序列处理)以及 ...

  5. Python深度学习读书笔记-6.二分类问题

    电影评论分类:二分类问题   加载 IMDB 数据集 from keras.datasets import imdb (train_data, train_labels), (test_data, t ...

  6. 深度学习读书笔记之RBM(限制波尔兹曼机)

    深度学习读书笔记之RBM 声明: 1)看到其他博客如@zouxy09都有个声明,老衲也抄袭一下这个东西 2)该博文是整理自网上很大牛和机器学习专家所无私奉献的资料的.具体引用的资料请看参考文献.具体的 ...

  7. [1天搞懂深度学习] 读书笔记 lecture I:Introduction of deep learning

    - 通常机器学习,目的是,找到一个函数,针对任何输入:语音,图片,文字,都能够自动输出正确的结果. - 而我们可以弄一个函数集合,这个集合针对同一个猫的图片的输入,可能有多种输出,比如猫,狗,猴子等, ...

  8. 深度学习课程笔记(一)CNN 卷积神经网络

    深度学习课程笔记(一)CNN 解析篇 相关资料来自:http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17_2.html 首先提到 Why CNN for I ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. vim最常用命令

    vi/vim常用命令汇总 vi/vim概述 vi/vim是Linux和Unix下的一款非常强大的编辑器,vim是vi的增强 版,命令更加多种和复杂,但是最常用的也就是那几个. vi有三种模式 命令行模 ...

  2. 什么是RESTful API、WSGI、pecan

    RESTful API REST的全称是Representational State Transfer(表征状态转移), 是Roy Fielding在他的博士论文Architectural Style ...

  3. 解决 "Could not autowire. No beans of 'SationMapper' type found" 的问题

    网上查找的方法,附上原文链接:https://blog.csdn.net/Coder_Knight/article/details/83999139 方法1:在mapper文件上加@Repositor ...

  4. C语言_扫雷代码

    本文详细讲述了基于C语言实现的扫雷游戏代码,代码中备有比较详细的注释,便于读者阅读和理解.希望对学习游戏开发的朋友能有一点借鉴价值. 完整的实例代码如下: ? 1 2 3 4 5 6 7 8 9 10 ...

  5. 从分析攻击方式来谈如何防御DDoS攻击

    DDoS攻击的定义: DDoS攻击全称——分布式拒绝服务攻击,是网络攻击中非常常见的攻击方式.在进行攻击的时候,这种方式可以对不同地点的大量计算机进行攻击,进行攻击的时候主要是对攻击的目标发送超过其处 ...

  6. linux的安全--Selinux,tcp_wrappers,iptables使用

    一.linux安全 安全主要是端口与服务的对应配置 1.1 linux安全主要通过下面三个进行加固 Selinux----主要是对内核的访问权限加以控制 tcp_wrappers---一定程度上限制某 ...

  7. js实现倒计时(分:秒)

    上代码: //倒计时start 需要传入的参数为秒数,此方法倒计时结束后会自动刷新页面 function resetTime(timetamp){ var timer=null; var t=time ...

  8. Jmeter之cookie的处理方式,token处理

    cookie是什么 由于http是无状态的协议,一旦客户端和服务器的数据交换完毕,就会断开连接,再次请求,会重新连接,这就说明服务器单从网络连接上是没有办法知道用户身份的.怎么办呢?那就给每次新的用户 ...

  9. zencart前台小语种后台英文 导入批量表 前后台不显示产品的问题

    admin\includes\init_includes\init_languages.php 前台小语种后台英文导致批量表导入后,前后台不显示产品的问题将红色部分修改成前台语言对应的值,前台语言对应 ...

  10. 树上独立集数量 树型DP

    题目描述: 对于一棵树,独立集是指两两互不相邻的节点构成的集合.例如,图1有5个不同的独立集(1个双点集合.3个单点集合.1个空集),图2有14个不同的独立集,图3有5536个不同的独立集.  输入: ...