Pytorch 实现线性回归
Pytorch 实现线性回归
import torch
from torch.utils import data
from torch import nn
# 合成数据
def synthetic_data(w, b, num_examples):
"""y = Xw + b + zs"""
X = torch.normal(0, 1, (num_examples, len(w)))
y = torch.matmul(X, w) + b
y += torch.normal(0, 0.01, y.shape)
return X, y.reshape((-1, 1))
# 用于合成数据的模板
true_w = torch.tensor([2, -3.4, 2])
true_b = 4.2
# 合成1000个数据
features, labels = synthetic_data(true_w, true_b, 1000)
# 随机批量加载数据
def load_array(data_arrays, batch_size, is_train=True):
dataset = data.TensorDataset(*data_arrays)
return data.DataLoader(dataset, batch_size, shuffle=is_train)
batch_size = 10
data_iter = load_array((features, labels), batch_size)
# 初始化线性网络,3输入1输出
net = nn.Sequential(nn.Linear(3, 1))
# 均方误差损失函数
loss = nn.MSELoss()
# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)
# 开始迭代
num_epochs = 3
for epoch in range(num_epochs):
for X, y in data_iter:
l = loss(net(X), y)
trainer.zero_grad()
l.backward()
trainer.step()
l = loss(net(features), labels)
print(f'epoch {epoch + 1}, loss {l:f}')
Pytorch 实现线性回归的更多相关文章
- 03_利用pytorch解决线性回归问题
03_利用pytorch解决线性回归问题 目录 一.引言 二.利用torch解决线性回归问题 2.1 定义x和y 2.2 自定制线性回归模型类 2.3 指定gpu或者cpu 2.4 设置参数 2.5 ...
- 从头学pytorch(三) 线性回归
关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...
- 【项目实战】用Pytorch实现线性回归
视频教程:https://www.bilibili.com/video/BV1Y7411d7Ys?p=5 准备数据 首先配置了环境变量,这里使用python3.9.7版本,在Anaconda下构建环境 ...
- 用Pytorch训练线性回归模型
假定我们要拟合的线性方程是:\(y=2x+1\) \(x\):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] \(y\):[1, 3, 5, 7, ...
- 【动手学pytorch】线性回归
代码及解释 错题整理
- 神经网络架构PYTORCH-宏观分析
基本概念和功能: PyTorch是一个能够提供两种高级功能的python开发包,这两种高级功能分别是: 使用GPU做加速的矢量计算 具有自动重放功能的深度神经网络从细的粒度来分,PyTorch是一个包 ...
- pytorch从入门到放弃(目录)
目录 前置基础 Pytorch从入门到放弃 推荐阅读 前置基础 Python从入门到放弃(目录) 人工智能(目录) Pytorch从入门到放弃 01_pytorch和tensorflow的区别 02_ ...
- Task2.设立计算图并自动计算
1.numpy和pytorch实现梯度下降法 import numpy as np # N is batch size; N, D_in, H, D_out = 64, 1000, 100, 10 # ...
- 动手学习pytorch——(1)线性回归
最近参加了伯禹教育的动手学习深度学习项目,现在对第一章(线性回归)部分进行一个总结. 这里从线性回归模型之从零开始的实现和使用pytorch的简洁两个部分进行总结. 损失函数,选取平方函数来评估误差, ...
随机推荐
- hutool包里的ObjectUtil.isNull和ObjectUtil.isEmpty的区别
大家都用过hutool包把,包路径为:cn.hutool.core.util,最近再使用的过程中一直没高明白ObjectUtil.isEmpty和ObjectUtil.isNull两者到底有那些区别, ...
- 罗杨老师带你了解谷歌编程之夏(GSoC)活动全流程
罗杨老师带你了解谷歌编程之夏(GSoC)活动全流程 为了帮助同学们更好地参与开源,Casbin 决定做一期访谈节目,由小编作为一名开源初学者,用访谈的形式与北京大学工学博士.Casbin作者.Npca ...
- 什么是 Git
概述 Git 是一个开源的分布式版本控制系统,用于敏捷高效地处理任何或小或大的项目. Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件. ...
- uoj450 【集训队作业2018】复读机(生成函数,单位根反演)
uoj450 [集训队作业2018]复读机(生成函数,单位根反演) uoj 题解时间 首先直接搞出单个复读机的生成函数 $ \sum\limits_{ i = 0 }^{ k } [ d | i ] ...
- JWT jti和kid属性的说明
jti chaim=== JWT ID " jti"(JWT ID)声明为JWT提供了唯一的标识符. 标识符值的分配方式必须确保将相同值偶然分配给不同数据对象的可能性可以忽略不计: ...
- kafka producer 源码总结
kafka producer可以总体上分为两个部分: producer调用send方法,将消息存放到内存中 sender线程轮询的从内存中将消息通过NIO发送到网络中 1 调用send方法 其实在调用 ...
- 学习GlusterFS(六)
一.GlusterFS概述 分布式文件系统由来 在介绍之前我们先来看下文件系统及典型的NFS文件系统. 计算机通过文件系统管理,存储数据的.而现在数据信息时代中人们可获取数据成指数倍的增长,单纯通过增 ...
- 滑动窗口法——Leetcode例题
滑动窗口法--Leetcode例题(连更未完结) 1. 方法简介 滑动窗口法可以理解为一种特殊的双指针法,通常用来解决数组和字符串连续几个元素满足特殊性质问题(对于字符串来说就是子串).滑动窗口法的显 ...
- 树莓派系统安装(ubuntu版本)无需屏幕
0.前提 所需物品:一个手机.一台电脑.一个树莓派.一张tf卡和一个读卡器.所需软件:Win32DiskImager.putty还需要ubuntu系统镜像源.这些我都放在百度网盘上了链接:https: ...
- VC 下如何正确的创建及管理项目
讲解 VC 下如何正确的创建及管理项目 本文讲解 Visual C++ 的项目文件组成,以及如何正确的创建及管理项目. 本文所设计的内容是初学者必须要掌握的.不能正确的管理项目,就不能进一步写有规模的 ...