手动实现线性回归

点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data

构造一个人造数据集

点击查看代码
	def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列 true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

每次读取一个batch数据量

点击查看代码
	def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回 batch_size = 10 for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0

定义模型

点击查看代码
	def linreg(x, w, b):
return torch.matmul(x, w) + b

损失函数 均方误差

点击查看代码
	def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

优化算法 小批量下降

点击查看代码
	def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

实现

点击查看代码
	lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

给笔者点个赞呀!

随机推荐

  1. 题解:B3646 数列前缀和 3

    分析 板子题,线段树维护矩阵区间积,除了难写没什么思维难度. 所以直接放代码吧. Code #include<bits/stdc++.h> #define int long long us ...

  2. C#从6.0~9.0都更新了什么?

    一.C#6中新增的功能 get 只读属性 简洁的语法来创建不可变类型,仅有get访问器: public string FirstName { get; } public string LastName ...

  3. GIS前沿技术

    无论是初步接触到GIS的学生,还是对GIS已经有一定的了解的从业者,肯定都非常关心两个问题:GIS有没有发展前景,GIS有哪些应用价值? 关于这两个问题,笔者的答案是GIS作为一门融合了空间数据采集. ...

  4. Docker 基于Dockerfile创建镜像实践

    需求描述 简单说,就是创建一个服务型的镜像,即运行基于该镜像创建的容器时,基于该容器自动开启一个服务.具体来说,是创建一个部署了nginx,uwsgi,python,django项目代码的镜像,运行基 ...

  5. MySQL 跨服务器关联查询

    如果您需要在 MySQL 中关联查询位于不同服务器的表(跨服务器关联查询),您可以考虑使用 MySQL 的联机查询(Federated MySQL).联机查询允许您在一个服务器上访问和查询另一个服务器 ...

  6. Net8将Serilog日志推送ES,附视频

    这是一个Serilog的实践Demo,包括了区别记录存放,AOP 日志记录,EF 执行记录,并且将日志推送到Elastic Search. 说在前面的话 自从AI出来之后,学习的曲线瞬间变缓了,学习的 ...

  7. 【Java】讲讲StreamAPI

    预设场景: 从Mybatis调用Mapper得到的用户集合 List<UserDTO> userList = new ArrayList<>(); 常用的几种API用法示例: ...

  8. 【SpringCloud】Nacos集群部署(Centos平台)

    一.前提环境准备 Nacos 下载 https://github.com/alibaba/nacos/releases 或者使用其它博主备份的 https://blog.csdn.net/weixin ...

  9. 【RabbitMQ】14 集群搭建

    多服务器单实例 -- 参考博客: https://www.cnblogs.com/lixioayi/articles/9993658.html 首先要找到cookie文件,所有实例要保持cookie一 ...

  10. 【IDEA】DEBUG调试问题

    不要将断点打在方法的声明上: 会有一个菱形标志,在标记之后运行DEBUG模式会跑不起来 查看所有的断点标记: 在这里直接找到所有标记位置,弄掉就会跑起来了