3-6softmax回归从0开始实现
3-6softmax回归从0开始实现
import torch
from IPython import display
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# print(len(train_iter), len(test_iter))
1.初始化模型参数
num_inputs = 784
num_outputs = 10
w = torch.normal(0,0.01,size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs,requires_grad=True)
2.定义softmax操作
torch.tensor
是 PyTorch 中用于创建张量的函数。- 这里创建了一个形状为
(2, 3)
的二维张量x
,其中包含两行三列的浮点数。
x.sum(0, keepdim=True)
x.sum(0, keepdim=True)
是对张量x
沿着第 0 维(即行)进行求和操作。- 参数
0
表示沿着第 0 维(行)进行操作,即对每一列的元素求和。 - 参数
keepdim=True
表示在求和后保留原张量的维度。如果不设置keepdim=True
,结果的维度会减少。
x = torch.tensor([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
x.sum(0, keepdim=True), x.sum(1,keepdim=True)
(tensor([[5., 7., 9.]]),
tensor([[ 6.],
[15.]]))
#定义 softmax 函数
def softmax(x):
x_exp = torch.exp(x)
partition = x_exp.sum(1, keepdim = True)
return x_exp / partition # 这里应用了广播机制
x = torch.normal(0, 1, (2,5))
x_prob = softmax(x)
x_prob, x_prob.sum(1) #打印结果
(tensor([[0.0908, 0.1063, 0.1892, 0.1219, 0.4918],
[0.0741, 0.0314, 0.1805, 0.6420, 0.0720]]),
tensor([1., 1.]))
3.定义模型
定义softmax操作后,我们可以实现softmax回归模型。下面的代码定义了输入如何通过网络映射到输出。注意,将数据传递到模型之前,我们使用reshape函数将每张原始图像展平为向量。
def net(x):
return softmax(torch.matmul(x.reshape((-1, w.shape[0])),w) + b)
4.定义损失函数
y_hat[[0, 1], y]
- 这里的
y_hat[[0, 1], y]
是一种高级索引操作,用于从y_hat
中选择特定的元素。 [[0, 1], y]
是一个索引组合,其中:[0, 1]
表示从y_hat
的第 0 行和第 1 行中选择元素。y
是一个张量[0, 2]
,表示从每行中选择第 0 列和第 2 列的元素。
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0,1],y] # 输出(0,0) 和 (1,2)
tensor([0.1000, 0.5000])
# 实现交叉熵损失函数
def cross_entropy(y_hat, y):
return -torch.log(y_hat[range(len(y_hat)), y])
cross_entropy(y_hat,y)
tensor([2.3026, 0.6931])
5.分类精度
y_hat
是模型的预测输出。y
是真实标签。- 函数的目标是计算预测数量正确的。
代码解析
1. 处理多分类问题
len(y_hat.shape) > 1
:检查y_hat
是否是一个二维张量(即是否有多个类别预测概率)。y_hat.shape[1] > 1
:进一步确认y_hat
的第二维(列数)大于 1,表示这是一个多分类问题。y_hat.argmax(axis=1)
:如果满足上述条件,说明y_hat
是一个概率分布(例如,softmax 输出)。argmax(axis=1)
会沿着第二维(列)找到每行的最大值的索引,这些索引即为预测的类别标签。
2. 比较预测值和真实值
y_hat.type(y.dtype)
:将y_hat
的数据类型转换为与y
相同的数据类型。这一步是为了确保比较操作能够正常进行。y_hat == y
:逐元素比较y_hat
和y
,返回一个布尔张量cmp
,其中True
表示预测正确,False
表示预测错误。
3. 计算预测正确的数量
cmp.type(y.dtype)
:将布尔张量cmp
转换为与y
相同的数据类型(通常是整数类型)。True
会被转换为 1,False
会被转换为 0。cmp.sum()
:对转换后的张量求和,计算预测正确的数量。float(cmp.sum())
:将结果转换为浮点数,以便后续计算准确率时可以使用浮点除法。
def accuracy(y_hat, y):
'''计算预测正确的数量'''
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1)
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())
accuracy(y_hat, y) / len(y)
0.5
def evaluate_accuracy(net, data_iter):
'''计算在指定数据集上模型的精度'''
if isinstance(net, torch.nn.Module):
net.eval() # 将模型设置为评估模式
metric = Accumulator(2) # 正确预测数、预测总数
with torch.no_grad():
for x,y in data_iter:
# print("x shape:", x.shape) # 打印 x 的形状
# y_hat = net(x)
# print("y_hat shape:", y_hat.shape) # 打印 y_hat 的形状
# print("y shape:", y.shape) # 打印 y 的形状
metric.add(accuracy(net(x), y), y.numel()) #返回张量 y 中所有元素的总数
return metric[0] / metric[1]
class Accumulator:
'''在n个变量上累加'''
def __init__(self, n):
self.data = [0.0] * n
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0.0] * len(self.data)
def __getitem__(self, idx):
return self.data[idx]
evaluate_accuracy(net, test_iter)
0.0858
6.训练
def train_epoch_ch3(net, train_iter, loss, updater):
'''训练模型一个迭代周期'''
# 将模型设置为训练模式
if isinstance(net, torch.nn.Module):
net.train()
# 训练损失总和、训练准确度综合、样本数
metric = Accumulator(3)
for x, y in train_iter:
# 计算梯度并更新参数
y_hat = net(x)
l = loss(y_hat, y)
if isinstance(updater, torch.optim.Optimizer):
# 使用pytorch内置的优化器和损失函数
updater.zero_grad()
l.mean().backward()
updater.step()
else:
# 使用定制的优化器和损失函数
l.sum().backward()
updater(x.shape[0])
metric.add(float(l.sum()), accuracy(y_hat, y),y.numel())
# 返回训练损失和训练精度
return metric[0] / metric[2], metric[1] / metric[2]
绘制代码功能总结
- 初始化 (
__init__
方法):- 设置图形的布局(如大小、子图数量等)。
- 初始化用于存储数据点的列表。
- 定义一个 lambda 函数来配置轴的属性(如标签、范围等)。
- 添加数据点 (
add
方法):- 接收新的数据点
x
和y
,并将它们添加到存储结构中。 - 动态更新图形,绘制最新的数据点。
- 配置图形的轴属性,并显示更新后的图形。
- 接收新的数据点
使用场景
这个类通常用于动态绘制训练过程中的数据(如损失值、准确率等),以便实时观察模型的性能变化。通过调用 add
方法,可以逐次添加新的数据点并更新图形。
class Animator:
""" 在动画中绘制数据"""
def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
ylim = None, xscale='linear', yscale = 'linear',
fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
figsize = (3.5, 2.5)):
"""
初始化 Animator 类的实例。
参数:
- xlabel: x轴标签
- ylabel: y轴标签
- legend: 图例列表
- xlim: x轴范围
- ylim: y轴范围
- xscale: x轴刻度类型(如 'linear' 或 'log')
- yscale: y轴刻度类型
- fmts: 每条线的格式字符串列表
- nrows: 子图的行数
- ncols: 子图的列数
- figsize: 图形的大小
"""
# 增量地绘制多条线
if legend is None:
legend = []
d2l.use_svg_display() # 使用 d2l 库的 SVG 显示功能
# 创建图形和子图
self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize = figsize)
if nrows * ncols == 1: # 如果只有一个子图,将其转换为列表以便统一处理
self.axes = [self.axes,]
# 使用lambda函数捕获参数,# 定义一个 lambda 函数,用于配置子图的轴
self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel,
xlim, ylim, xscale, yscale, legend)
# 初始化 x 和 y 数据为空,fmts 是每条线的格式
self.x, self.y, self.fmts = None, None, fmts
def add(self, x, y):
"""
向图表中添加数据点。
参数:
- x: x轴数据点(可以是单个值或列表)
- y: y轴数据点(可以是单个值或列表)
"""
# 向图表中添加多个数据点
if not hasattr(y, "__len__"):
y = [y]
n = len(y)
# 如果 x 是单个值,将其复制为与 y 同样长度的列表
if not hasattr(x, "__len__"):
x = [x] * n
# 如果是第一次调用 add 方法,初始化 x 和 y 的存储结构
if not self.x:
self.x = [[] for _ in range(n)]
if not self.y:
self.y = [[] for _ in range(n)]
# 遍历 x 和 y 的值,将它们添加到对应的列表中
for i, (a,b) in enumerate(zip(x, y)):
if a is not None and b is not None:
self.x[i].append(a)
self.y[i].append(b)
# 清除当前子图的内容
self.axes[0].cla()
# 绘制每条线
for x,y, fmt in zip(self.x, self.y, self.fmts):
self.axes[0].plot(x, y, fmt)
# 配置轴的属性
self.config_axes()
# 显示图形
display.display(self.fig)
# 清除之前的输出,避免重复显示
display.clear_output(wait=True)
训练模型代码功能总结
- 初始化动画对象:
- 使用
Animator
类初始化一个动画对象,设置 x 轴为训练轮数(epoch
),y 轴范围为[0.3, 0.9]
,并定义图例为['train loss', 'train acc', 'test acc']
。
- 使用
- 训练过程:
- 遍历每个训练轮数(
num_epochs
)。 - 在每个轮数中,调用
train_epoch_ch3
函数对训练数据进行一次完整的训练,返回训练损失和训练准确率。 - 在每个轮数中,调用
evaluate_accuracy
函数对测试数据进行评估,返回测试准确率。 - 将当前轮数的训练损失、训练准确率和测试准确率添加到动画中,用于动态绘制训练过程。
- 遍历每个训练轮数(
- 断言检查:
- 检查训练损失是否小于 0.5,确保模型训练效果良好。
- 检查训练准确率和测试准确率是否在合理范围内(小于等于 1 且大于 0.7),确保模型性能符合预期。
使用场景
这个函数通常用于训练一个简单的神经网络模型(如在第 3 章中介绍的线性网络)。它通过动态绘制训练损失、训练准确率和测试准确率,帮助用户直观地观察模型的训练过程和性能变化。
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
"""
训练模型(定义见第3章)
参数:
- net: 模型网络
- train_iter: 训练数据迭代器
- test_iter: 测试数据迭代器
- loss: 损失函数
- num_epochs: 训练的总轮数
- updater: 参数更新函数(如优化器)
"""
# 初始化一个动画对象,用于绘制训练过程中的损失和准确率
animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
legend=['train loss', 'train acc', 'test acc'])
# 遍历每个训练轮数
for epoch in range(num_epochs):
# 在当前轮数中,对训练数据进行一次完整的训练,并返回训练损失和训练准确率
train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
# 在当前轮数中,对测试数据进行评估,返回测试准确率
test_acc = evaluate_accuracy(net, test_iter)
# 将当前轮数的训练损失、训练准确率和测试准确率添加到动画中,用于绘制图表
animator.add(epoch + 1, train_metrics + (test_acc,))
# 从训练指标中提取训练损失和训练准确率
train_loss, train_acc = train_metrics
# 断言训练损失小于 0.5,确保模型训练效果良好
assert train_loss < 0.5, train_loss
# 断言训练准确率在合理范围内(小于等于 1 且大于 0.7)
assert train_acc <= 1 and train_acc > 0.7, train_acc
# 断言测试准确率在合理范围内(小于等于 1 且大于 0.7)
assert test_acc <= 1 and test_acc > 0.7, test_acc
lr = 0.1
def updater(batch_size):
return d2l.sgd([w, b], lr, batch_size)
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
7.预测
现在训练已经完成,我们的模型已经准备好对图像进行分类预测。给定一系列图像,我们将比较它们的实际
标签(文本输出的第一行)和模型预测(文本输出的第二行)
def predict_ch3(net, test_iter, n=6):
"""
预测标签
参数:
- net: 训练好的模型网络
- test_iter: 测试数据迭代器
- n: 要显示的预测样本数量,默认为 6
"""
# 从测试数据迭代器中获取第一批数据
for x, y in test_iter:
break
# 获取真实标签的文本描述
trues = d2l.get_fashion_mnist_labels(y)
# 使用模型进行预测,并获取预测标签的文本描述
preds = d2l.get_fashion_mnist_labels(net(x).argmax(axis=1))
# 生成标题,包含真实标签和预测标签
titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
# 显示前 n 个图像及其标题
# 将图像数据重塑为 (n, 28, 28) 的形状,以便显示
d2l.show_images(x[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
predict_ch3(net, test_iter)
3-6softmax回归从0开始实现的更多相关文章
- 用R语言的quantreg包进行分位数回归
什么是分位数回归 分位数回归(Quantile Regression)是计量经济学的研究前沿方向之一,它利用解释变量的多个分位数(例如四分位.十分位.百分位等)来得到被解释变量的条件分布的相应的分位数 ...
- Hyperledger Fabric 1.0.1至Hyperledger Fabric 1.0.5所升级的内容及修复的问题
基础更新 各版本每次迭代都会有一些基础更新内容,如文档修改覆盖.测试用例完善.用户体验改进及删除冗余无效代码等… 下面分类介绍的是一些版本迭代的重要更新内容,因个人实操和理解有限,部分更新并未明确,如 ...
- 斯坦福机器学习视频笔记 Week3 逻辑回归与正则化 Logistic Regression and Regularization
我们将讨论逻辑回归. 逻辑回归是一种将数据分类为离散结果的方法. 例如,我们可以使用逻辑回归将电子邮件分类为垃圾邮件或非垃圾邮件. 在本模块中,我们介绍分类的概念,逻辑回归的损失函数(cost fun ...
- 机器学习---逻辑回归(二)(Machine Learning Logistic Regression II)
在<机器学习---逻辑回归(一)(Machine Learning Logistic Regression I)>一文中,我们讨论了如何用逻辑回归解决二分类问题以及逻辑回归算法的本质.现在 ...
- 资深人士剖析微软开源.NET事件:战略重心已经从PC转移到云端
本文是雷锋网对我的访谈整理的文章,源地址是 http://www.leiphone.com/news/201411/6KaGhD7PDABnvrRf.html 2014年11月13日,微软表示开源.N ...
- 《BI那点儿事》数据挖掘各类算法——准确性验证
准确性验证示例1:——基于三国志11数据库 数据准备: 挖掘模型:依次为:Naive Bayes 算法.聚类分析算法.决策树算法.神经网络算法.逻辑回归算法.关联算法提升图: 依次排名为: 1. 神经 ...
- html5 svg动画
http://www.zhangxinxu.com/sp/svg/ 以上是svg的一个线上编辑器,也可以adobe Illustrator制作生成. 我们通过以上编辑器可以获得以下代码. 例: < ...
- RFID 基础/分类/编码/调制/传输
不同频段的RFID产品会有不同的特性,本文详细介绍了无源的感应器在不同工作频率产品的特性以及主要的应用. 目前定义RFID产品的工作频率有低频.高频和甚高频的频率范围内的符合不同标准的不同的产品,而且 ...
- 论文笔记之:Instance-aware Semantic Segmentation via Multi-task Network Cascades
Instance-aware Semantic Segmentation via Multi-task Network Cascades Jifeng Dai Kaiming He Jian Sun ...
- 创建理想的SEQUENCE和自增长的trigger
SEQUENCE CREATE SEQUENCE TEST_SEQ START 1 --从1开始,第一个一定是NEXTVAL,因为第一个CURRVAL不好使,返回值会是1,第一个NEXTVAL相当于从 ...
随机推荐
- Cursor 老改坏代码?六哥这几招超管用!
大家好,我是六哥!最近不少小伙伴和我吐槽,在使用Cursor时,AI老是把代码改坏,让人头疼不已.我自己也用了大几十个小时Cursor,今天就来给大家分享一些实用小窍门,教大家如何巧妙规避这类问题. ...
- 抽象类的注意事项、abstract关键字的冲突--java进阶day02
1.注意事项 1.抽象类不允许创建对象 2.抽象类存在构造方法 3.抽象类中可以存在普通成员方法 4.抽象类的子类存在两种处理方式 第一种不多解释,主要讲第二种,子类继承了抽象类,相当于子类里面有了抽 ...
- 记载火狐浏览器下的一次新手级的js解密工作
警告:该随笔内容仅用于合法范围下的学习,不得用于任何商业和非法用途,不得未经授权转载,否则后果自负. 首先是需要解密的网站:https://www.aqistudy.cn/historydata/mo ...
- 痞子衡嵌入式:恩智浦i.MX RT1xxx上特色外设XBAR那些事(1)- 初识
大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是恩智浦i.MX RT1xxx系列上的XBAR外设. 得益于 Arm Cortex-M 内核的普及,现如今 MCU 厂商遍地开花,只要能取 ...
- 查缺补漏——01-BFS
01bfs 解决的是一类特殊的最段路问题. 在学习它的过程中,我更加深刻地学习到了泛化路径和 bfs. 01-BFS 是什么 首先明确,01-BFS 是一种图论算法.它解决的事最短路径问题.最短路径算 ...
- L2-3 锦标赛
先画图理解 具体就是先存入每个左右子树的lose,然后存入根的lose和win 然后往下建树,左右的win也可以交换 可以学习这样的完全二叉树存储结构 #include <bits/stdc++ ...
- java的打包(JAR、War)
一.Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing i ...
- Eclipse 中 JAVA AWT相关包不提示问题(解决)
原因: 由于在2021年7月15日 OpenJDK管理委员会全票通过批准成立由Phil Race担任初始负责人的 Client Libraries Group(客户端类库工作组). 新的工作组将继续赞 ...
- 一个 CTO 的深度思考
今天和一些同事聊了一会,以下是我的观点 我的观点,成年人只能筛选,不能培养 在组织中,应该永远向有结果的人看齐.不能当他站出来讲话的时候,大家还要讨论讨论,他虽然拿到结果了,但是他就是有一点点小问题. ...
- Git错误,fatal: refusing to merge unrelated histories
错误:fatal: refusing to merge unrelated histories 中文意思就是拒绝合并不相关的历史, 解决 出现这个问题的最主要原因还是在于本地仓库和远程仓库实际上是独立 ...