LinerProgression
手动实现线性回归
点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data
构造一个人造数据集
点击查看代码
def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
每次读取一个batch数据量
点击查看代码
def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回
batch_size = 10
for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0
定义模型
点击查看代码
def linreg(x, w, b):
return torch.matmul(x, w) + b
损失函数 均方误差
点击查看代码
def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
优化算法 小批量下降
点击查看代码
def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
实现
点击查看代码
lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss
for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')
给笔者点个赞呀!
随机推荐
- oeasy教您玩转vim - 28 - 水平移动
水平移动 回忆上节课内容 根据扩展名我们可以设置某些特定类型文件的配置 相关文件类型的设置放在相应的文件夹里 文件类型缩进文件夹 /usr/share/vim/vim81/indent/ 文件类型 ...
- pytorch问题记录
1.找不到fused 2.找不到torch_extensions 网上的教程一般都是linux系统的,Windows这个是在C盘目录下 3.c++开发环境中找不到vcvars64.bat 解决方法:重 ...
- 移植自淘宝店家的,硬件SPI通讯3.5寸TFT,LCD屏幕。MSPM0G3507
适用MSPM0G3507 LP开发板 3.5寸TFTLCD屏,SPI通讯 项目是CCStheia的 特点:硬件SPI,速度更快,可以在syscfg中自行修改引脚 蓝奏云: https://wwo.la ...
- ComfyUI插件:ComfyUI Impact 节点(二)
前言: 学习ComfyUI是一场持久战,而 ComfyUI Impact 是一个庞大的模块节点库,内置许多非常实用且强大的功能节点 ,例如检测器.细节强化器.预览桥.通配符.Hook.图片发送器.图片 ...
- vue3 + ts 中出现 类型“typeof import(".........../node_modules/vue/dist/vue")”的参数不能赋给类型“Component<any, any, any, ComputedOptions, MethodOptions>”的参数。
错误示例截图 解决方法 修改shims-vue.d.ts中的内容 declare module "*.vue" { import { defineComponent } from ...
- 自然语言处理:通过API调用各大公司的机器翻译开放平台
国内大公司做机器翻译做的比较好的有讯飞和百度,这里给出这两个公司机器翻译的开放平台API的介绍: 讯飞开放平台: 链接:https://www.xfyun.cn/doc/nlp/xftrans_new ...
- bazel编译报错:absl/base/policy_checks.h:79:2: error: #error "C++ versions less than C++14 are not supported."
使用bazel编译一个软件时报错,报错的信息为: absl/base/policy_checks.h:79:2: error: #error "C++ versions less than ...
- 宝塔环境安装redis
参考: http://www.bt.cn/Help/Find?id=92 步骤: 1. 在安装宝塔时 PHP 版本选 7.0: 2. 安装 redis:wget http://125.88.182.1 ...
- 记一次 .NET某智慧出行系统 CPU爆高分析
一:背景 1. 讲故事 前些天有位朋友找到我,说他们的系统出现了CPU 100%的情况,让你帮忙看一下怎么回事?dump也拿到了,本想着这种情况让他多抓几个,既然有了就拿现有的分析吧. 二:WinDb ...
- Leetcode: 1484. Groups Sold Products By The Date
题目要求如下: 输入的数据为 要求按照日期查询出每日销售数量及相应产品的名称,并按照字符顺序进行排序. 下面是实现的代码: import pandas as pd def categorize_pro ...