修改pytorch官方实例适用于自己的二分类迁移学习项目
本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:
- 根据AUC来迭代最优参数;
- 五折交叉验证;
- 输出验证集错误分类图片;
- 输出分类报告并保存AUC结果图片。
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import KFold
from torch.autograd import Variable
import torch.optim as optim
import time
import copy
import shutil
import sys
import scikitplot as skplt
import matplotlib.pyplot as plt
import pandas as pd

# Select the non-interactive 'agg' backend so figures are rendered to files.
plt.switch_backend('agg')

N_CLASSES = 2
BATCH_SIZE = 8
DATA_DIR = './data'
# Maps a class index to the class name used in reports and CSV output.
LABEL_DICT = {0: 'class_1', 1: 'class_2'}


def imshow(inp, title=None):
    """Show a normalized image tensor with matplotlib.

    Args:
        inp: 3-D tensor; its axes are reordered with ``transpose((1, 2, 0))``
            and the normalization with the 3-element mean/std below is undone
            along the last axis before plotting.
        title: Optional title for the plot.
    """
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the Normalize() transform and clamp into the displayable range.
    inp = np.clip(std * inp + mean, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(100)


def _ordered_loader(loader):
    """Return a loader over ``loader.dataset`` that visits samples in index order.

    A sequentially sampled loader is returned unchanged; otherwise an
    equivalent non-shuffling loader is built so that the position of a sample
    in the stream equals its index in the dataset.
    """
    if isinstance(loader.sampler, torch.utils.data.SequentialSampler):
        return loader
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=loader.batch_size,
        shuffle=False,
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
        collate_fn=loader.collate_fn,
        drop_last=loader.drop_last,
    )


def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    """Train ``model`` and keep the weights of the best validation epoch.

    The function reads the module-level globals ``dataloaders``,
    ``dataset_sizes``, ``class_names`` and ``use_gpu`` and writes to the
    module-level file objects ``report_file`` and ``label_file``.

    Args:
        model: Network to train.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler, stepped once per epoch after the
            training phase.
        fold: Fold identifier, used only in the output file names.
        name: Model name, used only in the output file names.
        num_epochs: Number of epochs to run.

    Returns:
        ``model`` with the weights of the validation epoch that reached the
        highest AUC loaded.

    Side effects:
        Writes ``./prob_result/{name}_fold_{fold}_porb.csv`` and
        ``./roc_img/{name}_fold_{fold}_roc.png`` for the best validation
        epoch, appends a summary to ``report_file`` and the misclassified
        validation images to ``label_file``.
    """
    since = time.time()
    # Snapshot the current weights; replaced whenever a better epoch is found.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_auc = 0.0
    best_desc = [0, 0, None]
    best_records = None
    best_labels = None
    best_probs = None

    # Validation samples are matched to file names by position, which is only
    # valid when the loader does not shuffle. Zipping a shuffling loader with
    # its batch_sampler yields two independent permutations.
    loaders = {
        'train': dataloaders['train'],
        'val': _ordered_loader(dataloaders['val']),
    }

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('- ' * 50)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)
            loader = loaders[phase]

            pred_chunks = []
            label_chunks = []
            prob_chunks = []
            records = []  # (image path, class index) of validation samples
            running_loss = 0.0
            running_corrects = 0
            seen = 0

            for inputs, labels in loader:
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                optimizer.zero_grad()

                # Autograd is only needed while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    preds = torch.max(outputs, 1)[1]
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_n = inputs.size(0)
                if not is_train:
                    records.extend(loader.dataset.imgs[seen:seen + batch_n])
                    prob_chunks.append(outputs.detach().cpu().numpy())
                seen += batch_n

                pred_chunks.append(preds.cpu().numpy())
                label_chunks.append(labels.cpu().numpy())
                running_loss += loss.item() * batch_n
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                # Step the schedule after the optimizer updates of this epoch;
                # stepping it first skips the first value of the schedule.
                scheduler.step()

            print()
            phase_pred = np.concatenate(pred_chunks)
            phase_label = np.concatenate(label_chunks)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            # NOTE(review): the AUC is computed from hard arg-max predictions,
            # not from class scores - confirm this is the intended criterion.
            epoch_auc = roc_auc_score(phase_label, phase_pred)
            print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc, epoch_auc))
            report = classification_report(phase_label, phase_pred, target_names=class_names)
            print(report)

            # Keep everything belonging to the best validation epoch so far.
            if not is_train and epoch_auc > best_auc:
                best_auc = epoch_auc
                best_desc = epoch_acc, epoch_auc, report
                best_records = list(zip(records, phase_pred))
                best_model_wts = copy.deepcopy(model.state_dict())
                best_labels = phase_label
                best_probs = np.concatenate(prob_chunks)

    print()
    if best_probs is None:
        # No validation epoch exceeded the initial best_auc of 0.0.
        print('No validation epoch improved the AUC, skip the csv / roc export.')
    else:
        print(best_labels.shape, best_probs.shape)
        csv_file = pd.DataFrame(best_probs, columns=['class_1', 'class_2'])
        csv_file['true_label'] = [LABEL_DICT[int(x)] for x in best_labels]
        csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
        skplt.metrics.plot_roc_curve(best_labels, best_probs, curves=['each_class'])
        plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
        # Release the figure so figures do not pile up across folds.
        plt.close()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # The second value is the AUC; the label previously read 'Aclass_2'.
    reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAUC -> {:4f}\n\n{}'.format(
        best_desc[0], best_desc[1], best_desc[2])
    report_file.write(reports)
    print(reports)

    print('List the wrong judgement img ...')
    count = 0
    for (img_path, actual_label), pred_label in best_records or []:
        actual_label = int(actual_label)
        pred_label = int(pred_label)
        if actual_label != pred_label:
            tmp_word = f'{os.path.basename(img_path)}, actual: {LABEL_DICT[actual_label]}, ' \
                       f'pred: {LABEL_DICT[pred_label]}'
            print(tmp_word)
            label_file.write(tmp_word + '\n')
            count += 1
    print(f'This fold has {count} wrong records ...')

    # Load the best weights before handing the model back.
    model.load_state_dict(best_model_wts)
    return model


def plot_img():
    """Show every training batch as an image grid titled with its class names.

    Reads the module-level globals ``dataloaders`` and ``class_names``.
    """
    for inputs, classes in dataloaders['train']:
        out = torchvision.utils.make_grid(inputs)
        imshow(out, title=[class_names[x] for x in classes])


# The function below can be adapted to the image file names of your own project.
def move_file(data, file_path, dir_path, root_path):
    """Copy the listed files into per-class sub-folders of a fresh split folder.

    The split folder ``<root_path>/<dir_path>/<file_path>`` is deleted if it
    already exists and then recreated. Every name in ``data`` that starts with
    one of the two class labels is copied from ``<root_path>/<dir_path>`` into
    ``<split folder>/<label>/``; names matching no label are skipped.

    The working directory is left untouched.

    Args:
        data: Iterable of file names located directly in the data folder.
        file_path: Name of the split folder to (re)create, e.g. ``'train'``.
        dir_path: Data folder, resolved against ``root_path``.
        root_path: Base directory for ``dir_path``.
    """
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')

    source_dir = os.path.abspath(os.path.join(root_path, dir_path))
    split_dir = os.path.join(source_dir, file_path)
    if os.path.exists(split_dir):
        print(f'Find exist {file_path} file, the file will be dropped.')
        shutil.rmtree(split_dir)
        print(f'Finish drop the {file_path} file.')
    os.mkdir(split_dir)

    for d in data:
        for label in (label_0, label_1):
            # A fixed two-character slice can never equal the label strings,
            # so match on the full label prefix instead.
            if d.startswith(label):
                label_dir = os.path.join(split_dir, label)
                os.makedirs(label_dir, exist_ok=True)
                shutil.copyfile(os.path.join(source_dir, d), os.path.join(label_dir, d))
    print('finish this work ...')


def _build_model(arch):
    """Create the pretrained backbone selected by ``arch`` with a new head.

    The final layer is replaced by ``Linear(in_features, N_CLASSES)`` followed
    by ``Sigmoid``.

    Raises:
        Exception: for ``'inception'``, which is not provided.
        ValueError: for any other unknown name.
    """
    # NOTE(review): nn.CrossEntropyLoss is documented to take unnormalized
    # logits, yet the head ends in Sigmoid. Left unchanged because
    # train_model exports these outputs as class scores.
    if arch == 'resnet':
        net = models.resnet50(pretrained=True)
        net.fc = nn.Sequential(
            nn.Linear(net.fc.in_features, N_CLASSES),
            nn.Sigmoid()
        )
        return net
    # The inception model can be added here if needed.
    if arch == 'inception':
        raise Exception("not provide inception model ...")
    # 'desnet' is the historical spelling; 'densenet' is accepted as well.
    if arch in ('desnet', 'densenet'):
        net = models.densenet121(pretrained=True)
        net.classifier = nn.Sequential(
            nn.Linear(net.classifier.in_features, N_CLASSES),
            nn.Sigmoid()
        )
        return net
    raise ValueError(f'unknown model name: {arch!r}')


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit('usage: python <script> {resnet|desnet}')

    for out_dir in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
        os.makedirs(out_dir, exist_ok=True)

    kf = KFold(n_splits=5, shuffle=True, random_state=1)
    # All relative paths of this script are resolved against the launch
    # directory, so use it as the project root.
    origin_path = os.getcwd()
    dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))])

    data_transforms = {
        'train': transforms.Compose([
            # Random crop resized to 224 x 224.
            transforms.RandomResizedCrop(224),
            # Horizontally flip the given PIL image with probability 0.5.
            transforms.RandomHorizontalFlip(),
            # Convert a PIL image or (H, W, C) ndarray in [0, 255] to a
            # [C, H, W] FloatTensor in [0, 1.0].
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    # train_model reads label_file, report_file, dataloaders, dataset_sizes,
    # class_names and use_gpu as module-level globals.
    with open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') as label_file:
        for m, (train_idx, val_idx) in enumerate(kf.split(dd_list), start=1):
            with open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w') as report_file:
                print(f'The {m} fold for copy file and training ...')
                move_file(dd_list[train_idx], 'train', DATA_DIR, origin_path)
                move_file(dd_list[val_idx], 'val', DATA_DIR, origin_path)

                image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                          data_transforms[x])
                                  for x in ['train', 'val']}
                # Only the training data is shuffled: the validation order
                # must be reproducible so predictions can be matched to files.
                dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                                              shuffle=(x == 'train'), num_workers=8,
                                                              pin_memory=False)
                               for x in ['train', 'val']}
                dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
                class_names = image_datasets['train'].classes
                print('label mapping: ')
                print(image_datasets['train'].class_to_idx)

                use_gpu = torch.cuda.is_available()
                model_ft = _build_model(sys.argv[1])
                if use_gpu:
                    model_ft = model_ft.cuda()

                criterion = nn.CrossEntropyLoss()
                optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                # Multiply the learning rate by 0.1 every 7 epochs.
                exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1],
                                       num_epochs=25)
                print('Start save the model ...')
                torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
                print(f'The mission of the fold {m} finished.')
                print('# '*50)
修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章
- Unity-2017.3官方实例教程Space-Shooter(二)
由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...
- Unity-2017.2官方实例教程Roll-a-ball(二)
声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- Unity-2017.3官方实例教程Space-Shooter(一)
由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...
- Unity-2017.2官方实例教程Roll-a-ball(一)
声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...
- NLP(二十二)利用ALBERT实现文本二分类
在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...
- 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。
这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展
教程来源于:Unity官方实例教程 Space Shooter(一)-(五) http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...
随机推荐
- python数据类型(数字\字符串\列表)
一.基本数据类型——数字 1.布尔型 bool型只有两个值:True和False 之所以将bool值归类为数字,是因为我们也习惯用1表示True,0表示False. (1)布尔值是False的各种情况 ...
- 《CSS实现单行、多行文本溢出显示省略号》
如果实现单行文本的溢出显示省略号同学们应该都知道用text-overflow:ellipsis属性来,当然还需要加宽度width属来兼容部分浏览. 实现方式: overflow: hidden; te ...
- (EXPDP) Fails With Errors ORA-39079 ORA-25306 On One Node In RAC Environment
分类: Oracle DataPump export on one certain RAC instance fails with errors: ORA-39006: internal errorO ...
- OC 类 的声明
Student.h // @interface代表声明一个类 // : 代表继承 @interface Student : NSObject { // 成员变量要定义在下面的大括号中{} int ag ...
- POJ-3579 Median---二分第k大(二分套二分)
题目链接: https://cn.vjudge.net/problem/POJ-3579 题目大意: 求的是一列数所有相互之间差值的序列的最中间的值是多少. 解题思路: 可以用二分套二分的方法求解第m ...
- Android(java)学习笔记23:finally关键字的作用
1. finally 关键字的作用 package cn.itcast_07; import java.text.ParseException; import java.text.SimpleDate ...
- BZOJ2730:[HNOI2012]矿场搭建(双连通分量)
Description 煤矿工地可以看成是由隧道连接挖煤点组成的无向图.为安全起见,希望在工地发生事故时所有挖煤点的工人都能有一条出路逃到救援出口处.于是矿主决定在某些挖煤点设立救援出口,使得无论哪一 ...
- hihocoder 后缀自动机四·重复旋律7
题目 在\(DAG\)上跑一个\(dp\)就好了 设\(ans_i\)表示到了\(SAM\)的\(i\)位置上所有的子串形成的数的和,之后我们顺便记录一个方案数\(d_i\) 之后我们直接转移就好了 ...
- 课堂笔记-------字符串类型string------练习
字符串类型 一.string //打出s.时就会出现一堆的方框,要找不带箭头的(不带箭头的是我们现在可以用的到的),不要找带箭头的(带箭头的是扩展,现在还用不到) //不带箭头的都是对s的操作(动作和 ...
- C# .Net Framework4.5中配置和使用managedCUDA及常见问题解决办法
主要参考英文帖子.我就不翻译了哈.很容易懂的. 先说明我的运行平台: 1.IDE:Visual Studio 2012 C# .Net Framework4.5,使用默认安装路径: 2.显卡类型:NV ...