迁移学习的两个主要场景

  1. 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可
  2. 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练

下面修改预训练好的resnet18网络在私人数据集上进行训练来分类蚂蚁和蜜蜂

数据集下载

这里使用的数据集包含ants和bees训练图片各约120张,验证图片各75张。由于数据样本非常少,如果从0初始化一个网络进行训练很难有令人满意的结果,这时候迁移学习就派上了用场。数据集下载地址,下载后解压到项目目录

导入相关包

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
import torchvision.transforms as transforms
import time
import os
import copy device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

加载数据

PyTorch提供了 torchvision.datasets.ImageFolder 方法来加载私人数据集:

# 训练数据集需要扩充和归一化
# 验证数据集仅需要归一化
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
} data_dir = 'hymenoptera_data' image_datasets = {
x: torchvision.datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
for x in ['train', 'val']
} dataloaders = {
x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
for x in ['train', 'val']
} dataset_sizes = {
x: len(image_datasets[x])
for x in ['train', 'val']
} class_names = image_datasets['train'].classes

定义一个通用的训练函数,得到最优参数

# 训练模型函数,参数scheduler是一个 torch.optim.lr_scheduler 学习速率调整类对象
def train_model(model, criterion, optimizer, scheduler, num_epochs=2):
since = time.time() best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0 for epoch in range(num_epochs):
print('-' * 20)
print('Epoch {}/{}'.format(epoch+1, num_epochs)) # 每个epoch都有一个训练和验证阶段
for phase in ['train', 'val']:
if phase == 'train':
model.train() # 训练模式
else:
model.eval() # 验证模式 running_loss = 0.0
running_corrects = 0 for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device) # 训练阶段开启梯度跟踪
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels) # 仅在训练阶段进行后向+优化
if phase == 'train':
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step() # 统计
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) # 记录最好的状态
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
best_model_wts = copy.deepcopy(model.state_dict()) print('-' * 20)
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60))
print('Best val Acc: {:4f}'.format(best_acc)) # 返回最佳参数的模型
model.load_state_dict(best_model_wts)
return model

场景一:微调CNN

这里我们使用resnet18作为我们的初始网络,在自己的数据集上继续训练预训练好的模型,所不同的是,我们修改原网络最后的全连接层输出维度为2,因为我们只需要预测是蚂蚁还是蜜蜂,原网络输出维度是1000,预测了1000个类别:

net = torchvision.models.resnet18(pretrained=True)     # 加载resnet网络结构和预训练参数
num_ftrs = net.fc.in_features # 提取fc层的输入参数
net.fc = nn.Linear(num_ftrs, 2) # 修改输出维度为2 net = net.to(device) # 使用分类交叉熵 Cross-Entropy 作损失函数,动量SGD做优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 每5个epochs衰减一次学习率 new_lr = old_lr * gamma ^ (epoch/step_size)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) # 训练模型
net = train_model(net, criterion, optimizer, lr_scheduler, num_epochs=10)

场景二:CNN作为固定特征提取器

这里我们通过设置 requires_grad == False 冻结除最后一层之外的所有网络,这样在反向传播的时候他们的梯度就不会被计算,参数也不会更新:

net = torchvision.models.resnet18(pretrained=True)
# 通过设置requires_grad = False来冻结参数,这样在反向传播的时候他们的梯度就不会被计算
for param in net.parameters():
param.requires_grad = False # 新连接层参数默认requires_grad=True
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2) net = net.to(device) criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.fc.parameters(), lr=0.001, momentum=0.9)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) net = train_model(net, criterion, optimizer, lr_scheduler, num_epochs=20)

PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类的更多相关文章

  1. NASNet学习笔记——   核心一:延续NAS论文的核心机制使得能够自动产生网络结构;    核心二:采用resnet和Inception重复使用block结构思想;    核心三:利用迁移学习将生成的网络迁移到大数据集上提出一个new search space。

    from:https://blog.csdn.net/xjz18298268521/article/details/79079008 NASNet总结 论文:<Learning Transfer ...

  2. Pytorch迁移学习实现驾驶场景分类

    Pytorch迁移学习实现驾驶场景分类 源代码:https://github.com/Dalaska/scene_clf 1.安装 pytorch 直接用官网上的方法能装上但下载很慢.通过换源安装发现 ...

  3. Pytorch迁移学习

    环境: Pytorch1.1,Python3.6,win10/ubuntu18,GPU 正文 Pytorch构建ResNet18模型并训练,进行真实图片分类: 利用预训练的ResNet18模型进行Fi ...

  4. pytorch 迁移学习[摘自官网]

    迁移学习包含两种:微调和特征提取器. 微调:对整个网络进行训练,更新所有参数 特征提取器:只对最后的输出层训练,其他层的权重保持不变 当然,二者的共性就是需要加载训练好的权重,比如在ImageNet上 ...

  5. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  6. TensorFlow从1到2(九)迁移学习

    迁移学习基本概念 迁移学习是这两年比较火的一个话题,主要原因是在当前的机器学习中,样本数据的获取是成本最高的一块.而迁移学习可以有效的把原有的学习经验(对于模型就是模型本身及其训练好的权重值)带入到新 ...

  7. DLNg[结构化ML项目]第二周迁移学习+多任务学习

    1.迁移学习 比如要训练一个放射科图片识别系统,但是图片非常少,那么可以先在有大量其他图片的训练集上进行训练,比如猫狗植物等的图片,这样训练好模型之后就可以转移到放射科图片上,模型已经从其他图片中学习 ...

  8. tensorflow实现迁移学习

    此例程出自<TensorFlow实战Google深度学习框架>6.5.2小节 卷积神经网络迁移学习. 数据集来自http://download.tensorflow.org/example ...

  9. Python 迁移学习实用指南 | iBooker·ApacheCN

    原文:Hands-On Transfer Learning with Python 协议:CC BY-NC-SA 4.0 自豪地采用谷歌翻译 不要担心自己的形象,只关心如何实现目标.--<原则& ...

随机推荐

  1. luogu P3180 [HAOI2016]地图 仙人掌 线段树合并 圆方树

    LINK:地图 考虑如果是一棵树怎么做 权值可以离散 那么可以直接利用dsu on tree+树状数组解决. 当然 也可以使用莫队 不过前缀和比较难以维护 外面套个树状数组又带了个log 套分块然后就 ...

  2. springboot多数据源启动报错:required a single bean, but 6 were found:

    技术群: 816227112 参考:https://stackoverflow.com/questions/43455869/could-not-autowire-there-is-more-than ...

  3. 获取判断IE版本 TypeError: Cannot read property 'msie' of undefined

    注意:以下方法只适用于IE11 以下: TypeError: Cannot read property 'msie' of undefined jquery1.9去掉了 $.browser  所以报错 ...

  4. Java对象(创建过程、内存布局、访问方法)

    (Java 普通对象.不包括数组.Class 对象等.) ​ 对象创建过程 类加载 遇到 new 指令时,获取对应的符号引用,并检查该符号引用代表的类是否已被初始化.如果没有就进行类加载. 分配内存 ...

  5. Shiro探索1. Realm

    1. Realm 是什么?汉语意思:领域,范围:王国:这个比较抽象: 简单一点就是:Realm 用来对用户进行认证和角色授权的 再简单一点,一个用户怎么判断它有没有登陆?这个用户是什么角色有哪些权限? ...

  6. 使用免费证书安装 ipa 到真机

    使用免费证书安装 ipa 密码设置 进入 AppleId 官网 登录个人账号 登录进去之后, 找到 Security, 点击 Generate Password... 锁边输入几个字符, 再点击 Cr ...

  7. Tutte 定理与 Tutte–Berge 公式

    Tutte theorem 图 \(G=(V,E)\) 有完美匹配当且仅当满足 \(\forall U\subseteq V,o(G-U)\le|U|,o(X)\) 表示 X 子图的奇连通块数. Tu ...

  8. Flink的状态编程和容错机制(四)

    一.状态编程 Flink 内置的很多算子,数据源 source,数据存储 sink 都是有状态的,流中的数据都是 buffer records,会保存一定的元素或者元数据.例如 : ProcessWi ...

  9. 实现1.双击自动关联文件类型打开 2.PC所有驱动器 3.小型资源管理器

    感谢各位这里实现:双击自动关联文件类型打开 2.PC所有驱动器 3.小型资源管理器!! 首先主页面: 2.运用DriveInfo驱动器的信息:获得整个系统磁盘驱动!!,运用frorach循环遍历到Tr ...

  10. 机器学习:支持向量机(SVM)

    SVM,称为支持向量机,曾经一度是应用最广泛的模型,它有很好的数学基础和理论基础,但是它的数学基础却比以前讲过的那些学习模型复杂很多,我一直认为它是最难推导,比神经网络的BP算法还要难懂,要想完全懂这 ...