CNN基础三:预训练模型的微调
上一节中,我们利用了预训练的VGG网络卷积基,来简单的提取了图像的特征,并用这些特征作为输入,训练了一个小分类器。
这种方法好处在于简单粗暴,特征提取部分的卷积基不需要训练。但缺点在于,一是别人的模型是针对具体的任务训练的,里面提取到的特征不一定适合自己的任务;二是无法使用图像增强的方法进行端到端的训练。
因此,更为常用的一种方法是预训练模型修剪 + 微调,好处是可以根据自己任务需要,将预训练的网络和自定义网络进行一定的融合;此外还可以使用图像增强的方式进行端到端的训练。仍然以VGG16为例,过程为:
- 在已经训练好的基网络(base network)上添加自定义网络;
- 冻结基网络,训练自定义网络;
- 解冻部分基网络,联合训练解冻层和自定义网络。
注意在联合训练解冻层和自定义网络之前,通常要先训练自定义网络,否则,随机初始化的自定义网络权重会将大误差信号传到解冻层,破坏解冻层以前学到的表示,使得训练成本增大。
第一步:对预训练模型进行修改
##################第一步:在已经训练好的卷积基上添加自定义网络######################
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
#搭建模型
conv_base = VGG16(include_top=False, input_shape=(150,150,3)) #模型也可以看作一个层
model = Sequential()
model.add(conv_base)
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
#model.summary()
第二步:冻结卷积基,训练自定义网络
######################第二步:冻结卷积基,训练自定义网络##########################
#冻结卷积基,确保结果符合预期。或者用assert len(model.trainable_weights) == 30来验证
print("冻结之前可训练的张量个数:", len(model.trainable_weights)) #结果为30
conv_base.trainable = False
print("冻结之后可训练的张量个数:", len(model.trainable_weights)) #结果为4
#注:只有后两层Dense可以训练,每层一个权重张量和一个偏置张量,所以有4个
#利用图像生成器进行图像增强
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
test_datagen = ImageDataGenerator(rescale=1./255) #验证、测试的图像生成器不能用图像增强
train_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\train'
validation_dir = r'D:\KaggleDatasets\MyDatasets\dogs-vs-cats-small\validation'
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=(150,150),
batch_size=20,
class_mode='binary')
validation_generator = test_datagen.flow_from_directory(validation_dir,
target_size=(150,150),
batch_size=20,
class_mode='binary')
#模型编译和训练,注意修改trainable属性之后需要重新编译,否则修改无效
from keras import optimizers
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
H = model.fit_generator(train_generator,
steps_per_epoch=2000/20,
epochs=30,
validation_data=validation_generator,
validation_steps=1000/20)
训练30个epoch之后,结果如图所示。(结果可视化代码见上一节)
第三步:解冻部分卷积基(第5个block),联合训练
通常keras的冻结和解冻操作用的是模型或层的trainable属性。需要注意三点:
- model.trainable是全局属性,layer.trainable是层的属性,单独定义层的这一属性后全局属性即失效;
- 定义这一属性后,模型需要重新编译才能生效;
- conv_base是一个模型,但它在总模型model中是作为一个层的实例,因此遍历model.layers时会把conv_base作为一个层,如果需要深入conv_base内部各层进行操作,需要遍历conv_base.layers。
为了确保trainable属性符合预期,通常会确认一下,下面一些代码可能会有用。(这段主要是便于理解,跑代码时可选择性忽略这段。)
#可视化各层序号及名称
for i, layer in enumerate(model.layers):
print(i, layer.name)
for i, layer in enumerate(conv_base.layers):
print(i, layer.name)
#由于之前操作错误,导致模型全部层都被冻结,所以这个模块先把所有层解冻
for layer in conv_base.layers: #先解冻卷积基中所有层的张量
layer.trainable = True
for layer in model.layers: #解冻model中所有层张量
layer.trainable = True
#查看各层的trainable属性
for layer in model.layers:
print(layer.name, layer.trainable)
for layer in conv_base.layers:
print(layer.name, layer.trainable)
#model.trainable = True #注意:设定单独层的trainable属性后,全局trainable属性无效
print(len(conv_base.trainable_weights)) #26
print(len(model.trainable_weights)) #30
经过第二步之后,卷积基被冻结,后两层Dense可训练。接下来正式开始第三步,解冻第5个block,联合训练解冻层和自定义网络。
######################第三步:解冻部分卷积基,联合训练##########################
#冻结VGG16中前四个block,解冻第五个block
flag = False #标记是否到达第五个block
for layer in conv_base.layers: #注意不是遍历model.layers
if layer.name == 'block5_conv1': #若到达第五个block,则标记之
flag = True
if flag == False: #若标记为False,则冻结,否则设置为可训练
layer.trainable = False
else:
layer.trainable = True
print(len(model.trainable_weights)) #应为10
#重新编译并训练。血泪教训,一定要重新编译,不然trainable属性就白忙活了!
from keras import optimizers
#注:吐血,官网文档参数learning_rate,这里竟然不认,只能用lr
model.compile(loss='binary_crossentropy',
optimizer=optimizers.Adam(lr=1e-5), metrics=['accuracy'])
H2 = model.fit_generator(train_generator,
steps_per_epoch=2000/20,
epochs=100,
validation_data=validation_generator,
validation_steps=1000/20)
经过100个epoch之后,结果如下。可以看出验证准确率被提高到94%左右。
Reference:
书籍:Python深度学习
CNN基础三:预训练模型的微调的更多相关文章
- 自然语言处理(三) 预训练模型:XLNet 和他的先辈们
预训练模型 在CV中,预训练模型如ImagNet取得很大的成功,而在NLP中之前一直没有一个可以承担此角色的模型,目前,预训练模型如雨后春笋,是当今NLP领域最热的研究领域之一. 预训练模型属于迁移学 ...
- BERT的通俗理解 预训练模型 微调
1.预训练模型 BERT是一个预训练的模型,那么什么是预训练呢?举例子进行简单的介绍 假设已有A训练集,先用A对网络进行预训练,在A任务上学会网络参数,然后保存以备后用,当来一个新 ...
- 使用BERT预训练模型+微调进行文本分类
本文记录使用BERT预训练模型,修改最顶层softmax层,微调几个epoch,进行文本分类任务. BERT源码 首先BERT源码来自谷歌官方tensorflow版:https://github.co ...
- 我的Keras使用总结(4)——Application中五款预训练模型学习及其应用
本节主要学习Keras的应用模块 Application提供的带有预训练权重的模型,这些模型可以用来进行预测,特征提取和 finetune,上一篇文章我们使用了VGG16进行特征提取和微调,下面尝试一 ...
- 预训练模型——开创NLP新纪元
预训练模型--开创NLP新纪元 论文地址 BERT相关论文列表 清华整理-预训练语言模型 awesome-bert-nlp BERT Lang Street huggingface models 论文 ...
- BERT预训练模型的演进过程!(附代码)
1. 什么是BERT BERT的全称是Bidirectional Encoder Representation from Transformers,是Google2018年提出的预训练模型,即双向Tr ...
- CNN基础框架简介
卷积神经网络简介 卷积神经网络是多层感知机的变种,由生物学家休博尔和维瑟尔在早期关于猫视觉皮层的研究发展而来.视觉皮层的细胞存在一个复杂的构造,这些细胞对视觉输入空间的子区域非常敏感,我们称之为感受野 ...
- Pytorch——BERT 预训练模型及文本分类
BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...
- 我的Keras使用总结(3)——利用bottleneck features进行微调预训练模型VGG16
Keras的预训练模型地址:https://github.com/fchollet/deep-learning-models/releases 一个稍微讲究一点的办法是,利用在大规模数据集上预训练好的 ...
随机推荐
- SpringIntegration---Redis
1.依赖 <dependency> <groupId>org.springframework.integration</groupId> <artifactI ...
- vfs的super block
super block这个数据结构,乃至super block在磁盘上的位置,是哪里的规定? 没规定,1k偏移只是ext文件系统.但是像fat,它们第0扇区后就是保留扇区,但linux一样要识别它们. ...
- CDH6.3.1安装hue 报错
x 一.查看日志server运行日志 /var/log/cloudera-scm-server/cloudera-scm-server.log 2019-12-11 17:28:34,201 INFO ...
- C# 私有字段前缀 _ 的设置(VS2019, .editorconfig)
常量和静态只读字段大写 私有字段前缀 _ #### Naming styles #### # Naming rules dotnet_naming_rule.const_should_be_all_u ...
- LDD3 第11章 内核的数据类型
考虑到可移植性的问题,现代版本的Linux内核的可移植性是非常好的. 在把x86上的代码移植到新的体系架构上时,内核开发人员遇到的若干问题都和不正确的数据类型有关.坚持使用严格的数据类型,并且使用-W ...
- 【Dart学习】--之Duration相关方法总结
一,概述 Duration表示从一个时间点到另一个时间点的时间差 如果是一个较晚的时间点和一个较早的时间点,Duration可能是负数 二,创建Duration 唯一的构造函数创建Duration对象 ...
- 【CF1210C】Kamil and Making a Stream(vector,数论,树)
题意:给定一棵n个点带点权的树,i号点的点定义f(i,j)为i到j路径上所有点的gcd,其中i是j的一个祖先,求所有f(i,j)之和mod1e9+7 2<=n<=1e5,0<=a[i ...
- BZOJ 4484: [Jsoi2015]最小表示(拓扑排序+bitset)
传送门 解题思路 \(bitset\)维护连通性,给每个点开个\(bitset\),第\(i\)位为\(1\)则表示与第\(i\)位联通.算答案时显然要枚举每条边,而枚举边的顺序需要贪心,一个点先到达 ...
- 2018-2019-2 20175307实验三《敏捷开发与XP实践》实验报告
实验三 敏捷开发与XP实践-1 1.仔细学习了http://www.cnblogs.com/rocedu/p/4795776.html,发布了一篇关于Google的Java编码的博客,具体内容就不在这 ...
- wsl中加载git之后,发现文件是修改状态
查看git status,发现所有文件都被修改. git diff文件查看,发现是行尾的问题导致的. https://github.com/Microsoft/WSL/issues/184 在wsl里 ...