该案例主要目的是为了熟悉Keras基本用法,以及了解DNN基本流程。

示例代码:

import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.datasets import mnist
from keras.layers import Dense
from keras.utils.np_utils import to_categorical #加载数据,训练60000条,测试10000条,X_train.shape=(60000,28,28)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
#特征扁平化,缩放,标签独热
X_train_flat = X_train.reshape(60000, 28*28)
X_test_flat = X_test.reshape(10000, 28*28)
X_train_norm = X_train_flat / 255
X_test_norm = X_test_flat / 255
y_train_onehot = to_categorical(y_train, 10) #shape为(60000,10)
y_test_onehot = to_categorical(y_test, 10) #shape为(10000,10)
#构建模型
model = Sequential()
model.add(Dense(100, activation='relu', input_shape=(28*28,)))
model.add(Dense(50, activation='relu'))
model.add(Dense(10, activation='softmax'))
#模型配置和训练
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_train_norm, y_train_onehot, epochs=5, batch_size=32, verbose=1)
print("训练完毕!")

训练结果为:

继续在测试集上评估模型。

#测试集上评估表现
score = model.evaluate(X_test_norm, y_test_onehot)
print("在测试集上评估完毕!")
print("在测试集上表现:Loss={:.4f}, Accuracy={:.4f}".format(score[0], score[1]))
#在测试集上预测
y_pred_class = model.predict_classes(X_test_norm) #shape=(10000,)
print("预测完毕!")
#查看预测效果,随机查看多张图片
idx = 22 #随机设置
count = 0
fig1 = plt.figure(figsize = (10,7))
for i in range(3):
for j in range(5):
count += 1
ax = plt.subplot(3,5,count)
plt.imshow(X_test[idx+count])
ax.set_title("predict:{} label:{}".format(y_pred_class[idx+count],
y_test[idx+count]))
fig1.savefig('images/look.jpg')

运行结果为:


为了了解模型预测错误原因,可查看预测错误的图片。

#找出错误所在
X_test_err = X_test[y_test!=y_pred_class] #(num_errors, 28, 28)
y_test_err = y_test[y_test!=y_pred_class] #(num_errors,)
y_pred_class_err = y_pred_class[y_test!=y_pred_class]
#连续查看多张错误图片
idx = -1
count = 0
fig2 = plt.figure(figsize = (10,7))
for i in range(3):
for j in range(5):
count += 1
ax = plt.subplot(3,5,count)
plt.imshow(X_test_err[idx+count])
ax.set_title("predict:{} label:{}".format(y_pred_class_err[idx+count],
y_test_err[idx+count]))
fig2.savefig('images/errors.jpg')

运行结果为:

深度学习练手项目——DNN识别手写数字的更多相关文章

  1. 深度学习面试题12:LeNet(手写数字识别)

    目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...

  2. 深度学习-使用cuda加速卷积神经网络-手写数字识别准确率99.7%

    源码和运行结果 cuda:https://github.com/zhxfl/CUDA-CNN C语言版本参考自:http://eric-yuan.me/ 针对著名手写数字识别的库mnist,准确率是9 ...

  3. keras框架下的深度学习(一)手写体识别

    这个系列文章主要记录使用keras框架来搭建深度学习模型的学习过程,其中有一些自己的想法和体会,主要学习的书籍是:Deep Learning with Python,使用的IDE是pycharm. 在 ...

  4. 分别基于TensorFlow、PyTorch、Keras的深度学习动手练习项目

    ×下面资源个人全都跑了一遍,不会出现仅是字符而无法运行的状况,运行环境: Geoffrey Hinton在多次访谈中讲到深度学习研究人员不要仅仅只停留在理论上,要多编程.个人在学习中也体会到单单的看理 ...

  5. Tensorflow 实战Google深度学习框架 第五章 5.2.1Minister数字识别 源代码

    import os import tab import tensorflow as tf print "tensorflow 5.2 " from tensorflow.examp ...

  6. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

  7. 【机器学习】李宏毅机器学习-Keras-Demo-神经网络手写数字识别与调参

    参考: 原视频:李宏毅机器学习-Keras-Demo 调参博文1:深度学习入门实践_十行搭建手写数字识别神经网络 调参博文2:手写数字识别---demo(有小错误) 代码链接: 编程环境: 操作系统: ...

  8. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  9. [深度学习] 使用Darknet YOLO 模型破解中文验证码点击识别

    内容 背景 准备 实践 结果 总结 引用 背景 老规矩,先上代码吧 代码所在: https://github.com/BruceDone/darknet_demo 最近在做深度学习相关的项目的时候,了 ...

随机推荐

  1. Win7 ODBC驱动 Excel (转)

    “控制面板-管理工具-数据源(ODBC)”,打开“ODBC数据源管理器”窗口,然后“添加”,打开“创建新数据源”的窗口,最后选择Microsoft Access Driver(*.mdb)选项,往后等 ...

  2. 常用生物信息 ID 及转换方法

    众多不同的数据库所采用的对 Gene 和 Protein 编号的 ID 也是不同的, 所以在使用不同数据库数据的时候需要进行 ID 转换. 常用数据库 ID ID 示例 ID 来源 ENSG00000 ...

  3. vue 项目的运行与 打包

    1.vue init webpack 2.npm install axios 3.npm run dev  运行项目 4.npm run build 打包项目 会生成一个dist 文件夹,我们只需要把 ...

  4. Session过期,如何跳出iframe框架页的问题

    跳出框架页,实际上是更改父页面地址.那么更改父页面地址很简单即: window.parent.location='/Login/loginindex'; 这里说session过期,那么浏览器端的任何请 ...

  5. boost number handling

    Boost.Integer defines specialized for integers. 1. types for integers with number of bits #include & ...

  6. 【Tomcat】Tomcat系统架构

    一.Tomcat顶层架构 先上一张Tomcat的顶层结构图(图A),如下: Tomcat中最顶层的容器是Server,代表着整个服务器,从上图中可以看出,一个Server可以包含至少一个Service ...

  7. Linux基础优化(二)

    Linux基础优化(二) 一操作系统字符优化 避免出现中文乱码,UTF-8支持中文GBK-Xx支持中文 (一)查看默认编码 [root@centos7 ~]# echo $LANG en_US.UTF ...

  8. linux系统下tomcat应用开机自启动 配置

    linux系统下tomcat应用开机自启动 配置 相对简单的方式是将tomcat添加为系统服务第一步  复制文件将 $Tomcat_Home/bin目录下的 catalina.sh脚本文件复制到目录/ ...

  9. python中常犯错误之字符串列表下标问题

    下标用得是中括号[] 不是小括号() 1,python中的小括号( ):代表tuple元组数据类型,元组是一种不可变序列.创建方法很简单,大多时候都是用小括号括起来的. 2.python中的中括号[ ...

  10. 登录成功后如何利用cookie保持登录状态

    Cookie是一种服务器发送给浏览器的一组数据,用于浏览器跟踪用户,并访问服务器时保持登录状态等功能. 通常用户登录的时候,服务器根据用户名和密码在服务器数据库中校验该用户是否正确,校验正确后则可以根 ...