记录下学习使用sklearn,将使用sklearn实现机器学习大部分内容

基于scikit-learn机器学习(第2版)这本书,和scikit-learn中文社区

简单线性回归

首先,最简单的线性回归也有几个地方要注意

  1. fit的时候,对于X,要求是n*m的类型,y要是n*1的类型
  2. sklearn会将得到的系数存储起来,分别在coef_中和intercept_中,intercept_是偏移,也就是b,coef_是k,或者向量中的W

来看具体例子

from sklearn.linear_model import LinearRegression
model = LinearRegression()
import numpy as np
# 线性模型为y = kx + b
X = np.array([1,2,3,4,5,6]).reshape(-1,1)
y = [1,2.1,2.9,4.1,5.2,6.1]
# y = np.array([1,2.1,2.9,4.1,5.2,6.1]) 这样写也是可以的
# y = np.array([1,2.1,2.9,4.1,5.2,6.1]).reshape(-1,1) 这样写也是可以的
# y = np.array([1,2.1,2.9,4.1,5.2,6.1]).reshape(1,-1) 这样写就不行了
model.fit(X,y)
print(model.coef_)
print(model.intercept_) /*
[[1.02857143]]
[-0.03333333]
*/

上面y的长度总是要和X保持一致的

y = [1,2.1,2.9,4.1,5.2,6.1]
print("y:")
print(y)
y = np.array([1,2.1,2.9,4.1,5.2,6.1]) # 这样写也是可以的
print("y:")
print(y)
print(y.shape)
y = np.array([1,2.1,2.9,4.1,5.2,6.1]).reshape(-1,1) # 这样写也是可以的
print("y:")
print(y)
print(y.shape)
y = np.array([1,2.1,2.9,4.1,5.2,6.1]).reshape(1,-1) # 这样写就不行了
print("y:")
print(y)
print(y.shape) /*
y:
[1, 2.1, 2.9, 4.1, 5.2, 6.1]
y:
[1. 2.1 2.9 4.1 5.2 6.1]
(6,)
y:
[[1. ]
[2.1]
[2.9]
[4.1]
[5.2]
[6.1]]
(6, 1)
y:
[[1. 2.1 2.9 4.1 5.2 6.1]]
(1, 6)
*/

再来看这个例子

# 最小二乘法用于sklearn中的线性回归,引入它。
from sklearn import linear_model
reg = linear_model.LinearRegression() def foo(x1,x2): # w0 = 5, w1 = 2, w2 = 3
return 2 * x1 + 3 * x2 + 5 """生成测试数据 X,y
X 10行2列
y 10行1列
"""
X = [[i,(i+1)/2] for i in range(10)]
y = [foo(i,(i+1)/2) for i in range(10)] # 根据参数拟合直线
reg.fit(X,y) # 输出 w1,w2 = [2.8, 1.4]
print(reg.coef_) # 输出 w0 = 5.8
print(reg.intercept_) """
拟合直线: y = 2.8 * x1 + 1.4 * x2 + 5.8
""" # 用生成的直线进行预测
print(reg.predict(X))
[2.8 1.4]
5.799999999999997
X:
[[0, 0.5], [1, 1.0], [2, 1.5], [3, 2.0], [4, 2.5], [5, 3.0], [6, 3.5], [7, 4.0], [8, 4.5], [9, 5.0]]
y:
[6.5, 10.0, 13.5, 17.0, 20.5, 24.0, 27.5, 31.0, 34.5, 38.0]
predict:
[ 6.5 10. 13.5 17. 20.5 24. 27.5 31. 34.5 38. ]

这个例子来源:https://blog.csdn.net/weixin_43899202/article/details/104155313

看完之后第一反应是,明明 y = 2 * x1 + 3 * x2 + 5,怎么就变成 y = 2.8 * x1 + 1.4 * x2 + 5.8

这是因为我们对X赋值的时候都是等比例缩放,我们现在再用两个测试集去验证一下就知道了

# 验证集1
X_test = np.array([[0,0],[1,1],[2,2],[3,3],[4,4],[5,5]])
y_test1 = np.array([5,10,15,20,25,30])
score1 = reg.score(X_test,y_test1)
print("score1:")
print(score1)
# 验证集2
X_test = np.array([[0,0],[1,1],[2,2],[3,3],[4,4],[5,5]])
y_test2 = np.array([5.8,10,14.2,18.4,22.6,26.8])
score2 = reg.score(X_test,y_test2)
print("score2")
print(score2)
score1:
0.9546514285714286
score2
1.0

最后,再使用predict方法

X_predict = np.array([[0,0],[0,1],[1,1],[1,2],[2,2]])
print(reg.predict(X))
[ 6.5 10.  13.5 17.  20.5 24.  27.5 31.  34.5 38. ]

这样,学习使用了线性模型、fit方法、predict方法、score方法,以及线性模型的参数coef_intercept_

sklearn机器学习实战-简单线性回归的更多相关文章

  1. 使用sklearn机器学习库实现线性回归

    import numpy as np  # 导入科学技术框架import matplotlib.pyplot as plt  # 导入画图工具from sklearn.linear_model imp ...

  2. sklearn机器学习实战-KNN

    KNN分类 KNN是惰性学习模型,也被称为基于实例的学习模型 简单线性回归是勤奋学习模型,训练阶段耗费计算资源,但是预测阶段代价不高 首先工作是把label的内容进行二值化(如果多分类任务,则考虑On ...

  3. scikit-learn机器学习(一)简单线性回归

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt ## 设置字符集,防止中文乱码 import ma ...

  4. 机器学习实战 | SKLearn最全应用指南

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/41 本文地址:http://www.showmeai.tech/article-det ...

  5. 机器学习——Day 2 简单线性回归

    写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...

  6. sklearn学习笔记之简单线性回归

    简单线性回归 线性回归是数据挖掘中的基础算法之一,从某种意义上来说,在学习函数的时候已经开始接触线性回归了,只不过那时候并没有涉及到误差项.线性回归的思想其实就是解一组方程,得到回归函数,不过在出现误 ...

  7. 机器学习与Tensorflow(1)——机器学习基本概念、tensorflow实现简单线性回归

    一.机器学习基本概念 1.训练集和测试集 训练集(training set/data)/训练样例(training examples): 用来进行训练,也就是产生模型或者算法的数据集 测试集(test ...

  8. 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE

    前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...

  9. Python线性回归算法【解析解,sklearn机器学习库】

    一.概述 参考博客:https://www.cnblogs.com/yszd/p/8529704.html 二.代码实现[解析解] import numpy as np import matplotl ...

随机推荐

  1. 区分构造函数注入和 setter 注入?

    构造函数注入 setter 注入 没有部分注入 有部分注入 不会覆盖 setter 属性 会覆盖 setter 属性 任意修改都会创建一个新实例 任意修改不会创建一个新实例 适用于设置很多属性 适用于 ...

  2. 运筹学之"简单平均预测法"和"加权滑动平均预测法"和"确定平滑系数"

    1.简单滑动平均预测法就是将所有的售价加起来除以总数 665/5=133 2.加权滑动平均预测法:需要将售价分别乘以权之和,并除以权之和 1771/13≈136.23 二.某木材公司销售房架构件,其中 ...

  3. 6_比例积分控制器_PI控制

  4. 2018 百度web前端面试

    面试前 正式入职一年半左右,实习半年,勉强两年经验吧,然后很惊喜收到了百度的面试邀约,约得两点钟面试,然后本人一点钟就到了,通电话之后,面试官很热情,说正在吃饭吃完饭就去找我,让我去坐着等一会,然后一 ...

  5. css3中user-select的用法详解

    css3中user-select的用法详解 user-select属性是css3新增的属性,用于设置用户是否能够选中文本.可用于除替换元素外的所有元素,以下是user-select的主要用法和注意事项 ...

  6. java继承时能包括静态的变量和方法吗?举例说明!

    子类继承了超类定义的所有实例变量和方法包括静态的变量和方法(马克-to-win见下例),并且为它自己增添了独特的元素.子类只能有一个超类.Java不支持多超类的继承. 子类拥有超类的所有成员,但它不能 ...

  7. 类其中的变量为final时的用法

    类其中的变量为final时的用法:   类当中final变量没有初始缺省值,必须在构造函数中赋值或直接当时赋值.否则报错. public class Test {     final int i;   ...

  8. IDEA修改代码后不用重新启动项目即可刷新

    1.File--Settings--Build 2.Build,Execution,Deplyment--Compiler 3.选中打勾 "Build project automatical ...

  9. java集合总览

    在编写java程序中,我们最常用的除了八种基本数据类型,String对象外还有一个集合类,在我们的的程序中到处充斥着集合类的身影!java中集合大家族的成员实在是太丰富了,有常用的ArrayList. ...

  10. spring总览

    Spring 概述 1. 什么是spring? Spring 是个java企业级应用的开源开发框架.Spring主要用来开发Java应用,但是有些扩展是针对构建J2EE平台的web应用.Spring  ...