第一次,调了很久。它本来已经很OK了,同时适用CPU和GPU,且可正常运行的。

为了用于性能测试,主要改了三点:

一,每一批次显示处理时间。

二,本地加载测试数据。

三,兼容LINUX和WIN

本地加载测试数据时,要注意是用将两个pt文件,放在processed目录下,raw目录不要即可。

训练数据的定义目录是在当前目录 data/MNIST/processed目录下。

我自己弄了个下载:

http://u.163.com/2FUm6N1L  提取码: XJpmqUoR

只能下载20次,过了可在此留言。

import os
import timeit
import torch                     # pytorch 最基本模块
import torch.nn as nn            # pytorch中最重要的模块,封装了神经网络相关的函数
import torch.nn.functional as F  # 提供了一些常用的函数,如softmax
import torch.optim as optim      # 优化模块,封装了求解模型的一些优化器,如Adam SGD
from torch.optim import lr_scheduler # 学习率调整器,在训练过程中合理变动学习率
from torchvision import transforms  #pytorch 视觉库中提供了一些数据变换的接口
from torchvision import datasets  #pytorch 视觉库提供了加载数据集的接口

DATA_DIR = os.path.join(os.getcwd(),"data")
# 预设网络超参数 (所谓超参数就是可以人为设定的参数

BATCH_SIZE= 64 # 由于使用批量训练的方法,需要定义每批的训练的样本数目

EPOCHS=3      # 总共训练迭代的次数

# 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

learning_rate = 0.1  # 设定初始的学习率

# 加载训练集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=(0.5,), std=(0.5,)) # 数据规范化到正态分布
                    ])),
    batch_size=BATCH_SIZE, shuffle=True) # 指明批量大小,打乱,这是处于后续训练的需要。

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ])),
    batch_size=BATCH_SIZE, shuffle=True)

# 设计模型
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # 提取特征层
        self.features = nn.Sequential(
            # 卷积层
            # 输入图像通道为 1,因为我们使用的是黑白图,单通道的
            # 输出通道为32(代表使用32个卷积核),一个卷积核产生一个单通道的特征图
            # 卷积核kernel_size的尺寸为 3 * 3,stride 代表每次卷积核的移动像素个数为1
            # padding 填充,为1代表在图像长宽都多了两个像素
            nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size=3, stride=1, padding=1),

            # 批量归一化,跟上一层的out_channels大小相等,以下的通道规律也是必须要对应好的
            nn.BatchNorm2d(num_features = 32),

            # 激活函数,inplace=true代表直接进行运算
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            # 最大池化层
            # kernel_size 为2 * 2的滑动窗口
            # stride为2,表示每次滑动距离为2个像素
            # 经过这一步,图像的大小变为1/4,即 28 * 28 -》 14 * 14
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2) # 14 * 14 -》 7 * 7
        )
         # 分类层
        self.classifier = nn.Sequential(
            # Dropout层
            # p = 0.5 代表该层的每个权重有0.5的可能性为0
            nn.Dropout(p = 0.5),
            # 这里是通道数64 * 图像大小7 * 7,然后输入到512个神经元中
            nn.Linear(64 * 7 * 7, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p = 0.5),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p = 0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        # 经过特征提取层
        x = self.features(x)
        # 输出结果必须展平成一维向量
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x  

# 初始化模型
ConvModel = ConvNet().to(DEVICE)
# 定义交叉熵损失函数
criterion = nn.CrossEntropyLoss().to(DEVICE)
# 定义模型优化器
optimizer = torch.optim.Adam(ConvModel.parameters(), lr = learning_rate)
# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

def train(num_epochs,_model, _device, _train_loader, _optimizer, _lr_scheduler):
    _model.train()
    _lr_scheduler.step()
    for epoch in range(num_epochs):
        start = end = 0
        # 从迭代器抽取图片和标签
        for i, (images, labels) in enumerate(_train_loader):
            if (i + 1) % 100 == 1:
                start = timeit.default_timer()
            samples = images.to(_device)
            labels = labels.to(_device)
            #此时样本是一批图片,在CNN的输入中,我们需要将其变为四维,
            # reshape第一个-1 代表自动计算批量图片的数目n
            # 最后reshape得到的结果就是n张图片,每一张图片都是单通道的28 * 28,得到四维张量
            output = _model(samples.reshape(-1, 1, 28, 28))

            # 计算损失函数值
            loss = criterion(output, labels)

            # 优化器内部参数梯度必须变为0
            optimizer.zero_grad()

            # 损失值后向传播
            loss.backward()

            # 更新模型参数
            optimizer.step()

            if (i + 1) % 100 == 0:
                end = timeit.default_timer()
                print("Epoch:{}/{}, Time:{}s, step:{}, loss:{:.4f}".format(epoch+1, num_epochs, end-start, i + 1, loss.item()))

def test(_test_loader, _model, _device):
    _model.eval() # 设置模型进入预测模式 evaluation
    loss = 0
    correct = 0

    with torch.no_grad(): #如果不需要 backward更新梯度,那么就要禁用梯度计算,减少内存和计算资源浪费。
        for data, target in _test_loader:
            data, target = data.to(_device), target.to(_device)
            output = ConvModel(data.reshape(-1, 1, 28, 28))
            loss += criterion(output, target).item() # 添加损失值
            pred = output.data.max(1, keepdim=True)[1] # 找到概率最大的下标,为输出值
            correct += pred.eq(target.data.view_as(pred)).cpu().sum() # .cpu()是将参数迁移到cpu上来。

    loss /= len(_test_loader.dataset)

    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        loss, correct, len(_test_loader.dataset),
        100. * correct / len(_test_loader.dataset)))

for epoch in range(1, EPOCHS + 1):
    train(epoch, ConvModel, DEVICE, train_loader, optimizer, exp_lr_scheduler)
    test(test_loader,ConvModel, DEVICE)
    test(train_loader,ConvModel, DEVICE)

一套兼容win和Linux的PyTorch训练MNIST的算法代码(CNN)的更多相关文章

  1. php中路径斜杠的应用,兼容win与linux

    更多内容推荐微信公众号,欢迎关注: PHP中斜杠的运用 兼容win和linux 使用常量:DIRECTORY_SEPARATOR如:"www".DIRECTORY_SEPARATO ...

  2. 跨平台设置NODE_ENV(兼容win和linux)

    通过NODE_ENV可以来设置环境变量(默认值为development).一般我们通过检查这个值来分别对开发环境和生产环境下做不同的处理.可以在命令行中通过下面的方式设置这个值: linux & ...

  3. 用Pytorch训练MNIST分类模型

    本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...

  4. Sublime Text 2 - 性感无比的代码编辑器!程序员必备神器!跨平台支持Win/Mac/Linux

    我用过的编辑器不少,真不少- 但却没有哪款让我特别心仪的,直到我遇到了 Sublime Text 2 !如果说“神器”是我能给予一款软件最高的评价,那么我很乐意为它封上这么一个称号.它小巧绿色且速度非 ...

  5. [转载]Sublime Text 2 - 性感无比的代码编辑器!程序员必备神器!跨平台支持Win/Mac/Linux

    代码编辑器或者文本编辑器,对于程序员来说,就像剑与战士一样,谁都想拥有一把可以随心驾驭且锋利无比的宝剑,而每一位程序员,同样会去追求最适合自己的强大.灵活的编辑器,相信你和我一样,都不会例外. 我用过 ...

  6. Java文件夹操作,判断多级路径是否存在,不存在就创建(包括windows和linux下的路径字符分析),兼容Windows和Linux

    兼容windows和linux. 分析: 在windows下路径有以下表示方式: (标准)D:\test\1.txt (不标准,参考linux)D:/test/1.txt 然后在java中,尤其使用F ...

  7. paip兼容windows与linux的java类根目录路径的方法

    paip兼容windows与linux的java类根目录路径的方法 1.只有 pathx.class.getResource("")或者pathx.class.getResourc ...

  8. redhat 安装配置samba实现win共享linux主机目录

    [转]http://blog.chinaunix.net/uid-26642180-id-3135941.html redhat 安装配置samba实现win共享linux主机目录 2012-03-1 ...

  9. Win和Linux查看端口和杀死进程

    title: Win和Linux查看端口和杀死进程 date: 2017-7-30 tags: null categories: Linux --- 本文介绍Windows和Linux下查看端口和杀死 ...

随机推荐

  1. 超级简单,把PuppyLinux安装到U盘

    先说说使用感受:上网全是乱码!不支持中文 下载最新版puppylinux,从官网下载 现在U盘引导程序制作工具Unetbootin 打开下载的UNetbootin,进行下面的操作: 制作完毕后,修改U ...

  2. linux服务器之间文件传输

    有时候我们会遇到,把一个服务器上的文件夹,传到另一个服务器 我们需要先把文件夹打包成 tar.gz,这种格式在任何linux版本上都能压缩/解压 #解压命令 tar -zxvf xxx.tar.gz ...

  3. 002 spring boot框架,引入mybatis-generator插件,自动生成Mapper和Entity

    1.创建一个springboot项目 2.创建项目的文件结构以及jdk的版本 3.选择项目所需要的依赖 点击next,直到项目构建完成. 4.项目初步结构 5.POM文件 <?xml versi ...

  4. Python之路【第三十篇】:django 模型层-多表关系

    多表操作 文件为 ---->  orm2 数据库表关系之关联字段与外键约束 一对多Book id title price publish email addr 1 php 100 人民出版社 1 ...

  5. LeetCode 5073. 进击的骑士(Java)BFS

    题目:5073. 进击的骑士 一个坐标可以从 -infinity 延伸到 +infinity 的 无限大的 棋盘上,你的 骑士 驻扎在坐标为 [0, 0] 的方格里. 骑士的走法和中国象棋中的马相似, ...

  6. XXE任意文件读取(当xml解析内容有输出时)

    利用XXE漏洞读取文件 参考:https://www.jianshu.com/p/4fc721398e97 首先找到登录源码如下: 由题目可以利用XXE漏洞读取文件 先登录用Burp Suite抓包: ...

  7. 关于 exynos 4412 按键中断 异步通知

    以下是驱动测试代码: //内核的驱动代码 #include <linux/init.h> #include <linux/module.h> //for module_init ...

  8. IIS7 URL重写如何针对二级域名重写

    二级域名与站点主域名是绑在同一目录下,想实现访问二级域名重写至站点下的某个目录.  如:  访问so.abc.cn 实际访问的是站点根目录下的search目录下的文件 相当于so.abc.cn绑定至s ...

  9. C# 多线程与高并发处理并且具备暂停、继续、停止功能

    --近期有一个需要运用多线程的项目,会有并发概率,所以写了一份代码,可能有写地方还不完善,后续有需求在改 1 /// <summary> /// 并发对象 /// </summary ...

  10. C#设计模式之11:命令模式

    C#设计模式之11:命令模式 命令模式 命令模式用来解决一些复杂业务逻辑的时候会很有用,比如,你的一个方法中到处充斥着if else 这种结构的时候,用命令模式来解决这种问题就会让事情变得简单很多. ...