3-7softmax回归的简洁实现
3-7softmax回归的简洁实现
import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
1.初始化模型参数
代码功能总结
- 定义模型:
- 定义了一个简单的神经网络模型,包含一个展平层和一个全连接层。
- 展平层将输入的二维图像数据展平为一维张量。
- 全连接层将展平后的特征映射到 10 个输出类别。
- 定义权重初始化函数:
- 定义了一个自定义的权重初始化函数
init_weights
,用于初始化nn.Linear
层的权重。 - 使用正态分布初始化权重,标准差为 0.01。
- 定义了一个自定义的权重初始化函数
- 应用权重初始化函数:
- 使用
net.apply(init_weights)
将自定义的权重初始化函数应用于模型的所有层。 - 递归地检查每一层,如果是
nn.Linear
类型,则初始化其权重。
- 使用
使用场景
这段代码通常用于在训练神经网络之前对模型的权重进行初始化。合理的权重初始化可以加速模型的收敛速度,并提高模型的性能。
# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
# nn.Linear(784, 10) 是一个全连接层,将输入的 784 维特征映射到 10 个输出类别
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
def init_weights(m):
if type(m) == nn.Linear:
# nn.init_normal_(m.weight, std = 0.01):使用正态分布初始化权重,均值为 0,标准差为 0.01。
nn.init.normal_(m.weight, std = 0.01)
#net.apply(init_weights) 是 PyTorch 中用于对模型的所有层应用自定义函数的方法
# init_weights 函数会被递归地应用于 net 中的每一层
net.apply(init_weights)
Sequential(
(0): Flatten(start_dim=1, end_dim=-1)
(1): Linear(in_features=784, out_features=10, bias=True)
)
2.重新审视softmax的实现
在 PyTorch 中,nn.CrossEntropyLoss
是一个常用的损失函数,用于多分类问题。它结合了 nn.LogSoftmax
和 nn.NLLLoss
(负对数似然损失)的功能,适用于分类任务中计算模型输出与真实标签之间的损失。
1. nn.CrossEntropyLoss
nn.CrossEntropyLoss
是 PyTorch 提供的交叉熵损失函数。- 它适用于多分类问题,其中模型的输出是一个概率分布(通常是通过 softmax 函数得到的),而真实标签是一个类别索引。
2. reduction='none'
reduction
参数控制损失函数的输出形式。- 默认情况下,
reduction='mean'
,表示对所有样本的损失值取平均。 - 设置
reduction='none'
表示不对损失值进行任何聚合,返回每个样本的损失值。
参数解释
reduction
:'none'
:返回每个样本的损失值,形状与输入的y
相同。'mean'
:返回所有样本的平均损失值。'sum'
:返回所有样本的损失值之和。
loss = nn.CrossEntropyLoss(reduction='none')
3.优化算法
在 PyTorch 中,torch.optim.SGD
是一个用于实现随机梯度下降(Stochastic Gradient Descent, SGD)优化算法的类。这段代码创建了一个 SGD 优化器实例,用于更新神经网络的参数。以下是对代码的详细解释:
代码解析
1. torch.optim.SGD
torch.optim.SGD
是 PyTorch 提供的随机梯度下降优化器。- 它用于在训练过程中更新模型的参数,以最小化损失函数。
2. net.parameters()
net.parameters()
是一个生成器,返回模型net
中所有可训练的参数(如权重和偏置)。- 这些参数是优化器需要更新的对象。
3. lr=0.1
lr
是学习率(learning rate),控制参数更新的步长。- 学习率是一个超参数,决定了每次参数更新的幅度。
- 在这里,学习率被设置为
0.1
,表示每次参数更新的步长为0.1
。
优化器的作用
优化器的作用是在训练过程中根据梯度信息更新模型的参数。SGD 优化器的具体更新规则如下: θnew=θold−lr×∇L 其中:
- θnew 是更新后的参数。
- θold 是更新前的参数。
- lr 是学习率。
- ∇L 是损失函数 L 对参数 θ 的梯度。
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
4.训练
这段代码调用了 d2l.train_ch3
函数来训练一个神经网络模型。d2l.train_ch3
是一个封装好的训练函数,通常在 D2L(Dive into Deep Learning)库中定义,用于简化训练过程。以下是对代码的详细解释:
代码解析
1. 设置训练轮数
num_epochs
是一个整数,表示训练模型的总轮数(epoch)。- 在这个例子中,模型将训练 10 轮。
2. 调用训练函数
d2l.train_ch3
是一个封装好的训练函数,用于训练神经网络模型。- 它的参数包括:
net
:模型网络。train_iter
:训练数据迭代器。test_iter
:测试数据迭代器。loss
:损失函数。num_epochs
:训练的总轮数。trainer
:优化器。
d2l.train_ch3
函数的内部逻辑
虽然我们没有看到 d2l.train_ch3
的具体实现,但根据其功能描述,它通常会执行以下步骤:
- 初始化动画对象:
- 用于动态绘制训练过程中的损失和准确率。
- 训练循环:
- 遍历每个训练轮数(
num_epochs
)。 - 在每个轮数中,对训练数据进行一次完整的训练,并计算训练损失和训练准确率。
- 在每个轮数中,对测试数据进行评估,计算测试准确率。
- 将训练损失、训练准确率和测试准确率添加到动画中,动态绘制训练过程。
- 遍历每个训练轮数(
- 断言检查:
- 检查训练损失是否小于某个阈值(如 0.5)。
- 检查训练准确率和测试准确率是否在合理范围内(如大于 0.7)。
num_epochs = 10 #训练的总轮数
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
3-7softmax回归的简洁实现的更多相关文章
- 动手学深度学习8-softmax分类pytorch简洁实现
定义和初始化模型 softamx和交叉熵损失函数 定义优化算法 训练模型 import torch from torch import nn from torch.nn import init imp ...
- 动手学深度学习14- pytorch Dropout 实现与原理
方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...
- 动手学深度学习10- pytorch多层感知机从零实现
多层感知机 定义模型的参数 定义激活函数 定义模型 定义损失函数 训练模型 小结 多层感知机 import torch import numpy as np import sys sys.path.a ...
- 动手学深度学习7-从零开始完成softmax分类
获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...
- 小匠第一周期打卡笔记-Task01
一.线性回归 知识点记录 线性回归输出是一个连续值,因此适用于回归问题.如预测房屋价格.气温.销售额等连续值的问题.是单层神经网络. 线性判别模型 判别模型 性质:建模预测变量和观测变量之间的关系,亦 ...
- L3 多层感知机
**本小节用到的数据下载 1.涉及语句 import d2lzh1981 as d2l 数据1 : d2lzh1981 链接:https://pan.baidu.com/s/1LyaZ84Q4M75G ...
- pytorch和tensorflow的爱恨情仇之定义可训练的参数
pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch版本:1.6.0 tensorflow版本:1.15.0 之前我们就已 ...
- 动手学深度学习4-线性回归的pytorch简洁实现
导入同样导入之前的包或者模块 生成数据集 通过pytorch读取数据 定义模型 初始化模型 定义损失函数 定义优化算法 训练模型 小结 本节利用pytorch中的模块,生成一个更加简洁的代码来实现同样 ...
- 原生js实现简洁的返回顶部组件
本文内容相当简单,所以没有发布到博客园首页,如果你不幸看到,那只能是我这篇文章的荣幸,谢谢你的大驾光临~(本博客返回顶部的功能就使用的是这个组件) 返回顶部组件是一种极其常见的网页功能,需求简单:页面 ...
- 看大众点评V9新版如何为O2O止血 带领行业下半场回归理性
前不久,美团点评CEO王兴提出的“中国互联网进入下半场”观点一直在持续发酵,并引发了整个互联网圈对于进入下半场该如何改革,如何迎战的深刻反思.在互联网的上半场,大家依托的是人口红利,但是到了下半场,用 ...
随机推荐
- 手把手教你安装TrueNas(基础篇)
玩过蜗牛星际,体验过黑群晖系统崩掉导致里面珍藏12t大姐姐全没了(此处有哭声),我技术又菜,自己恢复是不可能恢复的,装的盗版系统,又不可能联系群晖官方售后恢复.于是乎就想要一个稳定.开 ...
- Golang 语言学习路线
学习Go语言是一个很好的选择,它具有高效的编译速度.强大的并发支持和简洁的语法.适用于初学者的Golang学习路线: 1. 学习基础: 安装Go:从官方网站下载并安装Go语言的最新版本. Hello, ...
- MySQL 事务隔离级别:社交恐惧症的四个阶段
MySQL 事务隔离级别:社交恐惧症的四个阶段 在数据库的世界里,数据们也有社交问题!事务隔离级别就是控制它们互相看到对方的程度... 什么是事务隔离? 想象一下,数据库是一个繁忙的餐厅,每个事务都是 ...
- 【Docker】命令行操作
Docker常用命令 帮助命令 docker version docker info docker --help Docker 客户端 docker 客户端非常简单 ,我们可以直接输入 docker ...
- php代码审计实战-开源项目Materialized CMS漏洞检测
一.下载Materialized CMS 链接地址:https://sourceforge.net/projects/materialized-cms/files/latest/download 二. ...
- D的SDK的设置
有点烦,被困扰.看大虾的文章一并感谢: 进入D:\Users\Public\Documents\Embarcadero\Studio\22.0\CatalogRepository\AndroidSDK ...
- 支付系统扩展:ZKmall开源商城支持跨境多币种结算的开发实践
于跨境电商平台而言,多币种支付是满足全球消费者支付需求的关键.不同国家和地区的消费者习惯使用各自的货币进行支付,如果平台不支持多币种交易,将极大地限制用户的购买意愿和支付便利性.因此,跨境电商平台必须 ...
- 🎀java-自定义日志注解
简介 创建自定义日志注解,对相关接口记录请求日志. 环境 SpringBoot 实现 注解定义 定义注解类 package com.zk.app.annotation; import com.zk.a ...
- Java 单元测试简单扫盲
前言 仔细回想起来,上次认真编写单元测试已经是两年前的事了.那时候觉得写单元测试是种负担. 为了应付代码覆盖率要求,常常依赖工具自动生成测试用例,有时需要启动Spring容器,有时又不需要(当时还分不 ...
- 在 ASP.NET Core 中编写高性能 Web API 的4个小技巧
Web API 通常用来与外部模块进行通信.发送和接收数据,作为后端开发人员,应该把写出高性能的应用作为目标. 下面 4 个技巧是我在编写 Web API 的小技巧. 1 .大量数据使用分页查询 接口 ...