知识点

  • scikit-learn 对于线性回归提供了比较多的类库,这些类库都可以用来做线性回归分析。
  • 我们也可以使用scikit-learn的线性回归函数,而不是从头开始实现这些算法。 我们将scikit-learn的线性回归算法应用于编程作业1.1的数据,并看看它的表现。
  • 一般来说,只要觉得数据有线性关系,LinearRegression类是我们的首选。如果发现拟合或者预测的不好,再考虑用其他的线性回归库。如果是学习线性回归,推荐先从这个类开始第一步的研究。
  • LinearRegression 的使用非常简单,主要分为两步:
    1. 使用 fit(x_train,y_train) 对训练集x, y进行训练。
    2. 使用 predict(x_test) 训练得到的估计器对输入为 x_test 的集合进行预测。( (x_test) 可以是测试集,也可以是需要预测的数据)

过程

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt # 导入数据
path = 'D:\BaiduNetdiskDownload\data_sets\ex1data1.txt' # pd.read_csv 将 TXT 文件读入并转化为数据框形式
# names 添加列名
# header 用指定的行来作为标题(表头),若原来无标题则设为 none
# 用到 Pandas 里面的 head( ) 函数读取数据(只能读取前五行)
data = pd.read_csv(path,header=None,names=['Population','Profit'])
data.head()


# 在训练集中插入一列1(其实是x0=1),方便我们可以使用向量化的解决方案来计算代价和梯度。
data.insert(0, 'Ones', 1) # set X(training set), y(target variable)
# 设置训练集X,和目标变量y的值
cols = data.shape[1] # 获取列数
X = data.iloc[:,0:cols-1] # 输入向量X为前cols-1列
y = data.iloc[:,cols-1:cols] # 目标变量y为最后一列 # 代价函数是应该是 numpy 矩阵,所以我们需要转换X和Y,然后才能使用它们。 我们还需要初始化 theta 。
X = np.array(X.values)
y = np.array(y.values)
theta = np.array([0,0])

核心代码:

from sklearn import linear_model

# 需要导入LinearRegression类,并将之实例化,并采用fit()方法已验证这些训练数据。
model = linear_model.LinearRegression()
model.fit(X, y) # fit(X, y)对训练集X, y进行训练

scikit-learn model的预测表现:

x = np.array(X[:, 1])
f = model.predict(X).flatten() # .flatten() 默认按行的方向降维 fig, ax = plt.subplots(figsize=(8,5))
ax.plot(x, f, 'r', label='Prediction')
ax.scatter(data.Population, data.Profit, label='Traning Data')
ax.legend(loc=2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()

参考资料

python_sklearn机器学习算法系列之LinearRegression线性回归

吴恩达机器学习作业Python实现(一):线性回归

scikit-learn 线性回归算法库小结

编程作业1.1——sklearn机器学习算法系列之LinearRegression线性回归的更多相关文章

  1. sklearn机器学习算法--线性模型

    线性模型 用于回归的线性模型 线性回归(普通最小二乘法) 岭回归 lasso 用于分类的线性模型 用于多分类的线性模型 1.线性回归 LinearRegression,模型简单,不同调节参数 #2.导 ...

  2. sklearn机器学习算法--K近邻

    K近邻 构建模型只需要保存训练数据集即可.想要对新数据点做出预测,算法会在训练数据集中找到最近的数据点,也就是它的“最近邻”. 1.K近邻分类 #第三步导入K近邻模型并实例化KN对象 from skl ...

  3. 机器学习作业(五)机器学习算法的选择与优化——Matlab实现

    题目下载[传送门] 第1步:读取数据文件,并可视化: % Load from ex5data1: % You will have X, y, Xval, yval, Xtest, ytest in y ...

  4. 机器学习算法系列:FM分解机

    在线性回归中,是假设每个特征之间独立的,也即是线性回归模型是无法捕获特征之间的关系.为了捕捉特征之间的关系,便有了FM分解机的出现了.FM分解机是在线性回归的基础上加上了交叉特征,通过学习交叉特征的权 ...

  5. 如何用Python实现常见机器学习算法-1

    最近在GitHub上学习了有关python实现常见机器学习算法 目录 一.线性回归 1.代价函数 2.梯度下降算法 3.均值归一化 4.最终运行结果 5.使用scikit-learn库中的线性模型实现 ...

  6. java数据结构和算法编程作业系列篇-数组

    /** * 编程作业 2.1 向highArray.java程序(清单2.3)的HighArray类添加一个名为getMax()的方法,它返回 数组中最大关键字的值,当数组为空时返回-1.向main( ...

  7. Andrew Ng机器学习编程作业:Logistic Regression

    编程作业文件: machine-learning-ex2 1. Logistic Regression (逻辑回归) 有之前学生的数据,建立逻辑回归模型预测,根据两次考试结果预测一个学生是否有资格被大 ...

  8. Stanford coursera Andrew Ng 机器学习课程编程作业(Exercise 2)及总结

    Exercise 1:Linear Regression---实现一个线性回归 关于如何实现一个线性回归,请参考:http://www.cnblogs.com/hapjin/p/6079012.htm ...

  9. stanford coursera 机器学习编程作业 exercise 3(逻辑回归实现多分类问题)

    本作业使用逻辑回归(logistic regression)和神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于逻辑回归的一个编程练习,可参考:http://www.cnb ...

随机推荐

  1. python基础笔记:判断与循环

    判断: #根据身高为1.75,体重为65的小明的bmi输出小明的身材 h=1.75 w=65 bmi=w/(h*h) if bmi<18.5: print('过轻') elif bmi<= ...

  2. idea中使用maven运行wordcount代码

    1.创建maven项目 pom文件: <?xml version="1.0" encoding="UTF-8"?> <project xmln ...

  3. NumPy 排序、查找、计数

    章节 Numpy 介绍 Numpy 安装 NumPy ndarray NumPy 数据类型 NumPy 数组创建 NumPy 基于已有数据创建数组 NumPy 基于数值区间创建数组 NumPy 数组切 ...

  4. python 虚拟环境的安装

    方式一 1. pip install virtualenv 2. virtualenv 虚拟环境的名字 3. mac上 source + 虚拟环境的目录/bin/activate win上 直接进入虚 ...

  5. 使用Oracle VM VirtualBox安装CentOS 7.6操作系统

    使用Oracle VM VirtualBox安装CentOS 7.6操作系统                                                               ...

  6. Profiling Top Kagglers: Bestfitting, Currently #1 in the World

    We have a new #1 on our leaderboard – a competitor who surprisingly joined the platform just two yea ...

  7. Arduino - -- 串口双向通信

    需要用到Arduino UNO的串口双向通信功能,以下源码: int val; void setup() {   Serial.begin(9600); // opensserial port, se ...

  8. HttpServlet中文乱码问题

    客户端提交数据给服务器端(Requset) 如果数据中带有中文的话,有可能会出现乱码情况,那么可以参照以下方法解决. 如果是GET方式 1.代码转码 String username = request ...

  9. POJ 1164:The Castle

    The Castle Time Limit: 1000MS   Memory Limit: 10000K Total Submissions: 6677   Accepted: 3767 Descri ...

  10. C++ 模板练习1

    //特定的模板友元关系 #include "stdafx.h" #include <iostream> using namespace std; template< ...