LinerProgression
手动实现线性回归
点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data
构造一个人造数据集
点击查看代码
def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
每次读取一个batch数据量
点击查看代码
def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回
batch_size = 10
for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0
定义模型
点击查看代码
def linreg(x, w, b):
return torch.matmul(x, w) + b
损失函数 均方误差
点击查看代码
def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
优化算法 小批量下降
点击查看代码
def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
实现
点击查看代码
lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss
for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')
给笔者点个赞呀!
随机推荐
- oeasy教您玩转vim - 41 - # 各寄存器
各寄存器 回忆上节课内容 上次是复制粘贴 y就是把东西yank到寄存器里,就是复制 d就是把东西delete到寄存器里,就是剪切 yank也可以配合motion 不管是yank.delete都是把 ...
- TIER 1: Responder
TIER 1: Responder Active Directory Active Directory(AD)是由微软开发的目录服务,用于在网络环境中管理和组织用户.计算机.应用程序和其他资源.它是 ...
- Oracle 存储过程学习总结
创建/更新存储过程 基础基础用法 创建/修改无参存储过程 CREATE OR REPLACE PROCEDURE procedure_name [IS|AS] --声明全局变量(可选) BEGIN - ...
- 用.Net实现GraphRag:从零开始构建智能知识图谱
近来,大模型技术日新月异,使得与其相关的研发项目也层出不穷.其中一个备受关注的技术便是RAG(Retrieval Augmented Generation).今天,我要跟大家分享一个出色的项目:Gra ...
- 如何对jar包修改并重新发布在本机
本人苦于jieba不能如何识别伊利丹·怒风,召唤者坎西恩这种名字,对jieba-analysis进行了解包和打包 步骤1:找到对应jar 步骤2:在cmd中输入jar -xvf xxx.jar解压包, ...
- RHCA rh442 009 磁盘算法 RAID相关 磁盘压力测试
磁盘 一个数据在磁盘A位置,一个数据在磁盘B位置,他们如果隔着很远.这对磁盘来说性能很差 (机械盘,磁头来回移动) 一个数据写进来,他会把数据放到缓存中,经过磁盘调度算法来调度,最后写到硬盘 io读写 ...
- 万字干货:从消息流平台Serverless之路,看Serverless标准演进
摘要:如今,Serverless化已经成为消息流平台发展的新趋势,而如何更好地基于Serverless化的消息流平台进行应用设计和开发,则成为了一个值得思考的问题. 本文分享自华为云社区<900 ...
- Windows安装虚拟机软件-VirtualBox
1.VirtualBox简介 VirtualBox号称是最强的开源免费虚拟机软件,它不仅具有丰富的特色,而且性能也很优异. 它简单易用,可虚拟的系统包括Windows.Mac OS X.Linux.O ...
- 【Uni-App】组件笔记
官网文档地址: https://uniapp.dcloud.io/component/README 组件是视图层的基本组成单元. 组件是一个单独且可复用的功能模块的封装. 每个组件,包括如下几个部分: ...
- 【SpringMVC】03 使用注解
第一步还是配置web.xml,使用分发器统一处理请求和加载容器文件 <?xml version="1.0" encoding="UTF-8"?> & ...