import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from torch.nn import functional as F EPOCH = 1000
BATCH_SIZE = 128
LR = 0.001
DOWNLOAD_MNIST = False train_data = datasets.MNIST(
root='./mnist',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),#0-255 -> 0-1
download=DOWNLOAD_MNIST
)
#plot one example
print(train_data.train_data.size())
print(train_data.train_labels.size())
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE,\
shuffle=True, num_workers=2) test_data = datasets.MNIST(
root='./mnist',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),#0-255 -> 0-1
download=DOWNLOAD_MNIST
)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE,\
shuffle=True, num_workers=2) x, label = iter(test_loader).next() #这个iter能把一个batch_size提取出来
print("x:",x.shape, 'label:',label.shape) class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn.Conv2d(
in_channels=1, # input height
out_channels=10, # n_filters
kernel_size=5, # filter size
stride=1, # filter movement/step
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
)
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
nn.Conv2d(10, 20, 5, 1, 2), # output shape (32, 14, 14)
nn.ReLU(), # activation
nn.MaxPool2d(2), # output shape (32, 7, 7)
)
self.out1 = nn.Linear(20 * 7 * 7, 512) # fully connected layer, output 10 classes
self.out2 = nn.Linear(512, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
output = self.out1(x)
output = F.relu(output)
output = self.out2(output)
return output # return x for visualization DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn = CNN().to(DEVICE)
print(cnn) # net architecture
optimizer = optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
criteon = nn.CrossEntropyLoss().to(DEVICE) # the target label is not one-hotted def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if(batch_idx+1)%30 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += torch.eq(pred, target).float().sum().item()
test_loss += criterion(output, target) # test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
# pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
# correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset))) # training and testing
for epoch in range(EPOCH):
train(cnn, DEVICE, train_loader, optimizer, epoch)
test(cnn, DEVICE, test_loader)

最后能得到99%的准确率

Pytorch搭建卷积神经网络用于MNIST分类的更多相关文章

  1. Tensorflow框架初尝试————搭建卷积神经网络做MNIST问题

    Tensorflow是一个非常好用的deep learning框架 学完了cs231n,大概就可以写一个CNN做一下MNIST了 tensorflow具体原理可以参见它的官方文档 然后CNN的原理可以 ...

  2. Pytorch搭建简单神经网络 Task2

    1>建立数据集(并绘制图像) # -*- coding: utf-8 -*- #demo.py import torch import torch.nn.functional as F # 主要 ...

  3. TensorFlow——CNN卷积神经网络处理Mnist数据集

    CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...

  4. 3层-CNN卷积神经网络预测MNIST数字

    3层-CNN卷积神经网络预测MNIST数字 本文创建一个简单的三层卷积网络来预测 MNIST 数字.这个深层网络由两个带有 ReLU 和 maxpool 的卷积层以及两个全连接层组成. MNIST 由 ...

  5. 使用MXNet远程编写卷积神经网络用于多标签分类

    最近试试深度学习能做点什么事情.MXNet是一个与Tensorflow类似的开源深度学习框架,在GPU显存利用率上效率高,比起Tensorflow显著节约显存,并且天生支持分布式深度学习,单机多卡.多 ...

  6. TensorFlow系列专题(十四): 手把手带你搭建卷积神经网络实现冰山图像分类

    目录: 冰山图片识别背景 数据介绍 数据预处理 模型搭建 结果分析 总结 一.冰山图片识别背景 这里我们要解决的任务是来自于Kaggle上的一道赛题(https://www.kaggle.com/c/ ...

  7. 动手学习Pytorch(6)--卷积神经网络基础

    卷积神经网络基础 本节我们介绍卷积神经网络的基础概念,主要是卷积层和池化层,并解释填充.步幅.输入通道和输出通道的含义.   二维卷积层 本节介绍的是最常见的二维卷积层,常用于处理图像数据.   二维 ...

  8. 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别

    这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...

  9. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

随机推荐

  1. php随机获取数组里面的值

    srand() 函数播下随机数发生器种子,array_rand() 函数从数组中随机选出一个或多个元素,并返回.第二个参数用来确定要选出几个元素.如果选出的元素不止一个,则返回包含随机键名的数组,否则 ...

  2. ARM仿真器之驱动黄色惊叹号

    JLink CDC UART PORT 黄色惊叹号 Windows 无法验证此设备所需的驱动程序的数字签名.某软件或硬件最近有所更改,可能安装了签名错误或损毁的文件,或者安装的文件可能是来路不明的恶意 ...

  3. JS 验证码的实现

    转自:https://github.com/ace0109/verifyCode 正要做一个验证码,网上找到这个还不错: gVerify.js: !(function(window, document ...

  4. vue-awesome-swiper 轮播图使用

    最近在做vue 的轮播图的问题,项目中也遇到一些问题,查了 swiper 官网资料, 还有vue-awesome-swiper的文案,最后把怎么使用这个插件简单的说下,啥东西都需要自己实践下,还是老规 ...

  5. 理解 es7 async/await

    简介 JavaScript ES7 中的 async / await 让多个异步 promise 协同工作起来更容易.如果要按一定顺序从多个数据库或者 API 异步获取数据,你可能会以一堆乱七八糟的 ...

  6. es6 Object.assign(target, ...sources)

    Object.assign() 方法用于将所有可枚举属性(对象属性)的值从一个或多个源对象复制到目标对象.它将返回目标对象. 语法 Object.assign(target, ...sources) ...

  7. U盘安装linux(CentOS Kali ubuntu)无法挂载_实测

    最新的Kali linux官方下载地址: https://www.kali.org/downloads/ 刻录软件官网:http://rufus.ie/ 1.打开. 2.插入U盘. 3.自动下载插件. ...

  8. 第六章 Linux文件与目录管理

    http://www.92csz.com/study/linux/6.htm 绝对路径:路径的写法一定由根目录”/”写起 相对路径:路径的写法不是由根目录”/”写起 mkdir 创建一个目录.mkdi ...

  9. 对TypeScript进行研究

    1.npm install -g typescript 在编辑器,将下面的代码输入到greeter.ts文件里: function greeter(person) { return "Hel ...

  10. weblogic 10c and 12c 打补丁

    https://pan.baidu.com/s/17IaK1SYwHxwt-CRb0zDqXw