Pytorch 手写数字识别 深度学习基础分享

本篇是一次内部分享,给项目开发的同事分享什么是深度学习。用最简单的手写数字识别做例子,讲解了大概的原理。
手写数字识别
展示首先数字识别项目的使用。项目实现过程:
- 训练出模型
- 准备html手写板
- flask 框架搭建简单后端



深度学习必备知识介绍
机器学习的概念
通俗解释
机器学习的关键内涵之一在于利用计算机的运算能力从大量的数据中发现一个规律,用这个规律实现预测或判断的功能。

深度学习算法分类
以算法区分深度学习应用,算法类别可分成三大类:
- 常用于图片数据进行分析处理的
卷积神经网络 - 文本分析或自然语言处理的
递归神经网络 - 常用于数据生成的
对抗神经网络
卷积神经网络(CNN)主要应用可分为图像分类、目标检测、语义分割


图片保存的本质
图片在计算机中以数字矩阵的形式存储。
https://h.markbuild.com/doc/binary-viewer-cn.html

图片的保存:

模型训练的通用步骤
模型训练的思想:

- 准备数据集
- 构建神经网络模型(面向对象中定义的一个类)
- 选择损失函数和优化器
- 训练模型
- 从模型训练得出数值
- 通过损失函数得到预测值和实际值的差距
- 通过优化器调整模型中的参数,让结果越来越准确
- 循环以上步骤
损失函数:衡量训练结果和实际偏差的函数。数值越大代表差距越大
优化器:优化模型的算法,让损失函数减小的方法
Q&A
Pytorch 手写数字识别讲解
模型训练使用pytorch框架,同样可以实现的框架还由tensorflow、keras。

数据集获取
手写识别使用的是MNIST数据集,手写数字图片。MNIST数据集由像素是28 × 28 的0~9的手写数字图片组成,一共有7万张图片,其中6万张是训练集,1万张是测试集。每个图片是黑底白字的形式。

pytorch 中提供了torchvision 包,可以通过该包可以下载数据集
import torchvision
import matplotlib.pyplot as plt
# 训练数据集
train_data = torchvision.datasets.MNIST(
root="data", # 表示把MINST保存在data文件夹下
download=True, # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
train=True, # 表示这是训练数据集
transform=torchvision.transforms.ToTensor()
# 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)
# 测试数据集
test_data = torchvision.datasets.MNIST(
root="data", # 表示把MINST保存在data文件夹下
download=True, # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
train=False, # 表示这是测试数据集
transform=torchvision.transforms.ToTensor()
# 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)
演示
模型定义
模型使用的是卷积神经网络模型。定义的神经网络模型如下:
import torch.nn as nn
# 定义卷积神经网络类
class RLS_CNN(nn.Module):
def __init__(self):
super(RLS_CNN, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, # 输入、输出通道数,输出通道数可以理解为提取了几种特征
kernel_size=(3, 3), # 卷积核尺寸
stride=(1, 1), # 卷积核每次移动多少个像素
padding=1), # 原图片边缘加几个空白像素
# 输入图片尺寸为 1×28×28
# 第一次卷积,尺寸为 16×28×28
nn.MaxPool2d(kernel_size=2), # 第一次池化,尺寸为 16×14×14
nn.Conv2d(16, 32, 3, 1, 1), # 第二次卷积,尺寸为 32×14×14
nn.MaxPool2d(2), # 第二次池化,尺寸为 32×7 ×7
nn.Flatten(), # 将三维数组变成一维数组
nn.Linear(32*7*7, 16), # 变成16个卷积核,每一个卷积核是1*1,最后输出16个数字
nn.ReLU(), # 激活函数 x<0 y=0 x>0 y=x,用在反向反向传导
nn.Linear(16, 10) # 将16变成10,预测0-9之间概率值
)
def forward(self, x):
return self.net(x)
卷积神经网络模型组成

卷积神经网络通常由3个部分构成:卷积层,池化层,全连接层。各部分的功能:
- 卷积层:负责提取图像中的特征,可以输出一张图片的很多种特征。
- 池化层:用来缩小尺寸,大幅降低参数量级,降低计算量
- 全连接层:合并特征并输出结果
美颜相机的原理就是提取图片的特征,如下图片第二张模糊轮廓,第三张是突出轮廓。

卷积
卷积的功能:提取图片的多种特征信息
卷积的原理:用一个卷积核和图片的矩阵相乘,得到一个新的矩阵。新矩阵就是一个新的特征。
卷积核
卷积核也是一个矩阵,通常是33的矩阵,或者是55的矩阵。卷积运算的过程如下:

图像边缘提取
使用如下的卷积核就可以提取图像的边缘轮廓特征


调参:
卷积核矩阵由3*3一共9个参数组成,这些参数都是模型自动生成的,所谓的调参,其中一部分就是指调整卷积核矩阵的参数,让其提取的特征能够使预测更加准确
池化
池化的功能:池化就是缩小矩阵的尺寸,从而减少后续操作的参数数量。通常会在相邻的卷积层之间加入一个池化层。
池化的原理:池化的运算过程:将一个44的矩阵最大池化成22的矩阵,就是取4*4矩阵中对应区域中最大的一个数值。

池化通常有两种:
- 最大池化(max pooling):选图像区域的最大值作为该区域池化后的值。
- 平均池化(average pooling):计算图像区域的平均值作为该区域池化后的值。
全连接
全连接功能: 全连接的作用是组合特征和分类。
在前面两个步骤中从一张图片提取多种特征,并将特征矩阵进行了压缩。当数据到达全连接层时得到是一张图片的多种特征。
某一个特征并不能说整个图片是什么,否则就是盲人摸象。那么全连接层就是将多种特征组合起来形成一个完整的特征,并根据特征计算出图片是某一个类型的概率。
全连接层最终输出就是概率。比如手写数字识别,最终全连接层输出就是某一个手写数字在0~9上的概率。
tensor([[ 0.949, 3.032, 0.771, -2.173, -0.038, -0.236, 0.013, 0.614, -1.125, -2.6991]])
全连接的原理
全连接层实现的是特征组合,原理和卷积类似,也就是用一个卷积核对矩阵做运算,最后得到一个一维的数组,也就是0-9的概率。

调参:全连接的实现也需要卷积核的参与,所以卷积核矩阵也是参数的一部分,调参就包括该部分的参数。
手写数字识别的模型定义
手写数字识别的卷积神经网络,下面分析卷积+池化+全连接的过程:

Q&A
选择损失函数和优化器
损失函数功能:衡量训练结果和实际偏差的函数。数值越大代表差距越大
优化器功能:让模型不断优化,让损失函数减小的方法
手写数字识别中使用的损失函数和优化器如下:
# 交叉熵损失函数,选择一种方法计算误差值
loss_func = torch.nn.CrossEntropyLoss()
# 优化器,随机梯度下降算法
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
损失函数
手写识别中选择了交叉熵损失函数,pytorch一共有19中损失函数可以使用,比较好理解的是平方差损失函数
优化器
手写识别中选了随机梯度下降算法,用来实现反向传播参数的修改。pytorch中一共有11中优化器可以使用。
模型训练
模型训练的流程:
- 定义训练的次数
- 遍历训练集,调用模型类传入图片,得到概率结果
- 通过损失函数计算损失值
- 通过优化器调整参数
- 训练完成保存模型
# 定义训练次数
cnt_epochs = 5 # 训练5个循环
# 循环训练
for cnt in range(cnt_epochs):
# 把训练集中的数据训练一遍
for imgs, labels in train_dataloader:
outputs = model(imgs) # 输出0~9预测的结果概率
loss = loss_func(outputs, labels) # 和输入做一个比较,得到一个误差
optimizer.zero_grad() # 初始化梯度,清空梯度。注意清空优化器的梯度,防止累计
loss.backward() # 方向传播计算
optimizer.step() # 累加1,执行一次
# 保存训练的结果(包括模型和参数)
torch.save(model, "my_cnn.nn")

需要注意的点:
- 训练的规律
- my_cnn.nn 模型保存的内容
Q&A
模型验证
- 模型在测试集上的准确率
- 一批模型准确率展示
总结
- 数据集非常重要。html手写识别中遇到的问题,以及如何解决。颜色,大小
- 数学知识。训练过程中遇到的数据知识:矩阵乘法
- 为什么需要GPU?如何使用GPU?
- 模型训练的过程。
卷积+池化+全连接+损失函数+优化器 - 目标检查的训练过程和手写识别有何不同?
图像分类:LeNet、AlexNet、VGG、GoogLeNet
目标检测:RCNN、Fast RCNN、Faster RCNN、YOLO、YOLOv2、SSD
Pytorch 手写数字识别 深度学习基础分享的更多相关文章
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- pytorch 手写数字识别项目 增量式训练
dataset.py ''' 准备数据集 ''' import torch from torch.utils.data import DataLoader from torchvision.datas ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- MINST手写数字识别(二)—— 卷积神经网络(CNN)
今天我们的主角是keras,其简洁性和易用性简直出乎David 9我的预期.大家都知道keras是在TensorFlow上又包装了一层,向简洁易用的深度学习又迈出了坚实的一步. 所以,今天就来带大家写 ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- 【深度学习系列】PaddlePaddle之手写数字识别
上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...
- 深度学习面试题12:LeNet(手写数字识别)
目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...
随机推荐
- 离线安装MySQL
离线安装mysql [下载地址](MySQL :: Download MySQL Community Server) 解压后依次执行如下命令 rpm -ivh mysql-community-comm ...
- 大语言模型(LLM)
大语言模型 LLM 人工智能 Artificial Intelligence 一门研究如何使计算机能够模拟和执行人类智能任务的科学和技术领域 是研究.开发用于模拟.延伸和扩展人的智能的理论.方法.技术 ...
- 深度解析Spring AI:请求与响应机制的核心逻辑
我们在前面的两个章节中基本上对Spring Boot 3版本的新变化进行了全面的回顾,以确保在接下来研究Spring AI时能够避免任何潜在的问题.今天,我们终于可以直接进入主题:Spring AI是 ...
- 分布式缓存 - 缓存服务器 - redis
如果一般的缓存可以解决问题,就不必使用分布式缓存 : 一般使用分布式缓存 都是使用 redis : 使用教程: 1. 安装包 Microsoft.Extensions.Caching.StackExc ...
- vuex 基本代码规范 js 文件
import Vue from "vue"; import Vuex from "vuex"; import { setItem, getItem } from ...
- 一、Spring Boot集成Spring Security专栏
一.Spring Boot集成Spring Security专栏 一.Spring Boot集成Spring Security之自动装配 二.实现功能及软件版本说明 使用Spring Boot集成Sp ...
- 【异常处理】Assistive Technology not found: com.sun.java.accessibility.AccessBridge
十一回来之后,工作电脑上的抓包工具Charles突然启动不起来了,双击图标后,一闪而过,就没动静了. 不知道是不是因为之前安装了什么工具.软件引起的. 打开CMD命令行,跳转到目录下启动,提示:Ass ...
- HDU-ACM 2024 Day1
T1009 数位的关系(HDU 7441) 考虑 \(l = r\) 的情况,此时只要计算一个数字,我们将其展开为一个字符串 \(S\).设 \(f_{i, j, k}\) 表示考虑了 \(S\) 的 ...
- 59 张高清大图,带你实战入门 KubeSphere DevOps
作者:运维有术星主 KubeSphere 基于 Jenkins 的 DevOps 系统是专为 Kubernetes 中的 CI/CD 工作流设计的,它提供了一站式的解决方案,帮助开发和运维团队用非常简 ...
- 云原生爱好者周刊:mist.io 开源多云管理平台
开源项目推荐 Mist Mist 是一个开源的多云管理平台,它提供了跨云和内部基础设施的可观测性,以及生命周期管理能力.同时还提供了一些功能更强大的商业组件. rga rga 是一个类似于 grep ...