卷积神经网络CNN实战:MINST手写数字识别——网络定义
本文基于python-pytorch框架,编写CNN网络,并采用CNN手写数字数据集训练、测试网络。
网络的构建
以LeNet-5 网络为例
类定义
首先先了解一下网络的最基本框架
- 一般而言,首先创建一个类
class,创建时,继承nn.Module父类,注意,在该类的构造函数中__init__中,显示的调用其父类的构造函数super(...).__init__() - 网络的结构,例如卷积层、线性层等,一般在其构造函数中定义。
- 对于一些不带参数的网络结构,也可以在forward方法中直接调用,而不定义,但不推荐。
- 每一个网络类必须显示的定义
forward方法,编写程序时需要在该函数中编写运算,实现对输入张量(tensor)的运算,并最后给予返回值;通常forward函数的返回值也为张量(tensor)。
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
''' Definition of Network Structure '''
def forward(self, x):
# x = f(x)
return x
卷积模块/特征提取器
self.features = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2), # 28x28x1 -> 28x28x6
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5, stride=1), # 14x14x6 -> 10x10x16
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
卷积层
nn.Conv2d 是 PyTorch 中用于定义二维卷积层的一个模块。其参数配置的含义如下:
in_channels (1): 输入图像的通道数。对于单通道的灰度图像,值为 1;对于 RGB 图像,则为 3。
out_channels (6): 卷积层输出的通道数,即卷积核的数量。每个卷积核会产生一个输出通道。这里设置为 6,意味着该卷积层会生成 6 个特征图(feature maps)。
kernel_size (5): 卷积核的尺寸。这里的
5表示使用 5x5 的卷积核。这是一个正方形的卷积核大小,但也可以设置成不相等的高度和宽度,比如(5, 3)。stride (1): 卷积操作的步幅。步幅决定了卷积核在输入图像上滑动的速度。步幅为 1 表示卷积核每次滑动一个像素。
padding (2): 输入图像的边界填充。填充用于控制输出特征图的空间尺寸,通常用于保持特征图的尺寸不变或减少尺寸。这里的
2表示在每一边填充 2 个像素。
综上,nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2) 的设置表示从单通道的输入图像中提取 6 个特征图,每个特征图通过一个 5x5 大小的卷积核生成,卷积操作的步幅为 1,并且在输入图像的每一边填充 2 个像素以控制输出尺寸。
激活函数
nn.Tanh() 是 PyTorch 中的一个激活函数层,它实现了双曲正切函数(tanh)。激活函数在神经网络中用于引入非线性特性,从而使网络能够学习更复杂的模式和特征。
双曲正切函数的数学形式是:
\]
该函数的输出范围是 ([-1, 1]),具有以下特点:
非线性:tanh 是一个非线性函数,使得神经网络能够处理更复杂的任务。
输出范围:输出值范围在 -1 到 1 之间,零中心化,使得均值接近零,这可以帮助加快训练过程和提高模型的收敛速度。
梯度:tanh 函数的导数在输入接近 0 时最大(为 1),而在输入较大或较小时梯度逐渐变小,这意味着它在大输入值时可能会遇到梯度消失的问题。
在 PyTorch 中,nn.Tanh() 可以作为模型的一个层来应用于网络的前向传播中:
nn.MaxPool2d 是 PyTorch 中的一个池化层,用于对二维输入数据进行最大池化操作。池化操作常用于卷积神经网络(CNN)中,以减少特征图的空间尺寸,减少计算量和过拟合的风险,并提高模型的鲁棒性。
池化层
nn.MaxPool2d(kernel_size=2, stride=2) 的参数含义如下:
kernel_size (2): 池化窗口的尺寸。这里的
2表示池化窗口是 2x2 的正方形区域。池化操作将在每个 2x2 的区域内选取最大值。stride (2): 池化窗口在输入特征图上滑动的步幅。步幅为 2 表示池化窗口每次滑动 2 个像素。这意味着池化操作会将特征图的空间尺寸缩小为原来的一半。
具体功能:
- 最大池化:在池化窗口(2x2 区域)内选择最大值,并用该最大值替代整个窗口区域中的所有值。这样,池化层能够保留特征图中的重要信息,同时减少空间尺寸。
作用:
降维:通过池化操作减少特征图的空间尺寸(宽度和高度),从而减少计算量和内存消耗。
提高鲁棒性:池化操作可以使网络对位置的微小变化更具鲁棒性,因为它只保留局部区域的最大值。
防止过拟合:通过减少特征图的尺寸,可以降低模型的复杂性,从而减少过拟合的风险。
解释:
对于一个 4x4 的输入特征图,使用 2x2 的池化窗口和步幅为 2,池化操作会将特征图缩小为 2x2 的尺寸,每个池化窗口选择区域内的最大值。例如,在 2x2 的池化窗口内,[[1, 2], [5, 6]] 会变成 6,依此类推。
总之,nn.MaxPool2d 是卷积神经网络中常用的池化层,用于减少特征图的尺寸和计算复杂度,同时保留重要的特征信息。
线性层/分类器
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), # 全连接层1
nn.Tanh(),
nn.Linear(120, 84), # 全连接层2
nn.Tanh(),
nn.Linear(84, 10) # 输出层
)
展平函数
nn.Flatten() 是 PyTorch 中用于将多维张量展平成一维张量的模块。这个操作通常在卷积层(Convolutional Layers)和线性层(Linear Layers)之间使用,以便将卷积层输出的多维特征图转换成适合于线性层处理的一维特征向量。
nn.Flatten() 的主要作用是将输入张量从多维转换为一维。例如,对于形状为 (N, C, H, W) 的输入张量,使用 nn.Flatten() 后,输出的张量将变为形状为 (N, C * H * W) 的一维张量。
为什么在卷积层和线性层之间使用 nn.Flatten()
卷积层的输出是多维的:卷积层生成的输出通常是一个四维张量,表示批量的特征图,其中包含多个通道的二维特征图。为了将这些特征图传递到线性层,必须将其展平成一维张量,因为线性层要求输入为一维特征向量。
线性层的输入是一维的:线性层(也称全连接层)只能处理一维的输入数据。通过
nn.Flatten(),可以将卷积层的多维输出展平为一维,从而可以将其作为线性层的输入。连接卷积和线性层:卷积层通常用于提取特征,而线性层则用于对这些特征进行分类或回归等任务。在这些任务中,线性层处理的是扁平化的特征向量,因此需要将卷积层的输出展平。
总结
nn.Flatten() 在卷积神经网络(CNN)的前向传播中充当了重要的角色,它将卷积层的多维特征图展平为线性层所需的一维特征向量。这使得卷积层提取的复杂特征可以被线性层进一步处理,从而完成分类、回归等任务。
nn.Linear 的基本概念
nn.Linear 是 PyTorch 中的一个模块,用于实现线性变换,也称为全连接层(Fully Connected Layer,FC Layer)。它将输入的特征通过一个线性变换映射到输出特征。
参数解释
输入特征的数量 (
in_features):- 在
nn.Linear(16 * 5 * 5, 120)中,第一个参数16 * 5 * 5表示输入特征的数量。 16通常表示通道数,5 * 5是特征图的高度和宽度。这里的计算表示输入的特征总数为16 * 5 * 5 = 400。- 卷积层的输出经过
Flatten层处理后,变为一维向量。
- 在
输出特征的数量 (
out_features):- 第二个参数
120表示输出特征的数量。在这个例子中,模型将输出一个长度为 120 的一维向量。
- 第二个参数
线性层的功能
线性层的功能可以用数学公式表示为:
\]
- 其中 \(y\) 是输出,\(A\) 是权重矩阵,\(x\) 是输入,\(b\) 是偏差项。
在使用 nn.Linear(16 * 5 * 5, 120) 时,PyTorch 会自动创建一个形状为 (120, 400) 的权重矩阵和一个形状为 (120,) 的偏差向量。权重和偏差会在训练过程中进行学习和优化。
总结
nn.Linear(16 * 5 * 5, 120)用于定义一个线性层,它将输入的 400 维特征向量映射到 120 维输出向量。- 此层通常用于将卷积层提取的特征连接到分类器或其他层,以形成完整的神经网络架构。
代码汇总
import torch.nn as nn
# LeNet-5
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2), # 28x28x1 -> 28x28x6
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2), # 池化层
nn.Conv2d(6, 16, kernel_size=5, stride=1), # 14x14x6 -> 10x10x16
nn.Tanh(),
nn.MaxPool2d(kernel_size=2, stride=2) # 池化层
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), # 全连接层1
nn.Tanh(),
nn.Linear(120, 84), # 全连接层2
nn.Tanh(),
nn.Linear(84, 10) # 输出层
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
卷积神经网络CNN实战:MINST手写数字识别——网络定义的更多相关文章
- 卷积神经网络应用于tensorflow手写数字识别(第三版)
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...
- MINST手写数字识别(三)—— 使用antirectifier替换ReLU激活函数
这是一个来自官网的示例:https://github.com/keras-team/keras/blob/master/examples/antirectifier.py 与之前的MINST手写数字识 ...
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- MINST手写数字识别(一)—— 全连接网络
这是一个简单快速入门教程——用Keras搭建神经网络实现手写数字识别,它大部分基于Keras的源代码示例 minst_mlp.py. 1.安装依赖库 首先,你需要安装最近版本的Python,再加上一些 ...
- MINST手写数字识别(二)—— 卷积神经网络(CNN)
今天我们的主角是keras,其简洁性和易用性简直出乎David 9我的预期.大家都知道keras是在TensorFlow上又包装了一层,向简洁易用的深度学习又迈出了坚实的一步. 所以,今天就来带大家写 ...
- 【TensorFlow-windows】(四) CNN(卷积神经网络)进行手写数字识别(mnist)
主要内容: 1.基于CNN的mnist手写数字识别(详细代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3-4.2.0-Windows-x86_64. ...
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
随机推荐
- 如何用matlab求隐式函数的导数
如何用matlab求隐式函数的导数 隐函数求导的例子 假设有一个圆 \(x^2+y^2=5\) , 要求在某个点上的切线的斜率. 我们可以把式\(x^2+y^2=5\)中的每一项对\(x\)求导, 可 ...
- 配置h5py、netCDF4库的方法:Anaconda环境
本文介绍基于Anaconda环境,下载并安装Python中h5py与netCDF4这两个模块的方法. 在Python语言中,h5py与netCDF4这两个模块是与遥感图像处理.地学分析等GIS ...
- SQL注入攻击及防御
SQL注入攻击及防御 1.项目实验环境 目标靶机OWASP_Broken_Web_App_VM_1.2: https://sourceforge.net/projects/owaspbwa/files ...
- 有手就会的 Java 处理压缩文件
@ 目录 前言 背景 第一步:编写代码 1.1 请求层 1.2 业务处理层 1.3 新增配置 第二步:解压缩处理 2.1 引入依赖 2.2 解压缩工具类 总结 前言 请各大网友尊重本人原创知识分享,谨 ...
- NKCTF 2023 Misc
NKCTF 2023 Misc hard-misc base32 --> N0wayBack公众号回复:NKCTF2023我来了! 得到flag:NKCTF{wtk2023Oo0oImcoM1N ...
- jQuery -- 手稿
- window10设置开机自启动exe的三种方式(亲测有效)
拷贝文件到自启动位置 路径地址:C:\ProgramData\Microsoft\Windows\Start Menu\Programs\StartUp 通过组策略设置脚本随服务器启动 开始-> ...
- 开源免费又好用的中式数据报表:UReport2是一款高性能的架构在Spring之上纯Java报表引擎,通过迭代单元格可以实现任意复杂的中国式报表。
北润乾.南帆软,数加发力在云端. uReport 身何安?中式报表真开源. 报表江湖之中,uReport安身立命的产品品类定位是什么? 说来很简单,uReport的价值在于填补了这样一个市场空白:开源 ...
- [oeasy]python0115_西里尔字符集_Cyrillic_俄文字符编码_KOI_8859系列
各语言字符编码 回忆上次内容 上次回顾了 非ascii的拉丁字符编码的进化过程 0-127 是 ascii 的领域 西欧.北欧语言 大多使用 拉丁字符 由iso组织 制定iso-8859-1 ...
- [rCore学习笔记 016]实现应用程序
写在前面 本随笔是非常菜的菜鸡写的.如有问题请及时提出. 可以联系:1160712160@qq.com GitHhub:https://github.com/WindDevil (目前啥也没有 设计方 ...