本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:

  1. 根据AUC来迭代最优参数;
  2. 五折交叉验证;
  3. 输出验证集错误分类图片;
  4. 输出分类报告并保存AUC结果图片。
     import os
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import lr_scheduler
    import torchvision
    from torchvision import datasets, models, transforms
    from torch.utils.data import DataLoader
    from sklearn.metrics import roc_auc_score, classification_report
    from sklearn.model_selection import KFold
    from torch.autograd import Variable
    import torch.optim as optim
    import time
    import copy
    import shutil
    import sys
    import scikitplot as skplt
    import matplotlib.pyplot as plt
    import pandas as pd plt.switch_backend('agg')
    N_CLASSES = 2
    BATCH_SIZE = 8
    DATA_DIR = './data'
    LABEL_DICT = {0: 'class_1', 1: 'class_2'} def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
    plt.title(title)
    plt.pause(100) def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
    since = time.time()
    # 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
    best_model_wts = copy.deepcopy(model.state_dict())
    # best_acc = 0.0
    # 初始auc
    best_auc = 0.0
    best_desc = [0, 0, None]
    best_img_name = None
    plt_auc = [None, None] for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('- ' * 50) for phase in ['train', 'val']:
    if phase == 'train':
    # 训练的时候进行学习率规划,其定义在下面给出
    scheduler.step()
    model.train(True)
    else:
    model.train(False)
    phase_pred = np.array([])
    phase_label = np.array([])
    img_name = np.zeros((1, 2))
    prob_pred = np.zeros((1, 2))
    running_loss = 0.0
    running_corrects = 0
    # 这样迭代方便跟踪图片路径,输出错误图片名称
    for data, index in zip(dataloaders[phase], dataloaders[phase].batch_sampler):
    inputs, labels = data
    if use_gpu:
    inputs = Variable(inputs.cuda())
    labels = Variable(labels.cuda())
    else:
    inputs, labels = Variable(inputs), Variable(labels) # 梯度参数设为0
    optimizer.zero_grad() # forward
    outputs = model(inputs)
    _, preds = torch.max(outputs.data, 1)
    loss = criterion(outputs, labels) # backward + 训练阶段优化
    if phase == 'train':
    loss.backward()
    optimizer.step() if phase == 'val':
    img_name = np.append(img_name, np.array(dataloaders[phase].dataset.imgs)[index], axis=0)
    prob = outputs.data.cpu().numpy()
    prob_pred = np.append(prob_pred, prob, axis=0) phase_pred = np.append(phase_pred, preds.cpu().numpy())
    phase_label = np.append(phase_label, labels.data.cpu().numpy())
    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data).float()
    print()
    epoch_loss = running_loss / dataset_sizes[phase]
    epoch_acc = running_corrects / dataset_sizes[phase]
    epoch_auc = roc_auc_score(phase_label, phase_pred)
    print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
    phase, epoch_loss, epoch_acc, epoch_auc))
    report = classification_report(phase_label, phase_pred, target_names=class_names)
    print(report) img_name = zip(img_name[1:], phase_pred)
    # 当验证时遇到了更好的模型则予以保留
    if phase == 'val' and epoch_auc > best_auc:
    best_auc = epoch_auc
    best_desc = epoch_acc, epoch_auc, report
    best_img_name = img_name
    # 深拷贝模型参数
    best_model_wts = copy.deepcopy(model.state_dict())
    plt_auc = phase_label, prob_pred[1:] print()
    print(plt_auc[0].shape, plt_auc[1].shape)
    csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
    csv_file['true_label'] = pd.DataFrame(plt_auc[0])
    csv_file['true_label'] = csv_file['true_label'].apply(lambda x: LABEL_DICT[x])
    csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
    skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
    plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAclass_2 -> {:4f}\n\n{}'.format(best_desc[0], best_desc[1],
    best_desc[2])
    report_file.write(reports)
    print(reports)
    print('List the wrong judgement img ...')
    count = 0
    for i in best_img_name:
    actual_label = int(i[0][1])
    pred_label = i[1]
    if actual_label != pred_label:
    tmp_word = f'{i[0][0].split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
    f'pred: {LABEL_DICT[pred_label]}'
    print(tmp_word)
    label_file.write(tmp_word + '\n')
    count += 1
    print(f'This fold has {count} wrong records ...') # 载入最优模型参数
    model.load_state_dict(best_model_wts)
    return model def plot_img():
    for i, data in enumerate(dataloaders['train']):
    inputs, classes = data
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes]) # 此函数可以修改适用于自己项目的图片文件名
    def move_file(data, file_path, dir_path, root_path):
    label_0 = 'class_2'
    label_1 = 'class_1'
    print(f'start copy the {file_path} file ...')
    os.chdir(dir_path)
    if os.path.exists(file_path):
    print(f'Find exist {file_path} file, the file will be dropped.')
    shutil.rmtree(os.path.join(root_path, dir_path, file_path))
    print(f'Finish drop the {file_path} file.') os.mkdir(file_path)
    tmp_path = os.path.join(os.getcwd(), file_path)
    tmp_pre_path = os.getcwd()
    for d in data:
    pre_path = os.path.join(tmp_pre_path, d)
    os.chdir(tmp_path)
    if d[:2] == label_0:
    if not os.path.exists(label_0):
    os.mkdir(label_0)
    cur_path = os.path.join(tmp_path, label_0, d)
    shutil.copyfile(pre_path, cur_path)
    if d[:2] == label_1:
    if not os.path.exists(label_1):
    os.mkdir(label_1)
    cur_path = os.path.join(tmp_path, label_1, d)
    shutil.copyfile(pre_path, cur_path)
    print('finish this work ...') if __name__ == "__main__":
    if not os.path.exists('roc_img'):
    os.mkdir('roc_img')
    if not os.path.exists('prob_result'):
    os.mkdir('prob_result')
    if not os.path.exists('report'):
    os.mkdir('report')
    if not os.path.exists('error_record'):
    os.mkdir('error_record')
    if not os.path.exists('model'):
    os.mkdir('model')
    label_file = open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') kf = KFold(n_splits=5, shuffle=True, random_state=1)
    origin_path = '/home/project/'
    dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))]) for m, n in enumerate(kf.split(dd_list), start=1):
    report_file = open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w')
    print(f'The {m} fold for copy file and training ...')
    move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
    os.chdir(origin_path)
    move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
    os.chdir(origin_path)
    data_transforms = {
    'train': transforms.Compose([
    # 裁剪到224,224
    transforms.RandomResizedCrop(224),
    # 随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
    transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), # HSV以及对比度变化
    transforms.ToTensor(),
    # 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的FloadTensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    } image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
    data_transforms[x])
    for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
    shuffle=True, num_workers=8, pin_memory=False)
    for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes
    size = len(class_names)
    print('label mapping: ')
    print(image_datasets['train'].class_to_idx)
    use_gpu = torch.cuda.is_available()
    model_ft = None
    if sys.argv[1] == 'resnet':
    model_ft = models.resnet50(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Sequential(
    nn.Linear(num_ftrs, N_CLASSES),
    nn.Sigmoid()
    ) # 这边可以自行把inception模型加进去
    if sys.argv[1] == 'inception':
    raise Exception("not provide inception model ...")
    # model_ft = models.inception_v3(pretrained=True) if sys.argv[1] == 'desnet':
    model_ft = models.densenet121(pretrained=True)
    num_ftrs = model_ft.classifier.in_features
    model_ft.classifier = nn.Sequential(
    nn.Linear(num_ftrs, N_CLASSES),
    nn.Sigmoid()
    )
    # use_gpu = False if use_gpu:
    model_ft = model_ft.cuda() criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    # 每7个epoch衰减0.1倍
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1], num_epochs=25)
    print('Start save the model ...')
    torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
    print(f'The mission of the fold {m} finished.')
    print('# '*50)
    report_file.close()
    label_file.close()

修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章

  1. Unity-2017.3官方实例教程Space-Shooter(二)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...

  2. Unity-2017.2官方实例教程Roll-a-ball(二)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...

  3. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  4. Unity-2017.3官方实例教程Space-Shooter(一)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...

  5. Unity-2017.2官方实例教程Roll-a-ball(一)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...

  6. NLP(二十二)利用ALBERT实现文本二分类

      在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...

  7. 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。

    这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...

  8. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  9. 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展

    教程来源于:Unity官方实例教程 Space Shooter(一)-(五)       http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...

随机推荐

  1. python作业-网络编程

    1.什么是C/S架构? 答:C指的是client(客户端软件),S指的是Server(服务端软件) 2.互联网协议是什么?分别介绍五层协议中每一层的功能? 答:互联网的核心就是由一堆协议组成.如果把计 ...

  2. 函数进阶3 —— 生成器、yield from

    今天我们在进一步了解一下,生成器. ①: def func(): print('这是函数func') return '函数func' func() 结果是 这是函数func ②: def func1( ...

  3. String变量的两种创建方式

    在java中,有两种创建String类型变量的方式: String str01="abc";//第一种方式 String str02=new String("abc&qu ...

  4. Node.js 操作Mongodb

    Node.js 操作Mongodb1.简介官网英文文档  https://docs.mongodb.com/manual/  这里几乎什么都有了MongoDB is open-source docum ...

  5. Android 自定义Dialog中加EditText弹不出键盘跟Dialog遮挡键盘的问题

    先上两张图 第一张问题很明显,第二张是成功的图, 其实第一张是加了 //getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INP ...

  6. keras 自定义 custom 函数

    转自: https://kexue.fm/archives/4493/,感谢分享! Keras是一个搭积木式的深度学习框架,用它可以很方便且直观地搭建一些常见的深度学习模型.在tensorflow出来 ...

  7. windows默认共享的打开和关闭?

    windows默认共享的打开和关闭?   Windows启动时都会默认打开admin$ ipc$ 和每个盘符的共享,对于不必要的默认共享,一般都会把它取消掉,可当又需要打开此默认共享时,又该从哪里设置 ...

  8. zan-framework mysql连接

    ①根据文档内容要配置sqlmap连接池的读写白名单 http://doc.zanphp.io/zh/libs/connection_pool.html 示例代码 // demo.demo.demo_s ...

  9. March 4 2017 Week 10 Saturday

    There is more to life than increasing its speed. 生活不仅仅是匆匆赶路. I always think I have walked very slowl ...

  10. Jerry的ABAP原创技术文章合集

    我之前发过三篇和ABAP相关的文章: 1. Jerry的ABAP, Java和JavaScript乱炖 这篇文章包含我多年来在SAP成都研究院使用ABAP, Java和JavaScript工作过程中的 ...