修改pytorch官方实例适用于自己的二分类迁移学习项目
本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:
- 根据AUC来迭代最优参数;
- 五折交叉验证;
- 输出验证集错误分类图片;
- 输出分类报告并保存AUC结果图片。
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import KFold
from torch.autograd import Variable
import torch.optim as optim
import time
import copy
import shutil
import sys
import scikitplot as skplt
import matplotlib.pyplot as plt
import pandas as pd plt.switch_backend('agg')
N_CLASSES = 2
BATCH_SIZE = 8
DATA_DIR = './data'
LABEL_DICT = {0: 'class_1', 1: 'class_2'} def imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(100) def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
since = time.time()
# 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
best_model_wts = copy.deepcopy(model.state_dict())
# best_acc = 0.0
# 初始auc
best_auc = 0.0
best_desc = [0, 0, None]
best_img_name = None
plt_auc = [None, None] for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('- ' * 50) for phase in ['train', 'val']:
if phase == 'train':
# 训练的时候进行学习率规划,其定义在下面给出
scheduler.step()
model.train(True)
else:
model.train(False)
phase_pred = np.array([])
phase_label = np.array([])
img_name = np.zeros((1, 2))
prob_pred = np.zeros((1, 2))
running_loss = 0.0
running_corrects = 0
# 这样迭代方便跟踪图片路径,输出错误图片名称
for data, index in zip(dataloaders[phase], dataloaders[phase].batch_sampler):
inputs, labels = data
if use_gpu:
inputs = Variable(inputs.cuda())
labels = Variable(labels.cuda())
else:
inputs, labels = Variable(inputs), Variable(labels) # 梯度参数设为0
optimizer.zero_grad() # forward
outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)
loss = criterion(outputs, labels) # backward + 训练阶段优化
if phase == 'train':
loss.backward()
optimizer.step() if phase == 'val':
img_name = np.append(img_name, np.array(dataloaders[phase].dataset.imgs)[index], axis=0)
prob = outputs.data.cpu().numpy()
prob_pred = np.append(prob_pred, prob, axis=0) phase_pred = np.append(phase_pred, preds.cpu().numpy())
phase_label = np.append(phase_label, labels.data.cpu().numpy())
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data).float()
print()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects / dataset_sizes[phase]
epoch_auc = roc_auc_score(phase_label, phase_pred)
print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
phase, epoch_loss, epoch_acc, epoch_auc))
report = classification_report(phase_label, phase_pred, target_names=class_names)
print(report) img_name = zip(img_name[1:], phase_pred)
# 当验证时遇到了更好的模型则予以保留
if phase == 'val' and epoch_auc > best_auc:
best_auc = epoch_auc
best_desc = epoch_acc, epoch_auc, report
best_img_name = img_name
# 深拷贝模型参数
best_model_wts = copy.deepcopy(model.state_dict())
plt_auc = phase_label, prob_pred[1:] print()
print(plt_auc[0].shape, plt_auc[1].shape)
csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
csv_file['true_label'] = pd.DataFrame(plt_auc[0])
csv_file['true_label'] = csv_file['true_label'].apply(lambda x: LABEL_DICT[x])
csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAclass_2 -> {:4f}\n\n{}'.format(best_desc[0], best_desc[1],
best_desc[2])
report_file.write(reports)
print(reports)
print('List the wrong judgement img ...')
count = 0
for i in best_img_name:
actual_label = int(i[0][1])
pred_label = i[1]
if actual_label != pred_label:
tmp_word = f'{i[0][0].split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
f'pred: {LABEL_DICT[pred_label]}'
print(tmp_word)
label_file.write(tmp_word + '\n')
count += 1
print(f'This fold has {count} wrong records ...') # 载入最优模型参数
model.load_state_dict(best_model_wts)
return model def plot_img():
for i, data in enumerate(dataloaders['train']):
inputs, classes = data
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes]) # 此函数可以修改适用于自己项目的图片文件名
def move_file(data, file_path, dir_path, root_path):
label_0 = 'class_2'
label_1 = 'class_1'
print(f'start copy the {file_path} file ...')
os.chdir(dir_path)
if os.path.exists(file_path):
print(f'Find exist {file_path} file, the file will be dropped.')
shutil.rmtree(os.path.join(root_path, dir_path, file_path))
print(f'Finish drop the {file_path} file.') os.mkdir(file_path)
tmp_path = os.path.join(os.getcwd(), file_path)
tmp_pre_path = os.getcwd()
for d in data:
pre_path = os.path.join(tmp_pre_path, d)
os.chdir(tmp_path)
if d[:2] == label_0:
if not os.path.exists(label_0):
os.mkdir(label_0)
cur_path = os.path.join(tmp_path, label_0, d)
shutil.copyfile(pre_path, cur_path)
if d[:2] == label_1:
if not os.path.exists(label_1):
os.mkdir(label_1)
cur_path = os.path.join(tmp_path, label_1, d)
shutil.copyfile(pre_path, cur_path)
print('finish this work ...') if __name__ == "__main__":
if not os.path.exists('roc_img'):
os.mkdir('roc_img')
if not os.path.exists('prob_result'):
os.mkdir('prob_result')
if not os.path.exists('report'):
os.mkdir('report')
if not os.path.exists('error_record'):
os.mkdir('error_record')
if not os.path.exists('model'):
os.mkdir('model')
label_file = open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') kf = KFold(n_splits=5, shuffle=True, random_state=1)
origin_path = '/home/project/'
dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))]) for m, n in enumerate(kf.split(dd_list), start=1):
report_file = open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w')
print(f'The {m} fold for copy file and training ...')
move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
os.chdir(origin_path)
move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
os.chdir(origin_path)
data_transforms = {
'train': transforms.Compose([
# 裁剪到224,224
transforms.RandomResizedCrop(224),
# 随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
transforms.RandomHorizontalFlip(),
# transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), # HSV以及对比度变化
transforms.ToTensor(),
# 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的FloadTensor
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
} image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
shuffle=True, num_workers=8, pin_memory=False)
for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes
size = len(class_names)
print('label mapping: ')
print(image_datasets['train'].class_to_idx)
use_gpu = torch.cuda.is_available()
model_ft = None
if sys.argv[1] == 'resnet':
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(
nn.Linear(num_ftrs, N_CLASSES),
nn.Sigmoid()
) # 这边可以自行把inception模型加进去
if sys.argv[1] == 'inception':
raise Exception("not provide inception model ...")
# model_ft = models.inception_v3(pretrained=True) if sys.argv[1] == 'desnet':
model_ft = models.densenet121(pretrained=True)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Sequential(
nn.Linear(num_ftrs, N_CLASSES),
nn.Sigmoid()
)
# use_gpu = False if use_gpu:
model_ft = model_ft.cuda() criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# 每7个epoch衰减0.1倍
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1], num_epochs=25)
print('Start save the model ...')
torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
print(f'The mission of the fold {m} finished.')
print('# '*50)
report_file.close()
label_file.close()
修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章
- Unity-2017.3官方实例教程Space-Shooter(二)
由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...
- Unity-2017.2官方实例教程Roll-a-ball(二)
声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- Unity-2017.3官方实例教程Space-Shooter(一)
由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...
- Unity-2017.2官方实例教程Roll-a-ball(一)
声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...
- NLP(二十二)利用ALBERT实现文本二分类
在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...
- 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。
这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展
教程来源于:Unity官方实例教程 Space Shooter(一)-(五) http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...
随机推荐
- 从零开始的全栈工程师——html篇1.4
背景与边框 一.背景(backgound) 1.背景颜色:background-color:red;(简写:background:color;) 备注:ie9以下给body设置background-c ...
- 如何才能快速入门python3?
一些朋友自学python过程中,发现书也能看懂,书上的玩具代码也能看懂,但为啥自己不能做习题,不能写代码解决问题,自己不能动手写代码? 原因是初学者没有学会计算思维.解决问题的方法.编程思路. 编程思 ...
- chosen下拉框插件的使用
效果如下 第一步: 第二步: 根据HTML5规范, 通常在引入CSS和JS时不需要指明 type,因为 text/css 和 text/javascript 分别是他们的默认值. <link r ...
- Android FlycoDialog 简单实用的自定义Android弹窗对话框之Dialog篇
效果图镇楼 FlycoDialog是一款非常棒的弹窗对话框处理框架,今天在这里主要讲一下他的自定义弹出对话框的功能,这里以第二幅效果图为例,图片已经放在博客最下方,X号自己随便找一个东西代替吧. ...
- iPython与notebook的基本用法
1 Ipython 安装 pip install ipython 2 Notebooke 基本用法 启动ipython使用ipython 启动notebook 使用 ipython notebook ...
- SpringMvc-helloword
说明:在此只说明helloword的简单实现,通过helloword例子先了解springMvc是这样工作的,然后在一步步的研究原理 配置web.xml 1.配置servlet servlet-cla ...
- react-webpack-express
这是一个整合react express 实现前后台交互,并且采用webpack进行打包和解析文件.其实react官方有一个脚手架create react app,也可以看那个,但是这个脚手架webpa ...
- IOS Xib使用
- 2019年5月训练记录(更新ing)
前言 \(ZJOI\)正式结束了. 但期中考试只考了年级\(216\),退役既视感... 于是就被抓回去补文化课了. 下半个学期可能要以文化课为主了吧! 但周三.周日应该还是会正常参加训练的,但其他时 ...
- LA 2038 最少点覆盖
题目链接:https://vjudge.net/problem/UVALive-2038 题意:我看了原题,lrj的书上题意写错了,应该是最少点覆盖,当然可以用最大匹配去做,由于是树形的: 可以树形D ...