本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:

  1. 根据AUC来迭代最优参数;
  2. 五折交叉验证;
  3. 输出验证集错误分类图片;
  4. 输出分类报告并保存AUC结果图片。
     import os
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import lr_scheduler
    import torchvision
    from torchvision import datasets, models, transforms
    from torch.utils.data import DataLoader
    from sklearn.metrics import roc_auc_score, classification_report
    from sklearn.model_selection import KFold
    from torch.autograd import Variable
    import torch.optim as optim
    import time
    import copy
    import shutil
    import sys
    import scikitplot as skplt
    import matplotlib.pyplot as plt
    import pandas as pd plt.switch_backend('agg')
    N_CLASSES = 2
    BATCH_SIZE = 8
    DATA_DIR = './data'
    LABEL_DICT = {0: 'class_1', 1: 'class_2'} def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
    plt.title(title)
    plt.pause(100) def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    since = time.time()
    # 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
    best_model_wts = copy.deepcopy(model.state_dict())
    # best_acc = 0.0
    # 初始auc
    best_auc = 0.0
    best_desc = [0, 0, None]
    best_img_name = None
    plt_auc = [None, None] for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('- ' * 50) for phase in ['train', 'val']:
    if phase == 'train':
    # 训练的时候进行学习率规划,其定义在下面给出
    scheduler.step()
    model.train(True)
    else:
    model.train(False)
    phase_pred = np.array([])
    phase_label = np.array([])
    img_name = np.zeros((1, 2))
    prob_pred = np.zeros((1, 2))
    running_loss = 0.0
    running_corrects = 0
    # 这样迭代方便跟踪图片路径,输出错误图片名称
    for data, index in zip(dataloaders[phase], dataloaders[phase].batch_sampler):
    inputs, labels = data
    if use_gpu:
    inputs = Variable(inputs.cuda())
    labels = Variable(labels.cuda())
    else:
    inputs, labels = Variable(inputs), Variable(labels) # 梯度参数设为0
    optimizer.zero_grad() # forward
    outputs = model(inputs)
    _, preds = torch.max(outputs.data, 1)
    loss = criterion(outputs, labels) # backward + 训练阶段优化
    if phase == 'train':
    loss.backward()
    optimizer.step() if phase == 'val':
    img_name = np.append(img_name, np.array(dataloaders[phase].dataset.imgs)[index], axis=0)
    prob = outputs.data.cpu().numpy()
    prob_pred = np.append(prob_pred, prob, axis=0) phase_pred = np.append(phase_pred, preds.cpu().numpy())
    phase_label = np.append(phase_label, labels.data.cpu().numpy())
    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data).float()
    print()
    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects / dataset_sizes[phase]
    epoch_auc = roc_auc_score(phase_label, phase_pred)
    print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
    phase, epoch_loss, epoch_acc, epoch_auc))
    report = classification_report(phase_label, phase_pred, target_names=class_names)
    print(report) img_name = zip(img_name[1:], phase_pred)
    # 当验证时遇到了更好的模型则予以保留
    if phase == 'val' and epoch_auc > best_auc:
    best_auc = epoch_auc
    best_desc = epoch_acc, epoch_auc, report
    best_img_name = img_name
    # 深拷贝模型参数
    best_model_wts = copy.deepcopy(model.state_dict())
    plt_auc = phase_label, prob_pred[1:] print()
    print(plt_auc[0].shape, plt_auc[1].shape)
    csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
    csv_file['true_label'] = pd.DataFrame(plt_auc[0])
    csv_file['true_label'] = csv_file['true_label'].apply(lambda x: LABEL_DICT[x])
    csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
    skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
    plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAclass_2 -> {:4f}\n\n{}'.format(best_desc[0], best_desc[1],
    best_desc[2])
    report_file.write(reports)
    print(reports)
    print('List the wrong judgement img ...')
    count = 0
    for i in best_img_name:
    actual_label = int(i[0][1])
    pred_label = i[1]
    if actual_label != pred_label:
    tmp_word = f'{i[0][0].split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
    f'pred: {LABEL_DICT[pred_label]}'
    print(tmp_word)
    label_file.write(tmp_word + '\n')
    count += 1
    print(f'This fold has {count} wrong records ...') # 载入最优模型参数
    model.load_state_dict(best_model_wts)
    return model def plot_img():
    for i, data in enumerate(dataloaders['train']):
    inputs, classes = data
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes]) # 此函数可以修改适用于自己项目的图片文件名
    def move_file(data, file_path, dir_path, root_path):
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')
    os.chdir(dir_path)
    if os.path.exists(file_path):
    print(f'Find exist {file_path} file, the file will be dropped.')
    shutil.rmtree(os.path.join(root_path, dir_path, file_path))
    print(f'Finish drop the {file_path} file.') os.mkdir(file_path)
    tmp_path = os.path.join(os.getcwd(), file_path)
    tmp_pre_path = os.getcwd()
    for d in data:
    pre_path = os.path.join(tmp_pre_path, d)
    os.chdir(tmp_path)
    if d[:2] == label_0:
    if not os.path.exists(label_0):
    os.mkdir(label_0)
    cur_path = os.path.join(tmp_path, label_0, d)
    shutil.copyfile(pre_path, cur_path)
    if d[:2] == label_1:
    if not os.path.exists(label_1):
    os.mkdir(label_1)
    cur_path = os.path.join(tmp_path, label_1, d)
    shutil.copyfile(pre_path, cur_path)
    print('finish this work ...') if __name__ == "__main__":
    if not os.path.exists('roc_img'):
    os.mkdir('roc_img')
    if not os.path.exists('prob_result'):
    os.mkdir('prob_result')
    if not os.path.exists('report'):
    os.mkdir('report')
    if not os.path.exists('error_record'):
    os.mkdir('error_record')
    if not os.path.exists('model'):
    os.mkdir('model')
    label_file = open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') kf = KFold(n_splits=5, shuffle=True, random_state=1)
    origin_path = '/home/project/'
    dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))]) for m, n in enumerate(kf.split(dd_list), start=1):
    report_file = open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w')
    print(f'The {m} fold for copy file and training ...')
    move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
    os.chdir(origin_path)
    move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
    os.chdir(origin_path)
    data_transforms = {
    'train': transforms.Compose([
    # 裁剪到224,224
    transforms.RandomResizedCrop(224),
    # 随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
    transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), # HSV以及对比度变化
    transforms.ToTensor(),
    # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的FloadTensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    } image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
    data_transforms[x])
    for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
    shuffle=True, num_workers=8, pin_memory=False)
    for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes
    size = len(class_names)
    print('label mapping: ')
    print(image_datasets['train'].class_to_idx)
    use_gpu = torch.cuda.is_available()
    model_ft = None
    if sys.argv[1] == 'resnet':
    model_ft = models.resnet50(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
    nn.Linear(num_ftrs, N_CLASSES),
    nn.Sigmoid()
    ) # 这边可以自行把inception模型加进去
    if sys.argv[1] == 'inception':
    raise Exception("not provide inception model ...")
    # model_ft = models.inception_v3(pretrained=True) if sys.argv[1] == 'desnet':
    model_ft = models.densenet121(pretrained=True)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Sequential(
    nn.Linear(num_ftrs, N_CLASSES),
    nn.Sigmoid()
    )
    # use_gpu = False if use_gpu:
    model_ft = model_ft.cuda() criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    # 每7个epoch衰减0.1倍
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1], num_epochs=25)
    print('Start save the model ...')
    torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
    print(f'The mission of the fold {m} finished.')
    print('# '*50)
    report_file.close()
    label_file.close()

修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章

  1. Unity-2017.3官方实例教程Space-Shooter(二)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...

  2. Unity-2017.2官方实例教程Roll-a-ball(二)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...

  3. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  4. Unity-2017.3官方实例教程Space-Shooter(一)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...

  5. Unity-2017.2官方实例教程Roll-a-ball(一)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...

  6. NLP(二十二)利用ALBERT实现文本二分类

      在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...

  7. 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。

    这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...

  8. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  9. 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展

    教程来源于:Unity官方实例教程 Space Shooter(一)-(五)       http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...

随机推荐

  1. random模块/string模块

    一.random模块 random模块可以很容易生成随机数和随机字符串. random.randint(1, 100) # 1-100之间取一个随机数 random.randrange(1, 100) ...

  2. JS多级树结构写法

    效果: 一.布局: <div class="three_tree"> <div class="tree_title_cut"> < ...

  3. sql产生流水号

    一个产生流水号(年月日+5位流水号)的存储过程 现在客户有一个需求,要产生一个流水号,如090611+000001(年月日+五位流水号),此流水号在数据库表中是主键,且为varchar类 型.如果在当 ...

  4. 跨平台移动开发_PhoneGap 再次点击返回键切换到桌面效果

    PhoneGap 再次点击返回键切换到桌面效果 相关代码 <!DOCTYPE html> <html> <head> <title> PhoneGap ...

  5. java:反射机制

    Java反射机制及IoC原理:https://www.cnblogs.com/Eason-S/p/5851078.html Java中反射机制详解:https://www.cnblogs.com/wh ...

  6. centos开启IPV6配置方法

    目前国内大部分服务器和PC不支持IPV6地址的,但是服务器上本身是可以正常开启IPV6服务,有部分程序在服务器上运行的时候,需要服务器能监听一个ipv6地址才行,因此本文档指导如何在centos服务器 ...

  7. ZT onActivityResult在android中的用法

    onActivityResult在android中的用法 举例说我想要做的一个事情是,在一个主界面(主Activity)上能连接往许多不同子功能模块(子Activity上去),当子模块的事情做完之后就 ...

  8. SAPGUI里实现自定义的语法检查

    需求:在SAPGUI里点击这个语法检查的小图标或者直接按快捷键Ctrl+F2可以执行ABAP标准的语法检查. 如果需要实现SAPGUI里自定义的语法检查,比如,某团队强制要求应用程序类的每个方法的实现 ...

  9. 『看球笔记』20131230切尔西vs利物浦,赛后复盘聊聊球

    2013-12-30 第十九轮英超联赛 切尔西vs利物浦         看着这张板凳合照… 有木有一种心里哇凉哇凉的感觉.   赛后whoscored的平均位置图 左边车子,右边我军     阿格回 ...

  10. nutz 结合QueryResult,Record 自定义分页查询,不构建pojo 整合

    public QueryResult getHistoryIncome(int d, int curPage) throws Exception { /**sql**/ Sql sql = Sqls. ...