博客地址:https://www.cnblogs.com/zylyehuo/









一、导入所用库

import torch
from torch import nn
from d2l import torch as d2l

二、自定义重塑层

class Reshape(nn.Module):
def forward(self, x):
return x.view(-1, 1, 28, 28)

三、构建 LeNet 网络

net = nn.Sequential(
Reshape(), # 将输入 (batch, 784) → (batch, 1, 28, 28)
nn.Conv2d(1, 6, kernel_size=5, padding=2), # 卷积层1:输入通道 1 → 输出通道 6,卷积核 5×5,padding=2 保持宽高不变
nn.Sigmoid(), # 激活函数:Sigmoid
nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化1:kernel=2, stride=2,下采样一半
nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:6→16,kernel=5×5,默认无 padding → 尺寸缩小
nn.Sigmoid(), # Sigmoid 激活
nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化2
nn.Flatten(), # 展平:把多维特征图拉成一维向量
nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入 16×5×5 → 输出 120
nn.Sigmoid(), # Sigmoid 激活
nn.Linear(120, 84), # 全连接层2:120 → 84
nn.Sigmoid(), # Sigmoid 激活
nn.Linear(84, 10) # 输出层:84 → 10 类别
)

四、验证每层输出形状

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
X = layer(X)
print(layer.__class__.__name__, 'output shape:\t', X.shape)



五、加载 Fashion-MNIST 数据

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

六、定义 GPU 下的准确率评估函数

def evaluate_accuracy_gpu(net, data_iter, device=None):
"""在 GPU 上评估模型在给定数据集上的准确率"""
if isinstance(net, nn.Module):
net.eval() # 切换到评估模式,关闭 dropout、batchnorm 等
if not device:
device = next(iter(net.parameters())).device
# metric[0] 累积正确预测数;metric[1] 累积样本总数
metric = d2l.Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
X, y = X.to(device), y.to(device)
y_hat = net(X)
metric.add(d2l.accuracy(y_hat, y), y.numel())
return metric[0] / metric[1]

七、定义训练函数(带 GPU 支持)

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
# 1. 权重初始化:对每个线性层和卷积层使用 Xavier 均匀分布初始化
def init_weights(m):
if type(m) in (nn.Linear, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
net.apply(init_weights) print('training on', device)
net.to(device) # 把模型参数搬到指定设备
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss() # 可视化工具:训练过程实时画图
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
legend=['train loss', 'train acc', 'test acc']) timer, num_batches = d2l.Timer(), len(train_iter)
# 2. 训练循环
for epoch in range(num_epochs):
# 累积训练损失、训练正确预测数、样本数
metric = d2l.Accumulator(3)
net.train() # 切回训练模式
for i, (X, y) in enumerate(train_iter):
timer.start()
X, y = X.to(device), y.to(device)
optimizer.zero_grad()
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
optimizer.step()
with torch.no_grad():
metric.add(l * y.numel(), d2l.accuracy(y_hat, y), y.numel())
timer.stop()
# 每训练完一个 epoch,或者到达最后一个 batch 时更新可视化
if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
animator.add(epoch + (i + 1) / num_batches,
(metric[0] / metric[2], metric[1] / metric[2], None))
# 每个 epoch 结束后计算一次测试集准确率并更新图示
test_acc = evaluate_accuracy_gpu(net, test_iter, device)
animator.add(epoch + 1, (None, None, test_acc))
# 输出整体训练速度
print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {device}')

八、运行训练

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())



九、总结

十、流程概览

1. 环境准备

2. 网络构建

3. 数据加载

4. 训练循环

for epoch in 1…N:
for 每个 batch (X, y):
1) 前向计算 ŷ = net(X)
2) 计算损失 L = Loss(ŷ, y)
3) 反向传播 L.backward()
4) 优化器更新参数 optimizer.step()
5) 累积训练损失 & 正确率
end-for # 每跑完一个 epoch:
- 在测试集上评估一次准确率
- 把训练损失、训练准确率、测试准确率推到“动画器”里,实时画图
end-for

5. 评估与可视化

6. 通俗小结



卷积神经网络(LeNet)的更多相关文章

  1. TensorFlow+实战Google深度学习框架学习笔记(12)------Mnist识别和卷积神经网络LeNet

    一.卷积神经网络的简述 卷积神经网络将一个图像变窄变长.原本[长和宽较大,高较小]变成[长和宽较小,高增加] 卷积过程需要用到卷积核[二维的滑动窗口][过滤器],每个卷积核由n*m(长*宽)个小格组成 ...

  2. 使用mxnet实现卷积神经网络LeNet

    1.LeNet模型 LeNet是一个早期用来识别手写数字的卷积神经网络,这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当 ...

  3. 卷积神经网络LeNet Convolutional Neural Networks (LeNet)

    Note This section assumes the reader has already read through Classifying MNIST digits using Logisti ...

  4. 卷积神经网络之LeNet

    开局一张图,内容全靠编. 上图引用自 [卷积神经网络-进化史]从LeNet到AlexNet. 目前常用的卷积神经网络 深度学习现在是百花齐放,各种网络结构层出不穷,计划梳理下各个常用的卷积神经网络结构 ...

  5. 卷积神经网络详细讲解 及 Tensorflow实现

    [附上个人git完整代码地址:https://github.com/Liuyubao/Tensorflow-CNN] [如有疑问,更进一步交流请留言或联系微信:523331232] Reference ...

  6. 经典卷积神经网络(LeNet、AlexNet、VGG、GoogleNet、ResNet)的实现(MXNet版本)

    卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现. 其中 文章 详解卷 ...

  7. 卷积神经网络的一些经典网络(Lenet,AlexNet,VGG16,ResNet)

    LeNet – 5网络 网络结构为: 输入图像是:32x32x1的灰度图像 卷积核:5x5,stride=1 得到Conv1:28x28x6 池化层:2x2,stride=2 (池化之后再经过激活函数 ...

  8. 从LeNet到SENet——卷积神经网络回顾

    从LeNet到SENet——卷积神经网络回顾 从 1998 年经典的 LeNet,到 2012 年历史性的 AlexNet,之后深度学习进入了蓬勃发展阶段,百花齐放,大放异彩,出现了各式各样的不同网络 ...

  9. LeNet - Python中的卷积神经网络

    本教程将  主要面向代码,  旨在帮助您 深入学习和卷积神经网络.由于这个意图,我  不会花很多时间讨论激活功能,池层或密集/完全连接的层 - 将来会有  很多教程在PyImageSearch博客上将 ...

  10. 深度学习方法(五):卷积神经网络CNN经典模型整理Lenet,Alexnet,Googlenet,VGG,Deep Residual Learning

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 关于卷积神经网络CNN,网络和文献中 ...

随机推荐

  1. windows10 激活教程

    1.环境 适用对象:VL版本的windows OEM版本请使用文末工具激活 1.1查询自己电脑版本 [win+R]->输入[slmgr /dlv]->查看[产品密钥通道] slmgr /d ...

  2. Traefik,想说爱你不容易:一场动态反向代理的心累之旅

    前言:技术选型的初心 在微服务盛行.容器部署逐渐常态化的今天,"动态反向代理"显得尤为重要. Traefik 凭借其原生支持 Docker.自动生成路由.集成 Let's Encr ...

  3. Canvas上批量创建可视对象(DrawingVisual)管理,获取鼠标悬浮图形状态,并控制鼠标右键快捷菜单等...

    近期公司有个新的定制,先简要说明下: 窗口上有个播放区域,区域上悬浮了很多可视对象(DrawingVisual),全部是动态生成的.... 现在的需求是在这些矩形框上需要添加右键快捷菜单... 需求知 ...

  4. K8s新手系列之K8s中的资源

    K8s中资源的概念 在kubernetes中,所有的内容都抽象为资源,用户需要通过操作资源来管理kubernetes. kubernetes的本质上就是一个集群系统,用户可以在集群中部署各种服务,所谓 ...

  5. 小模型工具调用能力激活:以Qwen2.5 0.5B为例的Prompt工程实践

    在之前的分析中,我们深入探讨了cline prompt的设计理念(Cline技术分析:prompt如何驱动大模型对本地文件实现自主变更),揭示了其在激发语言模型能力方面的潜力.现在,我们将这些理论付诸 ...

  6. windows快速开启【程序和功能】

    程序和功能一般常用的操作是对软件进行卸载. 方式一: 1. Win+R打开运行 2. 输入appwiz.cpl命令 方式二: 1.Win+X打开快捷开关 2. F进去应用和功能 3.点击右侧程序和功能 ...

  7. 【代码】Python3|Requests 库怎么继承 Selenium 的 Headers (2024,Chrome)

    本文使用的版本: Chrome 124 Python 12 Selenium 4.19.0 版本过旧可能会出现问题,但只要别差异太大,就可以看本文,因为本文对新老版本都有讲解. 文章目录 1 难点解析 ...

  8. 洛谷 P3792 由乃与大母神原型和偶像崇拜

    洛谷 P3792 由乃与大母神原型和偶像崇拜 Problem 糖果屋的故事讲的就是韩赛尔和格雷特被继母赶出家里,因为没饭吃了,然后进了森林发现了一个糖果屋,里面有个女巫,专门吃小孩子 然而如果我们仔细 ...

  9. 刚刚 B站又血崩了?!我来告诉你真正原因

    B 站又双叒叕崩了,这次是真炸裂了!6 月 12 日晚 9 点左右,我还在直播呢,突然就看到弹幕都在说 B 站炸了,我立马坐不住了,光速下波,作为一名前大厂程序员,就爱吃大厂的瓜,就爱吃同行的瓜,吃瓜 ...

  10. 提升PHP并行处理效率:深入解析数组排序算法及优化策略

    本文由 ChatMoney团队出品 在 PHP 开发中,数组排序是一个常见的操作.随着互联网技术的不断发展,对数据处理速度和效率的要求越来越高,如何在保证排序质量的同时提高处理速度成为了一个值得探讨的 ...