[cnn]cnn训练MINST数据集demo
[cnn]cnn训练MINST数据集demo
tips:
在文件路径进入conda
输入
jupyter nbconvert --to markdown test.ipynb
将ipynb文件转化成markdown文件
jupyter nbconvert --to html test.ipynb
jupyter nbconvert --to pdf test.ipynb
(html,pdf文件同理)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
input_size = 28 #图像尺寸 28*28
num_class = 10 #标签总数
num_epochs = 3 #训练总周期
batch_size = 64 #一个批次多少图片
train_dataset = datasets.MNIST(
root='data',
train=True,
transform=transforms.ToTensor(),
download=True,
)
test_dataset = datasets.MNIST(
root='data',
train=False,
transform=transforms.ToTensor(),
download=True,
)
train_loader = torch.utils.data.DataLoader(
dataset = train_dataset,
batch_size = batch_size,
shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
dataset = test_dataset,
batch_size = batch_size,
shuffle = True,
)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #输入为(1,28,28)
nn.Conv2d(
in_channels=1,
out_channels=16, #要得到几个特征图
kernel_size=5, #卷积核大小
stride=1, #步长
padding=2,
), #输出特征图为(16*28*28)
nn.ReLU(),
nn.MaxPool2d(kernel_size=2), #池化(2x2) 输出为(16,14,14)
)
self.conv2 = nn.Sequential( #输入(16,14,14)
nn.Conv2d(16, 32, 5, 1, 2), #输出(32,14,14)
nn.ReLU(),
nn.MaxPool2d(2), #输出(32,7,7)
)
self.out = nn.Linear(32 * 7 * 7, 10) #全连接
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #flatten操作 输出为(batch_size,32*7*7)
output = self.out(x)
return output, x
def accuracy(predictions,labels):
pred = torch.max(predictions.data,1)[1]
rights = pred.eq(labels.data.view_as(pred)).sum()
return rights,len(labels)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'
net = CNN().to(device)
criterion = nn.CrossEntropyLoss() #损失函数
#优化器
optimizer = optim.Adam(net.parameters(),lr = 0.001)
for epoch in range(num_epochs+1):
#保留epoch的结果
train_rights = []
for batch_idx,(data,target) in enumerate(train_loader):
data = data.to(device)
target = target.to(device)
net.train()
output = net(data)[0]
loss = criterion(output,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
right = accuracy(output,target)
train_rights.append(right)
if batch_idx %100 ==0:
net.eval()
val_rights = []
for(data,target) in test_loader:
data = data.to(device)
target = target.to(device)
output = net(data)[0]
right = accuracy(output,target)
val_rights.append(right)
#计算准确率
train_r = (sum([i[0] for i in train_rights]),sum(i[1] for i in train_rights))
val_r = (sum([i[0] for i in val_rights]),sum(i[1] for i in val_rights))
print('当前epoch:{}[{}/{}({:.0f}%)]\t损失:{:.2f}\t训练集准确率:{:.2f}%\t测试集准确率:{:.2f}%'.format(
epoch,
batch_idx * batch_size,
len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.data,
100. * train_r[0].cpu().numpy() / train_r[1],
100. * val_r[0].cpu().numpy() / val_r[1]
)
)
当前epoch:0[0/60000(0%)] 损失:2.31 训练集准确率:4.69% 测试集准确率:21.01%
当前epoch:0[6400/60000(11%)] 损失:0.51 训练集准确率:75.94% 测试集准确率:91.43%
当前epoch:0[12800/60000(21%)] 损失:0.28 训练集准确率:84.05% 测试集准确率:93.87%
当前epoch:0[19200/60000(32%)] 损失:0.15 训练集准确率:87.77% 测试集准确率:96.42%
当前epoch:0[25600/60000(43%)] 损失:0.08 训练集准确率:89.82% 测试集准确率:97.02%
当前epoch:0[32000/60000(53%)] 损失:0.14 训练集准确率:91.20% 测试集准确率:97.42%
当前epoch:0[38400/60000(64%)] 损失:0.04 训练集准确率:92.13% 测试集准确率:97.59%
当前epoch:0[44800/60000(75%)] 损失:0.08 训练集准确率:92.83% 测试集准确率:97.83%
当前epoch:0[51200/60000(85%)] 损失:0.12 训练集准确率:93.38% 测试集准确率:97.77%
当前epoch:0[57600/60000(96%)] 损失:0.19 训练集准确率:93.81% 测试集准确率:98.24%
当前epoch:1[0/60000(0%)] 损失:0.07 训练集准确率:95.31% 测试集准确率:97.90%
当前epoch:1[6400/60000(11%)] 损失:0.08 训练集准确率:97.96% 测试集准确率:98.27%
当前epoch:1[12800/60000(21%)] 损失:0.10 训练集准确率:97.99% 测试集准确率:98.30%
当前epoch:1[19200/60000(32%)] 损失:0.02 训练集准确率:98.07% 测试集准确率:98.20%
当前epoch:1[25600/60000(43%)] 损失:0.17 训练集准确率:98.09% 测试集准确率:98.40%
当前epoch:1[32000/60000(53%)] 损失:0.12 训练集准确率:98.11% 测试集准确率:98.68%
当前epoch:1[38400/60000(64%)] 损失:0.05 训练集准确率:98.11% 测试集准确率:98.63%
当前epoch:1[44800/60000(75%)] 损失:0.10 训练集准确率:98.14% 测试集准确率:98.70%
当前epoch:1[51200/60000(85%)] 损失:0.04 训练集准确率:98.19% 测试集准确率:98.56%
当前epoch:1[57600/60000(96%)] 损失:0.03 训练集准确率:98.23% 测试集准确率:98.67%
当前epoch:2[0/60000(0%)] 损失:0.06 训练集准确率:98.44% 测试集准确率:98.32%
当前epoch:2[6400/60000(11%)] 损失:0.03 训练集准确率:98.64% 测试集准确率:98.63%
当前epoch:2[12800/60000(21%)] 损失:0.05 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)] 损失:0.01 训练集准确率:98.72% 测试集准确率:98.69%
当前epoch:2[25600/60000(43%)] 损失:0.01 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[32000/60000(53%)] 损失:0.03 训练集准确率:98.70% 测试集准确率:98.76%
当前epoch:2[38400/60000(64%)] 损失:0.07 训练集准确率:98.70% 测试集准确率:98.62%
当前epoch:2[44800/60000(75%)] 损失:0.07 训练集准确率:98.72% 测试集准确率:98.60%
当前epoch:2[51200/60000(85%)] 损失:0.03 训练集准确率:98.71% 测试集准确率:98.99%
当前epoch:2[57600/60000(96%)] 损失:0.05 训练集准确率:98.74% 测试集准确率:98.84%
[cnn]cnn训练MINST数据集demo的更多相关文章
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)
版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...
- Fast RCNN 训练自己数据集 (1编译配置)
FastRCNN 训练自己数据集 (1编译配置) 转载请注明出处,楼燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/ https:/ ...
- 神经网络中的Heloo,World,基于MINST数据集的LeNet
前言 最近刚开始接触机器学习,记录下目前的一些理解,以及看到的一些好文章mark一下 1.MINST数据集 MNIST 数据集来自美国国家标准与技术研究所, National Institute of ...
- 使用py-faster-rcnn训练VOC2007数据集时遇到问题
使用py-faster-rcnn训练VOC2007数据集时遇到如下问题: 1. KeyError: 'chair' File "/home/sai/py-faster-rcnn/tools/ ...
- 分类问题(一)MINST数据集与二元分类器
分类问题 在机器学习中,主要有两大类问题,分别是分类和回归.下面我们先主讲分类问题. MINST 这里我们会用MINST数据集,也就是众所周知的手写数字集,机器学习中的 Hello World.sk- ...
- Scaled-YOLOv4 快速开始,训练自定义数据集
代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...
- 使用caffe训练mnist数据集 - caffe教程实战(一)
个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...
- 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集
上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...
- Paper Reading - CNN+CNN: Convolutional Decoders for Image Captioning
Link of the Paper: https://arxiv.org/abs/1805.09019 Innovations: The authors propose a CNN + CNN fra ...
随机推荐
- Docker容器怎么安装Vim编辑器
在现代软件开发和系统管理中,Docker已经成为一个不可或缺的工具.它允许我们轻松地创建.部署和运行应用程序,以及构建可移植的容器化环境.然而,在Docker容器中安装特定的工具可能会有一些挑战, ...
- TCP 粘包
TCP(Transmission Control Protocol,传输控制协议)是一种传输层协议. TCP提供了以下主要功能: 可靠性:TCP使用确认.重传和校验等机制来确保数据的可靠传输.它能够检 ...
- 基于bert-base-chinese训练bert模型(最后附上整体代码)
目录: 一.bert-base-chinese模型下载 二.数据集的介绍 三.完成类的代码 四.写训练方法 五.总源码及源码参考出处 一.bert-base-chinese模型下载 对于已经预训练好的 ...
- Socket.io入门
Socket.io入门 根据官方文档socket.io使用必须客户端根服务端一致,socket.io不兼容webSocket或者其他模块,因为socket.io在连接时做了自定义处理, 所以不同的长连 ...
- Vue.js 官方脚手架 create-vue 是怎么实现的?
Vue.js 官方脚手架 create-vue 是怎么实现的? 摘要 本文共分为四个部分,系统解析了vue.js 官方脚手架 create-vue 的实现细节. 第一部分主要是一些准备工作,如源码下载 ...
- SSM-Mybatis笔记
目录 Mybatis-9.28 1.简介 1.1.什么是Mybatis 1.2.持久化 1.3.持久层 1.4 为什么需要Mybatis? 2.第一个Mybatis程序 2.1.搭建环境 2.2.创建 ...
- 关于tiptop gp5.2采购模块,价格变更的随笔
采购价格变更要看具体环节,你可以把他当作是三张表,采购价格表.收货价格表.入库价格表,这些还好处理,如果已抛砖到财务端生成账款再要求改价格就更复杂,会产生更多张表了,改起来也就更复杂. 用apmt91 ...
- VoIP==Voice over Internet Protocol
基于IP的语音传输(英语:Voice over Internet Protocol,缩写为VoIP)是一种语音通话技术,经由网际协议(IP)来达成语音通话与多媒体会议,也就是经由互联网来进行通信.其他 ...
- SOA认知和方法论
1 前言 1.1 架构分类 在软件设计领域,企业架构通常被划分为如下五种分类: 如何理解架构分类依据及其彼此之间的关系?业务是企业赖以生存之本,因此业务架构是基础.是灵魂,其他一切均是对业务架构的支撑 ...
- .then()方法的意思和用法
then()方法是异步执行. 意思是:就是当.then()前的方法执行完后再执行then()内部的程序,这样就避免了,数据没获取到等的问题. 语法:promise.then(onCompleted, ...