手动实现线性回归

点击查看代码
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils import data

构造一个人造数据集

点击查看代码
	def synthetic_data(w, b, num_examples):
"""生成 y = Xw + b +噪声"""
x = torch.normal(0, 1, (num_examples, len(w))) # 均值为0,方差为1 的随机数,行数为num,列数为len(x)
y = torch.matmul(x, w) + b
y += torch.normal(0, 0.1, y.shape) # 随机噪音
return x, y.reshape(-1, 1) # 将y转换成一列 true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

每次读取一个batch数据量

点击查看代码
	def data_iter(batch_size, features, labels):
num_examples = len(features) # 样本数量
indices = list(range(num_examples)) # 生成一个下标列表
random.shuffle(indices) # 将列表中顺序打乱,否则就会有序提取不好,我们要随机取样本
for i in range(0, num_examples, batch_size): # 从0开始到num_examples结束,每次拿batch_size个数据
batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)]) # 将拿出的下标拿出来,如果最后不够一个batchsize则拿到最后位置
yield features[batch_indices], labels[batch_indices] # 每次返回一个x,一个y直到完全返回 batch_size = 10 for x, y in data_iter(batch_size, features, labels):
print(x, '\n', y)
break w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True) # 生成一个均值为0方差为0.1 的两行一列的张量
b = torch.zeros(1, requires_grad=True) # 生成了一个0

定义模型

点击查看代码
	def linreg(x, w, b):
return torch.matmul(x, w) + b

损失函数 均方误差

点击查看代码
	def squared_loss(y_hat, y):
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

优化算法 小批量下降

点击查看代码
	def sgd(params, lr, batch_size):
"""小批量下降"""
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()

实现

点击查看代码
	lr = 0.01
num_epochs = 5
net = linreg
loss = squared_loss for epoch in range(num_epochs):
for x, y in data_iter(batch_size, features, labels):
l = loss(net(x, w, b), y) # x, y的小批量损失
l.sum().backward()
sgd([w, b], lr, batch_size)
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}')

给笔者点个赞呀!

随机推荐

  1. manim边学边做--Table

    表格是一种常见的数据展示形式,manim提供了Table模块专门用于显示表格形式的数据.表格Table和上一节介绍的矩阵Matrix都是用来显示二维数据的,不过,Table的表现力更强,比如,它可以显 ...

  2. java面试一日一题:1.6/7/8Java内存区域有什么不同吗

    问题:请讲下在JDK6 JDK7 JDK8中java内存区域有什么不同吗 分析:该问题主要考察对JVM运行时区域的了解,首先要了解最基本的内存区域划分,然后再去掌握其中的变化,再延申一点,为什么要这样 ...

  3. 解决SpringMVC/SpringBoot @RequestBody无法注入基本数据类型

    我们都知道SpringMVC使用 @RequestBody 注解可以接收请求content-type 为 application/json 格式的消息体.但是我们必须使用实体对象,Map或者直接用St ...

  4. jmeter测试udp广播(jmeter发送udp)

    jmeter测试udp广播(jmeter发送udp) jmeter测试udp广播(jmeter接收udp) 先下载安装第三方插件 下载链接:https://jmeter-plugins.org/ins ...

  5. Linux 备份命令 fsarchiver 基础使用教程

    1 安装配置 fsarchiver 使用yum安装[二者选一个即可,我使用的是下面那个]: yum install https://dl.fedoraproject.org/pub/epel/epel ...

  6. 【Git】Gitee 码云的使用

    1.注册.登陆.设置配置 以上步骤省略,不需要太多指示操作 2.配置SSH公钥: 先进入自己的用户目录下面 C:\Users\Administrator\ 然后右键空白位置[Git Bash Here ...

  7. 如何在X86_64系统上运行arm架构的docker容器——(异构/不同架构)CPU下的容器启动

    近期使用华为的人工智能集群,其中不仅要求异构加速端需要使用昇腾的硬件,更是要求CPU是arm架构的,因此就导致在本地x86电脑上难以对云端的arm版本的镜像进行软件安装和打包操作,为此我们需要在x86 ...

  8. 如何查看mongodb的索引命中率

    如何查看mongodb的索引命中率 一.背景 现在mongodb使用率很高,经常会遇到查询慢时,就会创建索引,而有时候索引命中率又不高,下面来介绍下测试环境下如何查看索引命中率 二.方案 1.首先查看 ...

  9. Linux/Go环境搭建, HelloWorld运行

    package main import "fmt" func main() { fmt.Printf("Hello,World!!!\n") } 以上是Go语言 ...

  10. 零基础学习人工智能—Python—Pytorch学习(七)

    前言 本文主要讲神经网络的下半部分. 其实就是结合之前学习的全部内容,进行一次神经网络的训练. 神经网络 下面是使用MNIST数据集进行的手写数字识别的神经网络训练和使用. MNIST 数据集,是一个 ...