手动实现线性回归

点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data

构造一个人造数据集

点击查看代码
	def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列 true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

每次读取一个batch数据量

点击查看代码
	def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回 batch_size = 10 for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0

定义模型

点击查看代码
	def linreg(x, w, b):
return torch.matmul(x, w) + b

损失函数 均方误差

点击查看代码
	def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

优化算法 小批量下降

点击查看代码
	def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

实现

点击查看代码
	lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

给笔者点个赞呀!

随机推荐

  1. 面试题-python 什么是装饰器(decorator )?

    前言 python装饰器本质上就是一个函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外的功能,装饰器的返回值也是一个函数对象.很多python初学者学到面向对象类和方法是一道大坎,那么py ...

  2. Python 在PDF中添加、替换、或删除图片

    PDF文件中的图片可以丰富文档内容,提升用户的阅读体验.除了在PDF中添加图片外,有时也需要替换或删除其中的图片,以改进视觉效果或更新信息.本文将提供以下三个示例,介绍如何使用Python 操作PDF ...

  3. perf 性能分析工具

    perf 性能分析工具 perf topperf recordperf reportperf listperf stat perf top -p <pid> 例如查看redis进程的内核调 ...

  4. 【DataBase】MySQL 06 条件查询 & 排序查询

    视频参考自:P28 - P42 https://www.bilibili.com/video/BV1xW411u7ax 条件查询概述 # 进阶2 条件查询 -- 语法:SELECT 查询列表 FROM ...

  5. 【IDEA】转大小写快速操作

    需求场景: 快速修改一些字符全部变成大写,或者小写 例如修改SQL语句,部分字段大写,部分字段小写,需要统一 快捷键: [Ctrl + Shift + U] 演示案例: SELECT ( (SELEC ...

  6. 使用Transformer模型实现四足机器人控制—— 跨模态Transformer: LocoTransformer —— LEARNING VISION-GUIDED QUADRUPEDAL LOCOMOTION END-TO-END WITH CROSS-MODAL TRANSFORMERS

    论文: LEARNING VISION-GUIDED QUADRUPEDAL LOCOMOTION END-TO-END WITH CROSS-MODAL TRANSFORMERS 发表于ICLR20 ...

  7. 中国2023年GDP增速5.2%

    在中美贸易战和三年全球疫情的大背景下,我国的经济依旧保持强有力的增速,这表明了经济发展不断转好,一切恢复到疫情和贸易战之前也是有待期望的.

  8. 绑定国内主机IP的域名网站必须要备案

    买了个域名: http://devilmaycry812839668.top/ 然后绑定了国内的一个云主机,刚搭了个web server,一个网页都没有(短期内页没考虑做网页): 今天看了下web s ...

  9. mindspore.ops.Pow()等算子不能处理float64类型的数据

    原文地址: https://gitee.com/mindspore/mindspore/issues/I3ZG99 Software Environment: -- MindSpore r1.2 GP ...

  10. Singleton bean creation not allowed while singletons of this factory are in destruction

    1.背景 一直都是正常运行的程序,检查日志发现有一条报错如下: org.springframework.beans.factory.BeanCreationNotAllowedException: E ...