MNIST 数据集
包含60 000 张训练图像和10 000 张测试图像,由美国国家标准与技术研究院(National Institute of Standards and Technology,即MNIST 中
的NIST)在20 世纪80 年代收集得到。
 
类和标签
在机器学习中,分类问题中的某个类别叫作类(class)。数据点叫作样本(sample)。某
个样本对应的类叫作标签(label)。
 
MNIST 数据集预先加载在Keras 库中,其中包括4 个Numpy 数组。
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
train_images 和train_labels 组成了训练集(training set),模型将从这些数据中进行
学习。然后在测试集(test set,即test_images 和test_labels)上对模型进行测试。
 
图像被编码为Numpy 数组,而标签是数字数组,取值范围为0~9。图像和标签一一对应。
 
我们来看一下训练数据:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
 
 
测试数据:
>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)
 
 
神经网络架构
 
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
 
本例中的网络包含2 个Dense 层,它们是密集连接(也叫全连接)的神经层。第二层(也
是最后一层)是一个10 路softmax 层,它将返回一个由10 个概率值(总和为1)组成的数组。
每个概率值表示当前数字图像属于10 个数字类别中某一个的概率。
 
要想训练网络,我们还需要选择编译(compile)步骤的三个参数。
损失函数(loss function):网络如何衡量在训练数据上的性能,即网络如何朝着正确的
方向前进。
优化器(optimizer):基于训练数据和损失函数来更新网络的机制。
在训练和测试过程中需要监控的指标(metric):本例只关心精度,即正确分类的图像所
占的比例。
 
编译步骤
network.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
 
在开始训练之前,我们将对数据进行预处理,将其变换为网络要求的形状,并缩放到所
有值都在[0, 1] 区间。比如,之前训练图像保存在一个uint8 类型的数组中,其形状为
(60000, 28, 28),取值区间为[0, 255]。我们需要将其变换为一个float32 数组,其形
状为(60000, 28 * 28),取值范围为0~1。
 
准备图像数据
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
 
 
 
准备标签
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
 
 
 
开始训练网络
>>> network.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [=============================] - 9s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=======================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692
 
 
 
检查模型在测试集上的性能
>>> test_loss, test_acc = network.evaluate(test_images, test_labels)
>>> print('test_acc:', test_acc)
test_acc: 0.9785
 
 
 
 
 
 
 
 
 
 
 
 
 

Python深度学习读书笔记-2.初识神经网络的更多相关文章

  1. Python深度学习读书笔记-3.神经网络的数据表示

    标量(0D 张量) 仅包含一个数字的张量叫作标量(scalar,也叫标量张量.零维张量.0D 张量).在Numpy 中,一个float32 或float64 的数字就是一个标量张量(或标量数组).你可 ...

  2. Python深度学习读书笔记-4.神经网络入门

    神经网络剖析   训练神经网络主要围绕以下四个方面: 层,多个层组合成网络(或模型) 输入数据和相应的目标 损失函数,即用于学习的反馈信号 优化器,决定学习过程如何进行   如图 3-1 所示:多个层 ...

  3. Python深度学习读书笔记-1.什么是深度学习

    人工智能 什么是人工智能.机器学习与深度学习(见图1-1)?这三者之间有什么关系?

  4. Python深度学习读书笔记-5.Keras 简介

    Keras 重要特性 相同的代码可以在 CPU 或 GPU 上无缝切换运行. 具有用户友好的 API,便于快速开发深度学习模型的原型. 内置支持卷积网络(用于计算机视觉).循环网络(用于序列处理)以及 ...

  5. Python深度学习读书笔记-6.二分类问题

    电影评论分类:二分类问题   加载 IMDB 数据集 from keras.datasets import imdb (train_data, train_labels), (test_data, t ...

  6. 深度学习读书笔记之RBM(限制波尔兹曼机)

    深度学习读书笔记之RBM 声明: 1)看到其他博客如@zouxy09都有个声明,老衲也抄袭一下这个东西 2)该博文是整理自网上很大牛和机器学习专家所无私奉献的资料的.具体引用的资料请看参考文献.具体的 ...

  7. [1天搞懂深度学习] 读书笔记 lecture I:Introduction of deep learning

    - 通常机器学习,目的是,找到一个函数,针对任何输入:语音,图片,文字,都能够自动输出正确的结果. - 而我们可以弄一个函数集合,这个集合针对同一个猫的图片的输入,可能有多种输出,比如猫,狗,猴子等, ...

  8. 深度学习课程笔记(一)CNN 卷积神经网络

    深度学习课程笔记(一)CNN 解析篇 相关资料来自:http://speech.ee.ntu.edu.tw/~tlkagk/courses_ML17_2.html 首先提到 Why CNN for I ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. 将div生成图片并下载下来

    //文件需要引入html2canvas.js.jquery.js function downLoadImg(){ var element = $(".orgchart"); // ...

  2. Java 计算两点间的全部路径(一)

    算法要求: 在一个无向连通图中求出两个给定点之间的所有路径: 在所得路径上不能含有环路或重复的点: 算法思想描述: 整理节点间的关系,为每个节点建立一个集合,该集合中保存所有与该节点直接相连的节点(不 ...

  3. vue 列表渲染 v-for

    1.数组列表       v-for 块中,我们拥有对父作用域属性的完全访问权限.v-for 还支持一个可选的第二个参数为当前项的索引 1.1 普通渲染       v-for="item ...

  4. 9- 基于6U VPX的 XC7VX690T+C6678的双FMC接口雷达通信处理板 C6678板卡

    基于6U VPX的 XC7VX690T+C6678的双FMC接口雷达通信处理板   一.板卡概述 高性能VPX信号处理板基于标准6U VPX架构,提供两个标准FMC插槽,适用于电子对抗或雷达信号等领域 ...

  5. SpringCloud系列(一):Eureka 服务注册与服务发现

    上一篇,我们介绍了服务注册中心,光有服务注册中心没有用,我们得发服务注册上去,得从它那边获取服务.下面我们注册一个服务到服务注册中心上去. 我们创建一个 hello-service 的 spring ...

  6. 开发规范总结-java代码

    java8新特性: 开发的时候适当用一些新特性的语法,可以使代码更简洁.譬如List根据某个属性转map.stream.函数式编程.lambda表达式 有一种场景:两个list一个转map 两个lis ...

  7. mysql时间函数操作

    Mysql时间转换函数 https://blog.csdn.net/w_qqqqq/article/details/88863269 mysql时间日期函数 https://www.cnblogs.c ...

  8. git shell 右键启动注册表

    Windows Registry Editor Version 5.00 [HKEY_CLASSES_ROOT\Directory\Background\shell\Git Bash Here] ): ...

  9. 微信小程序-饮食日志_开发日志

    针对假期作业为父母或者身边的人做一款“小软件”这个课题,由于对 android 开发不熟悉 ,所以决定做一款微信小程序. 项目名称:饮食管理日志 目的:身边的人群对摄入食物热量及消耗不清楚,对健康需求 ...

  10. springboot项目作为其他项目子项目

    <?xml version="1.0"?> <project xsi:schemaLocation="http://maven.apache.org/P ...