Fine-Tuning微调原理
Fine-Tuning微调原理
如何在只有60000张图片的Fashion-MNIST训练数据集中训练模型。ImageNet,这是学术界使用最广泛的大型图像数据集,它拥有1000多万幅图像和1000多个类别的对象。然而,我们经常处理的数据集的大小通常比第一个大,但比第二个小。
假设我们想在图像中识别不同种类的椅子,然后将购买链接推给用户。一种可行的方法是先找到一百张常见的椅子,每把椅子取一千张不同角度的图像,然后在采集到的图像数据集上训练分类模型。虽然这个数据集可能比时尚MNIST大,但是示例的数量仍然不到ImageNet的十分之一。这可能导致适用于ImageNet的复杂模型在此数据集上过度拟合。同时,由于数据量有限,最终训练出的模型精度可能达不到实际要求。
为了解决上述问题,一个显而易见的解决办法就是收集更多的数据。然而,收集和标记数据会消耗大量的时间和金钱。例如,为了收集ImageNet的数据集,研究人员花费了数百万美元的研究经费。尽管近年来,数据采集成本大幅下降,但成本仍然不容忽视。
另一种解决方案是应用转移学习将从源数据集学习的知识迁移到目标数据集。例如,虽然ImageNet中的图像大多与椅子无关,但是在这个数据集上训练的模型可以提取更一般的图像特征,这些特征可以帮助识别边缘、纹理、形状和对象组成。这些相似的特征对于识别椅子同样有效。
在本节中,我们将介绍迁移学习中的一种常用技术:微调。如图13.2.1所示,微调包括以下四个步骤:
在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型。
建立一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数,输出层除外。我们假设这些模型参数包含从源数据集学习到的知识,这些知识将同样适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关,因此不在目标模型中使用。
将输出大小为目标数据集类别数的输出层添加到目标模型中,并随机初始化该层的模型参数。
在目标数据集上训练目标模型,例如椅子数据集。我们将从头开始训练输出层,同时根据源模型的参数对所有剩余层的参数进行微调。

Fig. 1. Fine tuning.
1. Hot Dog Recognition
我们将使用一个具体的例子来练习:热狗识别。我们将基于一个小的数据集,对在ImageNet数据集上训练的ResNet模型进行微调。这个小数据集包含数千张图像,其中一些包含热狗。我们将使用通过微调获得的模型来识别图像是否包含热狗。
首先,导入实验所需的软件包和模块。Gluon的model_zoo package提供了一个通用的预训练模型。如果你想获得更多的计算机视觉的预先训练模型,你可以使用GluonCV工具箱。
%matplotlib inline
from d2l import mxnet as d2l
from mxnet import gluon, init, np, npx
from mxnet.gluon import nn
import os
npx.set_np()
1.1. Obtaining the Dataset
我们使用的热狗数据集来自在线图像,包含1400个热狗的正面图片和其他食物的相同数量的负面图片。1000个各种课程的图像用于训练,其余的用于测试。
我们首先下载压缩数据集,得到两个文件夹hotdog/train和hotdog/test。这两个文件夹都有hotdog和not hotdog类别子文件夹,每个子文件夹都有相应的图像文件。
#@save
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL+'hotdog.zip',
'fba480ffa8aa7e0febbb511d181409f899b9baa5')
data_dir = d2l.download_extract('hotdog')
Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...
我们创建两个ImageFolderDataset实例,分别读取训练数据集和测试数据集中的所有图像文件。
train_imgs = gluon.data.vision.ImageFolderDataset(
os.path.join(data_dir, 'train'))
test_imgs = gluon.data.vision.ImageFolderDataset(
os.path.join(data_dir, 'test'))
前8个正面示例和最后8个负面图像如下所示。如您所见,图像的大小和纵横比各不相同。
hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在训练过程中,我们首先从图像中裁剪出一个大小和纵横比随机的随机区域,然后将该区域缩放到一个高度和宽度为224像素的输入。在测试过程中,我们将图像的高度和宽度缩放到256像素,然后裁剪高宽为224像素的中心区域作为输入。此外,我们规范化三个RGB(红色、绿色和蓝色)颜色通道的值。从每个值中减去信道所有值的平均值,然后将结果除以信道所有值的标准差,以产生输出。
# We specify the mean and variance of the three RGB channels to normalize the
# image channel
normalize = gluon.data.vision.transforms.Normalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_augs = gluon.data.vision.transforms.Compose([
gluon.data.vision.transforms.RandomResizedCrop(224),
gluon.data.vision.transforms.RandomFlipLeftRight(),
gluon.data.vision.transforms.ToTensor(),
normalize])
test_augs = gluon.data.vision.transforms.Compose([
gluon.data.vision.transforms.Resize(256),
gluon.data.vision.transforms.CenterCrop(224),
gluon.data.vision.transforms.ToTensor(),
normalize])
1.2. Defining and Initializing the Model
我们使用ResNet-18作为源模型,ResNet-18是在ImageNet数据集上预先训练的。这里,我们指定pretrained=True以自动下载和加载预先训练的模型参数。第一次使用时,需要从互联网上下载模型参数。
pretrained_net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)
预先训练的源模型实例包含两个成员变量:features和output。前者包含模型的所有层,输出层除外,后者是模型的输出层。这一划分的主要目的是促进除输出层之外的所有层的模型参数的微调。源模型的成员变量输出如下所示。作为一个完全连接的层,它将ResNet最终的全局平均池层输出转换为ImageNet数据集上的1000个类输出。
pretrained_net.output
Dense(512 -> 1000, linear)
然后构建一个新的神经网络作为目标模型。它的定义方式与预先训练的源模型相同,但最终输出数量等于目标数据集中的类别数。在下面的代码中,目标模型实例finetune_net的成员变量特征中的模型参数初始化为源模型对应层的模型参数。由于特征中的模型参数是通过对ImageNet数据集的预训练得到的,所以它是足够好的。因此,我们通常只需要使用较小的学习速率来“微调”这些参数。相比之下,成员变量输出中的模型参数是随机初始化的,通常需要更大的学习速率才能从头开始学习。假设训练实例中的学习率为 η,学习率为10η,更新成员变量输出中的模型参数。
finetune_net = gluon.model_zoo.vision.resnet18_v2(classes=2)
finetune_net.features = pretrained_net.features
finetune_net.output.initialize(init.Xavier())
# The model parameters in output will be updated using a learning rate ten
# times greater
finetune_net.output.collect_params().setattr('lr_mult', 10)
1.3. Fine Tuning the Model
我们首先定义了一个训练函数train_fine_tuning,它使用了微调,因此可以多次调用它。
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5):
train_iter = gluon.data.DataLoader(
train_imgs.transform_first(train_augs), batch_size, shuffle=True)
test_iter = gluon.data.DataLoader(
test_imgs.transform_first(test_augs), batch_size)
ctx = d2l.try_all_gpus()
net.collect_params().reset_ctx(ctx)
net.hybridize()
loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(), 'sgd', {
'learning_rate': learning_rate, 'wd': 0.001})
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx)
我们将训练器实例中的学习率设置为一个较小的值,如0.01,以便对预训练中获得的模型参数进行微调。基于前面的设置,我们将使用10倍以上的学习率从头开始训练目标模型的输出层参数。
train_fine_tuning(finetune_net, 0.01)
loss 0.518, train acc 0.890, test acc 0.927
634.3 examples/sec on [gpu(0), gpu(1)]

为了进行比较,我们定义了一个相同的模型,但将其所有模型参数初始化为随机值。由于整个模型需要从头开始训练,所以我们可以使用更大的学习率。
scratch_net = gluon.model_zoo.vision.resnet18_v2(classes=2)
scratch_net.initialize(init=init.Xavier())
train_fine_tuning(scratch_net, 0.1)
loss 0.371, train acc 0.839, test acc 0.784
706.5 examples/sec on [gpu(0), gpu(1)]

正如您所看到的,由于参数的初始值更好,微调后的模型往往在同一时代获得更高的精度。
2. Summary
- Transfer learning migrates the knowledge learned from the source dataset to the target dataset. Fine tuning is a common technique for transfer learning.
- The target model replicates all model designs and their parameters on the source model, except the output layer, and fine-tunes these parameters based on the target dataset. In contrast, the output layer of the target model needs to be trained from scratch.
- Generally, fine tuning parameters use a smaller learning rate, while training the output layer from scratch can use a larger learning rate.
Fine-Tuning微调原理的更多相关文章
- L23模型微调fine tuning
resnet185352 链接:https://pan.baidu.com/s/1EZs9XVUjUf1MzaKYbJlcSA 提取码:axd1 9.2 微调 在前面的一些章节中,我们介绍了如何在只有 ...
- (原)caffe中fine tuning及使用snapshot时的sh命令
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/5946041.html 参考网址: http://caffe.berkeleyvision.org/tu ...
- Fine Tuning
(转载自:WikiPedia) Fine tuning is a process to take a network model that has already been trained for a ...
- DL开源框架Caffe | 模型微调 (finetune)的场景、问题、技巧以及解决方案
转自:http://blog.csdn.net/u010402786/article/details/70141261 前言 什么是模型的微调? 使用别人训练好的网络模型进行训练,前提是必须和别人 ...
- FineTuning机制的分析
FineTuning机制的分析 为什么用FineTuning 使用别人训练好的网络模型进行训练,前提是必须和别人用同一个网络,因为参数是根据网络而来的.当然最后一层是可以修改的,因为我们的数据可能并没 ...
- [转载]关于Pretrain、Fine-tuning
[转载]关于Pretrain.Fine-tuning 这两种tricks的意思其实就是字面意思,pre-train(预训练)和fine -tuning(微调) 来源:https://blog.csdn ...
- 【原创】TextCNN原理详解(一)
最近一直在研究textCNN算法,准备写一个系列,每周更新一篇,大致包括以下内容: TextCNN基本原理和优劣势 TextCNN代码详解(附Github链接) TextCNN模型实践迭代经验总结 ...
- (原)torch中微调某层参数
转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...
- TorchVision Faster R-CNN 微调,实战 Kaggle 小麦检测
本文将利用 TorchVision Faster R-CNN 预训练模型,于 Kaggle: 全球小麦检测 上实践迁移学习中的一种常用技术:微调(fine tuning). 本文相关的 Kaggle ...
随机推荐
- Python socket编程(阻塞) --基于SocketServer
SocketServer模块是Python对socket常规通信的一个经过封装的模块,使用简单,基于面向对象的设计模式,但功能有限,可用于快速开发. Tips: 默认端口:6767 默认本地ip:12 ...
- Python中Numpy模块的使用
目录 NumPy ndarray对象 Numpy数据类型 Numpy数组属性 NumPy NumPy(Numerical Python) 是 Python 的一个扩展程序库,支持大量的维度数组与矩阵运 ...
- Windows核心编程 第十二章 纤程
第1 2章 纤 程 M i c r o s o f t公司给Wi n d o w s添加了一种纤程,以便能够非常容易地将现有的 U N I X服务器应用程序移植到Wi n d o w s中.U N I ...
- <JVM下篇:性能监控与调优篇>补充:使用OQL语言查询对象信息
笔记来源:尚硅谷JVM全套教程,百万播放,全网巅峰(宋红康详解java虚拟机) 同步更新:https://gitee.com/vectorx/NOTE_JVM https://codechina.cs ...
- CentOS安装Redis报错[server.o] Error 1
原因 准备安装的Redis服务版本为6.0.8, gcc的版本为4.8.5,可能是gcc版本过低到导致的 解决办法 安装低版本Redis或者安装高版本gcc
- tp5.1中返回当天、昨天、当月等的开始和结束时间戳
/** * 返回今日开始和结束的时间戳 * * @return array */function today(){ list($y, $m, $d) = explode('-', date('Y-m- ...
- solidworks中 toolbox调用出现未配置的解决方法
解决步骤:1:win7卸载安全补丁:KB3072630 WIN10,忽略.2:关闭所有Solidworks的进程3:CMD命令行进入:cd c:\program files\solidwokrs co ...
- 『政善治』Postman工具 — 8、Postman中Pre-request Script的使用
目录 1.Pre-request Script介绍 2.常用SNIPPETS(片段)说明 (1)获取变量脚本: (2)设置变量脚本: (3)清空变量脚本: (4)Send a request代码片段 ...
- MySQL5.7升级到8.0过程详解
前言: 不知不觉,MySQL8.0已经发布好多个GA小版本了.目前互联网上也有很多关于MySQL8.0的内容了,MySQL8.0版本基本已到稳定期,相信很多小伙伴已经在接触8.0了.本篇文章主要介绍从 ...
- MySQL库表设计小技巧
前言: 在我们项目开发中,数据库及表的设计可以说是非常重要,我遇到过很多库表设计比较杂乱的项目,像表名.字段名命名混乱.字段类型设计混乱等等,此类数据库后续极难维护与拓展.我一直相信只有优秀的库表设计 ...