用pytorch做手写数字识别,识别l率达97.8%
pytorch做手写数字识别
效果如下:
工程目录如下
第一步 数据获取
下载MNIST库,这个库在网上,执行下面代码自动下载到当前data文件夹下
from torchvision.datasets import MNIST
import torchvision mnist = MNIST(root='./data',train=True,download=True) print(mnist)
print(mnist[0])
print(len(mnist))
img = mnist[0][0]
img.show()
dataset.py文件,读取数据并做预处理
'''
准备数据集
''' import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision def mnist_dataset(train): func = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=(0.1307,),std=(0.3081,))
]) #1.准备Mnist数据集
return MNIST(root='./data',train=train,download=False,transform=func) def get_dataloader(train = True):
mnist = mnist_dataset(train)
return DataLoader(mnist,batch_size=128,shuffle=True) if __name__ == '__main__':
for (images,labels) in get_dataloader():
print(images.size())
print(labels.size())
break
models.py文件,定义训练的模型类
'''
定义模型
''' import torch.nn as nn
import torch.nn.functional as F class MnistModel(nn.Module): def __init__(self):
super(MnistModel,self).__init__()
self.fc1 = nn.Linear(1*28*28,100)
self.fc2 = nn.Linear(100,10) def forward(self,image):
image_viewd = image.view(-1,1*28*28) #[batch_size,1*28*28]
fc1_out = self.fc1(image_viewd) #[batch_size,100]
fc1_out_relu = F.relu(fc1_out) #[batch_size,100]
out = self.fc2(fc1_out_relu) #[batch_size,10] return F.log_softmax(out,dim=-1) #带权损失计算交叉熵
cong.py文件,定义一些常亮,设置使用cpu还是GPU
'''
项目配置
''' import torch train_batch_size = 128
test_batch_size = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train.py文件,模型训练文件,保存模型
"""
进行模型的训练
"""
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import conf
from tqdm import tqdm
import numpy as np
import torch
import os
from test import eval #1. 实例化模型,优化器,损失函数
model = MnistModel().to(conf.device)
optimizer = optim.Adam(model.parameters(),lr=1e-3) #2. 进行循环,进行训练
def train(epoch):
train_dataloader = get_dataloader(train=True)
bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
total_loss = []
for idx,(input,target) in bar:
input = input.to(conf.device)
target = target.to(conf.device)
#梯度置为0
optimizer.zero_grad()
#计算得到预测值
output = model(input)
#得到损失
loss = F.nll_loss(output,target)
#反向传播,计算损失
loss.backward()
total_loss.append(loss.item())
#参数的更新
optimizer.step()
#打印数据
if idx%10 ==0 :
bar.set_description_str("epcoh:{} idx:{},loss:{:.6f}".format(epoch,idx,np.mean(total_loss)))
torch.save(model.state_dict(),"./models/model.pkl")
torch.save(optimizer.state_dict(),"./models/optimizer.pkl") if __name__ == '__main__':
for i in range(10):
train(i)
eval()
test.py文件,模型测试文件,测试模型准确率
'''
进行模型评估
''' from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import conf
from tqdm import tqdm
import numpy as np
import torch
import os def eval():
#实例化模型,优化器,损失函数
model = MnistModel().to(conf.device) if os.path.exists("./models/model.pkl"):
model.load_state_dict(torch.load("./models/model.pkl")) test_dataloader = get_dataloader(train=False)
total_loss = []
total_acc = []
with torch.no_grad():
for input, target in test_dataloader: # 2. 进行循环,进行训练
input = input.to(conf.device)
target = target.to(conf.device)
# 计算得到预测值
output = model(input)
# 得到损失
loss = F.nll_loss(output, target)
# 反向传播,计算损失
total_loss.append(loss.item()) # 计算准确率
###计算预测值
pred = output.max(dim=-1)[-1]
total_acc.append(pred.eq(target).float().mean().item())
print("test loss:{},test acc:{}".format(np.mean(total_loss), np.mean(total_acc))) # if __name__ == '__main__':
# # for i in range(10):
# # train(i)
# eval()
用pytorch做手写数字识别,识别l率达97.8%的更多相关文章
- 【转】机器学习教程 十四-利用tensorflow做手写数字识别
模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...
- 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...
- opencv实现KNN手写数字的识别
人工智能是当下很热门的话题,手写识别是一个典型的应用.为了进一步了解这个领域,我阅读了大量的论文,并借助opencv完成了对28x28的数字图片(预处理后的二值图像)的识别任务. 预处理一张图片: 首 ...
- pytorch CNN 手写数字识别
一个被放弃的入门级的例子终于被我实现了,虽然还不太完美,但还是想记录下 1.预处理 相比较从库里下载数据集(关键是经常失败,格式也看不懂),更喜欢直接拿图片,从网上找了半天,最后从CSDN上下载了一个 ...
- caffe+opencv3.3dnn模块 完成手写数字图片识别
最近由于项目需要用到caffe,学习了下caffe的用法,在使用过程中也是遇到了些问题,通过上网搜索和问老师的方法解决了,在此记录下过程,方便以后查看,也希望能为和我一样的新手们提供帮助. 顺带附上老 ...
- 用tensorflow求手写数字的识别准确率 (简单版)
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 mnist = in ...
- 吴裕雄 python神经网络 手写数字图片识别(5)
import kerasimport matplotlib.pyplot as pltfrom keras.models import Sequentialfrom keras.layers impo ...
- 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) from keras.datasets import mnist fro ...
- 吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别
import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...
随机推荐
- 关于《自动化测试实战宝典:Robot Framework + Python从小工到专家》
受新冠疫情影响,笔者被“困”在湖北老家七十余天,于4月1号(愚人节)这天,终于返回到广州.当前国内疫情基本已趋于平稳,但全球疫情整体势态仍在持续疯涨,累计确诊病例已近80万人.祈祷这场全球性灾难能尽早 ...
- 模块 configparser 配置文件生成修改
此模块用于生成和修改常见配置文档 1.来看一个好多软件的常见配置文件格式如下***.ini [DEFAULT] ServerAliveInterval = 45 Compression = yes C ...
- JavaScript 异步、栈、事件循环、任务队列
概览 我们经常会听到引擎和runtime,它们的区别是什么呢? 引擎:解释并编译代码,让它变成能交给机器运行的代码(runnable commands). runtime:就是运行环境,它提供一些对外 ...
- javax.el.PropertyNotFoundException: 类型[cn.cqsw.pojo.Course]上找不到属性[CourseId]
今天在JSP利用EL表达式取值报了 "javax.el.PropertyNotFoundException” 1 Caused by: org.apache.jasper.JasperExc ...
- 四、【Docker笔记】Docker容器
容器是Docker的另一个核心概念,容器就是镜像的一个运行实例,只是它具有一个可写的文件层,而镜像是一个只读的文件. 一.创建容器 1.新建容器 我们可以使用 docker create 命令来创建一 ...
- IE不支持sessionStorage问题
IE8及以上版本是支持的,如果你的项目在IE8及以上打开报错: 那是因为:页面要放在服务器上才能有效!!!!!!!!!!!!!!!!!!!!!!!!
- python--爬虫(XPath与BeautifulSoup4)
获取页面内容除使用正则意外,还可以使用XPath,其原理是将html代码转换为xml格式,然后使用XPath查找html节点或元素. 选取节点 XPath使用路径表达式来选取XML文档中的节点或节点集 ...
- Spring(一):Spring入门程序和IoC初步理解
本文是按照狂神说的教学视频学习的笔记,强力推荐,教学深入浅出一遍就懂!b站搜索狂神说或点击下面链接 https://space.bilibili.com/95256449?spm_id_from=33 ...
- 天天写order by,你知道Mysql底层执行原理吗?
前言 文章首发于微信公众号[码猿技术专栏]. 在实际的开发中一定会碰到根据某个字段进行排序后来显示结果的需求,但是你真的理解order by在 Mysql 底层是如何执行的吗? 假设你要查询城市是苏州 ...
- Python 1基础语法四(数字类型、输入输出汇总和命令行参数)
一.数字(Number)类型 python中数字有四种类型:整数.布尔型.浮点数和复数. int (整数), 如 1, 只有一种整数类型 int,表示为长整型,没有 python2 中的 Long. ...