03_利用pytorch解决线性回归问题
03_利用pytorch解决线性回归问题
一、引言
上一篇文章我们利用numpy解决了线性回归问题,我们能感觉到他的麻烦之处,很多数学性的方法都需要我们自己亲手去实现,这对于数学不好的同学来说,简直就是灾难,让你数学又好并且码代码能力又强,臣妾做不到呀!因此我们说到,可以利用torch这个框架简化其中的很多操作,接下来就让我们提前体验下torch的强大,由于直接上代码,有些代码可能你可能看不懂,但没关系,能看懂torch的强大就行。
由于上一篇已经详细的介绍了线性回归模型的流程,我们这里就不再次罗嗦了,直接把线性回归的5个步骤贴过来:
- 初始化未知变量\(w=b=0\)
- 得到损失函数\(loss = \frac{1}{N}\sum_{i=1}^N{(\hat{y_i}-y_i)}^2\)
- 利用梯度下降算法更新得到\(w', b'\)
- 重复步骤3,利用\(w', b'\)得到新的更优的\(w', b'\),直至\(w', b'\)收敛
- 最后得到函数模型\(f=w'*x+b'\)
但是这里为了体现torch的强大,我们使用了神经网络全连接层的概念,但是用的是一层网络,相信不会难倒你。有人很好奇,为什么我这样做,我想说:如果不使用神经网络,我为什么不用sklearn框架做个线性回归给你体验下框架的强大呢?三行代码一套搞定。而且一层神经网络可以看做是感知机模型,而感知机模型无非就是在线性回归模型的基础上加了一个sgn函数。
二、利用torch解决线性回归问题
2.1 定义x和y
在下面的代码中,我们定义的x和y都是ndarray数据的格式,所以在之后的处理之中需要通过torch的from_numpy()方法把ndarray格式的数据转成tensor类型。
其中x为一维数组,例如:[1,2,3,4,5,……]
其中y假设\(y=2*x+3\)
import numpy as np
# torch里要求数据类型必须是float
x = np.arange(1, 12, dtype=np.float32).reshape(-1, 1)
y = 2 * x + 3
2.2 自定制线性回归模型类
由于torch内部封装了线性回归算法,我们只需要继承它给我们提供的模型类即可,然后通过Python中类的继承做出一些灵活的改动。(如果你对继承不熟悉,强烈推荐回炉重造Python)下面给出代码:
import torch
import torch.nn as nn
# 继承nn.module,实现前向传播,线性回归直接可以看做是全连接层
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__() # 继承父类方法
self.linear = nn.Linear(input_dim, output_dim) # 定义全连接层,其中input_dim和output_dim是输入和输出数据的维数
# 定义前向传播算法
def forward(self, inp):
out = self.linear(inp) # 输入x后,通过全连接层得到输入出结果out
return out # 返回被全连接层处理后的结果
# 定义线性回归模型
regression_model = LinearRegressionModel(1, 1) # x和y都是一维的
2.3 指定gpu或者cpu
# 可以通过to()或者cuda()使用GPU进行模型的训练,需要将模型和数据都转换到GPU上,也可以指定具体的GPU,如.cuda(1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
regression_model.to(device)
2.4 设置参数
epochs = 1000 # 训练次数
learning_rate = 0.01 # 学习速率
optimizer = torch.optim.SGD(regression_model.parameters(), learning_rate) # 优化器(未来会详细介绍),这里使用随机梯度下降算法(SGD)
criterion = nn.MSELoss() # 使用均方误差定义损失函数
2.5 训练
for epoch in range(epochs):
# 数据类型转换
inputs = torch.from_numpy(x).to(device) # 由于x是ndarray数组,需要转换成tensor类型,如果用gpu训练,则会通过to函数把数据传入gpu
labels = torch.from_numpy(y).to(device)
# 训练
optimizer.zero_grad() # 每次求偏导都会清零,否则会进行叠加
outputs = regression_model(inputs) # 把输入传入定义的线性回归模型中,进行前向传播,得到预测结果
loss = criterion(outputs, labels) # 通过均方误差评估预测误差
loss.backward() # 反向传播
optimizer.step() # 更新权重参数
# 每50次循环打印一次结果
if epoch % 50 == 0:
print("epoch:", epoch, "loss:", loss.item())
predict = regression_model(torch.from_numpy(x).requires_grad_()).data.numpy() # 通过训练好的模型预测结果
2.6 保存模型
torch.save(regression_model.state_dict(), "model.pk1") # 保存模型
result = regression_model.load_state_dict(torch.load("model.pk1")) # 加载模型
三、代码汇总
# author : 'nickchen121';
# date: 14/4/2021 20:11
import numpy as np
# torch里要求数据类型必须是float
x = np.arange(1, 12, dtype=np.float32).reshape(-1, 1)
y = 2 * x + 3
import torch
import torch.nn as nn
# 继承nn.module,实现前向传播,线性回归直接可以看做是全连接层
class LinearRegressionModel(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__() # 继承父类方法
self.linear = nn.Linear(input_dim, output_dim) # 定义全连接层,其中input_dim和output_dim是输入和输出数据的维数
# 定义前向传播算法
def forward(self, inp):
out = self.linear(inp) # 输入x后,通过全连接层得到输入出结果out
return out # 返回被全连接层处理后的结果
# 定义线性回归模型
regression_model = LinearRegressionModel(1, 1) # x和y都是一维的
# 可以通过to()或者cuda()使用GPU进行模型的训练,需要将模型和数据都转换到GPU上,也可以指定具体的GPU,如.cuda(1)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
regression_model.to(device)
epochs = 1000 # 训练次数
learning_rate = 0.01 # 学习速率
optimizer = torch.optim.SGD(regression_model.parameters(), learning_rate) # 优化器,这里使用随机梯度下降算法(SGD)
criterion = nn.MSELoss() # 使用均方误差定义损失函数
for epoch in range(epochs):
# 数据类型转换
inputs = torch.from_numpy(x).to(device) # 由于x是ndarray数组,需要转换成tensor类型,如果用gpu训练,则会通过to函数把数据传入gpu
labels = torch.from_numpy(y).to(device)
# 训练
optimizer.zero_grad() # 每次求偏导都会清零,否则会进行叠加
outputs = regression_model(inputs) # 把输入传入定义的线性回归模型中,进行前向传播
loss = criterion(outputs, labels) # 通过均方误差评估预测误差
loss.backward() # 反向传播
optimizer.step() # 更新权重参数
# 每50次循环打印一次结果
if epoch % 50 == 0:
print("epoch:", epoch, "loss:", loss.item())
predict = regression_model(torch.from_numpy(x).requires_grad_()).data.numpy() # 通过训练好的模型预测结果
# torch.save(regression_model.state_dict(), "model.pk1") # 保存模型
# result = regression_model.load_state_dict(torch.load("model.pk1")) # 加载模型
四、总结
本篇文章从torch的角度去解决了线性回归问题,细节你可能不懂,但也可以发现它是非常简单的,全程没有让你去实现优化器、去实现全连接层、去实现反向传播,在这里你就不需要去实现一个数学公式。你需要做的仅仅是成为一个优秀的调包侠,并且努力成为一个伟大的调参师即可。
至于为什么直接上代码,而不是先讲解torch的基础,一定是有深意的。站得高看得远,先带你领略下torch的用途及其强大,然后我们再慢慢的一步一步筑基。
03_利用pytorch解决线性回归问题的更多相关文章
- 02_利用numpy解决线性回归问题
02_利用numpy解决线性回归问题 目录 一.引言 二.线性回归简单介绍 2.1 线性回归三要素 2.2 损失函数 2.3 梯度下降 三.解决线性回归问题的五个步骤 四.利用Numpy实战解决线性回 ...
- 机器学习中梯度下降法原理及用其解决线性回归问题的C语言实现
本文讲梯度下降(Gradient Descent)前先看看利用梯度下降法进行监督学习(例如分类.回归等)的一般步骤: 1, 定义损失函数(Loss Function) 2, 信息流forward pr ...
- 利用闭包解决for循环里onclick事件不能捕捉实时i值问题
问题描述 我们都知道,如果我们对于一组元素(相同的标签)同时进行onclick事件处理的时候(在需要获取到索引的时候),一般是写一个for循环,但是onclick是一个异步调用的,所以会带来一个问题, ...
- 利用Readability解决网页正文提取问题
分享: 利用Readability解决网页正文提取问题 做数据抓取和分析的各位亲们, 有没有遇到下面的难题呢? - 如何从各式各样的网页中提取正文!? 虽然可以用SS为各种网站写脚本做解析, 但是 ...
- 利用gulp解决微信浏览器缓存问题
做了好多项目,这次终于要解决微信浏览器缓存这个令人头疼的问题了.每次上传新的文件,在微信浏览器中访问时,总要先清除微信的缓存,实在麻烦,在网上搜罗了很多解决办法,终于找到了方法:利用gulp解决缓存问 ...
- 利用Json_encode解决中文问题
利用Json_encode解决中文问题 public function return_json($data=array()){ echo json_encode($data ...
- 利用Filter解决跨域请求的问题
1.为什么出现跨域. 很简单的一句解释,A系统中使用ajax调用B系统中的接口,此时就是一个典型的跨域问题,此时浏览器会出现以下错误信息,此处使用的是chrome浏览器. 错误信息如下: jquery ...
- 利用NSProxy解决NSTimer内存泄漏问题
之前写过一篇利用RunTime解决由NSTimer导致的内存泄漏的文章,最近和同事讨论觉得这样写有点复杂,然后发现有NSProxy这么好用的根类,根类,根类,没错NSProxy与NSObject一样是 ...
- 利用dynamic解决匿名对象不能赋值的问题
原文:利用dynamic解决匿名对象不能赋值的问题 关于匿名对象 匿名对象是.Net Framework 3.0提供的新类型,例如: }; 就是一个匿名类,搭配Linq,可以很灵活的在代码中组合数据, ...
随机推荐
- WPF 数据绑定实例一
前言: 数据绑定的基本步骤: (1)先声明一个类及其属性 (2)初始化类赋值 (3)在C#代码中把控件DataContext=对象: (4)在界面设计里,控件给要绑定的属性{Binding 绑定类的属 ...
- 微信小程序:如何删除所有的console.log?
使用vscode正则匹配,手动去除 1.用vscode打开微信小程序项目 2.Edit-----replace in Files 1. console.log()加了分号 console\.log\( ...
- Django框架的forms组件与一些补充
目录 一.多对多的三种创建方式 1. 全自动 2. 纯手撸(了解) 3. 半自动(强烈推荐) 二.forms组件 1. 如何使用forms组件 2. 使用forms组件校验数据 3. 使用forms组 ...
- Linux基本命令——系统管理和磁盘管理
转: Linux基本命令--系统管理和磁盘管理 Linux命令--系统管理和磁盘管理 一.系统管理 1.1 时间相关指令 <1> 查看当前日历: cal <2> 显示或设置时间 ...
- docker的安装和基本的docker命令、镜像和容器的操作
1.yum 包更新到最新 yum update 2.安装需要的软件包, yum-util 提供yum-config-manager功能,另外两个是devicemapper驱动依赖的 yum insta ...
- dapr学习:dapr介绍
该部分主要是给出学习dapr的入门,描述dapr全貌告诉你dapr是啥以及介绍dapr的主要功能与组件 该部分分为两章: 第一章:介绍dapr 第二章:调试dapr的解决方案项目 1. 介绍dapr ...
- LeetCode-二叉搜索树的范围和
二叉搜索树的范围和 LeetCode-938 首先需要仔细理解题目的意思:找出所有节点值在L和R之间的数的和. 这里采用递归来完成,主要需要注意二叉搜索树的性质. /** * 给定二叉搜索树的根结点 ...
- CCF(公共钥匙盒):思维+模拟
公共钥匙盒 201709-2 这题的思路一开始不是很清晰,一开始想用贪心去做.但是发现按照题目的思路不对.所以这里采用的是类似于多项式的加减的处理. #include<iostream> ...
- 鸿蒙的js开发模式19:鸿蒙手机下载python服务器端文件的实现
目录:1.承接上篇鸿蒙客户端上传文件2.域名通过内网穿透工具3.python服务器端代码4.鸿蒙手机的界面和业务逻辑5.<鸿蒙的js开发模式>系列文章合集 1.承接上篇鸿蒙客户端上传文件, ...
- 越来越受欢迎的Vue想学么,90后小姐姐今儿来教你
摘要:Vue的相关技术原理成为了前端岗位面试中的必考知识点,掌握 Vue 对于前端工程师来说更像是一门"必修课". 本文原作者为尹婷,擅长前端组件库研发和微信机器人. 我们发现, ...