提示:建议先看day36-38的内容

TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即张量(tensor)。它灵活的架构让你可以在多种平台上展开计算,例如台式计算机中的一个或多个CPU(或GPU),服务器,移动设备等等。

TensorFlow 最初由Google大脑小组(隶属于Google机器智能研究机构)的研究员和工程师们开发出来,用于机器学习和深度神经网络方面的研究,但这个系统的通用性使其也可广泛用于其他计算领域。

1、安装库tensorflow

有些教程会推荐安装nightly,它适用于在一个全新的环境下进行TensorFlow的安装,默认会把需要依赖的库也一起装上。我使用的是anaconda,本文我们安装的是纯净版的tensorflow,非常简单,只需打开Prompt:

pip install tensorflow

安装成功

导入成功

#导入keras
from tensorflow import keras
#导入tensorflow
import tensorflow as tf

注:有些教程中导入Keras用的是import tensorflow.keras as keras会提示No module named 'tensorflow.keras'

2、导入mnist数据

在上篇文章中我们已经提到过 MNIST 了,用有趣的方式解释梯度下降算法

它是一个收录了许多 28 x 28 像素手写数字图片(以灰度值矩阵存储)及其对应的数字的数据集,可以把它理解成下图这个样子:

由于众所周知的原因,Keras自带minist数据集下载会报错,无法下载。博客园崔小秋同学给出了很好的解决方法:

1、找到本地keras目录下的mnist.py文件,通常在这个目录下。

2、下载mnist.npz文件到本地,下载链接如下。

https://pan.baidu.com/s/1C3c2Vn-_616GqeEn7hQQ2Q

3、修改mnist.py文件为以下内容,并保存

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function from ..utils.data_utils import get_file
import numpy as np def load_data(path='mnist.npz'):
    """Loads the MNIST dataset. # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets).     # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    path = 'E:/Data/Mnist/mnist.npz' #此处的path为你刚刚存放mnist.py的目录。注意斜杠
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)

看一下数据

mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
print(x_train[0].shape)

(28, 28)

import matplotlib.pyplot as plt
plt.imshow(x_train[0],cmap=plt.cm.binary)
plt.show()

print(y_train[0])

5

对数据进行归一化处理

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)

再看一下,图像的像素值被限定在了 [0,1]

plt.imshow(x_train[0],cmap=plt.cm.binary)
plt.show()

3 构建与训练模型我们使用 Keras 的 Sequential 模型(顺序模型),顺序模型是多个网络层的线性堆叠。本文旨在介绍TensorFlow 及Keras用法,不再展开,有兴趣的同学们学习其具体用法,可以参考Keras文档:

https://keras.io/zh/getting-started/sequential-model-guide/

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3)

我们构建出的模型大概是这个样子的,区别是我们的隐藏层有128个单元

在训练的过程中,我们会发现损失值(loss)在降低,而准确度(accuracy)在提高,最后达到了一个令人满意的程度。

Epoch 1/3

60000/60000  - 8s 127us/step - loss: 0.2677 - acc: 0.9211Epoch 2/3

60000/60000  - 8s 130us/step - loss: 0.1106 - acc: 0.9655Epoch 3/3

60000/60000  - 8s 136us/step - loss: 0.0751 - acc: 0.9764

4 测试模型

val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)

10000/10000  - 0s 45us/step0.0916121033909265

0.9713

损失和准确度看起来还凑合,尝试识别训练集

predictions = model.predict(x_test)
print(predictions)

用 argmax 解析一下(就是找出最大数对应的索引,即为识别出的数字)

import numpy as np
print(np.argmax(predictions[0]))

7

plt.imshow(x_test[0],cmap=plt.cm.binary)
plt.show()

OK,模型可以识别数字了。

5、保存模型

主要用于模型的存储和恢复。

model.save('epic_num_reader.model')
# 加载保存的模型
new_model = tf.keras.models.load_model('epic_num_reader.model')
# 测试保存的模型
predictions = new_model.predict(x_test)print(np.argmax(predictions[0]))

看到这里的都是真爱,另推荐一个Keras教程

Colab超火的Keras/TPU深度学习免费实战,有点Python基础就能看懂的快速课程

参考:

https://www.cnblogs.com/shinny/p/9283372.html

https://www.cnblogs.com/wj-1314/p/9579490.html

https://github.com/MLEveryday/100-Days-Of-ML-Code/blob/master/Code/Day 39.ipynb

100天搞定机器学习|day39 Tensorflow Keras手写数字识别的更多相关文章

  1. 100天搞定机器学习|day40-42 Tensorflow Keras识别猫狗

    100天搞定机器学习|1-38天 100天搞定机器学习|day39 Tensorflow Keras手写数字识别 前文我们用keras的Sequential 模型实现mnist手写数字识别,准确率0. ...

  2. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  3. 【机器学习】李宏毅机器学习-Keras-Demo-神经网络手写数字识别与调参

    参考: 原视频:李宏毅机器学习-Keras-Demo 调参博文1:深度学习入门实践_十行搭建手写数字识别神经网络 调参博文2:手写数字识别---demo(有小错误) 代码链接: 编程环境: 操作系统: ...

  4. OpenCV+TensorFlow图片手写数字识别(附源码)

    初次接触TensorFlow,而手写数字训练识别是其最基本的入门教程,网上关于训练的教程很多,但是模型的测试大多都是官方提供的一些素材,能不能自己随便写一串数字让机器识别出来呢?纸上得来终觉浅,带着这 ...

  5. python-积卷神经网络全面理解-tensorflow实现手写数字识别

    首先,关于神经网络,其实是一个结合很多知识点的一个算法,关于cnn(积卷神经网络)大家需要了解: 下面给出我之前总结的这两个知识点(基于吴恩达的机器学习) 代价函数: 代价函数 代价函数(Cost F ...

  6. Tensorflow实战 手写数字识别(Tensorboard可视化)

    一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...

  7. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  8. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

  9. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

随机推荐

  1. HBase学习笔记一

    HBase简介 HBase概念 HBase的原型是谷歌的Bigtable论文 HBase是一个高可靠性.高性能.面向列.可伸缩的分布式存储系统,利用HBase技术可在廉价PC上搭建起大规模结构化存储集 ...

  2. Python 爬虫:煎蛋网妹子图

    使用 Headless Chrome 替代了 PhatomJS. 图片保存到指定文件夹中. import requests from bs4 import BeautifulSoup from sel ...

  3. Dapper学习笔记

    听说有个轻量化的orm Dapper,我就去了解下.试着对Sql Server和Mysql进行增删改查,体验不错.它不如EF臃肿,也比一般的封装灵活,比如我们封装了一个映射类.利用反射,在Execut ...

  4. Linux中的保护机制

    Linux中的保护机制 在编写漏洞利用代码的时候,需要特别注意目标进程是否开启了NX.PIE等机制,例如存在NX的话就不能直接执行栈上的数据,存在PIE 的话各个系统调用的地址就是随机化的. 一:ca ...

  5. UVA663 Sorting Slides(烦人的幻灯片)

    UVA663 Sorting Slides(烦人的幻灯片) 第一次做到这么玄学的题,在<信息学奥赛一本通>拓扑排序一章找到这个习题(却发现标程都是错的),结果用二分图匹配做了出来 蒟蒻感觉 ...

  6. 《VR入门系列教程》之11---基本几何-材质-光照

    网格.多边形.顶点     绘制3D图形有许多方法,用的最多的是用网格绘制.一个网格由一个或多个多边形组成,这些多边形的顶点都是三维空间中的点,它们具有x.y.z三个坐标值.网格中通常采用三角形和四边 ...

  7. 实现简单的 IOC 和 AOP

    1 简单的 IOC 1.1 简单的 IOC 容器实现的步骤 加载 xml 配置文件,遍历其中的标签 获取标签中的 id 和 class 属性,加载 class 属性对应的类,并创建 bean 遍历标签 ...

  8. 如何挑选node docker镜像

    如何挑选node docker镜像 在使用Jenkins构建前端项目的时候遇到一点问题: node的版本问题. 由于可能编译的项目历史不同,所依赖的node版本也各有千秋,直接把所有项目都升级到最新的 ...

  9. Spring Boot 面试的十个问题

    用下面这些常见的面试问题为下一次 Spring Boot 面试做准备. 在本文中,我们将讨论 Spring boot 中最常见的10个面试问题.现在,在就业市场上,这些问题有点棘手,而且趋势日益严重. ...

  10. web前端开发-博客目录

    web前端开发是一个新的领域,知识连接范围广,处于设计与后端数据交互的桥梁,并且现在很多web前端相关语言标准,框架库都在高速发展.在学习过程中也常常处于烦躁与迷茫,有时候一直在想如何能够使自己更加系 ...