[cnn]cnn训练MINST数据集demo

tips:

在文件路径进入conda

输入

jupyter nbconvert --to markdown test.ipynb

将ipynb文件转化成markdown文件

jupyter nbconvert --to html test.ipynb

jupyter nbconvert --to pdf test.ipynb

(html,pdf文件同理)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
input_size = 28   #图像尺寸 28*28
num_class = 10 #标签总数
num_epochs = 3 #训练总周期
batch_size = 64 #一个批次多少图片 train_dataset = datasets.MNIST(
root='data',
train=True,
transform=transforms.ToTensor(),
download=True,
) test_dataset = datasets.MNIST(
root='data',
train=False,
transform=transforms.ToTensor(),
download=True,
) train_loader = torch.utils.data.DataLoader(
dataset = train_dataset,
batch_size = batch_size,
shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
dataset = test_dataset,
batch_size = batch_size,
shuffle = True,
)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #输入为(1,28,28)
nn.Conv2d(
in_channels=1,
out_channels=16, #要得到几个特征图
kernel_size=5, #卷积核大小
stride=1, #步长
padding=2,
), #输出特征图为(16*28*28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), #池化(2x2) 输出为(16,14,14)
)
self.conv2 = nn.Sequential( #输入(16,14,14)
nn.Conv2d(16, 32, 5, 1, 2), #输出(32,14,14)
nn.ReLU(),
nn.MaxPool2d(2), #输出(32,7,7)
)
self.out = nn.Linear(32 * 7 * 7, 10) #全连接 def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #flatten操作 输出为(batch_size,32*7*7)
output = self.out(x)
return output, x
def accuracy(predictions,labels):
pred = torch.max(predictions.data,1)[1]
rights = pred.eq(labels.data.view_as(pred)).sum()
return rights,len(labels)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss() #损失函数
#优化器
optimizer = optim.Adam(net.parameters(),lr = 0.001) for epoch in range(num_epochs+1):
#保留epoch的结果
train_rights = []
for batch_idx,(data,target) in enumerate(train_loader):
data = data.to(device)
target = target.to(device)
net.train()
output = net(data)[0]
loss = criterion(output,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
right = accuracy(output,target)
train_rights.append(right) if batch_idx %100 ==0:
net.eval()
val_rights = []
for(data,target) in test_loader:
data = data.to(device)
target = target.to(device)
output = net(data)[0]
right = accuracy(output,target)
val_rights.append(right)
#计算准确率
train_r = (sum([i[0] for i in train_rights]),sum(i[1] for i in train_rights))
val_r = (sum([i[0] for i in val_rights]),sum(i[1] for i in val_rights)) print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
epoch,
batch_idx * batch_size,
len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.data,
100. * train_r[0].cpu().numpy() / train_r[1],
100. * val_r[0].cpu().numpy() / val_r[1]
)
)
当前epoch:0[0/60000(0%)]	损失:2.31	训练集准确率:4.69%	测试集准确率:21.01%
当前epoch:0[6400/60000(11%)] 损失:0.51 训练集准确率:75.94% 测试集准确率:91.43%
当前epoch:0[12800/60000(21%)] 损失:0.28 训练集准确率:84.05% 测试集准确率:93.87%
当前epoch:0[19200/60000(32%)] 损失:0.15 训练集准确率:87.77% 测试集准确率:96.42%
当前epoch:0[25600/60000(43%)] 损失:0.08 训练集准确率:89.82% 测试集准确率:97.02%
当前epoch:0[32000/60000(53%)] 损失:0.14 训练集准确率:91.20% 测试集准确率:97.42%
当前epoch:0[38400/60000(64%)] 损失:0.04 训练集准确率:92.13% 测试集准确率:97.59%
当前epoch:0[44800/60000(75%)] 损失:0.08 训练集准确率:92.83% 测试集准确率:97.83%
当前epoch:0[51200/60000(85%)] 损失:0.12 训练集准确率:93.38% 测试集准确率:97.77%
当前epoch:0[57600/60000(96%)] 损失:0.19 训练集准确率:93.81% 测试集准确率:98.24%
当前epoch:1[0/60000(0%)] 损失:0.07 训练集准确率:95.31% 测试集准确率:97.90%
当前epoch:1[6400/60000(11%)] 损失:0.08 训练集准确率:97.96% 测试集准确率:98.27%
当前epoch:1[12800/60000(21%)] 损失:0.10 训练集准确率:97.99% 测试集准确率:98.30%
当前epoch:1[19200/60000(32%)] 损失:0.02 训练集准确率:98.07% 测试集准确率:98.20%
当前epoch:1[25600/60000(43%)] 损失:0.17 训练集准确率:98.09% 测试集准确率:98.40%
当前epoch:1[32000/60000(53%)] 损失:0.12 训练集准确率:98.11% 测试集准确率:98.68%
当前epoch:1[38400/60000(64%)] 损失:0.05 训练集准确率:98.11% 测试集准确率:98.63%
当前epoch:1[44800/60000(75%)] 损失:0.10 训练集准确率:98.14% 测试集准确率:98.70%
当前epoch:1[51200/60000(85%)] 损失:0.04 训练集准确率:98.19% 测试集准确率:98.56%
当前epoch:1[57600/60000(96%)] 损失:0.03 训练集准确率:98.23% 测试集准确率:98.67%
当前epoch:2[0/60000(0%)] 损失:0.06 训练集准确率:98.44% 测试集准确率:98.32%
当前epoch:2[6400/60000(11%)] 损失:0.03 训练集准确率:98.64% 测试集准确率:98.63%
当前epoch:2[12800/60000(21%)] 损失:0.05 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)] 损失:0.01 训练集准确率:98.72% 测试集准确率:98.69%
当前epoch:2[25600/60000(43%)] 损失:0.01 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[32000/60000(53%)] 损失:0.03 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[38400/60000(64%)] 损失:0.07 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[44800/60000(75%)] 损失:0.07 训练集准确率:98.72% 测试集准确率:98.60%
当前epoch:2[51200/60000(85%)] 损失:0.03 训练集准确率:98.71% 测试集准确率:98.99%
当前epoch:2[57600/60000(96%)] 损失:0.05 训练集准确率:98.74% 测试集准确率:98.84%

[cnn]cnn训练MINST数据集demo的更多相关文章

  1. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  2. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  3. Fast RCNN 训练自己数据集 (1编译配置)

    FastRCNN 训练自己数据集 (1编译配置) 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https:/ ...

  4. 神经网络中的Heloo,World,基于MINST数据集的LeNet

    前言 最近刚开始接触机器学习,记录下目前的一些理解,以及看到的一些好文章mark一下 1.MINST数据集 MNIST 数据集来自美国国家标准与技术研究所, National Institute of ...

  5. 使用py-faster-rcnn训练VOC2007数据集时遇到问题

    使用py-faster-rcnn训练VOC2007数据集时遇到如下问题: 1. KeyError: 'chair' File "/home/sai/py-faster-rcnn/tools/ ...

  6. 分类问题(一)MINST数据集与二元分类器

    分类问题 在机器学习中,主要有两大类问题,分别是分类和回归.下面我们先主讲分类问题. MINST 这里我们会用MINST数据集,也就是众所周知的手写数字集,机器学习中的 Hello World.sk- ...

  7. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  8. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  9. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  10. Paper Reading - CNN+CNN: Convolutional Decoders for Image Captioning

    Link of the Paper: https://arxiv.org/abs/1805.09019 Innovations: The authors propose a CNN + CNN fra ...

随机推荐

  1. SpringBoot3集成ElasticSearch

    目录 一.简介 二.环境搭建 1.下载安装包 2.服务启动 三.工程搭建 1.工程结构 2.依赖管理 3.配置文件 四.基础用法 1.实体类 2.初始化索引 3.仓储接口 4.查询语法 五.参考源码 ...

  2. OpenLDAP服务器搭建

    一.关闭防火墙和selinux [root@localhost ~]# systemctl stop firewalld.service [root@localhost ~]# systemctl d ...

  3. Web服务器部署上线的踩坑流程回顾与知新

    5月份时曾部署上线了C++的Web服务器,温故而知新,本篇文章梳理总结一下部署流程知识: 最初的解决方案:https://blog.csdn.net/BinBinCome/article/detail ...

  4. JavaWeb项目开发环境搭建

    1. 安装JDK1.8 2. 安装Tomcat8 此处安装解压版apache-tomcat-8.0.47,直接将压缩包解压到指定目录即可.例如,D:\apache-tomcat-8.0.47 3. 安 ...

  5. ffmpeg 在xp和server2003/2008/2012上修复无法定位GetNumaNodeProcessorMaskEx的问题

    问题 在给开发一个手机视频网站时需要用到ffmpeg截取视频缩略图, 把项目提交到服务器(server2003/ server2008)上时, 发现在调用命令时会出现错误"无法定位GetNu ...

  6. xv6 进程切换中的锁:MIT6.s081/6.828 lectrue12:Coordination 以及 Lab6 Thread 心得

    引言 这节课和上一节xv6进程切换是一个完整的的进程切换专题,上一节主要讨论进程切换过程中的细节,而这一节主要讨论进程切换过程中锁的使用,所以本节的两大关键词就是"Coordination& ...

  7. Mybatiplus通用3.5.1版本及其以上的代码生成器工具类

    Mybatiplus通用3.5.1版本及其以上的代码生成器工具类 package com.gton.util; import com.baomidou.mybatisplus.annotation.F ...

  8. 深度学习 YOLO v1 源码+笔记

    """ Yolo V1 by tensorflow """ import numpy as np import tensorflow._ap ...

  9. 如何在 Ubuntu上使用snap安装Docker

    1 检查系统版本 具有sudo或root用户权限 2 安装 SNAP ctrl+alt+T 打开终端 运行以下命令以安装 SNAP sudo apt update sudo apt install s ...

  10. Super Apps 超级应用们背后的道家哲学

    众所周知,Elon Musk 想将 Twitter 重新设计定位成一款"超级应用 - X"的野心已经不再是秘密.伴随着应用商店中 Twitter 标志性的蓝鸟 Logo 被 X 取 ...