import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms # 配置GPU或CPU设置
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 超参数设置
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001 # 下载 MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255
download=True) test_dataset = torchvision.datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor())# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255 # 训练数据加载,按照batch_size大小加载,并随机打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 测试数据加载,按照batch_size大小加载
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Convolutional neural network (two convolutional layers) 2层卷积
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.fc = nn.Linear(7 * 7 * 32, num_classes) def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out model = ConvNet(num_classes).to(device)
print(model) # ConvNet(
# (layer1): Sequential(
# (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
# (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
# (layer2): Sequential(
# (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
# (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))
# (fc): Linear(in_features=1568, out_features=10, bias=True)) # 损失函数与优化器设置
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器设置 ,并传入CNN模型参数和相应的学习率
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 训练CNN模型
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device) # 前向传播
outputs = model(images)
# 计算损失 loss
loss = criterion(outputs, labels) # 反向传播与优化
# 清空上一步的残余更新参数值
optimizer.zero_grad()
# 反向传播
loss.backward()
# 将参数更新值施加到RNN model的parameters上
optimizer.step()
# 每迭代一定步骤,打印结果值
if (i + 1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item())) # 测试模型
# model.train model.eval 在测试模型时在前面使用:model.eval() ; 在训练模型时会在前面加上:model.train()
# 让model变成测试模式,是针对model 在训练时和评价时不同的 Batch Normalization 和 Dropout 方法模式
# eval()时,让model变成测试模式, pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值,
# 不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大。
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # 保存已经训练好的模型
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

  

Convolutional neural network (CNN) - Pytorch版的更多相关文章

  1. 卷积神经网络(Convolutional Neural Network, CNN)简析

    目录 1 神经网络 2 卷积神经网络 2.1 局部感知 2.2 参数共享 2.3 多卷积核 2.4 Down-pooling 2.5 多层卷积 3 ImageNet-2010网络结构 4 DeepID ...

  2. Recurrent neural network (RNN) - Pytorch版

    import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms # ...

  3. 斯坦福大学卷积神经网络教程UFLDL Tutorial - Convolutional Neural Network

    Convolutional Neural Network Overview A Convolutional Neural Network (CNN) is comprised of one or mo ...

  4. 卷积神经网络(Convolutional Neural Network,CNN)

    全连接神经网络(Fully connected neural network)处理图像最大的问题在于全连接层的参数太多.参数增多除了导致计算速度减慢,还很容易导致过拟合问题.所以需要一个更合理的神经网 ...

  5. 深度学习FPGA实现基础知识10(Deep Learning(深度学习)卷积神经网络(Convolutional Neural Network,CNN))

    需求说明:深度学习FPGA实现知识储备 来自:http://blog.csdn.net/stdcoutzyx/article/details/41596663 说明:图文并茂,言简意赅. 自今年七月份 ...

  6. 【转载】 卷积神经网络(Convolutional Neural Network,CNN)

    作者:wuliytTaotao 出处:https://www.cnblogs.com/wuliytTaotao/ 本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,欢迎 ...

  7. CNN(Convolutional Neural Network)

    CNN(Convolutional Neural Network) 卷积神经网络(简称CNN)最早可以追溯到20世纪60年代,Hubel等人通过对猫视觉皮层细胞的研究表明,大脑对外界获取的信息由多层的 ...

  8. 论文阅读(Weilin Huang——【TIP2016】Text-Attentional Convolutional Neural Network for Scene Text Detection)

    Weilin Huang--[TIP2015]Text-Attentional Convolutional Neural Network for Scene Text Detection) 目录 作者 ...

  9. 论文笔记:(CVPR2019)Relation-Shape Convolutional Neural Network for Point Cloud Analysis

    目录 摘要 一.引言 二.相关工作 基于视图和体素的方法 点云上的深度学习 相关性学习 三.形状意识表示学习 3.1关系-形状卷积 建模 经典CNN的局限性 变换:从关系中学习 通道提升映射 3.2性 ...

随机推荐

  1. Kafka与ActiveMQ区别

    Kafka 是LinkedIn 开发的一个高性能.分布式的消息系统,广泛用于日志收集.流式数据处理.在线和离线消息分发等场景.虽然不是作为传统的MQ来设计,在大部分情况,Kafaka 也可以代替原先A ...

  2. centos7--zabbix3.4微信报警

    1.申请企业微信 1.1 注册企业微信的地址 https://qy.weixin.qq.com/ 1.2 按照提示进行填写 1.3 完善个人信息: 1.4 创建应用 根据提示创建应用: 1.5 筛出重 ...

  3. 浅析HTTP/2的多路复用

    HTTP/2有三大特性:头部压缩.Server Push.多路复用.前两个特性意思比较明确,也好理解,唯有多路复用不太好理解,尤其是和HTTP1.1进行对比的时候,这个问题我想了很长时间,也对比了很长 ...

  4. JVM命令行参数

    root@ubuntu-blade2:/sdf/jdk# javaUsage: java [-options] class [args...] (to execute a class) or java ...

  5. 《sicp》模块化程序设计 笔记

    <sicp>模块化程序设计 2.2.3 序列作为一种约定界面 学习笔记 这节中,讲述了一种模块化的程序设计思想,也就是将程序设计为如同信号处理过程一样,采用级联的方式将程序各个部分组合在一 ...

  6. 肿瘤免疫疗法 | 细胞治疗和PD1/PDL1 | Tumor immunotherapy | cell therapy

    人类肿瘤治疗史上的里程碑无疑一定有一座是肿瘤免疫疗法的. 而肿瘤免疫疗法的主要两大领域,细胞治疗以及以PD1/PDL1为代表的免疫检查点抑制剂都在飞速发展. 目前,已经有5种抗PD1/PDL1抗体药物 ...

  7. mycat启动报Unable to start JVM: No such file or directory (2)【转】

    mycat启动失败,查看日志 /mycat/logs/wrapper.log发现如下信息 1  STATUS | wrapper  | 2017/11/22 16:15:17 | --> Wra ...

  8. could not find 'gopls

    安装go tools 安装以上后用vim打开go代码,使用函数跳转时会出现: E718: Funcref requiredvim-go: could not find 'gopls'. Run :Go ...

  9. Android 关于selector中item顺序的问题

    selector的item从上到下是按照匹配原则来改变状态的,一旦匹配到某个item的状态,就不会继续往下匹配了. https://blog.csdn.net/l403040463/article/d ...

  10. Git - 高级合并

    Git - 高级合并https://git-scm.com/book/zh/v2/Git-%E5%B7%A5%E5%85%B7-%E9%AB%98%E7%BA%A7%E5%90%88%E5%B9%B6 ...