本项目参考:

https://www.bilibili.com/video/av31500120?t=4657

训练代码

 # coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Train import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense
from keras.optimizers import Adam
from keras.models import load_model nb_class = 10
nb_epoch = 4
batchsize = 128 '''
1st,准备参数
X_train: (0,255) --> (0,1) CNN中似乎没有必要?cnn自动转了吗?
设置时间函数测试一下两者对比。
小技巧:X_train /= 255.0 就可不用转换成浮点了???
'''
# Preparing your data mnist. MAC /.keras/datasets linux home ./keras/datasets
(X_train, Y_train), (X_test, Y_test) = mnist.load_data() # setup data shape
# (-1, 28, 28, 1) -1表示有默认个数据集,28*28是像素,1是1个通道
X_train = X_train.reshape(-1, 28, 28, 1) # tensorflow-channel last,while theano-channel first
X_test = X_test.reshape(-1, 28, 28, 1) X_train = X_train/255.000
X_test = X_test/255.000 # One-hot 6 --> [0,0,0,0,0,1,0,0,0]
Y_train = np_utils.to_categorical(Y_train, nb_class)
Y_test = np_utils.to_categorical(Y_test, nb_class) '''
2nd,设置模型
''' # setup model
model = Sequential() # 1st convolution layer # 滤波器要在28x28的图上横着走32次
model.add(Convolution2D(
filters=32, # 此处把filters写成了filter,找了半天。囧
kernel_size=[5, 5], # 滤波器是5x5大小的,可以是list列表,也可以是tuple元祖
padding='same', # padding也是一个窗口模式
input_shape=(28, 28, 1) # 定义输入的数据,必须是元组
))
model.add(Activation('relu'))
model.add(MaxPool2D(
pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
strides=(2, 2), # 相当于把图片缩小了。
padding="same",
)) # 2nd Conv2D layer
model.add(Convolution2D(
filters=64,
kernel_size=(5, 5),
padding='same',
))
model.add(Activation('relu'))
model.add(MaxPool2D(
pool_size=(2, 2), # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
strides=(2, 2), # 相当于把图片缩小了。
padding="same",
)) # 讨论,卷积层数和最终结果关系。 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
model.add(Flatten()) # 把卷积层里面的全部转换层一维数组
model.add(Dense(1024)) # Dense is output
model.add(Activation('relu')) # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
# 把卷积层里面的全部转换层一维数组
model.add(Dense(256)) # Dense is output
model.add(Activation('tanh')) # 2nd Fully connected Dense
model.add(Dense(10))
model.add(Activation('softmax')) '''
3rd 定义参数
'''
# Define Optimizer and setup Param
adam = Adam(lr=0.0001) # Adam实例化 # compile model
model.compile(
optimizer=adam, # optimizer='Adam'也是可以的,且默认lr=0.001,此处已经实例化为adam
loss='categorical_crossentropy',
metrics=['accuracy'],
) # Run network
model.fit(x=X_train, # 更多参数可以查看fit函数,alt+鼠标左键单击fit
y=Y_train,
epochs=nb_epoch,
batch_size=batchsize, # p=parameter, batch_size; v=var, batch size
verbose=1, # 显示模式
validation_data=(X_test, Y_test)
)
model.save('model_name.h5')
# evaluation = model.evaluate(X_test, Y_test) 现在用model.fit(validation_data)
# print(evaluation) 效果一样

测试代码:

 # coding: utf-8
# Learning from Mofan and Mike G
# Recreated by Paprikatree
# Convolution NN Predict import numpy as np
from keras.models import load_model # ??
import matplotlib.pyplot as plt
import matplotlib.image as processimage # load trained model
model = load_model('model_name.h5') # 已经训练好了的模型,在根目录下,默认为model_name.h5 # 写一个来预测的类
class MainPredictImg(object): def __init__(self):
pass def pred(self, filename):
pred_img = processimage.imread(filename)
pred_img = np.array(pred_img)
pred_img = pred_img.reshape(-1, 28, 28, 1)
prediction = model.predict(pred_img)
final_prediction = [result.argmax() for result in prediction][0]
a = 0
for i in prediction[0]:
print(a)
print('Percent:{:.30%}'.format(i))
a = a+1
return final_prediction def main():
predict = MainPredictImg()
res = predict.pred('4.png')
print("your number is:-->", res) if __name__ == '__main__':
main()

keras02 - hello convolution neural network 搭建第一个卷积神经网络的更多相关文章

  1. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  2. Convolution Neural Network (CNN) 原理与实现

    本文结合Deep learning的一个应用,Convolution Neural Network 进行一些基本应用,参考Lecun的Document 0.1进行部分拓展,与结果展示(in pytho ...

  3. Deeplearning - Overview of Convolution Neural Network

    Finally pass all the Deeplearning.ai courses in March! I highly recommend it! If you already know th ...

  4. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  5. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

  6. 深度学习:卷积神经网络(convolution neural network)

    (一)卷积神经网络 卷积神经网络最早是由Lecun在1998年提出的. 卷积神经网络通畅使用的三个基本概念为: 1.局部视觉域: 2.权值共享: 3.池化操作. 在卷积神经网络中,局部接受域表明输入图 ...

  7. Recurrent Neural Network系列1--RNN(循环神经网络)概述

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  8. 【面向代码】学习 Deep Learning(三)Convolution Neural Network(CNN)

    ========================================================================================== 最近一直在看Dee ...

  9. TensorFlow从入门到理解(三):你的第一个卷积神经网络(CNN)

    运行代码: from __future__ import print_function import tensorflow as tf from tensorflow.examples.tutoria ...

随机推荐

  1. python将字符串类型list转换成list

    python读取了一个list是字符串形式的'[11.23,23.34]',想转换成list类型: 方式一: import ast str_list = "[11.23,23.34]&quo ...

  2. kubernetes 存储卷

    kubernetes 存储卷    数据卷用于实现容器持久化数据,Kubernetes对于数据卷重新定义,提供了丰富强大的功能.在Kubernetes系统中,当Pod重建的时候,数据卷会丢失,Kube ...

  3. 【Android Studio安装部署系列】三、Android Studio项目目录结构

    版权声明:本文为HaiyuKing原创文章,转载请注明出处! 概述 简单介绍下Android studio新建项目的目录结构. 常用项目结构类型 在Android Studio中,提供了以下几种项目结 ...

  4. 【Android Studio安装部署系列】八、Android Studio主题皮肤更换

    版权声明:本文为HaiyuKing原创文章,转载请注明出处! 概述 Android Studio具有自己的主题皮肤,但是如果想要更换自己喜欢的主题皮肤,可以参考下面的步骤. 注意,更换主题皮肤,之前的 ...

  5. PreferencesUtils【SharedPreferences操作工具类】

    版权声明:本文为HaiyuKing原创文章,转载请注明出处! 前言 可以替代ACache用来保存用户名.密码. 相较于Acache,不存在使用猎豹清理大师进行垃圾清理的时候把缓存的数据清理掉的问题. ...

  6. java锁与监视器概念 为什么wait、notify、notifyAll定义在Object中 多线程中篇(九)

    在Java中,与线程通信相关的几个方法,是定义在Object中的,大家都知道Object是Java中所有类的超类 在Java中,所有的类都是Object,借助于一个统一的形式Object,显然在有些处 ...

  7. 前端javascript如何阻止按下退格键页面回退 但 不阻止文本框使用退格键删除文本

    这段代码可以: document.onkeydown = function (e) { e.stopPropagation(); // 阻止事件冒泡传递 e.preventDefault(); // ...

  8. 【转】IIS上的反向代理

    http://blog.csdn.net/yuanguozhengjust/article/details/23576033 一直说在IIS上做反向代理,由于沉迷在nginx一行指令完事的美好情景当中 ...

  9. 学习笔记—XML

    XML XML简介 XML指可扩展标记语言(EXtensible Markup Language),是一种标记语言. XML是一种灵活的语言,标签没有被预定义,需要自行定义标签. 通常,XML被用于信 ...

  10. Ajax - Apache安装配置

    apache安装配置 1.安装wamp2.配置根路径3.默认的网站根路径是安装目录的www子目录,如果不想使用默认目录,可以自己配置.配置方式如下: --找到文件wamp\bin\apache\Apa ...