Pytorch 实现线性回归

import torch
from torch.utils import data
from torch import nn # 合成数据
def synthetic_data(w, b, num_examples):
"""y = Xw + b + zs"""
X = torch.normal(0, 1, (num_examples, len(w)))
y = torch.matmul(X, w) + b
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1)) # 用于合成数据的模板
true_w = torch.tensor([2, -3.4, 2])
true_b = 4.2 # 合成1000个数据
features, labels = synthetic_data(true_w, true_b, 1000) # 随机批量加载数据
def load_array(data_arrays, batch_size, is_train=True):
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset, batch_size, shuffle=is_train) batch_size = 10
data_iter = load_array((features, labels), batch_size) # 初始化线性网络,3输入1输出
net = nn.Sequential(nn.Linear(3, 1))
# 均方误差损失函数
loss = nn.MSELoss()
# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03) # 开始迭代
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X), y)
trainer.zero_grad()
l.backward()
trainer.step()
l = loss(net(features), labels)
print(f'epoch {epoch + 1}, loss {l:f}')

Pytorch 实现线性回归的更多相关文章

  1. 03_利用pytorch解决线性回归问题

    03_利用pytorch解决线性回归问题 目录 一.引言 二.利用torch解决线性回归问题 2.1 定义x和y 2.2 自定制线性回归模型类 2.3 指定gpu或者cpu 2.4 设置参数 2.5 ...

  2. 从头学pytorch(三) 线性回归

    关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...

  3. 【项目实战】用Pytorch实现线性回归

    视频教程:https://www.bilibili.com/video/BV1Y7411d7Ys?p=5 准备数据 首先配置了环境变量,这里使用python3.9.7版本,在Anaconda下构建环境 ...

  4. 用Pytorch训练线性回归模型

    假定我们要拟合的线性方程是:\(y=2x+1\) \(x\):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] \(y\):[1, 3, 5, 7, ...

  5. 【动手学pytorch】线性回归

    代码及解释 错题整理

  6. 神经网络架构PYTORCH-宏观分析

    基本概念和功能: PyTorch是一个能够提供两种高级功能的python开发包,这两种高级功能分别是: 使用GPU做加速的矢量计算 具有自动重放功能的深度神经网络从细的粒度来分,PyTorch是一个包 ...

  7. pytorch从入门到放弃(目录)

    目录 前置基础 Pytorch从入门到放弃 推荐阅读 前置基础 Python从入门到放弃(目录) 人工智能(目录) Pytorch从入门到放弃 01_pytorch和tensorflow的区别 02_ ...

  8. Task2.设立计算图并自动计算

    1.numpy和pytorch实现梯度下降法 import numpy as np # N is batch size; N, D_in, H, D_out = 64, 1000, 100, 10 # ...

  9. 动手学习pytorch——(1)线性回归

    最近参加了伯禹教育的动手学习深度学习项目,现在对第一章(线性回归)部分进行一个总结. 这里从线性回归模型之从零开始的实现和使用pytorch的简洁两个部分进行总结. 损失函数,选取平方函数来评估误差, ...

随机推荐

  1. HTML-置换元素

    我们都知道,行内元素不能够定义宽度和高度,但 img,input,button等标签作为行内元素却可以定义宽高,为什么呢?这就牵扯到了置换元素和非置换元素. 置换元素: w3c官方解释:"A ...

  2. 去掉一个Vector集合中重复的元素 ?

    Vector newVector = new Vector(); For (int i=0;i<vector.size();i++) { Object obj = vector.get(i); ...

  3. Kafka 分区数可以增加或减少吗?为什么?

    我们可以使用 bin/kafka-topics.sh 命令对 Kafka 增加 Kafka 的分区数据,但是 Kafka 不支持减少分区数. Kafka 分区数据不支持减少是由很多原因的,比如减少的分 ...

  4. Oracle入门基础(十三)一一java调用oracle存储过程

    package demo; import java.sql.CallableStatement; import java.sql.Connection; import java.sql.ResultS ...

  5. Error 和 Exception 有什么区别?

    Error 表示系统级的错误和程序不必处理的异常,是恢复不是不可能但很困难的情 况下的一种严重问题:比如内存溢出,不可能指望程序能处理这样的情况: Exception 表示需要捕捉或者需要程序进行处理 ...

  6. 列举 Spring Framework 的优点?

    由于 Spring Frameworks 的分层架构,用户可以自由选择自己需要的组件. Spring Framework 支持 POJO(Plain Old Java Object) 编程,从而具备持 ...

  7. 你能保证 GC 执行吗?

    不能,虽然你可以调用 System.gc() 或者 Runtime.gc(),但是没有办法保证 GC 的执行.

  8. spring 中有多少种 IOC 容器?

    BeanFactory - BeanFactory 就像一个包含 bean 集合的工厂类.它会在客户端 要求时实例化 bean.ApplicationContext - ApplicationCont ...

  9. 学习Apache(四)

    介绍 Apache HTTP 服务器被设计为一个功能强大,并且灵活的 web 服务器, 可以在很多平台与环境中工作.不同平台和不同的环境往往需要不同 的特性,或可能以不同的方式实现相同的特性最有效率. ...

  10. 基于CrawlSpider全栈数据爬取

    CrawlSpider就是爬虫类Spider的一个子类 使用流程 创建一个基于CrawlSpider的一个爬虫文件 :scrapy genspider -t crawl spider_name www ...