手动实现线性回归

点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data

构造一个人造数据集

点击查看代码
	def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列 true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

每次读取一个batch数据量

点击查看代码
	def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回 batch_size = 10 for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0

定义模型

点击查看代码
	def linreg(x, w, b):
return torch.matmul(x, w) + b

损失函数 均方误差

点击查看代码
	def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

优化算法 小批量下降

点击查看代码
	def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

实现

点击查看代码
	lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

给笔者点个赞呀!

随机推荐

  1. oeasy教您玩转vim - 26 - 缩进设置

    ​ 缩进设置 回忆上节课内容 这次了解了颜色的细节 设置 256 色模式 :set t_Co=256 然后确定了具体的各种颜色 还可以生成网页 :TOhtml 还有什么好玩的么? ​ 缩进设置 ​ 在 ...

  2. oeasy教您玩转vim - 50 - # 命令行效率

    ​ 命令行效率 回忆上节课内容 总结 我们上次研究的是范围命令执行方法 批量控制缩进 :20,40> 批量执行普通模式下的命令 :4,10normal A; 直接切换到全屏命令模式 ex-mod ...

  3. iOS开发基础142-广告归因

    IDFA IDFA是苹果为iOS设备提供的一个唯一标识符,专门用于广告跟踪和相关的营销用途.与之对应的,在Android平台的是谷歌广告ID(Google Advertising ID). IDFA的 ...

  4. Jmeter函数助手22-V

    V函数用于执行变量名.嵌套函数.类似eval函数 Name of variable (may include variable and function references):必填,填入变量名称或者 ...

  5. 【XML】Extensible Markup Language 可扩展标记语言

    Extensible Markup Language 可扩展标记语言[XML] 视频资料参考自:https://www.bilibili.com/video/BV1B441117Lu?p=186 其他 ...

  6. vim跳转到上次和下次光标位置

    在vim的命令模式下: ctrl + i    下次光标位置; ctrl + o   上次光标位置. =====================================

  7. uview-ui toast 二次封装

    开发用到uview 的toast 很常用的内容使用却很繁琐 所以做了简单封装方便使用 前后对比: this.$refs.uToast.show({ type: 'success', title: '成 ...

  8. 题解:CF780B The Meeting Place Cannot Be Changed

    这道题一看就是 二分 板子题. 当然由于精度原因,最好由原来的二分模板转换成这个. while ((w - t) > 0.000001) { mid = (t + w) / 2.0 ; if ( ...

  9. 白鲸开源CEO郭炜荣获「2024中国数智化转型升级先锋人物」称号

    2024年7月24日,由数据猿主办,IDC协办,新华社中国经济信息社.上海大数据联盟.上海市数商协会.上海超级计算中心作为支持单位,举办"数智新质·力拓未来 2024企业数智化转型升级发展论 ...

  10. Go进程内存占用那些事(二)

    0x01 最简单的Go程序 package main import ( "fmt" "time" ) func main() { fmt.Println(&qu ...