知识点

  • scikit-learn 对于线性回归提供了比较多的类库,这些类库都可以用来做线性回归分析。
  • 我们也可以使用scikit-learn的线性回归函数,而不是从头开始实现这些算法。 我们将scikit-learn的线性回归算法应用于编程作业1.1的数据,并看看它的表现。
  • 一般来说,只要觉得数据有线性关系,LinearRegression类是我们的首选。如果发现拟合或者预测的不好,再考虑用其他的线性回归库。如果是学习线性回归,推荐先从这个类开始第一步的研究。
  • LinearRegression 的使用非常简单,主要分为两步:
    1. 使用 fit(x_train,y_train) 对训练集x, y进行训练。
    2. 使用 predict(x_test) 训练得到的估计器对输入为 x_test 的集合进行预测。( (x_test) 可以是测试集,也可以是需要预测的数据)

过程

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt # 导入数据
path = 'D:\BaiduNetdiskDownload\data_sets\ex1data1.txt' # pd.read_csv 将 TXT 文件读入并转化为数据框形式
# names 添加列名
# header 用指定的行来作为标题(表头),若原来无标题则设为 none
# 用到 Pandas 里面的 head( ) 函数读取数据(只能读取前五行)
data = pd.read_csv(path,header=None,names=['Population','Profit'])
data.head()


# 在训练集中插入一列1(其实是x0=1),方便我们可以使用向量化的解决方案来计算代价和梯度。
data.insert(0, 'Ones', 1) # set X(training set), y(target variable)
# 设置训练集X,和目标变量y的值
cols = data.shape[1] # 获取列数
X = data.iloc[:,0:cols-1] # 输入向量X为前cols-1列
y = data.iloc[:,cols-1:cols] # 目标变量y为最后一列 # 代价函数是应该是 numpy 矩阵,所以我们需要转换X和Y,然后才能使用它们。 我们还需要初始化 theta 。
X = np.array(X.values)
y = np.array(y.values)
theta = np.array([0,0])

核心代码:

from sklearn import linear_model

# 需要导入LinearRegression类,并将之实例化,并采用fit()方法已验证这些训练数据。
model = linear_model.LinearRegression()
model.fit(X, y) # fit(X, y)对训练集X, y进行训练

scikit-learn model的预测表现:

x = np.array(X[:, 1])
f = model.predict(X).flatten() # .flatten() 默认按行的方向降维 fig, ax = plt.subplots(figsize=(8,5))
ax.plot(x, f, 'r', label='Prediction')
ax.scatter(data.Population, data.Profit, label='Traning Data')
ax.legend(loc=2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()

参考资料

python_sklearn机器学习算法系列之LinearRegression线性回归

吴恩达机器学习作业Python实现(一):线性回归

scikit-learn 线性回归算法库小结

编程作业1.1——sklearn机器学习算法系列之LinearRegression线性回归的更多相关文章

  1. sklearn机器学习算法--线性模型

    线性模型 用于回归的线性模型 线性回归(普通最小二乘法) 岭回归 lasso 用于分类的线性模型 用于多分类的线性模型 1.线性回归 LinearRegression,模型简单,不同调节参数 #2.导 ...

  2. sklearn机器学习算法--K近邻

    K近邻 构建模型只需要保存训练数据集即可.想要对新数据点做出预测,算法会在训练数据集中找到最近的数据点,也就是它的“最近邻”. 1.K近邻分类 #第三步导入K近邻模型并实例化KN对象 from skl ...

  3. 机器学习作业(五)机器学习算法的选择与优化——Matlab实现

    题目下载[传送门] 第1步:读取数据文件,并可视化: % Load from ex5data1: % You will have X, y, Xval, yval, Xtest, ytest in y ...

  4. 机器学习算法系列:FM分解机

    在线性回归中,是假设每个特征之间独立的,也即是线性回归模型是无法捕获特征之间的关系.为了捕捉特征之间的关系,便有了FM分解机的出现了.FM分解机是在线性回归的基础上加上了交叉特征,通过学习交叉特征的权 ...

  5. 如何用Python实现常见机器学习算法-1

    最近在GitHub上学习了有关python实现常见机器学习算法 目录 一.线性回归 1.代价函数 2.梯度下降算法 3.均值归一化 4.最终运行结果 5.使用scikit-learn库中的线性模型实现 ...

  6. java数据结构和算法编程作业系列篇-数组

    /** * 编程作业 2.1 向highArray.java程序(清单2.3)的HighArray类添加一个名为getMax()的方法,它返回 数组中最大关键字的值,当数组为空时返回-1.向main( ...

  7. Andrew Ng机器学习编程作业:Logistic Regression

    编程作业文件: machine-learning-ex2 1. Logistic Regression (逻辑回归) 有之前学生的数据,建立逻辑回归模型预测,根据两次考试结果预测一个学生是否有资格被大 ...

  8. Stanford coursera Andrew Ng 机器学习课程编程作业(Exercise 2)及总结

    Exercise 1:Linear Regression---实现一个线性回归 关于如何实现一个线性回归,请参考:http://www.cnblogs.com/hapjin/p/6079012.htm ...

  9. stanford coursera 机器学习编程作业 exercise 3(逻辑回归实现多分类问题)

    本作业使用逻辑回归(logistic regression)和神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于逻辑回归的一个编程练习,可参考:http://www.cnb ...

随机推荐

  1. 067-PHP使用匿名函数

    <?php $func=function ($x,$y){ //匿名函数与变量绑定 return $x+$y; }; echo '5+6='.$func(5,6); //使用匿名函数 echo ...

  2. win10,64位操作系统安装mysql-8.0.16经验总结(图文详细,保证一次安装成功)

    文章目录 1.mysql下载 2.解压及配置文件 3.启动MySQL数据库 4.登录 MySQL 5.配置系统环境变量 6.mysql-8.0.16修改初始密码 机器配置: win10,64位: my ...

  3. 深度解析高分Essay写作组成部分

    很多留学生在国外呆了好几个学期对于怎么写Essay还是不太了解,今天小编大致总结了以下一些技巧,相信同学们可以尝试着自己完成一篇Essay写作!一起来看看吧! 与你的导师沟通 你的导师对于你的Essa ...

  4. ACM-AK吧!少年

    题目描述:AK吧!少年 AK is an ACM competition finished all of the problems. AC This problem is one of the ste ...

  5. css文本强制两行超出就显示省略号,不显示省略号

    1. 强制一行的情况很简单 overflow:hidden; //超出的隐藏 text-overflow:ellipsis; //省略号 white-space:nowrap; //强制一行显示 2. ...

  6. java处理浮点数小数点后几位

    转载:https://blog.csdn.net/xue_feitian/article/details/6556275 第一种方法: double f = 123.2315455458; BigDe ...

  7. pppd调试心得.md

    描述 pppd是用于驱动3g模块的一种方式,其本质是和运营商APN协商,建立连接 其与运营商之间使用ppp协议,而用户在应用层使用系统提供的socket即可,从而忽略底层使用的时何种接口的设备,避免因 ...

  8. 百度地图API提供Geocoder类进行地址解析

    根据地址描述获得坐标百度地图API提供Geocoder类进行地址解析,您可以通过Geocoder.getPoint()方法来将一段地址描述转换为一个坐标. // 创建地址解析器实例var myGeo ...

  9. java课程之团队开发冲刺阶段2.6

    总结昨天进度: 1.总体的思路已经完成,代码也差不多了,只剩下对闹钟activity的设置 遇到的困难: 1.在设置震动的时候,对方法有点不太理解,所以使用的时候产生了错误,没有达到预期的效果 今天的 ...

  10. 安装swoole redis异步 hiredis swoole扩展加载失败 或者不显示问题 解决办法

    当前办法仅供参考 贴上报错 找了好久 根据网上办法也试了 没解决 最后 仔细读问题 觉得可能是 hiredis路径问题 终于解决了 解决办法: 进入你的安装包目录然后执行下面 mkdir /usr/l ...