莫烦大大keras学习Mnist识别(3)-----CNN
一、步骤:
导入模块以及读取数据
数据预处理
构建模型
编译模型
训练模型
测试
二、代码:
导入模块以及读取数据
#导包
import numpy as np
np.random.seed(1337)
# from keras.datasets import mnist
from keras.utils import np_utils # 主要采用这个模块下的to_categorical函数,将该函数转成one_hot向量
from keras.models import Sequential #keras的模型模块
from keras.layers import Dense , Activation , Convolution2D, MaxPooling2D, Flatten #keras的层模块
from keras.optimizers import Adam #keras的优化器 #读取数据,因为本地已经下载好数据在绝对路径:E:\jupyter\TensorFlow\MNIST_data下,直接采用TensorFlow来读取
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('E:\jupyter\TensorFlow\MNIST_data',one_hot = True) X_train = mnist.train.images
Y_train = mnist.train.labels
X_test = mnist.test.images
Y_test = mnist.test.labels
2、数据预处理
x原本的shape为(55000,784),55000表示样本数量,784表示一个图像样本拉成一个向量的大小,故要将其转成28*28这种长×宽的形式。(-1,1,28,28)中的-1是之后batch_size的大小,即一次取batch大小的样本来训练,1,28,28表示高为1,长为28,宽为28。
#数据预处理
X_train = X_train.reshape(-1,1,28,28)
X_test = X_test.reshape(-1,1,28,28)
y_train = np_utils.to_categorical(Y_train,num_classes = 10)#to_categorical将标签转化成ont-hot
y_test = np_utils.to_categorical(Y_test,num_classes = 10)
3、构建模型
2个卷积层【包括卷积+激活relu+最大池化】+2个全连接层
#模型构建
model = Sequential() #建立一个序列模型 #在这个模型首层添加一个卷积层,一个卷积过滤器大小为5*5,32个过滤器,采用的padding模式是same,即通过补0使输入输出大小一下。首层要加一个输入大小(1,28,28)
model.add(Convolution2D(
nb_filter = 32,
nb_row = 5,
nb_col = 5,
border_mode = 'same',
input_shape = (1,28,28)
)) #接着加一个激活层
model.add(Activation('relu')) #接着加一个最大池化层,pool大小为(2,2),strides步长长移动2,宽移动2。padding采用same模式
model.add(MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same',
)) #卷积层2
model.add(Convolution2D(64,5,5,border_mode = 'same')) #激活层2
model.add(Activation('relu')) #池化层2
model.add(MaxPooling2D(pool_size = (2,2),border_mode = 'same')) #进行全连接之前将矩阵展开成一个长向量
model.add(Flatten()) #全连接层1,大小有1024个参数
model.add(Dense(1024)) #激活层
model.add(Activation('relu')) #全连接层2,大小为10
model.add(Dense(10)) #输出层加一个softmax处理
model.add(Activation('softmax'))
4、编译模型:model.compile
采用model.compile来编译,函数内参数说明优化器optimizer、损失函数loss、评价标准metrics。
#编译模型
adam = Adam(lr = 1e-4)
model.compile(optimizer=adam,
loss = 'categorical_crossentropy',
metrics = ['accuracy'])
5、训练模型:model.fit
类似sklearn中的形式
model.fit(X_train,Y_train,nb_epoch = 20,batch_size = 32)
6、测试:model.evaluate
输出测试的损失和准确度
loss , acc = model.evaluate(X_test,y_test)
莫烦大大keras学习Mnist识别(3)-----CNN的更多相关文章
- 莫烦大大keras学习Mnist识别(4)-----RNN
一.步骤: 导入包以及读取数据 设置参数 数据预处理 构建模型 编译模型 训练以及测试模型 二.代码: 1.导入包以及读取数据 #导入包 import numpy as np np.random.se ...
- 莫烦大大keras的Mnist手写识别(5)----自编码
一.步骤: 导入包和读取数据 数据预处理 编码层和解码层的建立 + 构建模型 编译模型 训练模型 测试模型[只用编码层来画图] 二.代码: 1.导入包和读取数据 #导入相关的包 import nump ...
- 莫烦大大TensorFlow学习笔记(9)----可视化
一.Matplotlib[结果可视化] #import os #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf i ...
- 莫烦python教程学习笔记——总结篇
一.机器学习算法分类: 监督学习:提供数据和数据分类标签.--分类.回归 非监督学习:只提供数据,不提供标签. 半监督学习 强化学习:尝试各种手段,自己去适应环境和规则.总结经验利用反馈,不断提高算法 ...
- 莫烦大大TensorFlow学习笔记(8)----优化器
一.TensorFlow中的优化器 tf.train.GradientDescentOptimizer:梯度下降算法 tf.train.AdadeltaOptimizer tf.train.Adagr ...
- 莫烦python教程学习笔记——保存模型、加载模型的两种方法
# View more python tutorials on my Youtube and Youku channel!!! # Youtube video tutorial: https://ww ...
- 莫烦python教程学习笔记——validation_curve用于调参
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
- 莫烦python教程学习笔记——learn_curve曲线用于过拟合问题
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
- 莫烦python教程学习笔记——利用交叉验证计算模型得分、选择模型参数
# View more python learning tutorial on my Youtube and Youku channel!!! # Youtube video tutorial: ht ...
随机推荐
- TCP 连接状态
TCP/IP的设计者如此设计,主要原因有两个: 防止上一次连接中的包迷路后重新出现,影响新的连接(经过2MSL时间后,上一次连接中所有重复的包都会消失). 为了可靠地关闭TCP连接.主动关闭方发送的最 ...
- install pip 回顾
在install pip的时候遇到如下问题 1. yum install 想安装一个package 总是提示没有package 可以安装. 但是后来可以了 2. make 和 configure 到底 ...
- Openfire:解决乱码问题
当部署openfire后,创建用户和发送离线消息时会出现中文字符乱码的问题.要解决这个问题需要同时配置openfire和mysql两端. 首先openfire端,在安装页面中指定odbc连接串中需要带 ...
- hdu1316(大数的斐波那契数)
题目信息:求两个大数之间的斐波那契数的个数(C++/JAVA) pid=1316">http://acm.hdu.edu.cn/showproblem.php? pid=1316 这里 ...
- Oracle学习(一):基本操作和基本查询语句
文中以"--"开头的语句为凝视,即为绿色部分 1.知识点:能够对比以下的录屏进行阅读 SQL> --录屏工具spool,開始录制,并指定保存路径为c:\基本查询.txt SQ ...
- 64位oracle数据库用32位plsql developer无法连接问题(无法载入oci.dll)
在64位操作系统下安装oracle数据库,新下载了64位数据库(假设是32位数据库安装在64位的操作系统上,无论是client还是server端.都不要去选择C:\Program Files (x86 ...
- 杂项-Company:ShineYoo
ylbtech-杂项-Company:ShineYoo 1. 网站返回顶部 1. 2. 3. 4. 2. 网站测试返回顶部 1. 2. 3.家服宝返回顶部 0.首页 http://www.jiafb. ...
- VS2010中文注释带红色下划线的解决方法
环境:Visual Studio 2010 问题:代码中出现中文后会带下划线,很多时候感觉很不舒服.找了很久的原因没找到,后来无意中在VisualAssist X里找到了解决办法. 1.安装完Visu ...
- (Go)10.流程控制示例
package main import ( "math/rand" "fmt" ) func main() { //var n int n := rand.In ...
- PCB LDI文件 自动化输出(改造)实现思路
由于工厂采用Liunxs系统输出LDI文件,由于我们数据库是用的Windows Server,编程语言是.net 无法与Liunxs系统进行有效对接, 所以造成才会造成LDI 资料输效率极低,人员工作 ...