Fine-Tuning微调原理

如何在只有60000张图片的Fashion-MNIST训练数据集中训练模型。ImageNet,这是学术界使用最广泛的大型图像数据集,它拥有1000多万幅图像和1000多个类别的对象。然而,我们经常处理的数据集的大小通常比第一个大,但比第二个小。

假设我们想在图像中识别不同种类的椅子,然后将购买链接推给用户。一种可行的方法是先找到一百张常见的椅子,每把椅子取一千张不同角度的图像,然后在采集到的图像数据集上训练分类模型。虽然这个数据集可能比时尚MNIST大,但是示例的数量仍然不到ImageNet的十分之一。这可能导致适用于ImageNet的复杂模型在此数据集上过度拟合。同时,由于数据量有限,最终训练出的模型精度可能达不到实际要求。

为了解决上述问题,一个显而易见的解决办法就是收集更多的数据。然而,收集和标记数据会消耗大量的时间和金钱。例如,为了收集ImageNet的数据集,研究人员花费了数百万美元的研究经费。尽管近年来,数据采集成本大幅下降,但成本仍然不容忽视。

另一种解决方案是应用转移学习将从源数据集学习的知识迁移到目标数据集。例如,虽然ImageNet中的图像大多与椅子无关,但是在这个数据集上训练的模型可以提取更一般的图像特征,这些特征可以帮助识别边缘、纹理、形状和对象组成。这些相似的特征对于识别椅子同样有效。

在本节中,我们将介绍迁移学习中的一种常用技术:微调。如图13.2.1所示,微调包括以下四个步骤:

在源数据集(例如ImageNet数据集)上预训练神经网络模型,即源模型。

建立一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数,输出层除外。我们假设这些模型参数包含从源数据集学习到的知识,这些知识将同样适用于目标数据集。我们还假设源模型的输出层与源数据集的标签密切相关,因此不在目标模型中使用。

将输出大小为目标数据集类别数的输出层添加到目标模型中,并随机初始化该层的模型参数。

在目标数据集上训练目标模型,例如椅子数据集。我们将从头开始训练输出层,同时根据源模型的参数对所有剩余层的参数进行微调。

Fig. 1.  Fine tuning.

1. Hot Dog Recognition

我们将使用一个具体的例子来练习:热狗识别。我们将基于一个小的数据集,对在ImageNet数据集上训练的ResNet模型进行微调。这个小数据集包含数千张图像,其中一些包含热狗。我们将使用通过微调获得的模型来识别图像是否包含热狗。

首先,导入实验所需的软件包和模块。Gluon的model_zoo package提供了一个通用的预训练模型。如果你想获得更多的计算机视觉的预先训练模型,你可以使用GluonCV工具箱。

%matplotlib inline

from d2l import mxnet as d2l

from mxnet import gluon, init, np, npx

from mxnet.gluon import nn

import os

npx.set_np()

1.1. Obtaining the Dataset

我们使用的热狗数据集来自在线图像,包含1400个热狗的正面图片和其他食物的相同数量的负面图片。1000个各种课程的图像用于训练,其余的用于测试。

我们首先下载压缩数据集,得到两个文件夹hotdog/train和hotdog/test。这两个文件夹都有hotdog和not hotdog类别子文件夹,每个子文件夹都有相应的图像文件。

#@save

d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL+'hotdog.zip',

'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')

Downloading ../data/hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

我们创建两个ImageFolderDataset实例,分别读取训练数据集和测试数据集中的所有图像文件。

train_imgs = gluon.data.vision.ImageFolderDataset(

os.path.join(data_dir, 'train'))

test_imgs = gluon.data.vision.ImageFolderDataset(

os.path.join(data_dir, 'test'))

前8个正面示例和最后8个负面图像如下所示。如您所见,图像的大小和纵横比各不相同。

hotdogs = [train_imgs[i][0] for i in range(8)]

not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]

d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

在训练过程中,我们首先从图像中裁剪出一个大小和纵横比随机的随机区域,然后将该区域缩放到一个高度和宽度为224像素的输入。在测试过程中,我们将图像的高度和宽度缩放到256像素,然后裁剪高宽为224像素的中心区域作为输入。此外,我们规范化三个RGB(红色、绿色和蓝色)颜色通道的值。从每个值中减去信道所有值的平均值,然后将结果除以信道所有值的标准差,以产生输出。

# We specify the mean and variance of the three RGB channels to normalize the

# image channel

normalize = gluon.data.vision.transforms.Normalize(

[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = gluon.data.vision.transforms.Compose([

gluon.data.vision.transforms.RandomResizedCrop(224),

gluon.data.vision.transforms.RandomFlipLeftRight(),

gluon.data.vision.transforms.ToTensor(),

normalize])

test_augs = gluon.data.vision.transforms.Compose([

gluon.data.vision.transforms.Resize(256),

gluon.data.vision.transforms.CenterCrop(224),

gluon.data.vision.transforms.ToTensor(),

normalize])

1.2. Defining and Initializing the Model

我们使用ResNet-18作为源模型,ResNet-18是在ImageNet数据集上预先训练的。这里,我们指定pretrained=True以自动下载和加载预先训练的模型参数。第一次使用时,需要从互联网上下载模型参数。

pretrained_net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)

预先训练的源模型实例包含两个成员变量:features和output。前者包含模型的所有层,输出层除外,后者是模型的输出层。这一划分的主要目的是促进除输出层之外的所有层的模型参数的微调。源模型的成员变量输出如下所示。作为一个完全连接的层,它将ResNet最终的全局平均池层输出转换为ImageNet数据集上的1000个类输出。

pretrained_net.output

Dense(512 -> 1000, linear)

然后构建一个新的神经网络作为目标模型。它的定义方式与预先训练的源模型相同,但最终输出数量等于目标数据集中的类别数。在下面的代码中,目标模型实例finetune_net的成员变量特征中的模型参数初始化为源模型对应层的模型参数。由于特征中的模型参数是通过对ImageNet数据集的预训练得到的,所以它是足够好的。因此,我们通常只需要使用较小的学习速率来“微调”这些参数。相比之下,成员变量输出中的模型参数是随机初始化的,通常需要更大的学习速率才能从头开始学习。假设训练实例中的学习率为 η,学习率为10η,更新成员变量输出中的模型参数。

finetune_net = gluon.model_zoo.vision.resnet18_v2(classes=2)

finetune_net.features = pretrained_net.features

finetune_net.output.initialize(init.Xavier())

# The model parameters in output will be updated using a learning rate ten

# times greater

finetune_net.output.collect_params().setattr('lr_mult', 10)

1.3. Fine Tuning the Model

我们首先定义了一个训练函数train_fine_tuning,它使用了微调,因此可以多次调用它。

def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5):

train_iter = gluon.data.DataLoader(

train_imgs.transform_first(train_augs), batch_size, shuffle=True)

test_iter = gluon.data.DataLoader(

test_imgs.transform_first(test_augs), batch_size)

ctx = d2l.try_all_gpus()

net.collect_params().reset_ctx(ctx)

net.hybridize()

loss = gluon.loss.SoftmaxCrossEntropyLoss()

trainer = gluon.Trainer(net.collect_params(), 'sgd', {

'learning_rate': learning_rate, 'wd': 0.001})

d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx)

我们将训练器实例中的学习率设置为一个较小的值,如0.01,以便对预训练中获得的模型参数进行微调。基于前面的设置,我们将使用10倍以上的学习率从头开始训练目标模型的输出层参数。

train_fine_tuning(finetune_net, 0.01)

loss 0.518, train acc 0.890, test acc 0.927

634.3 examples/sec on [gpu(0), gpu(1)]

为了进行比较,我们定义了一个相同的模型,但将其所有模型参数初始化为随机值。由于整个模型需要从头开始训练,所以我们可以使用更大的学习率。

scratch_net = gluon.model_zoo.vision.resnet18_v2(classes=2)

scratch_net.initialize(init=init.Xavier())

train_fine_tuning(scratch_net, 0.1)

loss 0.371, train acc 0.839, test acc 0.784

706.5 examples/sec on [gpu(0), gpu(1)]

正如您所看到的,由于参数的初始值更好,微调后的模型往往在同一时代获得更高的精度。

2. Summary

  • Transfer learning migrates the knowledge learned from the source dataset to the target dataset. Fine tuning is a common technique for transfer learning.
  • The target model replicates all model designs and their parameters on the source model, except the output layer, and fine-tunes these parameters based on the target dataset. In contrast, the output layer of the target model needs to be trained from scratch.
  • Generally, fine tuning parameters use a smaller learning rate, while training the output layer from scratch can use a larger learning rate.

Fine-Tuning微调原理的更多相关文章

  1. L23模型微调fine tuning

    resnet185352 链接:https://pan.baidu.com/s/1EZs9XVUjUf1MzaKYbJlcSA 提取码:axd1 9.2 微调 在前面的一些章节中,我们介绍了如何在只有 ...

  2. (原)caffe中fine tuning及使用snapshot时的sh命令

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/5946041.html 参考网址: http://caffe.berkeleyvision.org/tu ...

  3. Fine Tuning

    (转载自:WikiPedia) Fine tuning is a process to take a network model that has already been trained for a ...

  4. DL开源框架Caffe | 模型微调 (finetune)的场景、问题、技巧以及解决方案

    转自:http://blog.csdn.net/u010402786/article/details/70141261 前言 什么是模型的微调?   使用别人训练好的网络模型进行训练,前提是必须和别人 ...

  5. FineTuning机制的分析

    FineTuning机制的分析 为什么用FineTuning 使用别人训练好的网络模型进行训练,前提是必须和别人用同一个网络,因为参数是根据网络而来的.当然最后一层是可以修改的,因为我们的数据可能并没 ...

  6. [转载]关于Pretrain、Fine-tuning

    [转载]关于Pretrain.Fine-tuning 这两种tricks的意思其实就是字面意思,pre-train(预训练)和fine -tuning(微调) 来源:https://blog.csdn ...

  7. 【原创】TextCNN原理详解(一)

    ​ 最近一直在研究textCNN算法,准备写一个系列,每周更新一篇,大致包括以下内容: TextCNN基本原理和优劣势 TextCNN代码详解(附Github链接) TextCNN模型实践迭代经验总结 ...

  8. (原)torch中微调某层参数

    转载请注明出处: http://www.cnblogs.com/darkknightzh/p/6221664.html 参考网址: https://github.com/torch/nn/issues ...

  9. TorchVision Faster R-CNN 微调,实战 Kaggle 小麦检测

    本文将利用 TorchVision Faster R-CNN 预训练模型,于 Kaggle: 全球小麦检测 上实践迁移学习中的一种常用技术:微调(fine tuning). 本文相关的 Kaggle ...

随机推荐

  1. hdu 2058 枚举区间和个数

    题意:       给你两个数n,m,意思是有一个序列长度n,他是1 2 3 4 ...n,然后让你输出所有连续和等于m的范围. 思路:       是个小水题,随便写几个数字就能发现规律了,我们可以 ...

  2. SSRF(服务端请求伪造)漏洞

    目录 SSRF SSRF漏洞的挖掘 SSRF漏洞利用 SSRF漏洞防御 SSRF SSRF(Server-Side Request Forgery,服务器端请求伪造)漏洞,是一种由攻击者构造请求,由服 ...

  3. PhotoShop 第一课 功能认识

    功能认识 1.基本界面 可以对各工具栏进行编辑,对工具/栏目进行勾选添加和整合并搭建自己的专属操作页面. 2.画布设置 拍照或者画画都需要一个东西来呈现这个东西叫做画布(可以通过导航栏-文件-新建画布 ...

  4. Windows PE 第四章 导入表

    第四章 导入表 导入表是PE数据组织中的一个很重要的组成部分,它是为实现代码重用而设置的.通过分析导入表数据,可以获得诸如OE文件的指令中调用了多少外来函数,以及这些外来函数都存在于哪些动态链接库里等 ...

  5. jquery里面.length和.size()有什么区别

    区别: 1.针对标签对象元素,比如数html页面有多少个段落元素<p></p>,那么此时的$("p").size()==$("p").l ...

  6. thinkphp中常用到的sql操作

    1.清空某表数据: $sql = 'truncate table table_name'; Db::execute($sql );

  7. SQL必知必会 —— 性能优化篇

    数据库调优概述 数据库中的存储结构是怎样的 在数据库中,不论读一行,还是读多行,都是将这些行所在的页进行加载.也就是说,数据库管理存储空间的基本单位是页(Page). 一个页中可以存储多个行记录(Ro ...

  8. Markdown编辑器怎么用

    Markdown编辑器怎么用 1.代码块 快速创建一个代码块 // 语法: // ```+语言名称,如```java,```c++ 2.标题 语法:#+空格+标题名字,一个#表示一级标题,两个#表示二 ...

  9. Redis6.x学习笔记(一)

    前言 最近学习Redis6.x,特做笔记以备忘,与大家共学.课程是从私塾在线下载的,他们把架构师课程都放出来了,大家可以去下载学习,不要钱的,地址是http://t.hk.uy/eac,课程很不错,值 ...

  10. python爬虫——《英雄联盟》英雄及皮肤图片

    还记得那些年一起网吧开黑通宵的日子吗?<英雄联盟>绝对是大学时期的风靡游戏,即使毕业多年的大学同学相聚,难免不怀念一番当时一起玩<英雄联盟>的日子. 今天就给大家分享一下英雄及 ...