本文的keras后台为tensorflow,介绍如何利用预编译的模型进行迁移学习,以训练和识别自己的图片集。

官网 https://keras.io/applications/ 已经介绍了各个基于ImageNet的预编译模型,对于我们来说,既可以直接为我所用进行图片识别,也可在其基础上进行迁移学习,以满足自己的需求。

但在迁移学习的例子中,并不描述的十分详细,我将给出一个可运行的代码,以介绍如何进行迁移学习。

from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.optimizers import SGD
import tensorflow.keras.backend as K # 训练和测试的图片分为'bus', 'dinosaur', 'flower', 'horse', 'elephant'五类
# 其图片的下载地址为 http://pan.baidu.com/s/1nuqlTnN ,总共500张图片,其中图片以3,4,5,6,7开头进行按类区分
# 训练图片400张,测试图片100张;注意下载后,在train和test目录下分别建立上述的五类子目录,keras会按照子目录进行分类识别
NUM_CLASSES = 5
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/500pics/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/500pics/test'
# 代码最后挑出一张图片进行预测识别
PREDICT_IMG = '/home/yourname/Documents/tensorflow/images/500pics/test/elephant/502.jpg'
# FC层定义输入层的大小
FC_NUMS = 1024
# 冻结训练的层数,根据模型的不同,层数也不一样,根据调试的结果,VGG19和VGG16c层比较符合理想的测试结果,本文采用VGG19做示例
FREEZE_LAYERS = 17
# 进行训练和测试的图片大小,VGG19推荐为224×244
IMAGE_SIZE = 224 # 采用VGG19为基本模型,include_top为False,表示FC层是可自定义的,抛弃模型中的FC层;该模型会在~/.keras/models下载基本模型
base_model = VGG19(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), include_top=False, weights='imagenet') # 自定义FC层以基本模型的输入为卷积层的最后一层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(FC_NUMS, activation='relu')(x)
prediction = Dense(NUM_CLASSES, activation='softmax')(x) # 构造完新的FC层,加入custom层
model = Model(inputs=base_model.input, outputs=prediction)
# 可观察模型结构
model.summary()
# 获取模型的层数
print("layer nums:", len(model.layers)) # 除了FC层,靠近FC层的一部分卷积层可参与参数训练,
# 一般来说,模型结构已经标明一个卷积块包含的层数,
# 在这里我们选择FREEZE_LAYERS为17,表示最后一个卷积块和FC层要参与参数训练
for layer in model.layers[:FREEZE_LAYERS]:
layer.trainable = False
for layer in model.layers[FREEZE_LAYERS:]:
layer.trainable = True
for layer in model.layers:
print("layer.trainable:", layer.trainable) # 预编译模型
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy']) # 给出训练图片的生成器, 其中classes定义后,可让model按照这个顺序进行识别
train_datagen = ImageDataGenerator()
train_generator = train_datagen.flow_from_directory(directory=TRAIN_PATH,
target_size=(IMAGE_SIZE, IMAGE_SIZE), classes=['bus', 'dinosaur', 'flower', 'horse', 'elephant'])
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(directory=TEST_PATH,
target_size=(IMAGE_SIZE, IMAGE_SIZE), classes=['bus', 'dinosaur', 'flower', 'horse', 'elephant']) # 运行模型
model.fit_generator(train_generator, epochs=5, validation_data=test_generator) # 找一张图片进行预测验证
img = load_img(path=PREDICT_IMG, target_size=(IMAGE_SIZE, IMAGE_SIZE))
# 转换成numpy数组
x = img_to_array(img)
# 转换后的数组为3维数组(224,224,3),
# 而训练的数组为4维(图片数量, 224,224, 3),所以我们可扩充下维度
x = K.expand_dims(x, axis=0)
# 需要被预处理下
x = preprocess_input(x)
# 数据预测
result = model.predict(x, steps=1)
# 最后的结果是一个含有5个数的一维数组,我们取最大值所在的索引号,即对应'bus', 'dinosaur', 'flower', 'horse', 'elephant'的顺序
print("result:", K.eval(K.argmax(result)))

需要说明的是,各个预编译模型在面临不同的数据集时,其训练效果表现不一,需要我们不断地调整各种超参数,以期找到满意的模型

深度学习应用系列(二) | 如何使用keras进行迁移学习,以训练和识别自己的图片集的更多相关文章

  1. 深度学习基础系列(十一)| Keras中图像增强技术详解

    在深度学习中,数据短缺是我们经常面临的一个问题,虽然现在有不少公开数据集,但跟大公司掌握的海量数据集相比,数量上仍然偏少,而某些特定领域的数据采集更是非常困难.根据之前的学习可知,数据量少带来的最直接 ...

  2. Telegram学习解析系列(二):这我怎么给后台传输数据?

    写在前面: 在iOS开发的过程中,有很多时候我们都在和数据打交道,最基本的就是数据的下载和上传了,估计很多很多的小伙伴都在用AFNetworking与后台数据打交道,可有没有想过,哪天AFNetwor ...

  3. Socket学习总结系列(二) -- CocoaAsyncSocket

    这是系列的第二篇 这是这个系列文章的第二篇,要是没有看第一篇的还是建议看看第一篇,以为这个是接着第一篇梳理的 先大概的总结一下在上篇的文章中说的些内容: 1. 整理了一下做IM我们有那些途径,以及我们 ...

  4. 小白学习Spark系列二:spark应用打包傻瓜式教程(IntelliJ+maven 和 pycharm+jar)

    在做spark项目时,我们常常面临如何在本地将其打包,上传至装有spark服务器上运行的问题.下面是我在项目中尝试的两种方案,也踩了不少坑,两者相比,方案一比较简单,本博客提供的jar包适用于spar ...

  5. 学习CNN系列二:训练过程

    卷积神经网络在本质上是一种输入到输出的映射,它能够学习大量的输入与输出之间的映射关系,而不需要任何输入和输出之间精确的数学表达式,只要用已知的模式对卷积神经网络加以训练,网络就具有输入.输出之间映射的 ...

  6. Dubbo源码学习总结系列二 dubbo-rpc远程调用模块

    dubbo本质是一个RPC框架,我们首先讨论这个骨干中的骨干,dubbo-rpc模块. 主要讨论一下几部分内容: 一.此模块在dubbo整体框架中的作用: 二.此模块需要完成的需求功能点及接口定义: ...

  7. JNI 学习笔记系列(二)

    c中没有Boolean类型的值,一般是使用1表示true,0表示false,c中也没有String类型的数据,c中的字符串要通过char数组来表示.c中没有byte类型,一般用char表示byte类型 ...

  8. Windows-universal-samples学习笔记系列二:Controls, layout, and text

    Controls, layout, and text AutoSuggestBox migration Clipboard Commanding Context menu Context menu ( ...

  9. TensorFlow学习笔记(二)-- MNIST机器学习入门程序学习

    此程序被称为TF的 Hello World,19行代码,给人感觉很简单.第一遍看的时候,不到半个小时,就把程序看完了.感觉有点囫囵吞枣的意思,没理解透彻.现在回过头来看,感觉还可以从中学到更多东西. ...

随机推荐

  1. 解决在linux安装网易云音乐无法点击图标打开

    一下内容转载自:https://blog.csdn.net/Handoking/article/details/81026651 似乎linux下无法直接打开网易云音乐的原因是图标自带的启动脚本中没有 ...

  2. 【BZOJ1043】下落的圆盘 [计算几何]

    下落的圆盘 Time Limit: 10 Sec  Memory Limit: 162 MB[Submit][Status][Discuss] Description 有n个圆盘从天而降,后面落下的可 ...

  3. bzoj 2705: [SDOI2012]Longge的问题——欧拉定理

    Description Longge的数学成绩非常好,并且他非常乐于挑战高难度的数学问题.现在问题来了:给定一个整数N,你需要求出∑gcd(i, N)(1<=i <=N). Input 一 ...

  4. 【CODEVS】1281 Xn数列

    [算法]矩阵快速幂 [题解]T*A(n-1)=A(n)矩阵如下: a 1 * x(n-1) 0 = xn 0 0 1    c        0    c   0 防止溢出可以用类似快速幂的快速乘. ...

  5. bzoj 2786 DP

    我们可以将=左右的两个数看成一个块,块内无顺序要求,把<分隔的看成两个块,那么我们设w[i][j]代表将i个元素分成j个块的方案数,那么显然w[i][j]=w[i-1][j]*j+w[i-1][ ...

  6. hdu 1690 Bus System(Dijkstra最短路)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1690 Bus System Time Limit: 2000/1000 MS (Java/Others ...

  7. 【JDK】JDK7与JDK8环境共存与切换:先安装jdk7,配置好环境变量后再安装jdk8

    1.先安装JDK7 下载jdk-7u79-windows-i586.exe,安装后配置好环境变量JAVA_HOME,CLASSPATH,PATH java -version javac 指令都正常 2 ...

  8. 《Linux内核原理与设计》第十一周作业 ShellShock攻击实验

    <Linux内核原理与设计>第十一周作业 ShellShock攻击实验 分组: 和20179215袁琳完成实验及博客攥写 实验内容:   Bash中发现了一个严重漏洞shellshock, ...

  9. github删除文件夹

    git rm -rf dirgit add .git commit -m 'remove dir'git push origin master //dir是要删除的文件夹路径

  10. Python模块学习 - IPy

    简介 在IP地址规划中,涉及到计算大量的IP地址,包括网段.网络掩码.广播地址.子网数.IP类型等,即便是专业的网络人员也要进行繁琐的计算,而IPy模块提供了专门针对IPV4地址与IPV6地址的类与工 ...