本项目参考:

https://www.bilibili.com/video/av31500120?t=4657

训练代码

 # coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Train import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense
from keras.optimizers import Adam
from keras.models import load_model nb_class = 10
nb_epoch = 4
batchsize = 128 '''
1st,准备参数
X_train: (0,255) --> (0,1) CNN中似乎没有必要?cnn自动转了吗?
设置时间函数测试一下两者对比。
小技巧:X_train /= 255.0 就可不用转换成浮点了???
'''
# Preparing your data mnist. MAC /.keras/datasets linux home ./keras/datasets
(X_train, Y_train), (X_test, Y_test) = mnist.load_data() # setup data shape
# (-1, 28, 28, 1) -1表示有默认个数据集,28*28是像素,1是1个通道
X_train = X_train.reshape(-1, 28, 28, 1) # tensorflow-channel last,while theano-channel first
X_test = X_test.reshape(-1, 28, 28, 1) X_train = X_train/255.000
X_test = X_test/255.000 # One-hot 6 --> [0,0,0,0,0,1,0,0,0]
Y_train = np_utils.to_categorical(Y_train, nb_class)
Y_test = np_utils.to_categorical(Y_test, nb_class) '''
2nd,设置模型
''' # setup model
model = Sequential() # 1st convolution layer # 滤波器要在28x28的图上横着走32次
model.add(Convolution2D(
filters=32, # 此处把filters写成了filter,找了半天。囧
kernel_size=[5, 5], # 滤波器是5x5大小的,可以是list列表,也可以是tuple元祖
padding='same', # padding也是一个窗口模式
input_shape=(28, 28, 1) # 定义输入的数据,必须是元组
))
model.add(Activation('relu'))
model.add(MaxPool2D(
pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
strides=(2, 2), # 相当于把图片缩小了。
padding="same",
)) # 2nd Conv2D layer
model.add(Convolution2D(
filters=64,
kernel_size=(5, 5),
padding='same',
))
model.add(Activation('relu'))
model.add(MaxPool2D(
pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
strides=(2, 2), # 相当于把图片缩小了。
padding="same",
)) # 讨论,卷积层数和最终结果关系。 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
model.add(Flatten()) # 把卷积层里面的全部转换层一维数组
model.add(Dense(1024)) # Dense is output
model.add(Activation('relu')) # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
# 把卷积层里面的全部转换层一维数组
model.add(Dense(256)) # Dense is output
model.add(Activation('tanh')) # 2nd Fully connected Dense
model.add(Dense(10))
model.add(Activation('softmax')) '''
3rd 定义参数
'''
# Define Optimizer and setup Param
adam = Adam(lr=0.0001) # Adam实例化 # compile model
model.compile(
optimizer=adam, # optimizer='Adam'也是可以的,且默认lr=0.001,此处已经实例化为adam
loss='categorical_crossentropy',
metrics=['accuracy'],
) # Run network
model.fit(x=X_train, # 更多参数可以查看fit函数,alt+鼠标左键单击fit
y=Y_train,
epochs=nb_epoch,
batch_size=batchsize, # p=parameter, batch_size; v=var, batch size
verbose=1, # 显示模式
validation_data=(X_test, Y_test)
)
model.save('model_name.h5')
# evaluation = model.evaluate(X_test, Y_test) 现在用model.fit(validation_data)
# print(evaluation) 效果一样

测试代码:

 # coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Predict import numpy as np
from keras.models import load_model # ??
import matplotlib.pyplot as plt
import matplotlib.image as processimage # load trained model
model = load_model('model_name.h5') # 已经训练好了的模型,在根目录下,默认为model_name.h5 # 写一个来预测的类
class MainPredictImg(object): def __init__(self):
pass def pred(self, filename):
pred_img = processimage.imread(filename)
pred_img = np.array(pred_img)
pred_img = pred_img.reshape(-1, 28, 28, 1)
prediction = model.predict(pred_img)
final_prediction = [result.argmax() for result in prediction][0]
a = 0
for i in prediction[0]:
print(a)
print('Percent:{:.30%}'.format(i))
a = a+1
return final_prediction def main():
predict = MainPredictImg()
res = predict.pred('4.png')
print("your number is:-->", res) if __name__ == '__main__':
main()

keras02 - hello convolution neural network 搭建第一个卷积神经网络的更多相关文章

  1. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  2. Convolution Neural Network (CNN) 原理与实现

    本文结合Deep learning的一个应用,Convolution Neural Network 进行一些基本应用,参考Lecun的Document 0.1进行部分拓展,与结果展示(in pytho ...

  3. Deeplearning - Overview of Convolution Neural Network

    Finally pass all the Deeplearning.ai courses in March! I highly recommend it! If you already know th ...

  4. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  5. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

  6. 深度学习:卷积神经网络(convolution neural network)

    (一)卷积神经网络 卷积神经网络最早是由Lecun在1998年提出的. 卷积神经网络通畅使用的三个基本概念为: 1.局部视觉域: 2.权值共享: 3.池化操作. 在卷积神经网络中,局部接受域表明输入图 ...

  7. Recurrent Neural Network系列1--RNN(循环神经网络)概述

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  8. 【面向代码】学习 Deep Learning(三)Convolution Neural Network(CNN)

    ========================================================================================== 最近一直在看Dee ...

  9. TensorFlow从入门到理解(三):你的第一个卷积神经网络(CNN)

    运行代码: from __future__ import print_function import tensorflow as tf from tensorflow.examples.tutoria ...

随机推荐

  1. 第5章 简单的C程序设计——循环结构程序设计

    5.1 为什么需要循环控制 前面介绍了程序中常用到的顺序结构和选择结构,但是只有这两种结构是不够的,还需要用到循环结构(或称重复结构).因为在程序所处理的问题中常常遇到需要重复处理的问题. 循环结构和 ...

  2. 『练手』001 Laura.SqlForever架构基础(Laura.XtraFramework 的变迁)

    001 Laura.SqlForever架构的基础(Laura.XtraFramework 的变迁之路) Laura.XtraFramework 到底是 做什么的? Laura.XtraFramewo ...

  3. 玩转Spring Cloud之熔断降级(Hystrix)与监控

    本文内容导航目录: 前言:解释熔断降级一.搭建服务消费者项目,并集成 Hystrix环境 1.1.在POM XML中添加Hystrix依赖(spring-cloud-starter-netflix-h ...

  4. 操作MongoDB数据库知识点

    一.命令行操作mongo: 1.开启数据库 mongo 如果启动mongo报以下错误: 运行brew services start mongodb 2.创建数据库并进入实例 use test 3.查看 ...

  5. WinForm客户端限速下载(C#限速下载)

    最近由于工作需要,需要开发一个能把服务器上的文件批量下载下来本地保存,关键是要实现限速下载,如果全速下载会影响服务器上的带宽流量.本来我最开始的想法是在服务器端开发一个可以从源头就限速下载的Api端口 ...

  6. Rekit

    本文转自:http://rekit.js.org/docs/get-started.html Get started The easiest way to try out Rekit is creat ...

  7. WinForm DataGridView实时更新表格数据

    前言 一个特殊的项目没有用第三方控件库,但用到了DataGridView,由于是客户端产生的数据,所以原始数据源就是一个集合. 根据需要会向集合中添加数据项,或是修改某些数据项的值,但DataGrid ...

  8. nginx系列6:nginx的进程结构

    nginx的进程结构 如下图: 通过ps –ef | grep nginx可以看到共有三个进程,一个master进程,两个worker进程. nginx是多进程结构,多进程结构设计是为了保证nginx ...

  9. vue项目中获取cdn域名插件

    import axios from 'axios' let CdnPath = {} CdnPath.install = function (Vue, options) { Vue.prototype ...

  10. 禁止微信内的H5页面上下拖动

    客户需求:禁止微信内的H5页面上下拖动: 解决方案: 网上的答案几乎都是阻止默认事件,即: document.body.addEventListener('touchmove' , function( ...