手动实现线性回归

点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data

构造一个人造数据集

点击查看代码
	def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列 true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

每次读取一个batch数据量

点击查看代码
	def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回 batch_size = 10 for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0

定义模型

点击查看代码
	def linreg(x, w, b):
return torch.matmul(x, w) + b

损失函数 均方误差

点击查看代码
	def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

优化算法 小批量下降

点击查看代码
	def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

实现

点击查看代码
	lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

给笔者点个赞呀!

随机推荐

  1. 七天.NET 8操作SQLite入门到实战 - (3)第七天Blazor学生管理页面编写和接口对接

    前言 本章节的主要内容是完善Blazor学生管理页面的编写和接口对接. 七天.NET 8 操作 SQLite 入门到实战详细教程 第一天 SQLite 简介 第二天 在 Windows 上配置 SQL ...

  2. .NET 权限工作流框架 TOP 榜

    前言 .NET权限管理及快速开发框架.最好用的权限工作流系统. 基于经典领域驱动设计的权限管理及快速开发框架,源于Martin Fowler企业级应用开发思想及最新技术组合(SqlSugar.EF.Q ...

  3. BI 工具助力企业解锁数字化工厂,开启工业智能新视界

    背景 在 2022 年公布的<"十四五"数字经济发展规划>中,政府不断增加对制造业数字化转型的政策支持力度,积极倡导制造企业采用最新技术,提升自动化.数字化和智能化水平 ...

  4. 对比python学julia(第二章)--(第二节)勾股树—分形之美

    2.1.问题描述 二话不说,先上图: 图一.勾股定理图形                                                          图二.勾股树         ...

  5. 【REGX】正则表达式 选中空白行

    参考地址: https://www.cnblogs.com/peijyStudy/p/13201576.html VScode并列替换不够智能,我需要等行粘贴,结果SHIFT+ALT复制内容粘贴上去就 ...

  6. 一直让 PHP 程序员懵逼的同步阻塞异步非阻塞,终于搞明白了

    大家好,我是码农先森. 经常听到身边写 Java.Go 的朋友提到程序异步.非阻塞.线程.协程,让系统性能提高到百万.千万并发,使我甚是惊讶属实羡慕.对于常年写 PHP 的我来说,最初听到这几个词时, ...

  7. ApacheCon Asia 2022 精彩回顾 | 如何让更多人从大数据中获益?

    点亮 ️ Star · 照亮开源之路 GitHub:https://github.com/apache/dolphinscheduler 在 ApacheCon Asia 2022 Meetup上,有 ...

  8. 【Playwright+Python】系列教程(八)鉴权Authentication的使用

    写在前面 还是有些絮叨的感觉,官方翻译和某些博主写那个玩楞,基本都是软件直接翻译后的产物. 读起来生硬不说,甚至有的时候不到是什么意思,真的是实在不敢恭维. 到底是什么意思? 就是你已经登陆过一次,在 ...

  9. Floyd判联通(传递闭包) & poj1049 sorting it all out

    Floyd判联通(传递闭包) Floyd传递闭包顾名思义就是把判最短路的代码替换成了判是否连通的代码,它可以用来判断图中两点是否连通.板子大概是这个样的: for(int k=1; k<=n; ...

  10. C#数据结构与算法实战入门指南

    前言 在编程领域,数据结构与算法是构建高效.可靠和可扩展软件系统的基石.它们对于提升程序性能.优化资源利用以及解决复杂问题具有至关重要的作用.今天大姚分享一些非常不错的C#数据结构与算法实战教程,希望 ...