卷积神经网络(LeNet)





一、导入所用库
import torch
from torch import nn
from d2l import torch as d2l

二、自定义重塑层
class Reshape(nn.Module):
def forward(self, x):
return x.view(-1, 1, 28, 28)

三、构建 LeNet 网络

net = nn.Sequential(
Reshape(), # 将输入 (batch, 784) → (batch, 1, 28, 28)
nn.Conv2d(1, 6, kernel_size=5, padding=2), # 卷积层1:输入通道 1 → 输出通道 6,卷积核 5×5,padding=2 保持宽高不变
nn.Sigmoid(), # 激活函数:Sigmoid
nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化1:kernel=2, stride=2,下采样一半
nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:6→16,kernel=5×5,默认无 padding → 尺寸缩小
nn.Sigmoid(), # Sigmoid 激活
nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化2
nn.Flatten(), # 展平:把多维特征图拉成一维向量
nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入 16×5×5 → 输出 120
nn.Sigmoid(), # Sigmoid 激活
nn.Linear(120, 84), # 全连接层2:120 → 84
nn.Sigmoid(), # Sigmoid 激活
nn.Linear(84, 10) # 输出层:84 → 10 类别
)

四、验证每层输出形状
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
X = layer(X)
print(layer.__class__.__name__, 'output shape:\t', X.shape)


五、加载 Fashion-MNIST 数据
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

六、定义 GPU 下的准确率评估函数
def evaluate_accuracy_gpu(net, data_iter, device=None):
"""在 GPU 上评估模型在给定数据集上的准确率"""
if isinstance(net, nn.Module):
net.eval() # 切换到评估模式,关闭 dropout、batchnorm 等
if not device:
device = next(iter(net.parameters())).device
# metric[0] 累积正确预测数;metric[1] 累积样本总数
metric = d2l.Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
X, y = X.to(device), y.to(device)
y_hat = net(X)
metric.add(d2l.accuracy(y_hat, y), y.numel())
return metric[0] / metric[1]
七、定义训练函数(带 GPU 支持)
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
# 1. 权重初始化:对每个线性层和卷积层使用 Xavier 均匀分布初始化
def init_weights(m):
if type(m) in (nn.Linear, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)
print('training on', device)
net.to(device) # 把模型参数搬到指定设备
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss()
# 可视化工具:训练过程实时画图
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
legend=['train loss', 'train acc', 'test acc'])
timer, num_batches = d2l.Timer(), len(train_iter)
# 2. 训练循环
for epoch in range(num_epochs):
# 累积训练损失、训练正确预测数、样本数
metric = d2l.Accumulator(3)
net.train() # 切回训练模式
for i, (X, y) in enumerate(train_iter):
timer.start()
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
optimizer.step()
with torch.no_grad():
metric.add(l * y.numel(), d2l.accuracy(y_hat, y), y.numel())
timer.stop()
# 每训练完一个 epoch,或者到达最后一个 batch 时更新可视化
if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
animator.add(epoch + (i + 1) / num_batches,
(metric[0] / metric[2], metric[1] / metric[2], None))
# 每个 epoch 结束后计算一次测试集准确率并更新图示
test_acc = evaluate_accuracy_gpu(net, test_iter, device)
animator.add(epoch + 1, (None, None, test_acc))
# 输出整体训练速度
print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {device}')

八、运行训练
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())


九、总结

十、流程概览

1. 环境准备

2. 网络构建

3. 数据加载

4. 训练循环
for epoch in 1…N:
for 每个 batch (X, y):
1) 前向计算 ŷ = net(X)
2) 计算损失 L = Loss(ŷ, y)
3) 反向传播 L.backward()
4) 优化器更新参数 optimizer.step()
5) 累积训练损失 & 正确率
end-for
# 每跑完一个 epoch:
- 在测试集上评估一次准确率
- 把训练损失、训练准确率、测试准确率推到“动画器”里,实时画图
end-for

5. 评估与可视化

6. 通俗小结


卷积神经网络(LeNet)的更多相关文章
- TensorFlow+实战Google深度学习框架学习笔记(12)------Mnist识别和卷积神经网络LeNet
一.卷积神经网络的简述 卷积神经网络将一个图像变窄变长.原本[长和宽较大,高较小]变成[长和宽较小,高增加] 卷积过程需要用到卷积核[二维的滑动窗口][过滤器],每个卷积核由n*m(长*宽)个小格组成 ...
- 使用mxnet实现卷积神经网络LeNet
1.LeNet模型 LeNet是一个早期用来识别手写数字的卷积神经网络,这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当 ...
- 卷积神经网络LeNet Convolutional Neural Networks (LeNet)
Note This section assumes the reader has already read through Classifying MNIST digits using Logisti ...
- 卷积神经网络之LeNet
开局一张图,内容全靠编. 上图引用自 [卷积神经网络-进化史]从LeNet到AlexNet. 目前常用的卷积神经网络 深度学习现在是百花齐放,各种网络结构层出不穷,计划梳理下各个常用的卷积神经网络结构 ...
- 卷积神经网络详细讲解 及 Tensorflow实现
[附上个人git完整代码地址:https://github.com/Liuyubao/Tensorflow-CNN] [如有疑问,更进一步交流请留言或联系微信:523331232] Reference ...
- 经典卷积神经网络(LeNet、AlexNet、VGG、GoogleNet、ResNet)的实现(MXNet版本)
卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现. 其中 文章 详解卷 ...
- 卷积神经网络的一些经典网络(Lenet,AlexNet,VGG16,ResNet)
LeNet – 5网络 网络结构为: 输入图像是:32x32x1的灰度图像 卷积核:5x5,stride=1 得到Conv1:28x28x6 池化层:2x2,stride=2 (池化之后再经过激活函数 ...
- 从LeNet到SENet——卷积神经网络回顾
从LeNet到SENet——卷积神经网络回顾 从 1998 年经典的 LeNet,到 2012 年历史性的 AlexNet,之后深度学习进入了蓬勃发展阶段,百花齐放,大放异彩,出现了各式各样的不同网络 ...
- LeNet - Python中的卷积神经网络
本教程将 主要面向代码, 旨在帮助您 深入学习和卷积神经网络.由于这个意图,我 不会花很多时间讨论激活功能,池层或密集/完全连接的层 - 将来会有 很多教程在PyImageSearch博客上将 ...
- 深度学习方法(五):卷积神经网络CNN经典模型整理Lenet,Alexnet,Googlenet,VGG,Deep Residual Learning
欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 关于卷积神经网络CNN,网络和文献中 ...
随机推荐
- windows10 激活教程
1.环境 适用对象:VL版本的windows OEM版本请使用文末工具激活 1.1查询自己电脑版本 [win+R]->输入[slmgr /dlv]->查看[产品密钥通道] slmgr /d ...
- Traefik,想说爱你不容易:一场动态反向代理的心累之旅
前言:技术选型的初心 在微服务盛行.容器部署逐渐常态化的今天,"动态反向代理"显得尤为重要. Traefik 凭借其原生支持 Docker.自动生成路由.集成 Let's Encr ...
- Canvas上批量创建可视对象(DrawingVisual)管理,获取鼠标悬浮图形状态,并控制鼠标右键快捷菜单等...
近期公司有个新的定制,先简要说明下: 窗口上有个播放区域,区域上悬浮了很多可视对象(DrawingVisual),全部是动态生成的.... 现在的需求是在这些矩形框上需要添加右键快捷菜单... 需求知 ...
- K8s新手系列之K8s中的资源
K8s中资源的概念 在kubernetes中,所有的内容都抽象为资源,用户需要通过操作资源来管理kubernetes. kubernetes的本质上就是一个集群系统,用户可以在集群中部署各种服务,所谓 ...
- 小模型工具调用能力激活:以Qwen2.5 0.5B为例的Prompt工程实践
在之前的分析中,我们深入探讨了cline prompt的设计理念(Cline技术分析:prompt如何驱动大模型对本地文件实现自主变更),揭示了其在激发语言模型能力方面的潜力.现在,我们将这些理论付诸 ...
- windows快速开启【程序和功能】
程序和功能一般常用的操作是对软件进行卸载. 方式一: 1. Win+R打开运行 2. 输入appwiz.cpl命令 方式二: 1.Win+X打开快捷开关 2. F进去应用和功能 3.点击右侧程序和功能 ...
- 【代码】Python3|Requests 库怎么继承 Selenium 的 Headers (2024,Chrome)
本文使用的版本: Chrome 124 Python 12 Selenium 4.19.0 版本过旧可能会出现问题,但只要别差异太大,就可以看本文,因为本文对新老版本都有讲解. 文章目录 1 难点解析 ...
- 洛谷 P3792 由乃与大母神原型和偶像崇拜
洛谷 P3792 由乃与大母神原型和偶像崇拜 Problem 糖果屋的故事讲的就是韩赛尔和格雷特被继母赶出家里,因为没饭吃了,然后进了森林发现了一个糖果屋,里面有个女巫,专门吃小孩子 然而如果我们仔细 ...
- 刚刚 B站又血崩了?!我来告诉你真正原因
B 站又双叒叕崩了,这次是真炸裂了!6 月 12 日晚 9 点左右,我还在直播呢,突然就看到弹幕都在说 B 站炸了,我立马坐不住了,光速下波,作为一名前大厂程序员,就爱吃大厂的瓜,就爱吃同行的瓜,吃瓜 ...
- 提升PHP并行处理效率:深入解析数组排序算法及优化策略
本文由 ChatMoney团队出品 在 PHP 开发中,数组排序是一个常见的操作.随着互联网技术的不断发展,对数据处理速度和效率的要求越来越高,如何在保证排序质量的同时提高处理速度成为了一个值得探讨的 ...