[cnn]cnn训练MINST数据集demo

tips:

在文件路径进入conda

输入

jupyter nbconvert --to markdown test.ipynb

将ipynb文件转化成markdown文件

jupyter nbconvert --to html test.ipynb

jupyter nbconvert --to pdf test.ipynb

(html,pdf文件同理)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
input_size = 28   #图像尺寸 28*28
num_class = 10 #标签总数
num_epochs = 3 #训练总周期
batch_size = 64 #一个批次多少图片 train_dataset = datasets.MNIST(
root='data',
train=True,
transform=transforms.ToTensor(),
download=True,
) test_dataset = datasets.MNIST(
root='data',
train=False,
transform=transforms.ToTensor(),
download=True,
) train_loader = torch.utils.data.DataLoader(
dataset = train_dataset,
batch_size = batch_size,
shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
dataset = test_dataset,
batch_size = batch_size,
shuffle = True,
)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #输入为(1,28,28)
nn.Conv2d(
in_channels=1,
out_channels=16, #要得到几个特征图
kernel_size=5, #卷积核大小
stride=1, #步长
padding=2,
), #输出特征图为(16*28*28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), #池化(2x2) 输出为(16,14,14)
)
self.conv2 = nn.Sequential( #输入(16,14,14)
nn.Conv2d(16, 32, 5, 1, 2), #输出(32,14,14)
nn.ReLU(),
nn.MaxPool2d(2), #输出(32,7,7)
)
self.out = nn.Linear(32 * 7 * 7, 10) #全连接 def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #flatten操作 输出为(batch_size,32*7*7)
output = self.out(x)
return output, x
def accuracy(predictions,labels):
pred = torch.max(predictions.data,1)[1]
rights = pred.eq(labels.data.view_as(pred)).sum()
return rights,len(labels)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss() #损失函数
#优化器
optimizer = optim.Adam(net.parameters(),lr = 0.001) for epoch in range(num_epochs+1):
#保留epoch的结果
train_rights = []
for batch_idx,(data,target) in enumerate(train_loader):
data = data.to(device)
target = target.to(device)
net.train()
output = net(data)[0]
loss = criterion(output,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
right = accuracy(output,target)
train_rights.append(right) if batch_idx %100 ==0:
net.eval()
val_rights = []
for(data,target) in test_loader:
data = data.to(device)
target = target.to(device)
output = net(data)[0]
right = accuracy(output,target)
val_rights.append(right)
#计算准确率
train_r = (sum([i[0] for i in train_rights]),sum(i[1] for i in train_rights))
val_r = (sum([i[0] for i in val_rights]),sum(i[1] for i in val_rights)) print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
epoch,
batch_idx * batch_size,
len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.data,
100. * train_r[0].cpu().numpy() / train_r[1],
100. * val_r[0].cpu().numpy() / val_r[1]
)
)
当前epoch:0[0/60000(0%)]	损失:2.31	训练集准确率:4.69%	测试集准确率:21.01%
当前epoch:0[6400/60000(11%)] 损失:0.51 训练集准确率:75.94% 测试集准确率:91.43%
当前epoch:0[12800/60000(21%)] 损失:0.28 训练集准确率:84.05% 测试集准确率:93.87%
当前epoch:0[19200/60000(32%)] 损失:0.15 训练集准确率:87.77% 测试集准确率:96.42%
当前epoch:0[25600/60000(43%)] 损失:0.08 训练集准确率:89.82% 测试集准确率:97.02%
当前epoch:0[32000/60000(53%)] 损失:0.14 训练集准确率:91.20% 测试集准确率:97.42%
当前epoch:0[38400/60000(64%)] 损失:0.04 训练集准确率:92.13% 测试集准确率:97.59%
当前epoch:0[44800/60000(75%)] 损失:0.08 训练集准确率:92.83% 测试集准确率:97.83%
当前epoch:0[51200/60000(85%)] 损失:0.12 训练集准确率:93.38% 测试集准确率:97.77%
当前epoch:0[57600/60000(96%)] 损失:0.19 训练集准确率:93.81% 测试集准确率:98.24%
当前epoch:1[0/60000(0%)] 损失:0.07 训练集准确率:95.31% 测试集准确率:97.90%
当前epoch:1[6400/60000(11%)] 损失:0.08 训练集准确率:97.96% 测试集准确率:98.27%
当前epoch:1[12800/60000(21%)] 损失:0.10 训练集准确率:97.99% 测试集准确率:98.30%
当前epoch:1[19200/60000(32%)] 损失:0.02 训练集准确率:98.07% 测试集准确率:98.20%
当前epoch:1[25600/60000(43%)] 损失:0.17 训练集准确率:98.09% 测试集准确率:98.40%
当前epoch:1[32000/60000(53%)] 损失:0.12 训练集准确率:98.11% 测试集准确率:98.68%
当前epoch:1[38400/60000(64%)] 损失:0.05 训练集准确率:98.11% 测试集准确率:98.63%
当前epoch:1[44800/60000(75%)] 损失:0.10 训练集准确率:98.14% 测试集准确率:98.70%
当前epoch:1[51200/60000(85%)] 损失:0.04 训练集准确率:98.19% 测试集准确率:98.56%
当前epoch:1[57600/60000(96%)] 损失:0.03 训练集准确率:98.23% 测试集准确率:98.67%
当前epoch:2[0/60000(0%)] 损失:0.06 训练集准确率:98.44% 测试集准确率:98.32%
当前epoch:2[6400/60000(11%)] 损失:0.03 训练集准确率:98.64% 测试集准确率:98.63%
当前epoch:2[12800/60000(21%)] 损失:0.05 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)] 损失:0.01 训练集准确率:98.72% 测试集准确率:98.69%
当前epoch:2[25600/60000(43%)] 损失:0.01 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[32000/60000(53%)] 损失:0.03 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[38400/60000(64%)] 损失:0.07 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[44800/60000(75%)] 损失:0.07 训练集准确率:98.72% 测试集准确率:98.60%
当前epoch:2[51200/60000(85%)] 损失:0.03 训练集准确率:98.71% 测试集准确率:98.99%
当前epoch:2[57600/60000(96%)] 损失:0.05 训练集准确率:98.74% 测试集准确率:98.84%

[cnn]cnn训练MINST数据集demo的更多相关文章

  1. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. Fast RCNN 训练自己数据集 (1编译配置)

    FastRCNN 训练自己数据集 (1编译配置) 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https:/ ...

  4. 神经网络中的Heloo,World,基于MINST数据集的LeNet

    前言 最近刚开始接触机器学习,记录下目前的一些理解,以及看到的一些好文章mark一下 1.MINST数据集 MNIST 数据集来自美国国家标准与技术研究所, National Institute of ...

  5. 使用py-faster-rcnn训练VOC2007数据集时遇到问题

    使用py-faster-rcnn训练VOC2007数据集时遇到如下问题: 1. KeyError: 'chair' File "/home/sai/py-faster-rcnn/tools/ ...

  6. 分类问题(一)MINST数据集与二元分类器

    分类问题 在机器学习中,主要有两大类问题,分别是分类和回归.下面我们先主讲分类问题. MINST 这里我们会用MINST数据集,也就是众所周知的手写数字集,机器学习中的 Hello World.sk- ...

  7. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  8. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  9. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  10. Paper Reading - CNN+CNN: Convolutional Decoders for Image Captioning

    Link of the Paper: https://arxiv.org/abs/1805.09019 Innovations: The authors propose a CNN + CNN fra ...

随机推荐

  1. 优化Redis缓存淘汰机制解决性能测试中报错率逐渐攀升问题

    在某个查询场景的性能测试过程中,遇到了一个问题:测试过程中报错率逐渐攀升.进一步检查后发现,在查询业务所在应用的后台日志和平台应用的后台日志中,都出现了用户登录相关的报错信息.经过排查分析,发现了问题 ...

  2. WPF-封装自定义雷达图控件

    源码地址:https://gitee.com/LiuShuiRuoBing/code_blog 雷达图用于表示不同内容的占比关系,在项目中有广泛的应用,但是目前未曾有封装良好的雷达图控件,鉴于最近项目 ...

  3. 《CTFshow-Web入门》02. Web 11~20

    @ 目录 web11 题解 原理 web12 题解 web13 题解 web14 题解 web15 题解 web16 题解 原理 web17 题解 web18 题解 原理 web19 题解 web20 ...

  4. 《最新出炉》系列入门篇-Python+Playwright自动化测试-15-playwright处理浏览器多窗口切换

    1.简介 浏览器多窗口的切换问题相比大家不会陌生吧,之前宏哥在java+selenium系列文章中就有介绍过.大致步骤就是:使用selenium进行浏览器的多个窗口切换测试,如果我们打开了多个网页,进 ...

  5. 淘宝商品详情 API的使用说明

    淘宝平台提供了 API 接口可以用于获取淘宝商品详情信息.通过 API 接口,我们可以获取到商品的基本信息.价格.评论及评价等详细信息.以下是使用说明: 获取淘宝API账号 在获取淘宝商品详情 API ...

  6. 第1章 Git概述

    第1章 Git概述 Git 是一个免费的.开源的分布式版本控制系统,可以快速高效地处理从小型到大型的各种项目. Git 易于学习,占地面积小,性能极快. 它具有廉价的本地库,方便的暂存区域和多个工作流 ...

  7. Go语言中JSON的反序列化规则

    Unmarshal 解析 func Unmarshal(data []byte, v any) error Unmarshal 解析 JSON 编码的数据,并将结果存储在 v 指向的值中.如果 v 为 ...

  8. 各快 100 倍?4G、5G、6G 相差这么多吗

    二狗子今天晚上有点 emo,为什么呢? 原来是二狗子心心念很久的一个手游上线了,二狗子兴冲冲地下载了 40 多分钟,终于下载完了游戏.结果打开游戏一看,发现游戏内部的更新写着预计 30 分钟完成更新. ...

  9. LVS+keepalived配置高可用架构和负载均衡机制(1)

    一.基础知识 1. 四层负载均衡(基于IP+端口的负载均衡) 所谓四层负载均衡,也就是主要通过报文中的目标ip地址和端口,再加上负载均衡设备设置的服务器选择方式(分发策略,轮询),决定最终选择的内部服 ...

  10. 算法修养--广度优先搜索BFS

    广度优先算法(BFS) 广度优先算法(Breadth-First Search)是在图和树领域的搜索方法,其核心思想是从一个起始点开始,访问其所有的临近节点,然后再按照相同的方式访问这些临近节点的节点 ...