LR梯度下降法MSE演练
同步进行一波网上代码搬砖, 先来个入门的线性回归模型训练, 基于梯度下降法来, 优化用 MSE 来做. 理论部分就不讲了, 网上一大堆, 我自己也是理解好多年了, 什么 偏导数, 梯度(多远函数一阶偏导数组成的向量) , 方向导数, 反方向(梯度下降) 这些基本的高数知识, 假设大家是非常清楚原理的.
如不清楚原理, 那就没有办法了, 只能自己补, 毕竟 ML 这块, 如果不清楚其数学原理, 只会有框架和导包, 那得是多门的无聊和无趣呀. 这里是搬运代码的, 当然, 我肯定是有改动的,基于我的经验, 做个小笔记, 方便自己后面遇到时, 直接抄呀.
01 采样数据
这里呢, 假设已知一个线性模型, 就假设已经基本训练好了一个, 比如是这样的.
\(y=1.477x + 0.089\)
现在为了更好模拟真实样本的观测误差, 给模型添加一个误差变量 \(\epsilon\) (读作 \epsilon) , 然后想要搞成这样的.
\(y =1.477x + 0.089 + \epsilon, \epsilon-N(0, 0.01^2)\)
现在来随机采样 100次, 得到 n=100 的样本训练数据集
import numpy as np
def data_sample(times=100):
"""数据集采样times次, 返回一个二维数组"""
for i in range(times):
# 随机采样输入 x, 一个数值 (均匀分布)
x = np.random.uniform(-10, 10)
# 采样高斯噪声(随机误差),正态分布
epsilon = np.random.normal(0, 0.01)
# 得到模型输出
y = 1.447 * x + 0.089 + epsilon
# 用生成器来生成或存储样本点
yield [x, y]
# test
# 将数据转为 np.array 的二维数组
data = np.array(list(data_sample()))
data 是这样的, 2D数组, 一共100行记录, 每一行表示一个样本点 (x, y).
array([[ 5.25161007, 7.6922458 ],
[ 9.00034456, 13.11119931],
[ 9.47485633, 13.80426132],
[ -4.3644416 , -6.2183884 ],
[ -3.35345323, -4.76625711],
[ -5.10494006, -7.30976062],
.....
[ -6.78834597, -9.73362456]]
02 计算误差 MSE
计算每个点 (xi, yi) 处的预测值 与 真实值 之差的平方 并 累加, 从而得到整个训练集上的均方误差损失值.
# y = w * x + b
def get_mse(w, b, points):
"""计算训练集的 MSE"""
# points 是一个二维数组, 每行表示一个样本
# 每个样本, 第一个数是 x, 第二个数是 y
loss = 0
for i in range(0,len(X)):
x = points[i, 0]
y = points[i, 1]
# 计算每个点的误差平方, 并进行累加
loss += (y - (w * x + b)) ** 2
# 用 总损失 / 总样本数 = 均方误差 mse
return loss / len(points)
样本是一个二维数组, 或者矩阵. 每一行, 表示一个样本, 每一列表示该样本的某个子特征
03 计算梯度
关于梯度, 即多元函数的偏导数向量, 这个方向是, 多元函数的最大导数方向 (变化率最大) 方向 (向量方向), 于是, 反方向, 则是函数变化率最小, 即极值点的地方呀, 就咱需要的, 所以称为, 梯度下降法嘛, 从数学上就非常好理解哦.
def step_gradient(b_current, w_current, points, lr):
# 计算误差函数在所有点的导数, 并更新 w, b
b_gradient = 0
w_gradinet = 0
n = len(points) # 样本数
for i in range(n):
# x, y 都是一个数值
x = points[i, 0]
y = points[i, 1]
# 损失函数对 b 的导数 g_b = 2/n * (wx+b-y) 数学推导的
b_gradient += (n/2) * ((w_current * x + b) - y)
# 损失函数对 w 的导数 g_w = 2/n (wx+b-y) x
w_gradinet += (n/2) * x * ((w_current * x + b) - y)
# 根据梯度下降法, 更新 w, b
new_w = w_current - (lr * b_gradient)
new_b = b_current - (lr * b_gradient)
return [new_w, new_b]
04 更新梯度 Epoch
根据第三步, 在算出误差函数在 w, b 的梯度后, 就可以通过 梯度下降法来更新 w,b 的值. 我们把对数据集的所有样本训练一次称为一个 Epoch, 共循环迭代 num_iterations 个 Epoch.
def gradient_descent(points, w, b, lr, max_iter):
"""梯度下降 Epoch"""
for step in range(max_iter):
# 计算梯度并更新一次
w, b = step_gradient(b, w, np.array(points),lr)
# 计算当前的 均方差 mse
loss = get_mes(w, b, points)
if step % 50 == 0:
# 每隔50次打印一次当前信息
print(f"iteration: {step} loss: {loss}, w:{w}, b:{b}")
# 返回最后一次的 w,b
return [w, b]
05 主函数
def main():
# 加载训练数据, 即通过真实模型添加高斯噪声得到的
lr = 0.01 # 学习率
init_b = 0
init_w = 0
max_iter = 500 # 最大Epoch=100次
# 用梯度下降法进行训练
w, b = gradient_descent(data, init_w, init_b, lr, max_iter)
# 计算出最优的均方差 loss
loss = get_mse(w, b, dataa)
print(f"Final loss: {loss}, w:{w}, b:{b}")
# 运行主函数
main()
iteration: 0 loss: 52624.8637745707, w:-37.451784525811654, b:-37.451784525811654
iteration: 50 loss: 8.751081967754209e+134, w:-5.0141110054193505e+66, b:-5.0141110054193505e+66
iteration: 100 loss: 1.7286223665339186e+265, w:-7.047143783692584e+131, b:-7.047143783692584e+131
iteration: 150 loss: inf, w:-9.904494626138306e+196, b:-9.904494626138306e+196
iteration: 200 loss: inf, w:-1.3920393397706614e+262, b:-1.3920393397706614e+262
iteration: 250 loss: nan, w:nan, b:nan
iteration: 300 loss: nan, w:nan, b:nan
iteration: 350 loss: nan, w:nan, b:nan
iteration: 400 loss: nan, w:nan, b:nan
iteration: 450 loss: nan, w:nan, b:nan
************************************************************
Final loss: nan, w:nan, b:nan
可以看到, 在 Epoch 100多次, 后, 就已经收敛了. 当然正常来说, 应该给 loss 设置一个阈值的, 不然后面都 inf 了, 我还在 epoch, 这就有问题了. 这里就不改了, 总是习惯留下一些不完美, 这样才会记得更深. 其目的也是在与数理 ML 整个训练过程, 用入门级的 线性回归和 梯度下降法来整.
LR梯度下降法MSE演练的更多相关文章
- pytorch梯度下降法讲解(非常详细)
pytorch随机梯度下降法1.梯度.偏微分以及梯度的区别和联系(1)导数是指一元函数对于自变量求导得到的数值,它是一个标量,反映了函数的变化趋势:(2)偏微分是多元函数对各个自变量求导得到的,它反映 ...
- 梯度下降法实现(Python语言描述)
原文地址:传送门 import numpy as np import matplotlib.pyplot as plt %matplotlib inline plt.style.use(['ggplo ...
- 快速用梯度下降法实现一个Logistic Regression 分类器
前阵子听说一个面试题:你实现一个logistic Regression需要多少分钟?搞数据挖掘的人都会觉得实现这个简单的分类器分分钟就搞定了吧? 因为我做数据挖掘的时候,从来都是顺手用用工具的,尤其是 ...
- (3)梯度下降法Gradient Descent
梯度下降法 不是一个机器学习算法 是一种基于搜索的最优化方法 作用:最小化一个损失函数 梯度上升法:最大化一个效用函数 举个栗子 直线方程:导数代表斜率 曲线方程:导数代表切线斜率 导数可以代表方向, ...
- 线性回归(最小二乘法、批量梯度下降法、随机梯度下降法、局部加权线性回归) C++
We turn next to the task of finding a weight vector w which minimizes the chosen function E(w). Beca ...
- 梯度下降法及一元线性回归的python实现
梯度下降法及一元线性回归的python实现 一.梯度下降法形象解释 设想我们处在一座山的半山腰的位置,现在我们需要找到一条最快的下山路径,请问应该怎么走?根据生活经验,我们会用一种十分贪心的策略,即在 ...
- 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)
在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...
- optim.SDG 或者其他、实现随机梯度下降法
optim.SDG 或者其他.实现随机梯度下降法 待办 实现随机梯度下降算法的参数优化方式 另外还有class torch.optim.ASGD(params, lr=0.01, lambd=0.00 ...
- [Machine Learning] 梯度下降法的三种形式BGD、SGD以及MBGD
在应用机器学习算法时,我们通常采用梯度下降法来对采用的算法进行训练.其实,常用的梯度下降法还具体包含有三种不同的形式,它们也各自有着不同的优缺点. 下面我们以线性回归算法来对三种梯度下降法进行比较. ...
- 机器学习基础——梯度下降法(Gradient Descent)
机器学习基础--梯度下降法(Gradient Descent) 看了coursea的机器学习课,知道了梯度下降法.一开始只是对其做了下简单的了解.随着内容的深入,发现梯度下降法在很多算法中都用的到,除 ...
随机推荐
- Flink Watermark 不止可以用时间戳衡量
https://mp.weixin.qq.com/s/L5PqtcmffCIq_CnUs0WS3g
- P11620 [Ynoi Easy Round 2025] TEST_34
由子序列和最值异或可以想到线性基 发现其实线性基满足结合律 考虑线段树进行维护 那么显然的一个想法就是把1操作直接上tag 但是发现上tag其实会丢失线性基的性质 于是差分 将区间修改变为单点修改 考 ...
- 当我老丈人都安装上DeepSeek的时候,我就知道AI元年真的来了!
关注公众号回复1 获取一线.总监.高管<管理秘籍> 春节期间DeepSeek引爆了朋友圈,甚至连我老丈人都安装了APP,这与两年前OpenAI横空出世很不一样,DeepSeek似乎真的实现 ...
- 【Matlab】求解复合材料层合板刚度矩阵及柔度矩阵
1. matlab文件结构 2. main.m代码 clc clear; warning off; %% %铺层角度数组 angles=[0 90 0]; % ° %单层厚度 ply_thicknes ...
- 前端解析excel表格实现
1. 背景:在做react项目时,遇到一个解析excel的需求变更,把从原来后端解析变更为前端解析. 1.1 由于后端解析excel文件有安全隐患,因为项目中后端不允许上传文件,当然后端解析对前端来说 ...
- Django实战项目-学习任务系统-用户注册
接着上期代码框架,开发第2个功能,用户注册,在原有用户模型基础上,增加一个学生用户属性表,用来关联学生用户的各种属性值,这个属性表是参考网络小说里系统属性值设计的,方便直观了解用户的能力高低,等级以及 ...
- bs4库爬取天气预报
Python不仅用于网站开发,数据分析,图像处理,也常用于爬虫技术方向,最近学习了解下,爬虫技术入门一般先使用bs4库,爬取天气预报简单尝试下. 第一步:首先选定目标网站地址 网上查询,天气预报准确率 ...
- IvorySQL 4.0 之 Invisible Column 功能解析
前言 随着数据库应用场景的多样化,用户对数据管理的灵活性和隐私性提出了更高要求.IvorySQL 作为一款基于 PostgreSQL 并兼容 Oracle 的开源数据库,始终致力于在功能上保持领先和创 ...
- 抓包分析:wireshark抓不到TLS1.3数据包中证书的解决方案
近日工作中遇到需要分析使用TLS1.3协议进行通信的数据包的情况,但使用wireshark进行分析发现不能抓到服务端证书,感到诧异遂设法解决 这篇博客给出解决方案,和简单的原理分析 解决方案: 第一步 ...
- 如何学习 ROS+PX4
博客地址:https://www.cnblogs.com/zylyehuo/ 参考 https://www.bilibili.com/video/BV1vx4y1Y7Tu?spm_id_from=33 ...