Pytorch 实现简单线性回归
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
#解决内核挂掉
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
2 读取数据
data = pd.read_csv('dataset/Income1.csv')
print(type(data))
<class 'pandas.core.frame.DataFrame'>
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30 entries, 0 to 29
Data columns (total 3 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Unnamed: 0 30 non-null int64
1 Education 30 non-null float64
2 Income 30 non-null float64
dtypes: float64(2), int64(1)
memory usage: 848.0 bytes
data

查看数据类型
type(data.Education)
pandas.core.series.Series
4 图表显示数据
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 雅黑字体
plt.scatter(data.Education,data.Income)
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show()

5 转换数据为 Tensor 类型
查看特征数据
data.Education
0 10.000000
1 10.401338
2 10.842809
3 11.244147
4 11.645485
5 12.086957
6 12.488294
7 12.889632
data.Education.index
RangeIndex(start=0, stop=30, step=1)
data.Education.values
array([10. , 10.40133779, 10.84280936, 11.24414716, 11.64548495,
12.08695652, 12.48829431, 12.88963211, 13.2909699 , 13.73244147,
14.13377926, 14.53511706, 14.97658863, 15.37792642, 15.77926421,
16.22073579, 16.62207358, 17.02341137, 17.46488294, 17.86622074,
18.26755853, 18.7090301 , 19.11036789, 19.51170569, 19.91304348,
20.35451505, 20.75585284, 21.15719064, 21.59866221, 22. ])
data.Education.values.reshape(-1,1)
array([[10. ],
[10.40133779],
[10.84280936],
[11.24414716],
[11.64548495],
[12.08695652],
[12.48829431],
[12.88963211],
[13.2909699 ],
[13.73244147],
[14.13377926],
[14.53511706],
[14.97658863],
[15.37792642],
[15.77926421],
[16.22073579],
[16.62207358],
[17.02341137],
[17.46488294],
[17.86622074],
[18.26755853],
[18.7090301 ],
[19.11036789],
[19.51170569],
[19.91304348],
[20.35451505],
[20.75585284],
[21.15719064],
[21.59866221],
[22. ]])
data.Education.values.reshape(-1,1).shape
查看特征数据变换后的数据类型
type(data.Education.values.reshape(-1,1))
numpy.ndarray
X = data.Education.values.reshape(-1,1).astype(np.float32)
print(type(X))
X.shape
<class 'numpy.ndarray'>
(30, 1)
X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32) ) #转换数据类型
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32) ) #转换数据类型
6 定义模型
model = nn.Linear(1,1) #w@input+b 等价于model(input)
定义均方损失函数
loss_fn = nn.MSELoss() #定义均方损失函数
定义优化器
opt = torch.optim.SGD(model.parameters(),lr=0.00001)
7 模型训练
for epoch in range(200):
for x, y in zip(X,Y):
y_pred = model(x) #使用模型预测
loss = loss_fn(y,y_pred) #根据预测计算损失
opt.zero_grad() #进行梯度清零
loss.backward() #求解梯度
opt.step() #优化模型参数
8 输出权重和偏置
model.weight
model.bias
Tensor 类型数据带梯度转换为numpy需要先去梯度
type(model.weight.detach().numpy())
numpy.ndarray
model(X).data.numpy()
预测值类型
type(model(X).data.numpy())
numpy.ndarray
model(X).data.numpy().shape
(30, 1)
plt.scatter(data.Education,data.Income)
plt.plot(X.numpy(),model(X).data.numpy())
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show()

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" data = pd.read_csv('dataset/Income1.csv')
print(type(data)) data.info() data from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 雅黑字体
plt.scatter(data.Education,data.Income)
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show() X = torch.from_numpy(data.Education.values.reshape(-1,1).astype(np.float32) ) #转换数据类型
Y = torch.from_numpy(data.Income.values.reshape(-1,1).astype(np.float32) ) #转换数据类型 model = nn.Linear(1,1) #w@input+b 等价于model(input)
loss_fn = nn.MSELoss() #定义均方损失函数
opt = torch.optim.SGD(model.parameters(),lr=0.00001) for epoch in range(200):
for x, y in zip(X,Y):
y_pred = model(x) #使用模型预测
loss = loss_fn(y,y_pred) #根据预测计算损失
opt.zero_grad() #进行梯度清零
loss.backward() #求解梯度
opt.step() #优化模型参数
print(f'epoch {epoch + 1}, loss {loss.sum():f}') model.weight
model.bias type(model.weight.detach().numpy()) plt.scatter(data.Education,data.Income)
plt.plot(X.numpy(),model(X).data.numpy())
plt.xlabel("受教育年限")
plt.ylabel("工资")
plt.show()
Pytorch 实现简单线性回归的更多相关文章
- SPSS数据分析—简单线性回归
和相关分析一样,回归分析也可以描述两个变量间的关系,但二者也有所区别,相关分析可以通过相关系数大小描述变量间的紧密程度,而回归分析更进一步,不仅可以描述变量间的紧密程度,还可以定量的描述当一个变量变化 ...
- sklearn学习笔记之简单线性回归
简单线性回归 线性回归是数据挖掘中的基础算法之一,从某种意义上来说,在学习函数的时候已经开始接触线性回归了,只不过那时候并没有涉及到误差项.线性回归的思想其实就是解一组方程,得到回归函数,不过在出现误 ...
- 机器学习与Tensorflow(1)——机器学习基本概念、tensorflow实现简单线性回归
一.机器学习基本概念 1.训练集和测试集 训练集(training set/data)/训练样例(training examples): 用来进行训练,也就是产生模型或者算法的数据集 测试集(test ...
- day-12 python实现简单线性回归和多元线性回归算法
1.问题引入 在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...
- 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE
前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...
- 机器学习——Day 2 简单线性回归
写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...
- Python回归分析五部曲(一)—简单线性回归
回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...
- R 语言中的简单线性回归
... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...
- 简单线性回归(梯度下降法) python实现
grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...
随机推荐
- Python中管理数据库
前言:Python中是利用MySQL模块和数据库之间建立联系. MySQLdb 是用于Python链接Mysql数据库的接口,它实现了 Python 数据库 API 规范 V2.0,基于 MySQL ...
- Python之pytesseract模块-实现OCR
在给PC端应用做自动化测试时,某些情况下无法定位界面上的控件,但我们又想获得界面上的文字,则可以通过截图后从图片上去获取该文字信息.那么,Python中有没有对应的工具来实现OCR呢?答案是有的,它叫 ...
- 为什么Class实例可以不是全局唯一的——自定义类加载器
为什么Class实例可以不是全局唯一的 通过定义两个类加载器加载同一字节码文件来证明Class实例为什么不是全局唯一的 1.将一个名为Demo(没有后缀)的字节码文件放在D盘根目录 2.定义两个类加载 ...
- xxs攻击
1 XSS是一种经常出现在web应用中的计算机安全漏洞,它允许恶意web用户将代码植入到提供给其它用户使用的页面中.比如这些代码包括HTML代码和客户端脚本.攻击者利用XSS漏洞旁路掉访问控制--例如 ...
- 一文读懂Redis
目录结构如下: 简介 Redis是一个高性能的key-value数据库.Redis对数据的操作都是原子性的. 优缺点 优点: 基于内存操作,内存读写速度快. Redis是单线程的,避免线程切换开销及多 ...
- python模块--datetime
datatime.date类 构造器 返回值类型 说明 (year, month, day) date 类方法/属性 .max date datetime.date(9999, 12, 3 ...
- lombok时运行编译无法找到get/set方法 看这篇就够了
今天项目突然运行的时候报错,提示找不到get和set方法,这个时候我就检查了项目,在编译器(idea)是没有报错的.说明编译没问题,只是运行过不去. 后面就开始用我的方法解决这个问题,一步一步排查. ...
- spark相关介绍-提取hive表(一)
本文环境说明 centos服务器 jupyter的scala核spylon-kernel spark-2.4.0 scala-2.11.12 hadoop-2.6.0 本文主要内容 spark读取hi ...
- jquery监听动态添加的input的change事件
使用下面方法在监听普通的input的change事件正常 $('#pp').on('change', 'input.videos_poster_input', function () { consol ...
- 探究java的intern方法
本文主要解释java的intern方法的作用和原理,同时会解释一下经常问的String面试题. 首先先说一下结论,后面会实际操作,验证一下结论.intern方法在不同的Java版本中的实现是不一样的. ...