Pytorch学习系列(一)至(四)均摘自《深度学习框架PyTorch入门与实践》陈云

目录:

1.程序的主要功能

2.文件组织架构

3. 关于`__init__.py`

4.数据处理

5.模型定义

6.工具函数

7.配置文件

8.main.py

9.使用

1.程序的主要功能:

模型定义
    数据加载
    训练和测试

2.文件组织架构:

```
├── checkpoints/
├── data/
│   ├── __init__.py
│   ├── dataset.py
│   └── get_data.sh
├── models/
│   ├── __init__.py
│   ├── AlexNet.py
│   ├── BasicModule.py
│   └── ResNet34.py
└── utils/
│   ├── __init__.py
│   └── visualize.py
├── config.py
├── main.py
├── requirements.txt
├── README.md

```

其中:

- `checkpoints/`: 用于保存训练好的模型,可使程序在异常退出后仍能重新载入模型,恢复训练
- `data/`:数据相关操作,包括数据预处理、dataset实现等
- `models/`:模型定义,可以有多个模型,例如上面的AlexNet和ResNet34,一个模型对应一个文件
- `utils/`:可能用到的工具函数,在本次实验中主要是封装了可视化工具
- `config.py`:配置文件,所有可配置的变量都集中在此,并提供默认值
- `main.py`:主文件,训练和测试程序的入口,可通过不同的命令来指定不同的操作和参数
- `requirements.txt`:程序依赖的第三方库

- `README.md`:提供程序的必要说明

3. 关于`__init__.py`

可以看到,几乎每个文件夹下都有`__init__.py`,一个目录如果包含了`__init__.py` 文件,那么它就变成了一个包(package)。

`__init__.py`可以为空,也可以定义包的属性和方法,但其必须存在,其它程序才能从这个目录中导入相应的模块或函数。

例如在`data/`文件夹下有`__init__.py`,则在`main.py` 中就可以`from data.dataset import DogCat`。而如果在`__init__.py`中写入`from .dataset import DogCat`,则在main.py中就可以直接写为:`from data import DogCat`,或者`import data; dataset = data.DogCat`,相比于`from data.dataset import DogCat`更加便捷。

4.数据处理

数据的相关处理主要保存在`data/dataset.py`中。

关于数据加载的相关操作,其基本原理就是使用`Dataset`提供数据集的封装,再使用`Dataloader`实现数据并行加载。

Kaggle提供的数据包括训练集和测试集,而我们在实际使用中,还需专门从训练集中取出一部分作为验证集。对于这三类数据集,其相应操作也不太一样,而如果专门写三个`Dataset`,则稍显复杂和冗余,因此这里通过加一些判断来区分。对于训练集,我们希望做一些数据增强处理,如随机裁剪、随机翻转、加噪声等,而验证集和测试集则不需要。下面看`dataset.py`的代码:

#coding:utf8
    import os
    from PIL import  Image
    from torch.utils import data
    import numpy as np
    from torchvision import  transforms as T
     
     
    class DogCat(data.Dataset):
        
        def __init__(self,root,transforms=None,train=True,test=False):
            '''
            主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据
            '''
            self.test = test
            imgs = [os.path.join(root,img) for img in os.listdir(root)]
     
            # test1: data/test1/8973.jpg
            # train: data/train/cat.10004.jpg
            if self.test:
                imgs = sorted(imgs,key=lambda x:int(x.split('.')[-2].split('/')[-1]))
            else:
                imgs = sorted(imgs,key=lambda x:int(x.split('.')[-2]))
                
            imgs_num = len(imgs)
     
            if self.test:
                self.imgs = imgs
            elif train:
                self.imgs = imgs[:int(0.7*imgs_num)]
            else :
                self.imgs = imgs[int(0.7*imgs_num):]
                
        
            if transforms is None:
                normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
                                         std = [0.229, 0.224, 0.225])
     
                if self.test or not train:
                    self.transforms = T.Compose([
                        T.Scale(224),
                        T.CenterCrop(224),
                        T.ToTensor(),
                        normalize
                        ])
                else :
                    self.transforms = T.Compose([
                        T.Scale(256),
                        T.RandomSizedCrop(224),
                        T.RandomHorizontalFlip(),
                        T.ToTensor(),
                        normalize
                        ])
                    
            
        def __getitem__(self,index):
            '''
            一次返回一张图片的数据
            '''
            img_path = self.imgs[index]
            if self.test: label = int(self.imgs[index].split('.')[-2].split('/')[-1])
            else: label = 1 if 'dog' in img_path.split('/')[-1] else 0
            data = Image.open(img_path)
            data = self.transforms(data)
            return data, label
        
        def __len__(self):
            return len(self.imgs)

5.模型定义

模型的定义主要保存在`models/`目录下,其中`BasicModule`是对`nn.Module`的简易封装,提供快速加载和保存模型的接口。

#coding:utf8
    import torch as t
    import time
     
     
    class BasicModule(t.nn.Module):
        '''
        封装了nn.Module,主要是提供了save和load两个方法
        '''
     
        def __init__(self):
            super(BasicModule,self).__init__()
            self.model_name=str(type(self))# 默认名字
     
        def load(self, path):
            '''
            可加载指定路径的模型
            '''
            self.load_state_dict(t.load(path))
     
        def save(self, name=None):
            '''
            保存模型,默认使用“模型名字+时间”作为文件名
            '''
            if name is None:
                prefix = 'checkpoints/' + self.model_name + '_'
                name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
            t.save(self.state_dict(), name)
            return name
     
     
    class Flat(t.nn.Module):
        '''
        把输入reshape成(batch_size,dim_length)
        '''
     
        def __init__(self):
            super(Flat, self).__init__()
            #self.size = size
     
        def forward(self, x):
            return x.view(x.size(0), -1)

6.工具函数

在项目中,我们可能会用到一些helper方法,这些方法可以统一放在`utils/`文件夹下,需要使用时再引入。在本例中主要是封装了可视化工具visdom的一些操作,其代码如下,在本次实验中只会用到`plot`方法,用来统计损失信息。

#coding:utf8
    import visdom
    import time
    import numpy as np
     
    class Visualizer(object):
        '''
        封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`
        调用原生的visdom接口
        '''
     
        def __init__(self, env='default', **kwargs):
            self.vis = visdom.Visdom(env=env, **kwargs)
            
            # 画的第几个数,相当于横座标
            # 保存(’loss',23) 即loss的第23个点
            self.index = {}
            self.log_text = ''
        def reinit(self,env='default',**kwargs):
            '''
            修改visdom的配置
            '''
            self.vis = visdom.Visdom(env=env,**kwargs)
            return self
     
        def plot_many(self, d):
            '''
            一次plot多个
            @params d: dict (name,value) i.e. ('loss',0.11)
            '''
            for k, v in d.items():
                self.plot(k, v)
     
        def img_many(self, d):
            for k, v in d.items():
                self.img(k, v)
     
        def plot(self, name, y,**kwargs):
            '''
            self.plot('loss',1.00)
            '''
            x = self.index.get(name, 0)
            self.vis.line(Y=np.array([y]), X=np.array([x]),
                          win=name,
                          opts=dict(title=name),
                          update=None if x == 0 else 'append',
                          **kwargs
                          )
            self.index[name] = x + 1
     
        def img(self, name, img_,**kwargs):
            '''
            self.img('input_img',t.Tensor(64,64))
            self.img('input_imgs',t.Tensor(3,64,64))
            self.img('input_imgs',t.Tensor(100,1,64,64))
            self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10)
            !!!don‘t ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!!!
            '''
            self.vis.images(img_.cpu().numpy(),
                           win=name,
                           opts=dict(title=name),
                           **kwargs
                           )
     
     
        def log(self,info,win='log_text'):
            '''
            self.log({'loss':1,'lr':0.0001})
            '''
     
            self.log_text += ('[{time}] {info} <br>'.format(
                                time=time.strftime('%m%d_%H%M%S'),\
                                info=info))
            self.vis.text(self.log_text,win)   
     
        def __getattr__(self, name):
            return getattr(self.vis, name)

7.配置文件

在模型定义、数据处理和训练等过程都有很多变量,这些变量应提供默认值,并统一放置在配置文件中,这样在后期调试、修改代码或迁移程序时会比较方便,在这里我们将所有可配置项放在`config.py`中。

#coding:utf8
    import warnings
    class DefaultConfig(object):
        env = 'default' # visdom 环境
        model = 'ResNet34' # 使用的模型,名字必须与models/__init__.py中的名字一致
        
        train_data_root = './data/train/' # 训练集存放路径
        test_data_root = './data/test1' # 测试集存放路径
        load_model_path = 'checkpoints/model.pth' # 加载预训练的模型的路径,为None代表不加载
     
        batch_size = 128 # batch size
        use_gpu = True # user GPU or not
        num_workers = 4 # how many workers for loading data
        print_freq = 20 # print info every N batch
     
        debug_file = '/tmp/debug' # if os.path.exists(debug_file): enter ipdb
        result_file = 'result.csv'
          
        max_epoch = 10
        lr = 0.1 # initial learning rate
        lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay
        weight_decay = 1e-4 # 损失函数
     
     
     
    def parse(self,kwargs):
            '''
            根据字典kwargs 更新 config参数
            '''
            for k,v in kwargs.items():
                if not hasattr(self,k):
                    warnings.warn("Warning: opt has not attribut %s" %k)
                setattr(self,k,v)
     
            print('user config:')
            for k,v in self.__class__.__dict__.items():
                if not k.startswith('__'):
                    print(k,getattr(self,k))
     
     
    DefaultConfig.parse = parse
    opt =DefaultConfig()
    # opt.parse = parse

8.main.py

在讲解主程序`main.py`之前,我们先来看看2017年3月谷歌开源的一个命令行工具`fire`[^3] ,通过`pip install fire`即可安装。下面来看看`fire`的基础用法,假设`example.py`文件内容如下:

import fire
     
    def add(x, y):
      return x + y
      
    def mul(**kwargs):
        a = kwargs['a']
        b = kwargs['b']
        return a * b
     
    if __name__ == '__main__':
      fire.Fire()

python example.py add 1 2 # 执行add(1, 2)
    python example.py mul --a=1 --b=2 # 执行mul(a=1, b=2), kwargs={'a':1, 'b':2}
    python example.py add --x=1 --y==2 # 执行add(x=1, y=2)

在主程序`main.py`中,主要包含四个函数,其中三个需要命令行执行,`main.py`的代码组织结构如下:

def train(**kwargs):
        '''
        训练
        '''
        pass
        
    def val(model, dataloader):
        '''
        计算模型在验证集上的准确率等信息,用以辅助训练
        '''
        pass
     
    def test(**kwargs):
        '''
        测试(inference)
        '''
        pass
     
    def help():
        '''
        打印帮助的信息
        '''
        print('help')
     
    if __name__=='__main__':
        import fire
        fire.Fire()

训练

训练的主要步骤如下:

- 定义网络
- 定义数据
- 定义损失函数和优化器
- 计算重要指标
- 开始训练
  - 训练网络
  - 可视化各种指标
  - 计算在验证集上的指标

def train(**kwargs):
        opt.parse(kwargs)
        vis = Visualizer(opt.env)
     
        # step1: configure model
        model = getattr(models, opt.model)()
        if opt.load_model_path:
            model.load(opt.load_model_path)
        if opt.use_gpu: model.cuda()
     
        # step2: data
        train_data = DogCat(opt.train_data_root,train=True)
        val_data = DogCat(opt.train_data_root,train=False)
        train_dataloader = DataLoader(train_data,opt.batch_size,
                            shuffle=True,num_workers=opt.num_workers)
        val_dataloader = DataLoader(val_data,opt.batch_size,
                            shuffle=False,num_workers=opt.num_workers)
        
        # step3: criterion and optimizer
        criterion = t.nn.CrossEntropyLoss()
        lr = opt.lr
        optimizer = t.optim.Adam(model.parameters(),lr = lr,weight_decay = opt.weight_decay)
            
        # step4: meters
        loss_meter = meter.AverageValueMeter()
        confusion_matrix = meter.ConfusionMeter(2)
        previous_loss = 1e100
     
        # train
        for epoch in range(opt.max_epoch):
            
            loss_meter.reset()
            confusion_matrix.reset()
     
            for ii,(data,label) in enumerate(train_dataloader):
     
                # train model
                input = Variable(data)
                target = Variable(label)
                if opt.use_gpu:
                    input = input.cuda()
                    target = target.cuda()
     
                optimizer.zero_grad()
                score = model(input)
                loss = criterion(score,target)
                loss.backward()
                optimizer.step()
                
                
                # meters update and visualize
                loss_meter.add(loss.data[0])
                confusion_matrix.add(score.data, target.data)
     
                if ii%opt.print_freq==opt.print_freq-1:
                    vis.plot('loss', loss_meter.value()[0])
                    
                    # 进入debug模式
                    if os.path.exists(opt.debug_file):
                        import ipdb;
                        ipdb.set_trace()
     
     
            model.save()
     
            # validate and visualize
            val_cm,val_accuracy = val(model,val_dataloader)
     
            vis.plot('val_accuracy',val_accuracy)
            vis.log("epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}".format(
                        epoch = epoch,loss = loss_meter.value()[0],val_cm = str(val_cm.value()),train_cm=str(confusion_matrix.value()),lr=lr))
            
            # update learning rate
            if loss_meter.value()[0] > previous_loss:          
                lr = lr * opt.lr_decay
                # 第二种降低学习率的方法:不会有moment等信息的丢失
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            
     
            previous_loss = loss_meter.value()[0]

验证

验证相对来说比较简单,但要注意需将模型置于验证模式(`model.eval()`),验证完成后还需要将其置回为训练模式(`model.train()`),这两句代码会影响`BatchNorm`和`Dropout`等层的运行模式。验证模型准确率的代码如下。

def val(model,dataloader):
        '''
        计算模型在验证集上的准确率等信息
        '''
        model.eval()
        confusion_matrix = meter.ConfusionMeter(2)
        for ii, data in enumerate(dataloader):
            input, label = data
            val_input = Variable(input, volatile=True)
            val_label = Variable(label.type(t.LongTensor), volatile=True)
            if opt.use_gpu:
                val_input = val_input.cuda()
                val_label = val_label.cuda()
            score = model(val_input)
            confusion_matrix.add(score.data.squeeze(), label.type(t.LongTensor))
     
        model.train()
        cm_value = confusion_matrix.value()
        accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
        return confusion_matrix, accuracy

测试

测试时,需要计算每个样本属于狗的概率,并将结果保存成csv文件。测试的代码与验证比较相似,但需要自己加载模型和数据。

def test(**kwargs):
        opt.parse(kwargs)
        import ipdb;
        ipdb.set_trace()
        # configure model
        model = getattr(models, opt.model)().eval()
        if opt.load_model_path:
            model.load(opt.load_model_path)
        if opt.use_gpu: model.cuda()
     
        # data
        train_data = DogCat(opt.test_data_root,test=True)
        test_dataloader = DataLoader(train_data,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)
        results = []
        for ii,(data,path) in enumerate(test_dataloader):
            input = t.autograd.Variable(data,volatile = True)
            if opt.use_gpu: input = input.cuda()
            score = model(input)
            probability = t.nn.functional.softmax(score)[:,0].data.tolist()
            # label = score.max(dim = 1)[1].data.tolist()
            
            batch_results = [(path_,probability_) for path_,probability_ in zip(path,probability) ]
     
            results += batch_results
        write_csv(results,opt.result_file)
     
        return results

帮助函数

为了方便他人使用, 程序中还应当提供一个帮助函数,用于说明函数是如何使用。程序的命令行接口中有众多参数,如果手动用字符串表示不仅复杂,而且后期修改config文件时,还需要修改对应的帮助信息,十分不便。这里使用了Python标准库中的inspect方法,可以自动获取config的源代码。help的代码如下:

def help():
        '''
        打印帮助的信息: python file.py help
        '''
        
        print('''
        usage : python file.py <function> [--args=value]
        <function> := train | test | help
        example:
                python {0} train --env='env0701' --lr=0.01
                python {0} test --dataset='path/to/dataset/root/'
                python {0} help
        avaiable args:'''.format(__file__))
     
        from inspect import getsource
        source = (getsource(opt.__class__))
        print(source)

9.使用

正如`help`函数的打印信息所述,可以通过命令行参数指定变量名.下面是三个使用例子,fire会将包含`-`的命令行参数自动转层下划线`_`,也会将非数值的值转成字符串。所以`--train-data-root=data/train`和`--train_data_root='data/train'`是等价的。

```
# 训练模型
python main.py train
        --train-data-root=data/train/
        --load-model-path='checkpoints/resnet34_16:53:00.pth'
        --lr=0.005
        --batch-size=32
        --model='ResNet34'  
        --max-epoch = 20

# 测试模型
python main.py test
       --test-data-root=data/test1
       --load-model-path='checkpoints/resnet34_00:23:05.pth'
       --batch-size=128
       --model='ResNet34'
       --num-workers=12

# 打印帮助信息
python main.py help
---------------------  
作者:寻找如意  
来源:CSDN  
原文:https://blog.csdn.net/qq_34447388/article/details/79541824  
版权声明:本文为博主原创文章,转载请附上博文链接!

Pytorch学习--编程实战:猫和狗二分类的更多相关文章

  1. 1.keras实现-->自己训练卷积模型实现猫狗二分类(CNN)

    原数据集:包含 25000张猫狗图像,两个类别各有12500 新数据集:猫.狗 (照片大小不一样) 训练集:各1000个样本 验证集:各500个样本 测试集:各500个样本 1= 狗,0= 猫 # 将 ...

  2. java并发编程实战:第十二章---并发程序的测试

    并发程序中潜在错误的发生并不具有确定性,而是随机的. 安全性测试:通常会采用测试不变性条件的形式,即判断某个类的行为是否与其规范保持一致 活跃性测试:进展测试和无进展测试两方面,这些都是很难量化的(性 ...

  3. 《Java并发编程实战》第十二章 测试并发程序 读书笔记

    并发测试分为两类:安全性测试(无论错误的行为不会发生)而活性测试(会发生). 安全測试 - 通常採用測试不变性条件的形式,即推断某个类的行为是否与其它规范保持一致. 活跃性測试 - 包含进展測试和无进 ...

  4. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

  5. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  6. 【Java并发编程实战】----- AQS(二):获取锁、释放锁

    上篇博客稍微介绍了一下AQS,下面我们来关注下AQS的所获取和锁释放. AQS锁获取 AQS包含如下几个方法: acquire(int arg):以独占模式获取对象,忽略中断. acquireInte ...

  7. Java多线程编程实战指南(核心篇)读书笔记(二)

    (尊重劳动成果,转载请注明出处:http://blog.csdn.net/qq_25827845/article/details/76651408冷血之心的博客) 博主准备恶补一番Java高并发编程相 ...

  8. java并发编程实战学习(3)--基础构建模块

    转自:java并发编程实战 5.3阻塞队列和生产者-消费者模式 BlockingQueue阻塞队列提供可阻塞的put和take方法,以及支持定时的offer和poll方法.如果队列已经满了,那么put ...

  9. Shell高级编程视频教程-跟着老男孩一步步学习Shell高级编程实战视频教程

    Shell高级编程视频教程-跟着老男孩一步步学习Shell高级编程实战视频教程 教程简介: 本教程共71节,主要介绍了shell的相关知识教程,如shell编程需要的基础知识储备.shell脚本概念介 ...

随机推荐

  1. Java 布尔运算

    章节 Java 基础 Java 简介 Java 环境搭建 Java 基本语法 Java 注释 Java 变量 Java 数据类型 Java 字符串 Java 类型转换 Java 运算符 Java 字符 ...

  2. 关于SSM中mybatis向oracle添加语句采用序列自增的问题

    在SSM向oracle数据库中插入语句时,报错如下: ### Error updating database.  Cause: java.sql.SQLException: 不支持的特性 ### SQ ...

  3. 利用QRCoder生成二维码

    1.项目添加QRCoder.dll 和System.Drawing.dll的引用 2.创建二维码公共处理类(QRCoderHelper.cs) /// <summary> /// 二维码公 ...

  4. java的形参与实参的区别以及java的方法

    package com.lv.study; public class Demo05 { public static void main(String[] args) { //我想要用什么分隔符进行分隔 ...

  5. jdk的配置和安装

    1.Jdk的安装和配置 一.安装JDK与配置环境与检验配置成功: 1.进入java.com网站,然后按照以下步骤进行 =>=>等会出现java茶杯双击,一次一次的按下一步,最后会在同一个j ...

  6. Pytorch_torch.nn.MSELoss

    Pytorch_torch.nn.MSELoss 均方损失函数作用主要是求预测实例与真实实例之间的loss loss(xi,yi)=(xi−yi)2 函数需要输入两个tensor,类型统一设置为flo ...

  7. URAL_1146/uva_108 最大子矩阵 DP 降维

    题意很简单,给定一个N*N的大矩阵,求其中数值和最大的子矩阵. 一开始找不到怎么DP,没有最优子结构啊,后来聪哥给了我思路,化成一维,变成最大连续和即可.为了转化成一维,必须枚举子矩阵的宽度,通过预处 ...

  8. 如何保障Assignment写作效率?

    有没有因为开学要交的Assignment而日夜赶工.身心俱疲啊?写Assignment确实是个体力+脑力活,要一直保持旺盛的精力并不容易.精神和身体的疲劳会慢慢分散你的注意力,进而影响效率和写作质量. ...

  9. 51Nod大数加法(两个数正负都可)

    很多大数的问题都运用模拟的思想,但是这个说一样也一样,但是难度较大,很麻烦,我自己谢写了100多行的代码,感觉很对,但就是WA.其实个人感觉C和C++没有大数类,是对人思想和算法的考验,但是有时候做不 ...

  10. Unity3d中渲染到RenderTexture的原理,几种方式以及一些问题

    超级搬运工 http://blog.csdn.net/leonwei/article/details/54972653 ---------------------------------------- ...