# -*- coding: utf-8 -*-
import numpy as np
np.random.seed(1337) #for reproducibility再现性
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential#按层
from keras.layers import Dense, Activation#全连接层
import matplotlib.pyplot as plt
from keras.optimizers import RMSprop

从mnist下载手写数字图片数据集,图片为28*28,将每个像素的颜色(0到255)改为(0倒1),将标签y变为10个长度,若为1,则在1处为1,剩下的都标为0。

#dowmload the mnisst the path '~/.keras/datasets/' if it is the first time to be called
#x shape (60000 28*28),y shape(10000,)
(x_train,y_train),(x_test,y_test) = mnist.load_data()#0-9的图片数据集 #data pre-processing
x_train = x_train.reshape(x_train.shape[0],-1)/255 #normalize 到【0,1】
x_test = x_test.reshape(x_test.shape[0],-1)/255
y_train = np_utils.to_categorical(y_train, num_classes=10) #把标签变为10个长度,若为1,则在1处为1,剩下的都标为0
y_test = np_utils.to_categorical(y_test,num_classes=10)

 

搭建神经网络,Activation为激活函数。由于第一个Dense传出32.所以第二个的Dense默认传进32,不用特意设置。

#Another way to build neural net
model = Sequential([
Dense(32,input_dim=784),#传出32
Activation('relu'),
Dense(10),
Activation('softmax')
]) #Another way to define optimizer
rmsprop = RMSprop(lr=0.001,rho=0.9,epsilon=1e-08,decay=0.0) # We add metrics to get more results you want to see
model.compile( #编译
optimizer = rmsprop,
loss = 'categorical_crossentropy',
metrics=['accuracy'], #在更新时同时计算一下accuracy
)

 

训练和测试

print("Training~~~~~~~~")
#Another way to train the model
model.fit(x_train,y_train, epochs=2, batch_size=32) #训练2大批,每批32个 print("\nTesting~~~~~~~~~~")
#Evalute the model with the metrics we define earlier
loss,accuracy = model.evaluate(x_test,y_test) print('test loss:',loss)
print('test accuracy:', accuracy)

 

全代码:

# -*- coding: utf-8 -*-
import numpy as np
np.random.seed(1337) #for reproducibility再现性
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential#按层
from keras.layers import Dense, Activation#全连接层
import matplotlib.pyplot as plt
from keras.optimizers import RMSprop #dowmload the mnisst the path '~/.keras/datasets/' if it is the first time to be called
#x shape (60000 28*28),y shape(10000,)
(x_train,y_train),(x_test,y_test) = mnist.load_data()#0-9的图片数据集 #data pre-processing
x_train = x_train.reshape(x_train.shape[0],-1)/255 #normalize 到【0,1】
x_test = x_test.reshape(x_test.shape[0],-1)/255
y_train = np_utils.to_categorical(y_train, num_classes=10) #把标签变为10个长度,若为1,则在1处为1,剩下的都标为0
y_test = np_utils.to_categorical(y_test,num_classes=10) #Another way to build neural net
model = Sequential([
Dense(32,input_dim=784),#传出32
Activation('relu'),
Dense(10),
Activation('softmax')
]) #Another way to define optimizer
rmsprop = RMSprop(lr=0.001,rho=0.9,epsilon=1e-08,decay=0.0) # We add metrics to get more results you want to see
model.compile( #编译
optimizer = rmsprop,
loss = 'categorical_crossentropy',
metrics=['accuracy'], #在更新时同时计算一下accuracy
) print("Training~~~~~~~~")
#Another way to train the model
model.fit(x_train,y_train, epochs=2, batch_size=32) #训练2大批,每批32个 print("\nTesting~~~~~~~~~~")
#Evalute the model with the metrics we define earlier
loss,accuracy = model.evaluate(x_test,y_test) print('test loss:',loss)
print('test accuracy:', accuracy)

结果为:

 

用Keras搭建神经网络 简单模版(二)——Classifier分类(手写数字识别)的更多相关文章

  1. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  2. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  3. 用Keras搭建神经网络 简单模版(六)——Autoencoder 自编码

    import numpy as np np.random.seed(1337) from keras.datasets import mnist from keras.models import Mo ...

  4. 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) from keras.datasets import mnist fro ...

  5. 用Keras搭建神经网络 简单模版(一)——Regressor 回归

    首先需要下载Keras,可以看到我用的是TensorFlow 的backend 自己构建虚拟数据,x是-1到1之间的数,y为0.5*x+2,可视化出来 # -*- coding: utf-8 -*- ...

  6. 用Keras搭建神经网络 简单模版(五)——RNN LSTM Regressor 循环神经网络

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) import matplotlib.pyplot as plt from ...

  7. 吴裕雄 python 神经网络TensorFlow实现LeNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf tf.reset_default_graph() # 配置神经网络的参数 INPUT_NODE = 784 OUTPUT_NODE = 10 IMAGE ...

  8. 吴裕雄 python 神经网络——TensorFlow实现AlexNet模型处理手写数字识别MNIST数据集

    import tensorflow as tf # 输入数据 from tensorflow.examples.tutorials.mnist import input_data mnist = in ...

  9. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

随机推荐

  1. Flask初级(五)flash在模板中使用继承,模板的模板

    Project name :Flask_Plan templates:templates static:static 继续上一篇文章. 我们不希望每个页面都写一遍引入js,css,导航条……………… ...

  2. L1-007 念数字

    输入一个整数,输出每个数字对应的拼音.当整数为负数时,先输出fu字.十个数字对应的拼音如下: 0: ling 1: yi 2: er 3: san 4: si 5: wu 6: liu 7: qi 8 ...

  3. 玩转X-CTR100 l STM32F4 l HMC5983/HMC5883L三轴磁力计传感器

    我造轮子,你造车,创客一起造起来!塔克创新资讯[塔克社区 www.xtark.cn ][塔克博客 www.cnblogs.com/xtark/ ]      本文介绍X-CTR100控制器 扩展HMC ...

  4. CUDA ---- CUDA库简介

    CUDA Libraries简介 上图是CUDA 库的位置,本文简要介绍cuSPARSE.cuBLAS.cuFFT和cuRAND,之后会介绍OpenACC. cuSPARSE线性代数库,主要针对稀疏矩 ...

  5. 20165210 Java第四周学习总结

    20165210 Java第四周学习总结 教材学习内容 第五章学习总结 子类与父类: 子类: class 子类名 extends 父类名 { ... } 类的树形结构 子类的继承性: 子类和父类在同一 ...

  6. <?xml version="1.0" encoding="UTF-8" standalone="no"?>

    XML standalone 定义了外部定义的 DTD 文件的存在性. standalone element 有效值是 yes 和 no. 如下是一个例子: <?xml version=&quo ...

  7. vue 兼容360及safari的方法

    1. npm install --save-dev babel-polyfill 2.  main.js 中 import "babel-polyfill";        或者: ...

  8. OC基础:类的扩展.协议 分类: ios学习 OC 2015-06-22 19:22 34人阅读 评论(0) 收藏

    //再设计一个类的时候,有些方法需要对外公开(接口),有些仅供内部使用. 类的扩展:为类添加新的特征(属性)或者方法 对已知类: 1.直接添加 2.继承(在其子类中添加实例变量和方法) 3.使用ext ...

  9. CTC+pytorch编译配置warp-CTC

    CTC CTC可以生成一个损失函数,用于在序列数据上进行监督式学习,不需要对齐输入数据及标签,经常连接在一个RNN网络的末端,训练端到端的语音和文本识别系统.CTC论文地址:http://www.cs ...

  10. 利用WebApplicationInitializer配置SpringMVC取代web.xml

    对于Spring MVC的DispatcherServlet配置方式,传统的是基于XML方式的,也就是官方说明的XML-based,如下: <servlet> <servlet-na ...