本篇是一次内部分享,给项目开发的同事分享什么是深度学习。用最简单的手写数字识别做例子,讲解了大概的原理。

手写数字识别

展示首先数字识别项目的使用。项目实现过程:

  1. 训练出模型
  2. 准备html手写板
  3. flask 框架搭建简单后端





深度学习必备知识介绍

机器学习的概念

通俗解释

机器学习的关键内涵之一在于利用计算机的运算能力从大量的数据中发现一个规律,用这个规律实现预测或判断的功能。

深度学习算法分类

以算法区分深度学习应用,算法类别可分成三大类:

  • 常用于图片数据进行分析处理的卷积神经网络
  • 文本分析或自然语言处理的递归神经网络
  • 常用于数据生成的对抗神经网络

卷积神经网络(CNN)主要应用可分为图像分类、目标检测、语义分割

图片保存的本质

图片在计算机中以数字矩阵的形式存储。

https://h.markbuild.com/doc/binary-viewer-cn.html

图片的保存:

模型训练的通用步骤

模型训练的思想:

  1. 准备数据集
  2. 构建神经网络模型(面向对象中定义的一个类)
  3. 选择损失函数和优化器
  4. 训练模型
    • 从模型训练得出数值
    • 通过损失函数得到预测值和实际值的差距
    • 通过优化器调整模型中的参数,让结果越来越准确
    • 循环以上步骤

损失函数:衡量训练结果和实际偏差的函数。数值越大代表差距越大

优化器:优化模型的算法,让损失函数减小的方法

Q&A


Pytorch 手写数字识别讲解

模型训练使用pytorch框架,同样可以实现的框架还由tensorflow、keras。

数据集获取

手写识别使用的是MNIST数据集,手写数字图片。MNIST数据集由像素是28 × 28 的0~9的手写数字图片组成,一共有7万张图片,其中6万张是训练集,1万张是测试集。每个图片是黑底白字的形式。

pytorch 中提供了torchvision 包,可以通过该包可以下载数据集

import torchvision
import matplotlib.pyplot as plt # 训练数据集
train_data = torchvision.datasets.MNIST(
root="data", # 表示把MINST保存在data文件夹下
download=True, # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
train=True, # 表示这是训练数据集
transform=torchvision.transforms.ToTensor()
# 要把数据集中的数据转换为pytorch能够使用的Tensor类型
) # 测试数据集
test_data = torchvision.datasets.MNIST(
root="data", # 表示把MINST保存在data文件夹下
download=True, # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
train=False, # 表示这是测试数据集
transform=torchvision.transforms.ToTensor()
# 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)

演示


模型定义

模型使用的是卷积神经网络模型。定义的神经网络模型如下:

import torch.nn as nn

# 定义卷积神经网络类
class RLS_CNN(nn.Module):
def __init__(self):
super(RLS_CNN, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, # 输入、输出通道数,输出通道数可以理解为提取了几种特征
kernel_size=(3, 3), # 卷积核尺寸
stride=(1, 1), # 卷积核每次移动多少个像素
padding=1), # 原图片边缘加几个空白像素
# 输入图片尺寸为 1×28×28
# 第一次卷积,尺寸为 16×28×28
nn.MaxPool2d(kernel_size=2), # 第一次池化,尺寸为 16×14×14
nn.Conv2d(16, 32, 3, 1, 1), # 第二次卷积,尺寸为 32×14×14
nn.MaxPool2d(2), # 第二次池化,尺寸为 32×7 ×7
nn.Flatten(), # 将三维数组变成一维数组
nn.Linear(32*7*7, 16), # 变成16个卷积核,每一个卷积核是1*1,最后输出16个数字
nn.ReLU(), # 激活函数 x<0 y=0 x>0 y=x,用在反向反向传导
nn.Linear(16, 10) # 将16变成10,预测0-9之间概率值
) def forward(self, x):
return self.net(x)

卷积神经网络模型组成

卷积神经网络通常由3个部分构成:卷积层,池化层,全连接层。各部分的功能:

  • 卷积层:负责提取图像中的特征,可以输出一张图片的很多种特征。
  • 池化层:用来缩小尺寸,大幅降低参数量级,降低计算量
  • 全连接层:合并特征并输出结果

美颜相机的原理就是提取图片的特征,如下图片第二张模糊轮廓,第三张是突出轮廓。

卷积

卷积的功能:提取图片的多种特征信息

卷积的原理:用一个卷积核和图片的矩阵相乘,得到一个新的矩阵。新矩阵就是一个新的特征。

卷积核

卷积核也是一个矩阵,通常是33的矩阵,或者是55的矩阵。卷积运算的过程如下:

图像边缘提取

使用如下的卷积核就可以提取图像的边缘轮廓特征

调参

卷积核矩阵由3*3一共9个参数组成,这些参数都是模型自动生成的,所谓的调参,其中一部分就是指调整卷积核矩阵的参数,让其提取的特征能够使预测更加准确

池化

池化的功能:池化就是缩小矩阵的尺寸,从而减少后续操作的参数数量。通常会在相邻的卷积层之间加入一个池化层。

池化的原理:池化的运算过程:将一个44的矩阵最大池化成22的矩阵,就是取4*4矩阵中对应区域中最大的一个数值。

池化通常有两种:

  • 最大池化(max pooling):选图像区域的最大值作为该区域池化后的值。
  • 平均池化(average pooling):计算图像区域的平均值作为该区域池化后的值。

全连接

全连接功能: 全连接的作用是组合特征分类

在前面两个步骤中从一张图片提取多种特征,并将特征矩阵进行了压缩。当数据到达全连接层时得到是一张图片的多种特征。

某一个特征并不能说整个图片是什么,否则就是盲人摸象。那么全连接层就是将多种特征组合起来形成一个完整的特征,并根据特征计算出图片是某一个类型的概率。

全连接层最终输出就是概率。比如手写数字识别,最终全连接层输出就是某一个手写数字在0~9上的概率。

tensor([[ 0.949,  3.032,  0.771, -2.173, -0.038, -0.236,  0.013,  0.614, -1.125, -2.6991]])

全连接的原理

全连接层实现的是特征组合,原理和卷积类似,也就是用一个卷积核对矩阵做运算,最后得到一个一维的数组,也就是0-9的概率。

调参:全连接的实现也需要卷积核的参与,所以卷积核矩阵也是参数的一部分,调参就包括该部分的参数。

手写数字识别的模型定义

手写数字识别的卷积神经网络,下面分析卷积+池化+全连接的过程:

Q&A


选择损失函数和优化器

损失函数功能:衡量训练结果和实际偏差的函数。数值越大代表差距越大

优化器功能:让模型不断优化,让损失函数减小的方法

手写数字识别中使用的损失函数和优化器如下:

# 交叉熵损失函数,选择一种方法计算误差值
loss_func = torch.nn.CrossEntropyLoss() # 优化器,随机梯度下降算法
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

损失函数

手写识别中选择了交叉熵损失函数,pytorch一共有19中损失函数可以使用,比较好理解的是平方差损失函数

优化器

手写识别中选了随机梯度下降算法,用来实现反向传播参数的修改。pytorch中一共有11中优化器可以使用。

模型训练

模型训练的流程:

  1. 定义训练的次数
  2. 遍历训练集,调用模型类传入图片,得到概率结果
  3. 通过损失函数计算损失值
  4. 通过优化器调整参数
  5. 训练完成保存模型
# 定义训练次数
cnt_epochs = 5 # 训练5个循环 # 循环训练
for cnt in range(cnt_epochs):
# 把训练集中的数据训练一遍
for imgs, labels in train_dataloader:
outputs = model(imgs) # 输出0~9预测的结果概率
loss = loss_func(outputs, labels) # 和输入做一个比较,得到一个误差
optimizer.zero_grad() # 初始化梯度,清空梯度。注意清空优化器的梯度,防止累计
loss.backward() # 方向传播计算
optimizer.step() # 累加1,执行一次 # 保存训练的结果(包括模型和参数)
torch.save(model, "my_cnn.nn")

需要注意的点:

  • 训练的规律
  • my_cnn.nn 模型保存的内容

Q&A

模型验证

  • 模型在测试集上的准确率
  • 一批模型准确率展示

总结

  1. 数据集非常重要。html手写识别中遇到的问题,以及如何解决。颜色,大小
  2. 数学知识。训练过程中遇到的数据知识:矩阵乘法
  3. 为什么需要GPU?如何使用GPU?
  4. 模型训练的过程。卷积 + 池化 + 全连接 + 损失函数 + 优化器
  5. 目标检查的训练过程和手写识别有何不同?

    图像分类:LeNet、AlexNet、VGG、GoogLeNet

    目标检测:RCNN、Fast RCNN、Faster RCNN、YOLO、YOLOv2、SSD

Pytorch 手写数字识别 深度学习基础分享的更多相关文章

  1. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  2. pytorch 手写数字识别项目 增量式训练

    dataset.py ''' 准备数据集 ''' import torch from torch.utils.data import DataLoader from torchvision.datas ...

  3. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  4. MINST手写数字识别(二)—— 卷积神经网络(CNN)

    今天我们的主角是keras,其简洁性和易用性简直出乎David 9我的预期.大家都知道keras是在TensorFlow上又包装了一层,向简洁易用的深度学习又迈出了坚实的一步. 所以,今天就来带大家写 ...

  5. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  6. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  7. MindSpore手写数字识别初体验,深度学习也没那么神秘嘛

    摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...

  8. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

  9. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

  10. 深度学习面试题12:LeNet(手写数字识别)

    目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...

随机推荐

  1. SQL server temporal table 学习笔记

    refer: https://blog.csdn.net/Hehuyi_In/article/details/89670462 https://docs.microsoft.com/en-us/sql ...

  2. HDLC报文简单分析

    最近在学习HDLC协议,从刚开始的一窍不通到现在的懵懵懂懂,下面分享一段报文解析,给初学者一点点经验的分析. 报文:7E A0 57 03 02 B8 4B 5B E6 E7 00 C4 01 C1 ...

  3. word在原有的方框里打勾

    按住键盘上的ALT键不放,然后在小键盘区输入"9745"这几个数字,最后松开 ALT 键,自动变成框框中带勾符号.

  4. 【赵渝强老师】大数据分析引擎:Presto

    一.什么是Presto? 背景知识:Hive的缺点和Presto的背景 Hive使用MapReduce作为底层计算框架,是专为批处理设计的.但随着数据越来越多,使用Hive进行一个简单的数据查询可能要 ...

  5. LeetCode 1388. Pizza With 3n Slices(3n 块披萨)(DP)

    给你一个披萨,它由 3n 块不同大小的部分组成,现在你和你的朋友们需要按照如下规则来分披萨: 你挑选 任意 一块披萨.Alice 将会挑选你所选择的披萨逆时针方向的下一块披萨.Bob 将会挑选你所选择 ...

  6. (系列五).net8 中使用Dapper搭建底层仓储连接数据库(附源码)

    说明 该文章是属于OverallAuth2.0系列文章,每周更新一篇该系列文章(从0到1完成系统开发). 该系统文章,我会尽量说的非常详细,做到不管新手.老手都能看懂. 说明:OverallAuth2 ...

  7. 逆向WeChat(七)

    上篇介绍了如何通过嗅探MojoIPC抓包小程序的HTTPS数据. 本篇逆向微信客户端本地数据库相关事宜. 本篇在博客园地址https://www.cnblogs.com/bbqzsl/p/184235 ...

  8. Windows 10 LTSC 2019(1809) WSL 安装 CentOS 7

    1.安装WSL    通过控制面板--程序和功能--启用或关闭WIndows功能,勾选"适用于Linux的Windows子系统".    或者通过管理员权限打开 PowerShel ...

  9. 14 Positional Encoding (为什么 Self-Attention 需要位置编码)

    博客配套视频链接: https://space.bilibili.com/383551518?spm_id_from=333.1007.0.0 b 站直接看 配套 github 链接:https:// ...

  10. 通过maven动态配置spring boot配置文件

    一.引入maven插件的jar包 <plugin> <groupId>org.apache.maven.plugins</groupId> <artifact ...