Pytorch 实现简单线性回归
问题描述:
  使用 pytorch 实现一个简单的线性回归。
      

            受教育年薪与收入数据集
单变量线性回归
  单变量线性回归算法(比如,$x$ 代表学历,$f(x)$ 代表收入): 
    $f(x) = w*x + b $

  我们使用 $f(x)$ 这个函数来映射输入特征和输出值。
目标:
  预测函数 $f(x)$ 与真实值之间的整体误差最小
损失函数: 
  使用均方差作为作为成本函数。
  也就是预测值和真实值之间差的平方取均值。
成本函数与损失函数: 
  优化的目标( $y$ 代表实际的收入):
  找到合适的 $w$ 和 $b$ ,使得 $(f(x) - y)^{2}$越小越好
  注意:现在求解的是参数 $w$ 和 $b$。

过程
1 导入实验所需要的包
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
#解决内核挂掉
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

2 读取数据

data = pd.read_csv('dataset/Income1.csv')
print(type(data))
<class 'pandas.core.frame.DataFrame'>
3 查看数据信息
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30 entries, 0 to 29
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Unnamed: 0 30 non-null int64
1 Education 30 non-null float64
2 Income 30 non-null float64
dtypes: float64(2), int64(1)
memory usage: 848.0 bytes
  查看数据
data

      

   查看数据类型

type(data.Education)
pandas.core.series.Series

4 图表显示数据

from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 雅黑字体
plt.scatter(data.Education,data.Income)
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show()

      

5 转换数据为 Tensor 类型

查看特征数据

data.Education
0     10.000000
1 10.401338
2 10.842809
3 11.244147
4 11.645485
5 12.086957
6 12.488294
7 12.889632
查看特征数据 index
data.Education.index
RangeIndex(start=0, stop=30, step=1)
查看特征数据 value
data.Education.values
array([10.        , 10.40133779, 10.84280936, 11.24414716, 11.64548495,
12.08695652, 12.48829431, 12.88963211, 13.2909699 , 13.73244147,
14.13377926, 14.53511706, 14.97658863, 15.37792642, 15.77926421,
16.22073579, 16.62207358, 17.02341137, 17.46488294, 17.86622074,
18.26755853, 18.7090301 , 19.11036789, 19.51170569, 19.91304348,
20.35451505, 20.75585284, 21.15719064, 21.59866221, 22. ])
特征数据变换形状
data.Education.values.reshape(-1,1)
array([[10.        ],
[10.40133779],
[10.84280936],
[11.24414716],
[11.64548495],
[12.08695652],
[12.48829431],
[12.88963211],
[13.2909699 ],
[13.73244147],
[14.13377926],
[14.53511706],
[14.97658863],
[15.37792642],
[15.77926421],
[16.22073579],
[16.62207358],
[17.02341137],
[17.46488294],
[17.86622074],
[18.26755853],
[18.7090301 ],
[19.11036789],
[19.51170569],
[19.91304348],
[20.35451505],
[20.75585284],
[21.15719064],
[21.59866221],
[22. ]])
查看特征数据变换后的形状
data.Education.values.reshape(-1,1).shape

查看特征数据变换后的数据类型

type(data.Education.values.reshape(-1,1))
numpy.ndarray
修改特征数据变换后的数据类型
X = data.Education.values.reshape(-1,1).astype(np.float32)
print(type(X))
X.shape
<class 'numpy.ndarray'>
(30, 1)
特征数据和标签转换为Tensor
X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32) ) #转换数据类型
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32) ) #转换数据类型

6 定义模型

定义线性回归模型:
model = nn.Linear(1,1)    #w@input+b   等价于model(input)

定义均方损失函数

loss_fn = nn.MSELoss()    #定义均方损失函数

定义优化器

opt = torch.optim.SGD(model.parameters(),lr=0.00001)  

7 模型训练

for epoch in range(200):
for x, y in zip(X,Y):
y_pred = model(x) #使用模型预测
loss = loss_fn(y,y_pred) #根据预测计算损失
opt.zero_grad() #进行梯度清零
loss.backward() #求解梯度
opt.step() #优化模型参数

8 输出权重和偏置

model.weight
model.bias

  Tensor 类型数据带梯度转换为numpy需要先去梯度

type(model.weight.detach().numpy())
numpy.ndarray
9 获取预测值 y_pred
model(X).data.numpy()

预测值类型

type(model(X).data.numpy())
numpy.ndarray
预测值size
model(X).data.numpy().shape
(30, 1)
10 绘制回归曲线
plt.scatter(data.Education,data.Income)
plt.plot(X.numpy(),model(X).data.numpy())
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show()
 完整代码:

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" data = pd.read_csv('dataset/Income1.csv')
print(type(data)) data.info() data from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 雅黑字体
plt.scatter(data.Education,data.Income)
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show() X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32) ) #转换数据类型
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32) ) #转换数据类型 model = nn.Linear(1,1) #w@input+b 等价于model(input)
loss_fn = nn.MSELoss() #定义均方损失函数
opt = torch.optim.SGD(model.parameters(),lr=0.00001) for epoch in range(200):
for x, y in zip(X,Y):
y_pred = model(x) #使用模型预测
loss = loss_fn(y,y_pred) #根据预测计算损失
opt.zero_grad() #进行梯度清零
loss.backward() #求解梯度
opt.step() #优化模型参数
print(f'epoch {epoch + 1}, loss {loss.sum():f}') model.weight
model.bias type(model.weight.detach().numpy()) plt.scatter(data.Education,data.Income)
plt.plot(X.numpy(),model(X).data.numpy())
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show()
看完点个关注呗!!(总结不易)
 

Pytorch 实现简单线性回归的更多相关文章

  1. SPSS数据分析—简单线性回归

    和相关分析一样,回归分析也可以描述两个变量间的关系,但二者也有所区别,相关分析可以通过相关系数大小描述变量间的紧密程度,而回归分析更进一步,不仅可以描述变量间的紧密程度,还可以定量的描述当一个变量变化 ...

  2. sklearn学习笔记之简单线性回归

    简单线性回归 线性回归是数据挖掘中的基础算法之一,从某种意义上来说,在学习函数的时候已经开始接触线性回归了,只不过那时候并没有涉及到误差项.线性回归的思想其实就是解一组方程,得到回归函数,不过在出现误 ...

  3. 机器学习与Tensorflow(1)——机器学习基本概念、tensorflow实现简单线性回归

    一.机器学习基本概念 1.训练集和测试集 训练集(training set/data)/训练样例(training examples): 用来进行训练,也就是产生模型或者算法的数据集 测试集(test ...

  4. day-12 python实现简单线性回归和多元线性回归算法

    1.问题引入  在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...

  5. 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE

    前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...

  6. 机器学习——Day 2 简单线性回归

    写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...

  7. Python回归分析五部曲(一)—简单线性回归

    回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...

  8. R 语言中的简单线性回归

    ... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...

  9. 简单线性回归(梯度下降法) python实现

    grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...

随机推荐

  1. 小程序生成商品分享二维码海报解决方案和实现方式JAVA

    使用技术:  Graphics , 七牛云 , 微信sdk(github上非常出名的wxjava,地址https://github.com/Wechat-Group/WxJava/)直接上干货代码,每 ...

  2. Python语法之函数、引用和装饰器

    所谓函数,就是把具有独立功能的代码块组织成为一个小模块,在需要的时候调用 函数是带名字的代码块,用于完成具体的工作 需要在程序中多次执行同一项任务时,你无需反复编写完成该任务的代码,而只需调用该 任务 ...

  3. 硕盟type-c转接头HDMI+VGA+USB3.0+PD3.0四合一多功能扩展坞

    硕盟SM-T54是一款 TYPE C转HDMI+VGA+USB3.0+PD3.0四合一多功能扩展坞,支持四口同时使用,您可以将含有USB 3.1协议的电脑主机,通过此产品连接到具有HDMI或VGA的显 ...

  4. Expression 表达式动态生成

    http://blog.csdn.net/duan1311/article/details/51769119 以上是拼装和调用GroupBy的方法,是不是很简单,只要传入分组列与合计列就OK了! 下面 ...

  5. vim中字符串的替换

    vi/vim 中可以使用 :s 命令来替换字符串 :s/vivian/sky/ 替换当前行第一个 vivian 为 sky :s/vivian/sky/g 替换当前行所有 vivian 为 sky : ...

  6. python中字符串的各种方法

     图片来源见水印,一个学python的公众号

  7. 最新版微软视窗(Windows)作业系统下载(2020-08-19)

    为了更好的使用WSL(Windows Subsystem For Linux),不得不用最新的windows 10 2004版了,这个版本的WSL已经是第二版了,即WSL2.下面给出下载地址 系统发布 ...

  8. TP5指定讲师页面文章上下篇

    控制器代码 // 查询上下篇 $courseIds = model('course') ->where([ 'isdel' => 0, 'teacherid' => $teacher ...

  9. centos7.6,nginx1.18,php-7.4.6,mysql-5.7.30 安装

    #1.下载,来自各官网 nginx-1.18.0.tar.gz php-7.4.6.tar.gz mysql-5.7.30-linux-glibc2.12-x86_64.tar.gz #下载到本地再传 ...

  10. ecshop调用指定广告的方法

    在include/lib_goods.php文件下面新增:function getads($cat,$num){ $time = gmtime();$sql = "SELECT * FROM ...