在深度学习的学习过程中,可能会用到一些已经训练好的模型,比如Alex Net,google Net,VGG,Resnet等,那我们怎样对这些训练好的模型进行fine-tune来提高准确率呢?

参考文章:https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html

使用已经训练好的VGG16模型来帮助我们进行这个分类任务,因为要分类的是猫,狗这类物体,而VGG net是在ImageNet上训练的,而imageNet实际上已经包含了这2种物体(猫,狗)了。

方法

首先载入VGG-16的权重

接下来在初始化好的VGG网络上添加我们预训练好的模型

最后将最后一个卷积块的层数冻结,然后以很低的学习率开始训练(我们只选择最后一个卷积块进行训练,因为训练样本很少,而VGG模型层数很多,全部训练肯定不能训练好,会过拟合)。其次fine-tune是由于在一个已经训练好的模型上进行的,故权值更新应该是一个小范围的,以免破坏预训练好的特征。

首先构造VGG16模型

# build the VGG16 network
model = Sequential()
model.add(ZeroPadding2D((1, 1), input_shape=(3, img_width, img_height))) model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, 3, 3, activation='relu', name='conv1_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2))) model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, 3, 3, activation='relu', name='conv2_2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2))) model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, 3, 3, activation='relu', name='conv3_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2))) model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv4_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2))) model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_1'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_2'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))

加载VGG16训练好的权重(我们只要全连接以前的权重):

# load the weights of the VGG16 networks
# (trained on ImageNet, won the ILSVRC competition in 2014)
# note: when there is a complete match between your model definition
# and your weight savefile, you can simply call model.load_weights(filename)
assert os.path.exists(weights_path), 'Model weights not found (see "weights_path" variable in script).'
f = h5py.File(weights_path)
for k in range(f.attrs['nb_layers']):
if k >= len(model.layers):
# we don't look at the last (fully-connected) layers in the savefile
break
g = f['layer_{}'.format(k)]
weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
model.layers[k].set_weights(weights)
f.close()
print('Model loaded.')

然后再VGG16结构基础上添加一个简单的分类器及预训练好的模型:

# build a classifier model to put on top of the convolutional model
top_model = Sequential()
top_model.add(Flatten(input_shape=model.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid')) # note that it is necessary to start with a fully-trained
# classifier, including the top classifier,
# in order to successfully do fine-tuning
top_model.load_weights(top_model_weights_path) # add the model on top of the convolutional base
model.add(top_model)

把随后一个卷积块前的权重设置为不训练:

# set the first 25 layers (up to the last conv block)
# to non-trainable (weights will not be updated)
for layer in model.layers[:25]:
layer.trainable = False # compile the model with a SGD/momentum optimizer
# and a very slow learning rate.
model.compile(loss='binary_crossentropy',
optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
metrics=['accuracy'])

这样一个很简单的fine-tune在50个epoch后就可以达到一个大概0.94的accuracy

Keras-在预训练好网络模型上进行fine-tune的更多相关文章

  1. 学习AI之NLP后对预训练语言模型——心得体会总结

    一.学习NLP背景介绍:      从2019年4月份开始跟着华为云ModelArts实战营同学们一起进行了6期关于图像深度学习的学习,初步了解了关于图像标注.图像分类.物体检测,图像都目标物体检测等 ...

  2. 知识增强的预训练语言模型系列之ERNIE:如何为预训练语言模型注入知识

    NLP论文解读 |杨健 论文标题: ERNIE:Enhanced Language Representation with Informative Entities 收录会议:ACL 论文链接: ht ...

  3. 在Keras模型中one-hot编码,Embedding层,使用预训练的词向量/处理图片

    最近看了吴恩达老师的深度学习课程,又看了python深度学习这本书,对深度学习有了大概的了解,但是在实战的时候, 还是会有一些细枝末节没有完全弄懂,这篇文章就用来总结一下用keras实现深度学习算法的 ...

  4. VGG16等keras预训练权重文件的下载及本地存放

    VGG16等keras预训练权重文件的下载: https://github.com/fchollet/deep-learning-models/releases/ .h5文件本地存放目录: Linux ...

  5. AI:拿来主义——预训练网络(二)

    上一篇文章我们聊的是使用预训练网络中的一种方法,特征提取,今天我们讨论另外一种方法,微调模型,这也是迁移学习的一种方法. 微调模型 为什么需要微调模型?我们猜测和之前的实验,我们有这样的共识,数据量越 ...

  6. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

  7. pytorch预训练

    Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet.densenet.inception.resnet. ...

  8. 第二十四节,TensorFlow下slim库函数的使用以及使用VGG网络进行预训练、迁移学习(附代码)

    在介绍这一节之前,需要你对slim模型库有一些基本了解,具体可以参考第二十二节,TensorFlow中的图片分类模型库slim的使用.数据集处理,这一节我们会详细介绍slim模型库下面的一些函数的使用 ...

  9. BERT总结:最先进的NLP预训练技术

    BERT(Bidirectional Encoder Representations from Transformers)是谷歌AI研究人员最近发表的一篇论文:BERT: Pre-training o ...

随机推荐

  1. springcloud(十):服务网关zuul初级篇

    前面的文章我们介绍了,Eureka用于服务的注册于发现,Feign支持服务的调用以及均衡负载,Hystrix处理服务的熔断防止故障扩散,Spring Cloud Config服务集群配置中心,似乎一个 ...

  2. 搜集的一些酷炫的金属色 ,RGB值 和大家分享一下

    开发iOS程序过程中会使用到RGB,要注意每个RGB值都要除以 255.0 ,注意: ' .0 ' 不能省!! 一下是本人搜集的一些酷炫金属色的RGB值:   黄金 242,192,86 石墨 87, ...

  3. storm深入研究

    著作权归作者所有.商业转载请联系作者获得授权,非商业转载请注明出处.作者:He Ransom链接:http://www.zhihu.com/question/23441639/answer/28075 ...

  4. linux服务器检测CPU使用率、负载以及java占用CPU使用率的shell脚本

    #!/bin/bash CPU=`top -b -n 1|grep Cpu|awk '{print $2}'|cut -f 1 -d "."`LOAD=`top -b -n 1|g ...

  5. php错误:You don't have permission to access / on this server.

    以前php环境崩溃了,重新装了个,打开第一个文件就出现: You don't have permission to access / on this server.错误,让我情以何堪啊,居然说我此台服 ...

  6. brew faq:call ISHELL_GetJulianDate always return 1980 1 6

    假设你当时系统的时间为20130804000000,那么如果你将系统的时间改为20140104000000,那么ISHELL_GetJulianDate  将返回20140104000000. 但如果 ...

  7. Linux下安装配置SVN

    1.检查系统上是否安装了SVN rpm -qa subversion 没有安装,则使用以下命令安装 yum -y install  subversion 2.配置svn并启动svn服务 (1) 指定s ...

  8. JavaScript作用域原理——预编译

    JavaScript是一种脚本语言, 它的执行过程, 是一种翻译执行的过程.并且JavaScript是有预编译过程的,在执行每一段脚本代码之前, 都会首先处理var关键字和function定义式(函数 ...

  9. 用Broadcast Receiver刷新数据(二)

    采用消息发布/订阅的一个很大的优点就是代码的简洁性,并且能够有效地降低消息发布者和订阅者之间的耦合度.举个例子,比如有两个界面,ActivityA和ActivityB,从ActivityA界面跳转到A ...

  10. CSS3 属性组参考资料

    CSS 属性组: 动画 背景 边框和轮廓 盒(框) 颜色 内容分页媒体 定位 可伸缩框 字体 生成内容 网格 超链接 行框 列表 外边距 Marquee 多列 内边距 分页媒体 定位 打印 Ruby ...