[cnn]cnn训练MINST数据集demo
[cnn]cnn训练MINST数据集demo
tips:
在文件路径进入conda
输入
jupyter nbconvert --to markdown test.ipynb
将ipynb文件转化成markdown文件
jupyter nbconvert --to html test.ipynb
jupyter nbconvert --to pdf test.ipynb
(html,pdf文件同理)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
input_size = 28 #图像尺寸 28*28
num_class = 10 #标签总数
num_epochs = 3 #训练总周期
batch_size = 64 #一个批次多少图片
train_dataset = datasets.MNIST(
root='data',
train=True,
transform=transforms.ToTensor(),
download=True,
)
test_dataset = datasets.MNIST(
root='data',
train=False,
transform=transforms.ToTensor(),
download=True,
)
train_loader = torch.utils.data.DataLoader(
dataset = train_dataset,
batch_size = batch_size,
shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
dataset = test_dataset,
batch_size = batch_size,
shuffle = True,
)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #输入为(1,28,28)
nn.Conv2d(
in_channels=1,
out_channels=16, #要得到几个特征图
kernel_size=5, #卷积核大小
stride=1, #步长
padding=2,
), #输出特征图为(16*28*28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), #池化(2x2) 输出为(16,14,14)
)
self.conv2 = nn.Sequential( #输入(16,14,14)
nn.Conv2d(16, 32, 5, 1, 2), #输出(32,14,14)
nn.ReLU(),
nn.MaxPool2d(2), #输出(32,7,7)
)
self.out = nn.Linear(32 * 7 * 7, 10) #全连接
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #flatten操作 输出为(batch_size,32*7*7)
output = self.out(x)
return output, x
def accuracy(predictions,labels):
pred = torch.max(predictions.data,1)[1]
rights = pred.eq(labels.data.view_as(pred)).sum()
return rights,len(labels)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss() #损失函数
#优化器
optimizer = optim.Adam(net.parameters(),lr = 0.001)
for epoch in range(num_epochs+1):
#保留epoch的结果
train_rights = []
for batch_idx,(data,target) in enumerate(train_loader):
data = data.to(device)
target = target.to(device)
net.train()
output = net(data)[0]
loss = criterion(output,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
right = accuracy(output,target)
train_rights.append(right)
if batch_idx %100 ==0:
net.eval()
val_rights = []
for(data,target) in test_loader:
data = data.to(device)
target = target.to(device)
output = net(data)[0]
right = accuracy(output,target)
val_rights.append(right)
#计算准确率
train_r = (sum([i[0] for i in train_rights]),sum(i[1] for i in train_rights))
val_r = (sum([i[0] for i in val_rights]),sum(i[1] for i in val_rights))
print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
epoch,
batch_idx * batch_size,
len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.data,
100. * train_r[0].cpu().numpy() / train_r[1],
100. * val_r[0].cpu().numpy() / val_r[1]
)
)
当前epoch:0[0/60000(0%)] 损失:2.31 训练集准确率:4.69% 测试集准确率:21.01%
当前epoch:0[6400/60000(11%)] 损失:0.51 训练集准确率:75.94% 测试集准确率:91.43%
当前epoch:0[12800/60000(21%)] 损失:0.28 训练集准确率:84.05% 测试集准确率:93.87%
当前epoch:0[19200/60000(32%)] 损失:0.15 训练集准确率:87.77% 测试集准确率:96.42%
当前epoch:0[25600/60000(43%)] 损失:0.08 训练集准确率:89.82% 测试集准确率:97.02%
当前epoch:0[32000/60000(53%)] 损失:0.14 训练集准确率:91.20% 测试集准确率:97.42%
当前epoch:0[38400/60000(64%)] 损失:0.04 训练集准确率:92.13% 测试集准确率:97.59%
当前epoch:0[44800/60000(75%)] 损失:0.08 训练集准确率:92.83% 测试集准确率:97.83%
当前epoch:0[51200/60000(85%)] 损失:0.12 训练集准确率:93.38% 测试集准确率:97.77%
当前epoch:0[57600/60000(96%)] 损失:0.19 训练集准确率:93.81% 测试集准确率:98.24%
当前epoch:1[0/60000(0%)] 损失:0.07 训练集准确率:95.31% 测试集准确率:97.90%
当前epoch:1[6400/60000(11%)] 损失:0.08 训练集准确率:97.96% 测试集准确率:98.27%
当前epoch:1[12800/60000(21%)] 损失:0.10 训练集准确率:97.99% 测试集准确率:98.30%
当前epoch:1[19200/60000(32%)] 损失:0.02 训练集准确率:98.07% 测试集准确率:98.20%
当前epoch:1[25600/60000(43%)] 损失:0.17 训练集准确率:98.09% 测试集准确率:98.40%
当前epoch:1[32000/60000(53%)] 损失:0.12 训练集准确率:98.11% 测试集准确率:98.68%
当前epoch:1[38400/60000(64%)] 损失:0.05 训练集准确率:98.11% 测试集准确率:98.63%
当前epoch:1[44800/60000(75%)] 损失:0.10 训练集准确率:98.14% 测试集准确率:98.70%
当前epoch:1[51200/60000(85%)] 损失:0.04 训练集准确率:98.19% 测试集准确率:98.56%
当前epoch:1[57600/60000(96%)] 损失:0.03 训练集准确率:98.23% 测试集准确率:98.67%
当前epoch:2[0/60000(0%)] 损失:0.06 训练集准确率:98.44% 测试集准确率:98.32%
当前epoch:2[6400/60000(11%)] 损失:0.03 训练集准确率:98.64% 测试集准确率:98.63%
当前epoch:2[12800/60000(21%)] 损失:0.05 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)] 损失:0.01 训练集准确率:98.72% 测试集准确率:98.69%
当前epoch:2[25600/60000(43%)] 损失:0.01 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[32000/60000(53%)] 损失:0.03 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[38400/60000(64%)] 损失:0.07 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[44800/60000(75%)] 损失:0.07 训练集准确率:98.72% 测试集准确率:98.60%
当前epoch:2[51200/60000(85%)] 损失:0.03 训练集准确率:98.71% 测试集准确率:98.99%
当前epoch:2[57600/60000(96%)] 损失:0.05 训练集准确率:98.74% 测试集准确率:98.84%
[cnn]cnn训练MINST数据集demo的更多相关文章
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- Fast RCNN 训练自己数据集 (1编译配置)
FastRCNN 训练自己数据集 (1编译配置) 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https:/ ...
- 神经网络中的Heloo,World,基于MINST数据集的LeNet
前言 最近刚开始接触机器学习,记录下目前的一些理解,以及看到的一些好文章mark一下 1.MINST数据集 MNIST 数据集来自美国国家标准与技术研究所, National Institute of ...
- 使用py-faster-rcnn训练VOC2007数据集时遇到问题
使用py-faster-rcnn训练VOC2007数据集时遇到如下问题: 1. KeyError: 'chair' File "/home/sai/py-faster-rcnn/tools/ ...
- 分类问题(一)MINST数据集与二元分类器
分类问题 在机器学习中,主要有两大类问题,分别是分类和回归.下面我们先主讲分类问题. MINST 这里我们会用MINST数据集,也就是众所周知的手写数字集,机器学习中的 Hello World.sk- ...
- Scaled-YOLOv4 快速开始,训练自定义数据集
代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- Paper Reading - CNN+CNN: Convolutional Decoders for Image Captioning
Link of the Paper: https://arxiv.org/abs/1805.09019 Innovations: The authors propose a CNN + CNN fra ...
随机推荐
- 【LaTeX】语法(更新中)
目录 长度 空行 空格 超链接 数学公式 段落中(隐式) 单独成段(显式) 居中,左对齐,右对齐 居中 左对齐 右对齐 参考文献配置 TODO 参考资料 中文支持参考环境配置中的 内容,在这里不做重复 ...
- AI绘画Stable Diffusion实战操作: 62个咒语调教-时尚杂志封面
今天来给大家分享,如何用sd简单的咒语输出好看的图片的教程,今天做的是时尚杂志专题,话不多说直入主题. 还不会StableDiffusion的基本操作,推荐看看这篇保姆级教程: AI绘画:Stable ...
- Avalonia开发(一)环境搭建
一.介绍 开源 GitHub:https://github.com/AvaloniaUI/Avalonia/ 多平台支持,包括Windows.mac OS.Linux.iOS.Android.Sams ...
- dedebiz数据重置
TRUNCATE biz_addonarticle;TRUNCATE biz_addonimages;TRUNCATE biz_addoninfos;TRUNCATE biz_addonshop;TR ...
- 中国这么多 Java 开发者,应该诞生出生态级应用开发框架
1.必须要有,不然就永远不会有 应用开发框架,虽然没有芯片.操作系统.数据库.编程语言这些重要.但是最终呈现在用户面前的,总是有软件部分.而软件系统开发,一般都需要应用开发框架,它是软件系统的基础性部 ...
- DBeaver Ultimate 22.1.0 连接数据库(MySQL+Mongo+Clickhouse)
前言 继续书接上文 Docker Compose V2 安装常用数据库MySQL+Mongo,部署安装好之后我本来是找了一个web端的在线连接数据库的工具,但是使用过程中并不丝滑,最终还是选择了使用 ...
- 14.9 Socket 高效文件传输
网络上的文件传输功能也是很有必要实现一下的,网络传输文件的过程通常分为客户端和服务器端两部分.客户端可以选择上传或下载文件,将文件分块并逐块发送到服务器,或者从服务器分块地接收文件.服务器端接收来自客 ...
- 【XXE实战】——浅看两道CTF题
[XXE实战]--浅看两道CTF题 上一条帖子[XXE漏洞]原理及实践演示对XXE的一些原理进行了浅析,于是写了两道CTF题巩固一下,顺便也记录一下第一次写出来CTF.两道题都是在BUU上找的:[NC ...
- Building Bridges 题解
Building Bridges 题目大意 连接两根柱子 \(i,j\) 的代价是 \((h_i-h_j)^2+\sum\limits_{k=j+1}^{i-1}w_k\),连接具有传递性,求将 \( ...
- 游戏客户端开发中对MVC模式的思考
话说在前头,我分析MVC模式是为了确定自己要做的独立游戏的结构出来,并不适用于大型商业游戏的开发. MVC模式的概述 关于MVC模式,Model用于存储数据,View层用于显示数据,Controlle ...