视频教程:https://www.bilibili.com/video/BV1Y7411d7Ys?p=5

准备数据

首先配置了环境变量,这里使用python3.9.7版本,在Anaconda下构建环境运行,并且安装pytorch

决定使用模型y=wx+b

然后根据视频在pycharm中输入如下代码

import torch
x_data=torch.Tensor([[1.0],[2.0],[3.0]])
y_data=torch.Tensor([[2.0],[4.0],[6.0]])

首先这里预期使用小批量梯度下降

而在使用numpy的时候会有广播机制存在,会自动把不能相加的矩阵直接扩充广播成相同大小,而我们的Tensor构造的时候就需要注意了,我们必须要保证我们构造的数据一开始就是矩阵,所以必须用[1.0],[2.0],[3.0]的方法来进行构造

设计模型

在pytorch中目标不再是人工求出导数,重点变成了构造计算图

构造计算图首先要知道X的维度,还需要知道输出的y是几维的,这样就可以确定w权重以及b偏值的形状

这样就可以输入x,经过w和b计算出y,再由y算出算是loss,由loss调用反向传播,拿到所有损失还需要求和求均值,否则无法反向传播

首先要把我们的模型定义成一个类,而我们在构造模型的时候使用的都是这样一个模板,必须掌握这样的一个编写方式

class LinearModel(torch.nn.Module):         //所有编写的模型都要记成Module,这里面有很多方法,要从这个模块里面把继承下来
//类里面至少要有两个函数,一个是init,属于构造函数,还有一个forward函数(必须叫这个),而module里面会根据计算图自动帮你实现反向传播过程,如果你觉得他的计算效率不高,你也可以用torch里面的一个function类来构造自己的计算方法
def __init__(self):
super(LinearModel,self).__init__() //调用父类的构造
self.Linear = torch.nn.Linear(1, 1) //torch.nn.Linear是torch里面的一个类,这里是构造一个对象,这里包含了权重和偏置 def forward(self,x):
y_pred = self.Linear(x) //这里表示实现一个可调用的对象
return y_pred
model1 = LinearModel()

这里给出class torch.nn.Linear的具体文档



这里的size表示维度,这里的bias是一个布尔类型,来决定你是否需要偏置量(默认是ture)

构造损失函数和优化器

criterion = torch.nn.MSELoss(size_average=False)               //size.average表示最后是否需要求均值
optimizer = torch.optim.SGD(model1.parameters(),lr=0.01) //parameters可以把模型中所有的参数全部找出来,lr就是学习率,而这个optimizer,他就知道需要对哪些东西做优化

训练过程

for epoch in range(100):
y_pred = model1(x_data) //先算出y
loss = criterion(y_pred,y_data) //利用criterion算出损失函数
print(epoch, loss.item()) //打印出来 optimizer.zero_grad() //先梯度归零
loss.backward() //反向传播
optimizer.step() //更新参数

打印结果

print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_text = model(x_test)
print('y_pred=', y_text.data)

一开始给我报错



原来最上面X,Y赋值少了一对中括号,导致识别不出来

然后成功跑出!



输入4,1000次迭代输出7.9994,很不错了

然后我们修改一下,加一个数据4和8,然后输入10,训练2000次



帅的嘛不谈了!

完整代码如下:

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0], [4.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0], [8.0]]) class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear = torch.nn.Linear(1, 1) def forward(self, x):
y_pred = self.linear(x)
return y_pred model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss.item()) optimizer.zero_grad()
loss.backward()
optimizer.step() print('w=', model.linear.weight.item())
print('b=', model.linear.bias.item())
x_test = torch.Tensor([[10.0]])
y_text = model(x_test)
print('y_pred=', y_text.data)

打算修改步长再试试

后记:果然学习率设置过程就是炼丹

【项目实战】用Pytorch实现线性回归的更多相关文章

  1. 【项目实战】pytorch实现逻辑斯蒂回归

    视频指导:https://www.bilibili.com/video/BV1Y7411d7Ys?p=6 一些数据集 在pytorch框架下,里面面有配套的数据集,pytorch里面有一个torchv ...

  2. 机器学习_线性回归和逻辑回归_案例实战:Python实现逻辑回归与梯度下降策略_项目实战:使用逻辑回归判断信用卡欺诈检测

    线性回归: 注:为偏置项,这一项的x的值假设为[1,1,1,1,1....] 注:为使似然函数越大,则需要最小二乘法函数越小越好 线性回归中为什么选用平方和作为误差函数?假设模型结果与测量值 误差满足 ...

  3. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...

  4. 第24月第30天 scrapy《TensorFlow机器学习项目实战》项目记录

    1.Scrapy https://www.imooc.com/learn/1017 https://github.com/pythonsite/spider/tree/master/jobboleSp ...

  5. Asp.Net Core 项目实战之权限管理系统(4) 依赖注入、仓储、服务的多项目分层实现

    0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...

  6. 给缺少Python项目实战经验的人

    我们在学习过程中最容易犯的一个错误就是:看的多动手的少,特别是对于一些项目的开发学习就更少了! 没有一个完整的项目开发过程,是不会对整个开发流程以及理论知识有牢固的认知的,对于怎样将所学的理论知识应用 ...

  7. 【腾讯Bugly干货分享】React Native项目实战总结

    本文来自于腾讯bugly开发者社区,非经作者同意,请勿转载,原文地址:http://dev.qq.com/topic/577e16a7640ad7b4682c64a7 “8小时内拼工作,8小时外拼成长 ...

  8. Asp.Net Core 项目实战之权限管理系统(0) 无中生有

    0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...

  9. Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端

    0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...

  10. Asp.Net Core 项目实战之权限管理系统(2) 功能及实体设计

    0 Asp.Net Core 项目实战之权限管理系统(0) 无中生有 1 Asp.Net Core 项目实战之权限管理系统(1) 使用AdminLTE搭建前端 2 Asp.Net Core 项目实战之 ...

随机推荐

  1. 干货 |《2022B2B新增长系列之企服行业橙皮书》重磅发布

    企服行业面临的宏观环境和微观环境已然发生了明显的变化.一方面,消费级互联网成为过去式,爆发式增长的时代结束.资本.媒体的目光已经悄然聚焦到以企服行业所代表的产品互联网身上,B2B企业正稳步走向C位. ...

  2. Java开发学习(十)----基于注解开发定义bean 已完成

    一.环境准备 先来准备下环境: 创建一个Maven项目 pom.xml添加Spring的依赖 <dependencies>    <dependency>        < ...

  3. Blazor快速实现扫雷(MineSweeper)

    如何快速的实现一个扫雷呢,最好的办法不是从头写,而是移植一个已经写好的! Blazor出来时间也不短了,作为一个.net开发者就用它来作吧.Blazor给我的感觉像是Angular和React的结合体 ...

  4. python获取线程返回值

    python获取线程返回值 前言 工作中的需求 将前端传过来的字符串信息通过算法转换成语音,并将语音文件返回回去 由于算法不是我写的,只需要调用即可,但是算法执行速度相当缓慢 我的优化思路是,将前端的 ...

  5. 最近公共祖先(LCA)学习笔记 | P3379 【模板】最近公共祖先(LCA)题解

    研究了LCA,写篇笔记记录一下. 讲解使用例题 P3379 [模板]最近公共祖先(LCA). 什么是LCA 最近公共祖先简称 LCA(Lowest Common Ancestor).两个节点的最近公共 ...

  6. MYSQL(进阶篇)——一篇文章带你深入掌握MYSQL

    MYSQL(进阶篇)--一篇文章带你深入掌握MYSQL 我们在上篇文章中已经学习了MYSQL的基本语法和概念 在这篇文章中我们将讲解底层结构和一些新的语法帮助你更好的运用MYSQL 温馨提醒:该文章大 ...

  7. 抖音web端 s_v_web_id 参数生成分析与实现

    本文所有教程及源码.软件仅为技术研究.不涉及计算机信息系统功能的删除.修改.增加.干扰,更不会影响计算机信息系统的正常运行.不得将代码用于非法用途,如侵立删! 抖音web端 s_v_web_id 参数 ...

  8. C#/VB.NET 将PDF转为PDF/X-1a:2001

    PDF/X-1a是一种PDF文件规范标准,在制作.使用PDF以及印刷时所需要遵循的技术条件,属于PDF/X-1标准下的一个子标准. PDF/X-1标准有由CGATS于1999年制定的PDF/X-1:1 ...

  9. Python 实现列表与二叉树相互转换并打印二叉树封装类-详细注释+完美对齐

    # Python 实现列表与二叉树相互转换并打印二叉树封装类-详细注释+完美对齐 from binarytree import build import random # https://www.cn ...

  10. 手把手教你定位线上MySQL锁超时问题,包教包会

    昨晚我正在床上睡得着着的,突然来了一条短信. 什么?线上的订单无法取消! 我赶紧登录线上系统,查看业务日志. 发现有MySQL锁超时的错误日志. 不用想,肯定有另一个事务正在修改这条订单,持有这条订单 ...