手动实现线性回归

点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data

构造一个人造数据集

点击查看代码
	def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列 true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

每次读取一个batch数据量

点击查看代码
	def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回 batch_size = 10 for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0

定义模型

点击查看代码
	def linreg(x, w, b):
return torch.matmul(x, w) + b

损失函数 均方误差

点击查看代码
	def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

优化算法 小批量下降

点击查看代码
	def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

实现

点击查看代码
	lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

给笔者点个赞呀!

随机推荐

  1. 【vue3】详解单向数据流,大家千万不用为了某某而某某了。

    总览 Vue3 的单向数据流 尽信官网,不如那啥. vue的版本一直在不断更新,内部实现方式也是不断的优化,官网也在不断更新. 既然一切皆在不停地发展,那么我们呢?等着官网更新还是有自己的思考? 我觉 ...

  2. 【Vue】el-table 简易表格可筛选列

    需求实现: 代码逻辑: 按钮控件: <el-popover placement="top-start"> <el-checkbox-group v-model=& ...

  3. 【Mybatis-Plus】制作逆向工具

    官方文档可参考: https://baomidou.com/pages/779a6e/#快速入门 工具需要的依赖 <?xml version="1.0"?> <p ...

  4. 【Layui】15 日期时间选择器 Laydate

    文档地址: https://www.layui.com/demo/laydate.html [基本案例] 基本日期与国际日期 <fieldset class="layui-elem-f ...

  5. 【Docker】10 容器存储

    将容器保存为一个镜像: docker commit 容器的名称 创建的镜像的名称 将镜像保存为一个tar包文件: docker save -o tar包文件名称.tar 镜像名称 可以看到Docker ...

  6. 【转载】 docker挂载volume的用户权限问题,理解docker容器的uid

    =================================================================== 在刚开始使用docker volume挂载数据卷的时候,经常出现 ...

  7. [BJOI2016] IP地址 题解

    前言 来个不一样的做法,用到了 Trie 树和主席树,并且是可爱的在线算法. 题目链接:洛谷. 题目分析 对于一个查询 \(\texttt{ip}\),只考虑所有前缀字符串规则.以时间建里横轴,匹配长 ...

  8. Elsa V3学习之脚本

    在前面的文章中,可以看到我们经常使用JS脚本来获取变量的值.在Elsa中是支持多种脚本的,最常用的基本是JS脚本和C#脚本. 本文来介绍以下这两个脚本使用. Javascript 在ELSA中的jav ...

  9. Odoo13开发环境搭建

    准备:windows10 64位系统.Python3.6.8.Pycharm2019.2.Postgresql-12.0-1.Odoo13 其它:nodejs.rtlcss.wkhtmltopdf 安 ...

  10. DDD是软件工程的第一性原理?

    本文书接上回<DDD建模后写代码的正确姿势>,关注公众号(老肖想当外语大佬)获取信息: 最新文章更新: DDD框架源码(.NET.Java双平台): 加群畅聊,建模分析.技术实现交流: 视 ...