一、前述

VGG16是由16层神经网络构成的经典模型,包括多层卷积,多层全连接层,一般我们改写的时候卷积层基本不动,全连接层从后面几层依次向前改写,因为先改参数较小的。

二、具体

1、因为本文中代码需要依赖OpenCV,所以第一步先安装OpenCV

因为VGG要求输入244*244,而数据集是28*28的,所以需要通过OpenCV在代码里去改变。

2、把模型下载后离线放入用户的管理目录下面,这样训练的时候就不需要从网上再下载了

3、我们保留的是除了全连接的所有层。

4、选择数据生成器,在真正使用的时候才会生成数据,加载到内存,前面yield只是做了一个标记

 代码:

# 使用迁移学习的思想,以VGG16作为模板搭建模型,训练识别手写字体
# 引入VGG16模块
from keras.applications.vgg16 import VGG16 # 其次加载其他模块
from keras.layers import Input
from keras.layers import Flatten
from keras.layers import Dense
from keras.layers import Dropout
from keras.models import Model
from keras.optimizers import SGD # 加载字体库作为训练样本
from keras.datasets import mnist # 加载OpenCV(在命令行中窗口中输入pip install opencv-python),这里为了后期对图像的处理,
# 大家使用pip install C:\Users\28542\Downloads\opencv_python-3.4.1+contrib-cp35-cp35m-win_amd64.whl
# 比如尺寸变化和Channel变化。这些变化是为了使图像满足VGG16所需要的输入格式
import cv2
import h5py as h5py
import numpy as np # 建立一个模型,其类型是Keras的Model类对象,我们构建的模型会将VGG16顶层(全连接层)去掉,只保留其余的网络
# 结构。这里用include_top = False表明我们迁移除顶层以外的其余网络结构到自己的模型中
# VGG模型对于输入图像数据要求高宽至少为48个像素点,由于硬件配置限制,我们选用48个像素点而不是原来
# VGG16所采用的224个像素点。即使这样仍然需要24GB以上的内存,或者使用数据生成器
model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(48, 48, 3))#输入进来的数据是48*48 3通道
#选择imagnet,会选择当年大赛的初始参数
#include_top=False 去掉最后3层的全连接层看源码可知
for layer in model_vgg.layers:
layer.trainable = False#别去调整之前的卷积层的参数
model = Flatten(name='flatten')(model_vgg.output)#去掉全连接层,前面都是卷积层
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(10, activation='softmax')(model)#model就是最后的y
model_vgg_mnist = Model(inputs=model_vgg.input, outputs=model, name='vgg16')
#把model_vgg.input X传进来
#把model Y传进来 就可以训练模型了 # 打印模型结构,包括所需要的参数
model_vgg_mnist.summary() #以下是原版的模型结构 224*224
model_vgg = VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
for layer in model_vgg.layers:
layer.trainable = False#别去调整之前的卷积层的参数
model = Flatten()(model_vgg.output)
model = Dense(4096, activation='relu', name='fc1')(model)
model = Dense(4096, activation='relu', name='fc2')(model)
model = Dropout(0.5)(model)
model = Dense(10, activation='softmax', name='prediction')(model)
model_vgg_mnist_pretrain = Model(model_vgg.input, model, name='vgg16_pretrain') model_vgg_mnist_pretrain.summary() # 新的模型不需要训练原有卷积结构里面的1471万个参数,但是注意参数还是来自于最后输出层前的两个
# 全连接层,一共有1.2亿个参数需要训练
sgd = SGD(lr=0.05, decay=1e-5)#lr 学习率 decay 梯度的逐渐减小 每迭代一次梯度就下降 0.05*(1-(10的-5))这样来变
#随着越来越下降 学习率越来越小 步子越小
model_vgg_mnist.compile(loss='categorical_crossentropy',
optimizer=sgd, metrics=['accuracy']) # 因为VGG16对网络输入层需要接受3通道的数据的要求,我们用OpenCV把图像从32*32变成224*224,把黑白图像转成RGB图像
# 并把训练数据转化成张量形式,供keras输入
(X_train, y_train), (X_test, y_test) = mnist.load_data("../test_data_home")
X_train, y_train = X_train[:1000], y_train[:1000]#训练集1000条
X_test, y_test = X_test[:100], y_test[:100]#测试集100条
X_train = [cv2.cvtColor(cv2.resize(i, (48, 48)), cv2.COLOR_GRAY2RGB)
for i in X_train]#变成彩色的
#np.concatenate拼接到一起把
X_train = np.concatenate([arr[np.newaxis] for arr in X_train]).astype('float32') X_test = [cv2.cvtColor(cv2.resize(i, (48, 48)), cv2.COLOR_GRAY2RGB)
for i in X_test]
X_test = np.concatenate([arr[np.newaxis] for arr in X_test]).astype('float32') print(X_train.shape)
print(X_test.shape) X_train = X_train / 255
X_test = X_test / 255 def tran_y(y):
y_ohe = np.zeros(10)
y_ohe[y] = 1
return y_ohe y_train_ohe = np.array([tran_y(y_train[i]) for i in range(len(y_train))])
y_test_ohe = np.array([tran_y(y_test[i]) for i in range(len(y_test))]) model_vgg_mnist.fit(X_train, y_train_ohe, validation_data=(X_test, y_test_ohe),
epochs=100, batch_size=50)

 结果:

 自定义的网络层:

【Keras篇】---利用keras改写VGG16经典模型在手写数字识别体中的应用的更多相关文章

  1. 利用神经网络算法的C#手写数字识别(二)

    利用神经网络算法的C#手写数字识别(二)   本篇主要内容: 让项目编译通过,并能打开图片进行识别.   1. 从上一篇<利用神经网络算法的C#手写数字识别>中的源码地址下载源码与资源, ...

  2. 利用神经网络算法的C#手写数字识别(一)

    利用神经网络算法的C#手写数字识别 转发来自云加社区,用于学习机器学习与神经网络 欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwri ...

  3. 利用c++编写bp神经网络实现手写数字识别详解

    利用c++编写bp神经网络实现手写数字识别 写在前面 从大一入学开始,本菜菜就一直想学习一下神经网络算法,但由于时间和资源所限,一直未展开比较透彻的学习.大二下人工智能课的修习,给了我一个学习的契机. ...

  4. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  5. NN:利用深度学习之神经网络实现手写数字识别(数据集50000张图片)—Jason niu

    import mnist_loader import network training_data, validation_data, test_data = mnist_loader.load_dat ...

  6. 手写数字识别——利用keras高层API快速搭建并优化网络模型

    在<手写数字识别——手动搭建全连接层>一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配.梯度计算.准确度的统计等问题,但 ...

  7. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  8. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  9. keras—多层感知器MLP—MNIST手写数字识别

    一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能:  其中隐含层节点数量(即神经细胞数量)计算 ...

随机推荐

  1. CSS选择器详细总结

    一.基本选择器 序号 选择器 含义 1. * 通用元素选择器,匹配任何元素 2. E 标签选择器,匹配所有使用E标签的元素 3. .info class选择器,匹配所有class属性中包含info的元 ...

  2. Windows下python3和python2同时安装python2.exe、python3.exe和pip2、pip3设置

    1.添加python2到系统环境变量 打开,控制面板\系统和安全\系统,选择高级系统设置,环境变量,选择Path,点击编辑,新建,分别添加D:\Python\python27和D:\Python\py ...

  3. java 匿名对象,内部类,修饰符,代码块

    匿名对象是在建对象时只有创建对象的语句方法而没有把对象的地址赋值给变量,匿名对象只能调用一次方法,想再调用时需要再创建一个新的匿名对象 创建普通对象:Person p =new Person(); 创 ...

  4. BZOJ_4517_[Sdoi2016]排列计数_组合数学

    BZOJ_4517_[Sdoi2016]排列计数_组合数学 Description 求有多少种长度为 n 的序列 A,满足以下条件: 1 ~ n 这 n 个数在序列中各出现了一次 若第 i 个数 A[ ...

  5. 视频转字符动画-Python-60行代码

    更新:2018-5-21 注意: 最后一步播放字符动画使用了只支持类 unix 系统的模块 curses, 因此在windows上是播放不了的... 解决方法: 1. 最近好像有一个移植 https: ...

  6. 实时监听input输入框value的变化:

    HTML5 标准事件 oninput 和 IE 专属事件 onpropertychange 事件实时监听输入框value的变化 oninput 事件在用户输入时触发. 该事件在 <input&g ...

  7. 理解ASP.NET Core 依赖注入

    目录: 一.什么是依赖注入 1.1.什么是依赖? 1.2. 什么是注入? 1.3.依赖注入解决的问题 二.服务的生命周期(.Net Core DI) 三.替换默认服务容器 3.1.为什么替换默认服务容 ...

  8. 一大波开发者福利来了,一份微软官方Github上发布的开源项目清单等你签收

    目录 微软Github开源项目入口 微软开源项目受欢迎程度排名 Visual Studio Code TypeScript RxJS .NET Core 基础类库 CNTK Microsoft cal ...

  9. ajax异步请求302分析

    1.前言 遇到这样一种情况,打开网页两个窗口a,b(都是已经登录授权的),在a页面中退出登录,然后在b页面执行增删改查,这个时候因为授权原因,b页面后端的请求肯定出现异常(对这个异常的处理,进行内部跳 ...

  10. 关于ES5的indexof()和ES7的includes()的区别

    早es5的时候就有了查找数组中是否包含某个值的API  indexOf(); 使用方法很简单,比如有个数组是: var arr=[2,3,4,"php"] 如果我们想知道数组中有没 ...