本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:

  1. 根据AUC来迭代最优参数;
  2. 五折交叉验证;
  3. 输出验证集错误分类图片;
  4. 输出分类报告并保存AUC结果图片。
     import os
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import lr_scheduler
    import torchvision
    from torchvision import datasets, models, transforms
    from torch.utils.data import DataLoader
    from sklearn.metrics import roc_auc_score, classification_report
    from sklearn.model_selection import KFold
    from torch.autograd import Variable
    import torch.optim as optim
    import time
    import copy
    import shutil
    import sys
    import scikitplot as skplt
    import matplotlib.pyplot as plt
    import pandas as pd plt.switch_backend('agg')
    N_CLASSES = 2
    BATCH_SIZE = 8
    DATA_DIR = './data'
    LABEL_DICT = {0: 'class_1', 1: 'class_2'} def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
    plt.title(title)
    plt.pause(100) def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    since = time.time()
    # 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
    best_model_wts = copy.deepcopy(model.state_dict())
    # best_acc = 0.0
    # 初始auc
    best_auc = 0.0
    best_desc = [0, 0, None]
    best_img_name = None
    plt_auc = [None, None] for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('- ' * 50) for phase in ['train', 'val']:
    if phase == 'train':
    # 训练的时候进行学习率规划,其定义在下面给出
    scheduler.step()
    model.train(True)
    else:
    model.train(False)
    phase_pred = np.array([])
    phase_label = np.array([])
    img_name = np.zeros((1, 2))
    prob_pred = np.zeros((1, 2))
    running_loss = 0.0
    running_corrects = 0
    # 这样迭代方便跟踪图片路径,输出错误图片名称
    for data, index in zip(dataloaders[phase], dataloaders[phase].batch_sampler):
    inputs, labels = data
    if use_gpu:
    inputs = Variable(inputs.cuda())
    labels = Variable(labels.cuda())
    else:
    inputs, labels = Variable(inputs), Variable(labels) # 梯度参数设为0
    optimizer.zero_grad() # forward
    outputs = model(inputs)
    _, preds = torch.max(outputs.data, 1)
    loss = criterion(outputs, labels) # backward + 训练阶段优化
    if phase == 'train':
    loss.backward()
    optimizer.step() if phase == 'val':
    img_name = np.append(img_name, np.array(dataloaders[phase].dataset.imgs)[index], axis=0)
    prob = outputs.data.cpu().numpy()
    prob_pred = np.append(prob_pred, prob, axis=0) phase_pred = np.append(phase_pred, preds.cpu().numpy())
    phase_label = np.append(phase_label, labels.data.cpu().numpy())
    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data).float()
    print()
    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects / dataset_sizes[phase]
    epoch_auc = roc_auc_score(phase_label, phase_pred)
    print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
    phase, epoch_loss, epoch_acc, epoch_auc))
    report = classification_report(phase_label, phase_pred, target_names=class_names)
    print(report) img_name = zip(img_name[1:], phase_pred)
    # 当验证时遇到了更好的模型则予以保留
    if phase == 'val' and epoch_auc > best_auc:
    best_auc = epoch_auc
    best_desc = epoch_acc, epoch_auc, report
    best_img_name = img_name
    # 深拷贝模型参数
    best_model_wts = copy.deepcopy(model.state_dict())
    plt_auc = phase_label, prob_pred[1:] print()
    print(plt_auc[0].shape, plt_auc[1].shape)
    csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
    csv_file['true_label'] = pd.DataFrame(plt_auc[0])
    csv_file['true_label'] = csv_file['true_label'].apply(lambda x: LABEL_DICT[x])
    csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
    skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
    plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAclass_2 -> {:4f}\n\n{}'.format(best_desc[0], best_desc[1],
    best_desc[2])
    report_file.write(reports)
    print(reports)
    print('List the wrong judgement img ...')
    count = 0
    for i in best_img_name:
    actual_label = int(i[0][1])
    pred_label = i[1]
    if actual_label != pred_label:
    tmp_word = f'{i[0][0].split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
    f'pred: {LABEL_DICT[pred_label]}'
    print(tmp_word)
    label_file.write(tmp_word + '\n')
    count += 1
    print(f'This fold has {count} wrong records ...') # 载入最优模型参数
    model.load_state_dict(best_model_wts)
    return model def plot_img():
    for i, data in enumerate(dataloaders['train']):
    inputs, classes = data
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes]) # 此函数可以修改适用于自己项目的图片文件名
    def move_file(data, file_path, dir_path, root_path):
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')
    os.chdir(dir_path)
    if os.path.exists(file_path):
    print(f'Find exist {file_path} file, the file will be dropped.')
    shutil.rmtree(os.path.join(root_path, dir_path, file_path))
    print(f'Finish drop the {file_path} file.') os.mkdir(file_path)
    tmp_path = os.path.join(os.getcwd(), file_path)
    tmp_pre_path = os.getcwd()
    for d in data:
    pre_path = os.path.join(tmp_pre_path, d)
    os.chdir(tmp_path)
    if d[:2] == label_0:
    if not os.path.exists(label_0):
    os.mkdir(label_0)
    cur_path = os.path.join(tmp_path, label_0, d)
    shutil.copyfile(pre_path, cur_path)
    if d[:2] == label_1:
    if not os.path.exists(label_1):
    os.mkdir(label_1)
    cur_path = os.path.join(tmp_path, label_1, d)
    shutil.copyfile(pre_path, cur_path)
    print('finish this work ...') if __name__ == "__main__":
    if not os.path.exists('roc_img'):
    os.mkdir('roc_img')
    if not os.path.exists('prob_result'):
    os.mkdir('prob_result')
    if not os.path.exists('report'):
    os.mkdir('report')
    if not os.path.exists('error_record'):
    os.mkdir('error_record')
    if not os.path.exists('model'):
    os.mkdir('model')
    label_file = open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') kf = KFold(n_splits=5, shuffle=True, random_state=1)
    origin_path = '/home/project/'
    dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))]) for m, n in enumerate(kf.split(dd_list), start=1):
    report_file = open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w')
    print(f'The {m} fold for copy file and training ...')
    move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
    os.chdir(origin_path)
    move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
    os.chdir(origin_path)
    data_transforms = {
    'train': transforms.Compose([
    # 裁剪到224,224
    transforms.RandomResizedCrop(224),
    # 随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
    transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), # HSV以及对比度变化
    transforms.ToTensor(),
    # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的FloadTensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    } image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
    data_transforms[x])
    for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
    shuffle=True, num_workers=8, pin_memory=False)
    for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes
    size = len(class_names)
    print('label mapping: ')
    print(image_datasets['train'].class_to_idx)
    use_gpu = torch.cuda.is_available()
    model_ft = None
    if sys.argv[1] == 'resnet':
    model_ft = models.resnet50(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
    nn.Linear(num_ftrs, N_CLASSES),
    nn.Sigmoid()
    ) # 这边可以自行把inception模型加进去
    if sys.argv[1] == 'inception':
    raise Exception("not provide inception model ...")
    # model_ft = models.inception_v3(pretrained=True) if sys.argv[1] == 'desnet':
    model_ft = models.densenet121(pretrained=True)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Sequential(
    nn.Linear(num_ftrs, N_CLASSES),
    nn.Sigmoid()
    )
    # use_gpu = False if use_gpu:
    model_ft = model_ft.cuda() criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    # 每7个epoch衰减0.1倍
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1], num_epochs=25)
    print('Start save the model ...')
    torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
    print(f'The mission of the fold {m} finished.')
    print('# '*50)
    report_file.close()
    label_file.close()

修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章

  1. Unity-2017.3官方实例教程Space-Shooter(二)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...

  2. Unity-2017.2官方实例教程Roll-a-ball(二)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...

  3. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  4. Unity-2017.3官方实例教程Space-Shooter(一)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...

  5. Unity-2017.2官方实例教程Roll-a-ball(一)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...

  6. NLP(二十二)利用ALBERT实现文本二分类

      在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...

  7. 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。

    这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...

  8. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  9. 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展

    教程来源于:Unity官方实例教程 Space Shooter(一)-(五)       http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...

随机推荐

  1. javaSystem.out.println()输出byte[]和char[]异常的问题

    javaSystem.out.println()输出byte[]和char[]异常的问题 今天 突然有人问我他写的byte[]和char[],在用System.out.println()输出的时候所得 ...

  2. Celery-------周期任务

    在项目目录例子的基础上进行修改一下celery文件 from celery import Celery from celery.schedules import crontab celery_task ...

  3. php有经纬度计算距离

    /** *  @desc 根据两点间的经纬度计算距离 *  @param float $lat 纬度值 *  @param float $lng 经度值 */  function getDistanc ...

  4. javascript实现数据结构: 树和二叉树,二叉树的遍历和基本操作

    树型结构是一类非常重要的非线性结构.直观地,树型结构是以分支关系定义的层次结构. 树在计算机领域中也有着广泛的应用,例如在编译程序中,用树来表示源程序的语法结构:在数据库系统中,可用树来组织信息:在分 ...

  5. C语言买卖股票问题

    遇到个简单的算法题,没有当场解出来,以后可以写伪代码表达思路. 数组中保存每天的股票价值,求买入卖出的时间和最大利润,比较好的解法如下: 伪代码: begin start day = 0; end d ...

  6. HTTP Strict Transport Security

    HTTP Strict Transport Security (通常简称为HSTS) 是一个安全功能,它告诉浏览器只能通过HTTPS访问当前资源, 禁止HTTP方式. 作用 一个网站接受一个HTTP的 ...

  7. extension Kingfisher where Base: Image:泛型类型的具体化与实例化

    具体化:针对特定的类型参量进行二次定义: 实例化:实例化:

  8. on-session问题

    .D:\0kecheng\bos\bosv2.0_chapter03.无条件查询. 方法1.@JSON(serialize=false)是注解排除不需要加载的实体类上,找到它的get方法,解决no-s ...

  9. org.slf4j.impl.Log4jLoggerAdapter cannot be cast to ch.qos.logback.classic.Logger

    https://stackoverflow.com/questions/31433246/classcastexception-org-slf4j-impl-log4jloggeradapter-ca ...

  10. div可编辑框,去除粘贴文字样式😄

    上个月做了个聊天的需求(网页版的).说到聊天都想到输入框,说到输入框都会想到input,但是input标签是不支持插入图片的(包括areatext标签).查阅了一些资料就看到div标签有一个属性con ...