参考:林大贵.TensorFlow+Keras深度学习人工智能实践应用[M].北京:清华大学出版社,2018.

首先在命令行中写入 activate tensorflow和jupyter notebook,运行如下代码。当然,事先准备好MNIST数据集。

 # coding: utf-8

 # In[4]:

 from keras.datasets import mnist
from keras.utils import np_utils
import numpy as np
np.random.seed(10) # In[5]: (x_train,y_train),(x_test,y_test)=mnist.load_data() # In[6]: x_train4d = x_train.reshape(x_train.shape[0],28,28,1).astype('float32')
x_test4d = x_test.reshape(x_test.shape[0],28,28,1).astype('float32') # In[7]: x_train4d_normalize = x_train4d/255
x_test4d_normalize = x_test4d/255 # In[8]: y_train_oneHot = np_utils.to_categorical(y_train)
y_test_oneHot = np_utils.to_categorical(y_test) # In[9]: from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPooling2D # In[10]: model = Sequential() # In[11]: model.add(Conv2D(filters = 16,
kernel_size = (5,5),
padding = 'same',
input_shape = (28,28,1),
activation = ('relu')
)) # In[12]: model.add(MaxPooling2D(pool_size=(2,2))) # In[13]: model.add(Conv2D(filters = 36,
kernel_size = (5,5),
padding = 'same',
activation = 'relu')) # In[14]: model.add(MaxPooling2D(pool_size=(2,2))) # In[15]: model.add(Dropout(0,255)) # In[16]: model.add(Flatten()) # In[17]: model.add(Dense(128,activation = 'relu')) # In[18]: model.add(Dropout(0.5)) # In[19]: model.add(Dense(10,activation = 'sigmoid')) # In[20]: print(model.summary()) # In[21]: model.compile(loss='categorical_crossentropy',
optimizer = 'adam',
metrics = ['accuracy']) # In[22]: train_history = model.fit(x = x_train4d_normalize,
y = y_train_oneHot,
validation_split = 0.2,
epochs = 10,
batch_size = 300,
verbose = 2) # In[23]: import matplotlib.pyplot as plt
def show_train_history(train_history,train,validation):
plt.plot(train_history.history[train])
plt.plot(train_history.history[validation])
plt.title('Train_History')
plt.ylabel(train)
plt.xlabel('Epoch')
plt.legend(['train','validation'], loc = 'upper left')
plt.show() # In[24]: show_train_history(train_history,'acc','val_acc') # In[25]: show_train_history(train_history,'loss','val_loss') # In[26]: scores = model.evaluate(x_test4d_normalize,y_test_oneHot)
scores[1] # In[27]: def plot_image_labels_prediction(images,labels,prediction,idx,num=10):
fig = plt.gcf()
fig.set_size_inches(12,24)
if num>50 : num = 50
for i in range(0,num):
ax = plt.subplot(10,5,1+i)
ax.imshow(images[idx],cmap='binary')
title = "lable="+str(labels[idx])
if len(prediction)>0:
title+=",predict="+str(prediction[idx])
ax.set_title(title,fontsize=10)
ax.set_xticks([]);ax.set_yticks([])
idx+=1
plt.show() # In[28]: prediction = model.predict_classes(x_test4d_normalize) # In[29]: plot_image_labels_prediction(x_test,
y_test,
prediction,
0,
50) # In[30]: import pandas as pd # In[31]: pd.crosstab(y_test,
prediction,
rownames=['label'],
colnames=['predict']) # In[ ]:

卷积神经网络简介:https://www.cnblogs.com/bai2018/p/10413889.html

产生如下神经网络:

其中加入的dropout层用于避免过度拟合。在每次迭代时,随机舍弃一部分的训练样本。

训练效果较好。

林大贵.TensorFlow+Keras深度学习人工智能实践应用[M].北京:清华大学出版社,2018.

keras框架的CNN手写数字识别MNIST的更多相关文章

  1. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  2. TensorFlow 之 手写数字识别MNIST

    官方文档: MNIST For ML Beginners - https://www.tensorflow.org/get_started/mnist/beginners Deep MNIST for ...

  3. Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别 #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层 #基于Keras 2.1.1 Tensorflow ...

  4. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  5. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  6. kaggle 实战 (2): CNN 手写数字识别

    文章目录 Tensorflow 官方示例 CNN 提交结果 Tensorflow 官方示例 import tensorflow as tf mnist = tf.keras.datasets.mnis ...

  7. pytorch CNN 手写数字识别

    一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下 1.预处理 相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个 ...

  8. keras基于卷积网络手写数字识别

    import time import keras from keras.utils import np_utils start = time.time() (x_train, y_train), (x ...

  9. Tensorflow手写数字识别---MNIST

    MNIST数据集:包含数字0-9的灰度图, 图片size为28x28.训练样本:55000,测试样本:10000,验证集:5000

随机推荐

  1. 四 sys模块

    1 sys.argv 命令行参数List,第一个元素是程序本身路径 2 sys.exit(n) 退出程序,正常退出时exit(0) 3 sys.version 获取Python解释程序的版本信息 4 ...

  2. Math.random控制随机数范围

    let minNum= parseInt(Math.random()*7) + 1 let maxNum= parseInt(Math.random()*83) + 1 生成7~83的随机整数

  3. as3.0去除空格

    var str:String="是 我们 呀CuPlay er.com网站" function trim(string:String):String { return string ...

  4. 第二章 向量(e)起泡排序

  5. perl-我的第一个程序

    1.问题描述: 总共90位长度的位流数据,其中只有5位的数据为1,其余位全部为0.统计好多组5位的简化数据(每一位之间空格隔开,每一组一行),将其扩展到90位. #!D:/EDA/Perl/bin $ ...

  6. 解决 MySQL 比如我要拉取一个消息表中用户id为1的前10条最新数据

    我们都知道,各种主流的社交应用或者阅读应用,基本都有列表类视图,并且都有滑到底部加载更多这一功能, 对应后端就是分页拉取数据.好处不言而喻,一般来说,这些数据项都是按时间倒序排列的,用户只关心最新的动 ...

  7. SY-SUBRC

    一般是对read table和select语句使用. loop at g_it_data where level < <wa_data>-level and seq < < ...

  8. jira与svn的调研

    centos7.3 + jira7.8.3 + svn 1.7.14 一.环境搭建 1.centos7.3环境搭建:(1)下载centos7.3的.iso文件 http://mirrors.aliyu ...

  9. Mac下IntelliJ的Git、GitHub配置及使用

    1.git简介 Git是目前流行的分布式版本管理系统.它拥有两套版本库,本地库和远程库,在不进行合并和删除之类的操作时这两套版本库互不影响.也因此其近乎所有的操作都是本地执行,所以在断网的情况下任然可 ...

  10. ApplicationContextAware的使用

    一.这个接口有什么用? 当一个类实现了这个接口(ApplicationContextAware)之后,这个类就可以方便获得ApplicationContext中的所有bean.换句话说,就是这个类可以 ...