分享一下 线性回归中 欠拟合 和 过拟合 是怎么回事~
为了解决欠拟合的情 经常要提高线性的次数建立模型拟合曲线, 次数过高会导致过拟合,次数不够会欠拟合。
再建立高次函数时候,要利用多项式特征生成器 生成训练数据。
下面把整个流程展示一下
模拟了一个预测蛋糕价格的从欠拟合到过拟合的过程 git: https://github.com/linyi0604/MachineLearning 在做线性回归预测时候,为了提高模型的泛化能力,经常采用多次线性函数建立模型 f = k*x + b 一次函数
f = a*x^2 + b*x + w 二次函数
f = a*x^3 + b*x^2 + c*x + w 三次函数
。。。 泛化:
对未训练过的数据样本进行预测。 欠拟合:
由于对训练样本的拟合程度不够,导致模型的泛化能力不足。 过拟合:
训练样本拟合非常好,并且学习到了不希望学习到的特征,导致模型的泛化能力不足。 在建立超过一次函数的线性回归模型之前,要对默认特征生成多项式特征再输入给模型
  poly2 = PolynomialFeatures(degree=2)    # 2次多项式特征生成器
  x_train_poly2 = poly2.fit_transform(x_train)

下面模拟 根据蛋糕的直径大小 预测蛋糕价格

 from sklearn.linear_model import LinearRegression
import numpy as np
import matplotlib.pyplot as plt '''
在做线性回归预测时候,
为了提高模型的泛化能力,经常采用多次线性函数建立模型 f = k*x + b 一次函数
f = a*x^2 + b*x + w 二次函数
f = a*x^3 + b*x^2 + c*x + w 三次函数
。。。 泛化:
对未训练过的数据样本进行预测。 欠拟合:
由于对训练样本的拟合程度不够,导致模型的泛化能力不足。 过拟合:
训练样本拟合非常好,并且学习到了不希望学习到的特征,导致模型的泛化能力不足。 在建立超过一次函数的线性回归模型之前,要对默认特征生成多项式特征再输入给模型 下面模拟 根据蛋糕的直径大小 预测蛋糕价格 ''' # 样本的训练数据,特征和目标值
x_train = [[6], [8], [10], [14], [18]]
y_train = [[7], [9], [13], [17.5], [18]] # 一次线性回归的学习与预测
# 线性回归模型 学习
regressor = LinearRegression()
regressor.fit(x_train, y_train)
# 画出一次线性回归的拟合曲线
xx = np.linspace(0, 25, 100) # 0到16均匀采集100个点做x轴
xx = xx.reshape(xx.shape[0], 1)
yy = regressor.predict(xx) # 计算每个点对应的y
plt.scatter(x_train, y_train) # 画出训练数据的点
plt1, = plt.plot(xx, yy, label="degree=1")
plt.axis([0, 25, 0, 25])
plt.xlabel("Diameter")
plt.ylabel("Price")
plt.legend(handles=[plt1])
plt.show()

一次线性函数拟合曲线的结果,是欠拟合的情况:

下面进行建立2次线性回归模型进行预测:

 # 2次线性回归进行预测
poly2 = PolynomialFeatures(degree=2) # 2次多项式特征生成器
x_train_poly2 = poly2.fit_transform(x_train)
# 建立模型预测
regressor_poly2 = LinearRegression()
regressor_poly2.fit(x_train_poly2, y_train)
# 画出2次线性回归的图
xx_poly2 = poly2.transform(xx)
yy_poly2 = regressor_poly2.predict(xx_poly2)
plt.scatter(x_train, y_train)
plt1, = plt.plot(xx, yy, label="Degree1")
plt2, = plt.plot(xx, yy_poly2, label="Degree2")
plt.axis([0, 25, 0, 25])
plt.xlabel("Diameter")
plt.ylabel("Price")
plt.legend(handles=[plt1, plt2])
plt.show()
# 输出二次回归模型的预测样本评分
print("二次线性模型在训练数据上得分:", regressor_poly2.score(x_train_poly2, y_train)) # 0.9816421639597427

二次线性回归模型拟合的曲线:

拟合程度明显比1次线性拟合的要好

下面进行4次线性回归模型:

 # 进行四次线性回归模型拟合
poly4 = PolynomialFeatures(degree=4) # 4次多项式特征生成器
x_train_poly4 = poly4.fit_transform(x_train)
# 建立模型预测
regressor_poly4 = LinearRegression()
regressor_poly4.fit(x_train_poly4, y_train)
# 画出2次线性回归的图
xx_poly4 = poly4.transform(xx)
yy_poly4 = regressor_poly4.predict(xx_poly4)
plt.scatter(x_train, y_train)
plt1, = plt.plot(xx, yy, label="Degree1")
plt2, = plt.plot(xx, yy_poly2, label="Degree2")
plt4, = plt.plot(xx, yy_poly4, label="Degree2")
plt.axis([0, 25, 0, 25])
plt.xlabel("Diameter")
plt.ylabel("Price")
plt.legend(handles=[plt1, plt2, plt4])
plt.show()
# 输出二次回归模型的预测样本评分
print("四次线性训练数据上得分:", regressor_poly4.score(x_train_poly4, y_train)) # 1.0

四次线性模型预测准确率为百分之百, 但是看一下拟合曲线,明显存在不合逻辑的预测曲线,

在样本点之外的情况,可能预测的非常不准确,这种情况为过拟合

之前我们一直在展示在训练集合上获得的模型评分,次数越高的模型,训练拟合越好。

下面查看一组测试数据进行预测的得分情况:

 # 准备测试数据
x_test = [[6], [8], [11], [16]]
y_test = [[8], [12], [15], [18]]
print("一次线性模型在测试集合上得分:", regressor.score(x_test, y_test)) # 0.809726797707665
x_test_poly2 = poly2.transform(x_test)
print("二次线性模型在测试集合上得分:", regressor_poly2.score(x_test_poly2, y_test)) # 0.8675443656345054
x_test_poly4 = poly4.transform(x_test)
print("四次线性模型在测试集合上得分:", regressor_poly4.score(x_test_poly4, y_test)) # 0.8095880795746723

会发现,二次模型在预测集合上表现最好,四次模型表现反而不好!

这就是由于对训练数据学习的太过分,学习到了不重要的东西,反而导致预测不准确。


机器学习之路:python 多项式特征生成PolynomialFeatures 欠拟合与过拟合的更多相关文章

  1. 机器学习之路: python k近邻分类器 KNeighborsClassifier 鸢尾花分类预测

    使用python语言 学习k近邻分类器的api 欢迎来到我的git查看源代码: https://github.com/linyi0604/MachineLearning from sklearn.da ...

  2. 机器学习之路--Python

    常用数据结构 1.list 列表 有序集合 classmates = ['Michael', 'Bob', 'Tracy'] len(classmates) classmates[0] len(cla ...

  3. 机器学习之路: python 回归树 DecisionTreeRegressor 预测波士顿房价

    python3 学习api的使用 git: https://github.com/linyi0604/MachineLearning 代码: from sklearn.datasets import ...

  4. 机器学习之路: python 线性回归LinearRegression, 随机参数回归SGDRegressor 预测波士顿房价

    python3学习使用api 线性回归,和 随机参数回归 git: https://github.com/linyi0604/MachineLearning from sklearn.datasets ...

  5. 机器学习之路: python 决策树分类DecisionTreeClassifier 预测泰坦尼克号乘客是否幸存

    使用python3 学习了决策树分类器的api 涉及到 特征的提取,数据类型保留,分类类型抽取出来新的类型 需要网上下载数据集,我把他们下载到了本地, 可以到我的git下载代码和数据集: https: ...

  6. 机器学习基础:(Python)训练集测试集分割与交叉验证

    在上一篇关于Python中的线性回归的文章之后,我想再写一篇关于训练测试分割和交叉验证的文章.在数据科学和数据分析领域中,这两个概念经常被用作防止或最小化过度拟合的工具.我会解释当使用统计模型时,通常 ...

  7. 一个完整的机器学习项目在Python中演练(三)

    大家往往会选择一本数据科学相关书籍或者完成一门在线课程来学习和掌握机器学习.但是,实际情况往往是,学完之后反而并不清楚这些技术怎样才能被用在实际的项目流程中.就像你的脑海中已经有了一块块"拼 ...

  8. 机器学习算法与Python实践之(四)支持向量机(SVM)实现

    机器学习算法与Python实践之(四)支持向量机(SVM)实现 机器学习算法与Python实践之(四)支持向量机(SVM)实现 zouxy09@qq.com http://blog.csdn.net/ ...

  9. 机器学习算法与Python实践之(三)支持向量机(SVM)进阶

    机器学习算法与Python实践之(三)支持向量机(SVM)进阶 机器学习算法与Python实践之(三)支持向量机(SVM)进阶 zouxy09@qq.com http://blog.csdn.net/ ...

随机推荐

  1. C++利用cin输入时检测回车的方法

    今天做TJU的OJ ,其中一道题是先读入一个字符串,再读入一个整数,循环往复,直到字符串是空,也就是说回车键结束循环. 但是cin对空格和回车都不敏感,都不影响继续读入数据,所以需要一种新的方式检测回 ...

  2. Linux/Unix系统编程手册 第一章:历史和标准

    Unix的开发不受控于某一个厂商或者组织,是由诸多商业和非商业团体共同贡献进行演化的.这导致两个结果:一是Unix集多种特性于一身,二是由于参与者众多,随着时间推移,Unix实现方式逐渐趋于分裂. 由 ...

  3. java四舍五入BigDecimal和js保留小数点两位

    java四舍五入BigDecimal保留两位小数的实现方法: // 四舍五入保留两位小数System.out.println("四舍五入取整:(3.856)="      + ne ...

  4. RobotFramework安装扩展库包Selenium2Library(三)

    Robot Framework扩展库包 http://robotframework.org/#libraries 一,自动化测试web端 1,pip安装SeleniumLibrary pip inst ...

  5. MySQL 联合查询

    联合查询:将多次查询(多条select语句), 在记录上进行拼接(字段不会增加) 基本语法:多条select语句构成: 每一条select语句获取的字段数必须严格一致(但是字段类型无关) 语法 Sel ...

  6. 数据库-mysql数据操作

    一:mysql 数据的插入 语法 以下为向MySQL数据表插入数据通用的 INSERT INTO SQL语法: INSERT INTO table_name ( field1, field2,...f ...

  7. python基础--hashlib模块

    hashlib模块用于加密操作,代替了md5和sha模块, 主要提供SHA1, SHA224, SHA256, SHA384, SHA512 ,MD5 算法. # -*- coding:utf-8 - ...

  8. No.15 selenium for python JavaScript

    JS处理滚动条 一.上下滚动 1.滚动条回到顶部: js="var q=document.documentElement.scrollTop=10000" driver.execu ...

  9. python随笔(二)

    range(2,10):不包括10 range(2,10,3):步长为3 range(10,2,-1):从10到2,步长-1.

  10. DedeCMS栏目页调用当前栏目名和上级栏目名

    在构建网页的时候,如果不想逐个写栏目列表页的标题,即列表页标题形式为:{field:seotitle/}_{dede:global.cfg_webname/},其中{field:seotitle/}为 ...