Pytorch搭建卷积神经网络用于MNIST分类
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch import nn, optim
from torch.nn import functional as F EPOCH = 1000
BATCH_SIZE = 128
LR = 0.001
DOWNLOAD_MNIST = False train_data = datasets.MNIST(
root='./mnist',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),#0-255 -> 0-1
download=DOWNLOAD_MNIST
)
#plot one example
print(train_data.train_data.size())
print(train_data.train_labels.size())
plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[0])
plt.show() train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE,\
shuffle=True, num_workers=2) test_data = datasets.MNIST(
root='./mnist',
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),#0-255 -> 0-1
download=DOWNLOAD_MNIST
)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE,\
shuffle=True, num_workers=2) x, label = iter(test_loader).next() #这个iter能把一个batch_size提取出来
print("x:",x.shape, 'label:',label.shape) class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn.Conv2d(
in_channels=1, # input height
out_channels=10, # n_filters
kernel_size=5, # filter size
stride=1, # filter movement/step
padding=2, # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
)
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
nn.Conv2d(10, 20, 5, 1, 2), # output shape (32, 14, 14)
nn.ReLU(), # activation
nn.MaxPool2d(2), # output shape (32, 7, 7)
)
self.out1 = nn.Linear(20 * 7 * 7, 512) # fully connected layer, output 10 classes
self.out2 = nn.Linear(512, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
output = self.out1(x)
output = F.relu(output)
output = self.out2(output)
return output # return x for visualization DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn = CNN().to(DEVICE)
print(cnn) # net architecture
optimizer = optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
criteon = nn.CrossEntropyLoss().to(DEVICE) # the target label is not one-hotted def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if(batch_idx+1)%30 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item())) def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += torch.eq(pred, target).float().sum().item()
test_loss += criterion(output, target) # test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
# pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
# correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset))) # training and testing
for epoch in range(EPOCH):
train(cnn, DEVICE, train_loader, optimizer, epoch)
test(cnn, DEVICE, test_loader)
最后能得到99%的准确率
Pytorch搭建卷积神经网络用于MNIST分类的更多相关文章
- Tensorflow框架初尝试————搭建卷积神经网络做MNIST问题
Tensorflow是一个非常好用的deep learning框架 学完了cs231n,大概就可以写一个CNN做一下MNIST了 tensorflow具体原理可以参见它的官方文档 然后CNN的原理可以 ...
- Pytorch搭建简单神经网络 Task2
1>建立数据集(并绘制图像) # -*- coding: utf-8 -*- #demo.py import torch import torch.nn.functional as F # 主要 ...
- TensorFlow——CNN卷积神经网络处理Mnist数据集
CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...
- 3层-CNN卷积神经网络预测MNIST数字
3层-CNN卷积神经网络预测MNIST数字 本文创建一个简单的三层卷积网络来预测 MNIST 数字.这个深层网络由两个带有 ReLU 和 maxpool 的卷积层以及两个全连接层组成. MNIST 由 ...
- 使用MXNet远程编写卷积神经网络用于多标签分类
最近试试深度学习能做点什么事情.MXNet是一个与Tensorflow类似的开源深度学习框架,在GPU显存利用率上效率高,比起Tensorflow显著节约显存,并且天生支持分布式深度学习,单机多卡.多 ...
- TensorFlow系列专题(十四): 手把手带你搭建卷积神经网络实现冰山图像分类
目录: 冰山图片识别背景 数据介绍 数据预处理 模型搭建 结果分析 总结 一.冰山图片识别背景 这里我们要解决的任务是来自于Kaggle上的一道赛题(https://www.kaggle.com/c/ ...
- 动手学习Pytorch(6)--卷积神经网络基础
卷积神经网络基础 本节我们介绍卷积神经网络的基础概念,主要是卷积层和池化层,并解释填充.步幅.输入通道和输出通道的含义. 二维卷积层 本节介绍的是最常见的二维卷积层,常用于处理图像数据. 二维 ...
- 写给程序员的机器学习入门 (八) - 卷积神经网络 (CNN) - 图片分类和验证码识别
这一篇将会介绍卷积神经网络 (CNN),CNN 模型非常适合用来进行图片相关的学习,例如图片分类和验证码识别,也可以配合其他模型实现 OCR. 使用 Python 处理图片 在具体介绍 CNN 之前, ...
- 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)
1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...
随机推荐
- UVa 10047 自行车 状态记录广搜
每个格子(x,y,drection,color) #include<iostream> #include<cstdio> #include<cstring> #in ...
- ZROI 19.08.05模拟赛
传送门 写在前面:为了保护正睿题目版权,这里不放题面,只写题解. A \(21pts:\) 随便枚举,随便爆搜就好了. \(65pts:\) 比较显然的dp,设\(f_{i,j,k}\)表示在子树\( ...
- Vue文件路径引入
- SonarQube 7.7默认数据库连接方法
SonarQube7.7默认数据库为H2 embbed数据库 连接字符串:jdbc:h2:tcp://localhost:9092/sonar 用户名密码都为空
- PIXI如何绘制离屏canvas到舞台上
有个方法是toDataURL(),原生的,先转换成图片再绘制. 但是pixi提供了一个BaseTexture,其构造函数的参数可以是一个canvas 因此可以直接使用如下代码绘制canvas //微信 ...
- 【bzoj4562】[Haoi2016]食物链
*题目描述: 如图所示为某生态系统的食物网示意图,据图回答第1小题 现在给你n个物种和m条能量流动关系,求其中的食物链条数. 物种的名称为从1到n编号 M条能量流动关系形如 a1 b1 a2 b2 a ...
- vue 渲染是出现 Do not use built-in or reserved HTML elements as component id 的警告
情况1.是因为组件命名和引入不一致造成的. 命名组件(nav) export default { name: 'nav', data () { return { } } 引入组件(Navigation ...
- nginx报错 nginx: [alert] kill(25903, 1) failed (3: No such process)
当nginx 中报错 时 nginx报错 nginx: [alert] kill(25903, 1) failed (3: No such process) 通过在nginx/sbin,目录下 运行命 ...
- BZOJ 3203 Luogu P3299 [SDOI2013]保护出题人 (凸包、斜率优化、二分)
惊了,我怎么这么菜啊.. 题目链接: (bzoj)https://www.lydsy.com/JudgeOnline/problem.php?id=3203 (luogu)https://www.lu ...
- sun.misc.BASE64Decoder 替代
加密解密经常用到sun.misc.BASE64Decoder处理,编译时会提示: sun.misc.BASE64Decoder是内部专用 API, 可能会在未来发行版中删除 解决办法: Java8以后 ...