3-7softmax回归的简洁实现

import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

1.初始化模型参数

代码功能总结

  1. 定义模型

    • 定义了一个简单的神经网络模型,包含一个展平层和一个全连接层。
    • 展平层将输入的二维图像数据展平为一维张量。
    • 全连接层将展平后的特征映射到 10 个输出类别。
  2. 定义权重初始化函数
    • 定义了一个自定义的权重初始化函数 init_weights,用于初始化 nn.Linear 层的权重。
    • 使用正态分布初始化权重,标准差为 0.01。
  3. 应用权重初始化函数
    • 使用 net.apply(init_weights) 将自定义的权重初始化函数应用于模型的所有层。
    • 递归地检查每一层,如果是 nn.Linear 类型,则初始化其权重。

使用场景

这段代码通常用于在训练神经网络之前对模型的权重进行初始化。合理的权重初始化可以加速模型的收敛速度,并提高模型的性能。

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状 # nn.Linear(784, 10) 是一个全连接层,将输入的 784 维特征映射到 10 个输出类别
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) def init_weights(m):
if type(m) == nn.Linear:
# nn.init_normal_(m.weight, std = 0.01):使用正态分布初始化权重,均值为 0,标准差为 0.01。
nn.init.normal_(m.weight, std = 0.01) #net.apply(init_weights) 是 PyTorch 中用于对模型的所有层应用自定义函数的方法
# init_weights 函数会被递归地应用于 net 中的每一层
net.apply(init_weights)
Sequential(
(0): Flatten(start_dim=1, end_dim=-1)
(1): Linear(in_features=784, out_features=10, bias=True)
)

2.重新审视softmax的实现

在 PyTorch 中,nn.CrossEntropyLoss 是一个常用的损失函数,用于多分类问题。它结合了 nn.LogSoftmaxnn.NLLLoss(负对数似然损失)的功能,适用于分类任务中计算模型输出与真实标签之间的损失。

1. nn.CrossEntropyLoss

  • nn.CrossEntropyLoss 是 PyTorch 提供的交叉熵损失函数。
  • 它适用于多分类问题,其中模型的输出是一个概率分布(通常是通过 softmax 函数得到的),而真实标签是一个类别索引。

2. reduction='none'

  • reduction 参数控制损失函数的输出形式。
  • 默认情况下,reduction='mean',表示对所有样本的损失值取平均。
  • 设置 reduction='none' 表示不对损失值进行任何聚合,返回每个样本的损失值。

参数解释

  • reduction

    • 'none':返回每个样本的损失值,形状与输入的 y 相同。
    • 'mean':返回所有样本的平均损失值。
    • 'sum':返回所有样本的损失值之和。
loss = nn.CrossEntropyLoss(reduction='none')

3.优化算法

在 PyTorch 中,torch.optim.SGD 是一个用于实现随机梯度下降(Stochastic Gradient Descent, SGD)优化算法的类。这段代码创建了一个 SGD 优化器实例,用于更新神经网络的参数。以下是对代码的详细解释:

代码解析

1. torch.optim.SGD

  • torch.optim.SGD 是 PyTorch 提供的随机梯度下降优化器。
  • 它用于在训练过程中更新模型的参数,以最小化损失函数。

2. net.parameters()

  • net.parameters() 是一个生成器,返回模型 net 中所有可训练的参数(如权重和偏置)。
  • 这些参数是优化器需要更新的对象。

3. lr=0.1

  • lr 是学习率(learning rate),控制参数更新的步长。
  • 学习率是一个超参数,决定了每次参数更新的幅度。
  • 在这里,学习率被设置为 0.1,表示每次参数更新的步长为 0.1

优化器的作用

优化器的作用是在训练过程中根据梯度信息更新模型的参数。SGD 优化器的具体更新规则如下: θnew=θold−lr×∇L 其中:

  • θnew 是更新后的参数。
  • θold 是更新前的参数。
  • lr 是学习率。
  • L 是损失函数 L 对参数 θ 的梯度。
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

4.训练

这段代码调用了 d2l.train_ch3 函数来训练一个神经网络模型。d2l.train_ch3 是一个封装好的训练函数,通常在 D2L(Dive into Deep Learning)库中定义,用于简化训练过程。以下是对代码的详细解释:

代码解析

1. 设置训练轮数

  • num_epochs 是一个整数,表示训练模型的总轮数(epoch)。
  • 在这个例子中,模型将训练 10 轮。

2. 调用训练函数

  • d2l.train_ch3 是一个封装好的训练函数,用于训练神经网络模型。
  • 它的参数包括:
    • net:模型网络。
    • train_iter:训练数据迭代器。
    • test_iter:测试数据迭代器。
    • loss:损失函数。
    • num_epochs:训练的总轮数。
    • trainer:优化器。

d2l.train_ch3 函数的内部逻辑

虽然我们没有看到 d2l.train_ch3 的具体实现,但根据其功能描述,它通常会执行以下步骤:

  1. 初始化动画对象

    • 用于动态绘制训练过程中的损失和准确率。
  2. 训练循环
    • 遍历每个训练轮数(num_epochs)。
    • 在每个轮数中,对训练数据进行一次完整的训练,并计算训练损失和训练准确率。
    • 在每个轮数中,对测试数据进行评估,计算测试准确率。
    • 将训练损失、训练准确率和测试准确率添加到动画中,动态绘制训练过程。
  3. 断言检查
    • 检查训练损失是否小于某个阈值(如 0.5)。
    • 检查训练准确率和测试准确率是否在合理范围内(如大于 0.7)。
num_epochs = 10  #训练的总轮数
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

3-7softmax回归的简洁实现的更多相关文章

  1. 动手学深度学习8-softmax分类pytorch简洁实现

    定义和初始化模型 softamx和交叉熵损失函数 定义优化算法 训练模型 import torch from torch import nn from torch.nn import init imp ...

  2. 动手学深度学习14- pytorch Dropout 实现与原理

    方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...

  3. 动手学深度学习10- pytorch多层感知机从零实现

    多层感知机 定义模型的参数 定义激活函数 定义模型 定义损失函数 训练模型 小结 多层感知机 import torch import numpy as np import sys sys.path.a ...

  4. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  5. 小匠第一周期打卡笔记-Task01

    一.线性回归 知识点记录 线性回归输出是一个连续值,因此适用于回归问题.如预测房屋价格.气温.销售额等连续值的问题.是单层神经网络. 线性判别模型 判别模型 性质:建模预测变量和观测变量之间的关系,亦 ...

  6. L3 多层感知机

    **本小节用到的数据下载 1.涉及语句 import d2lzh1981 as d2l 数据1 : d2lzh1981 链接:https://pan.baidu.com/s/1LyaZ84Q4M75G ...

  7. pytorch和tensorflow的爱恨情仇之定义可训练的参数

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch版本:1.6.0 tensorflow版本:1.15.0 之前我们就已 ...

  8. 动手学深度学习4-线性回归的pytorch简洁实现

    导入同样导入之前的包或者模块 生成数据集 通过pytorch读取数据 定义模型 初始化模型 定义损失函数 定义优化算法 训练模型 小结 本节利用pytorch中的模块,生成一个更加简洁的代码来实现同样 ...

  9. 原生js实现简洁的返回顶部组件

    本文内容相当简单,所以没有发布到博客园首页,如果你不幸看到,那只能是我这篇文章的荣幸,谢谢你的大驾光临~(本博客返回顶部的功能就使用的是这个组件) 返回顶部组件是一种极其常见的网页功能,需求简单:页面 ...

  10. 看大众点评V9新版如何为O2O止血 带领行业下半场回归理性

    前不久,美团点评CEO王兴提出的“中国互联网进入下半场”观点一直在持续发酵,并引发了整个互联网圈对于进入下半场该如何改革,如何迎战的深刻反思.在互联网的上半场,大家依托的是人口红利,但是到了下半场,用 ...

随机推荐

  1. 实现领域驱动设计 - 使用ABP框架 - 领域逻辑 & 应用逻辑

    领域逻辑 & 应用逻辑 如前所述,领域驱动设计中的业务逻辑分为两部分(层):领域逻辑和应用逻辑: 领域逻辑由系统的核心领域规则组成,应用逻辑实现应用特定的用例 虽然定义很明确,但实现起来可能并 ...

  2. 深入掌握Map的这8个操作方法,让代码更简洁优雅

    Map 是我们经常使用的数据结构接口,它的子类 HashMap.ConcurrentHashMap 也是我们使用比较频繁的集合. 了解了 Map 接口中的方法,也就相当于知道了其子类中的可用方法,管它 ...

  3. Spring框架中的单例bean是线程安全的吗?

    1.介绍两个概念 有状态的bean:对象中有实例变量(成员变量),可以保存数据,是非线程安全的 无状态的bean:对象中没有实例变量(成员变量),不能保存数据,可以在多线程环境下共享,是线程安全的 2 ...

  4. Nginx日志拆分(linux环境下)

    1.新增shell脚本[nginx_log.sh],进行每日自动切割一次,存储在nginx文件夹下的logs下 #!/bin/bash #设置日志文件存放目录 LOG_HOME="/app/ ...

  5. 为了掌握设计模式,开发了一款Markdown 文本编辑器软件(已开源)

    设计模式实战项目:Markdown 文本编辑器软件开发(已开源) 一.项目简介 项目名称:YtyMark-java 本项目是一款基于 Java 语言 和 JavaFX 图形界面框架 开发的 Markd ...

  6. 如何在 Java 中进行内存泄漏分析?

    如何在 Java 中进行内存泄漏分析? 内存泄漏是指程序中无法访问的对象仍然被占用内存,导致内存无法回收,最终导致内存不足.程序崩溃等问题.Java 中的内存泄漏通常与垃圾回收机制的工作方式相关,虽然 ...

  7. windows下redis设置redis开机自启动

    windows系统下启动redis命令 进入redis安装目录 cd redis 输入 redis-server.exe redis.windows.conf 启动redis命令,看是否成功 可能会启 ...

  8. Windows查看端口占用、相应进程、杀死进程等[netstat]

    Windows 通过cmd或powerShell查看端口占用.相应进程.杀死进程等的命令 由于一般开发环境是在windows上,相应的一些测试必然涉及到一些端口的监听与使用.当开发使用的端口被占用后, ...

  9. 多文件,从url地址中下载文件并进行压缩

    直接上代码 Controller层 //我这里直接拿实体接收,entity.getFile()是List<对象>,对象里面存储文件相关的内容 @PostMapping("/zip ...

  10. 同余最短路&转圈背包算法学习笔记(超详细)

    一.问题引入 当你想要解决一个完全背包计数问题,但是 \(M\) 的范围太大,那么你就可以使用同余最短路. 二.算法推导过程 首先对于一个完全背包计数问题,我们要知道如果 \(x\) 这个数能凑出来, ...