import numpy as np
import matplotlib.pyplot as plt from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
from sklearn import datasets, linear_model,discriminant_analysis def load_data():
# 使用 scikit-learn 自带的 iris 数据集
iris=datasets.load_iris()
X_train=iris.data
y_train=iris.target
return train_test_split(X_train, y_train,test_size=0.25,random_state=0,stratify=y_train) #线性判断分析LinearDiscriminantAnalysis
def test_LinearDiscriminantAnalysis(*data):
X_train,X_test,y_train,y_test=data
lda = discriminant_analysis.LinearDiscriminantAnalysis()
lda.fit(X_train, y_train)
print('Coefficients:%s, intercept %s'%(lda.coef_,lda.intercept_))
print('Score: %.2f' % lda.score(X_test, y_test)) # 产生用于分类的数据集
X_train,X_test,y_train,y_test=load_data()
# 调用 test_LinearDiscriminantAnalysis
test_LinearDiscriminantAnalysis(X_train,X_test,y_train,y_test)

def plot_LDA(converted_X,y):
'''
绘制经过 LDA 转换后的数据
:param converted_X: 经过 LDA转换后的样本集
:param y: 样本集的标记
'''
fig=plt.figure()
ax=Axes3D(fig)
colors='rgb'
markers='o*s'
for target,color,marker in zip([0,1,2],colors,markers):
pos=(y==target).ravel()
X=converted_X[pos,:]
ax.scatter(X[:,0], X[:,1], X[:,2],color=color,marker=marker,label="Label %d"%target)
ax.legend(loc="best")
fig.suptitle("Iris After LDA")
plt.show() def run_plot_LDA():
'''
执行 plot_LDA 。其中数据集来自于 load_data() 函数
'''
X_train,X_test,y_train,y_test=load_data()
X=np.vstack((X_train,X_test))
Y=np.vstack((y_train.reshape(y_train.size,1),y_test.reshape(y_test.size,1)))
lda = discriminant_analysis.LinearDiscriminantAnalysis()
lda.fit(X, Y)
converted_X=np.dot(X,np.transpose(lda.coef_))+lda.intercept_
plot_LDA(converted_X,Y) # 调用 run_plot_LDA
run_plot_LDA()

def test_LinearDiscriminantAnalysis_solver(*data):
'''
测试 LinearDiscriminantAnalysis 的预测性能随 solver 参数的影响
'''
X_train,X_test,y_train,y_test=data
solvers=['svd','lsqr','eigen']
for solver in solvers:
if(solver=='svd'):
lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver)
else:
lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver,shrinkage=None)
lda.fit(X_train, y_train)
print('Score at solver=%s: %.2f' %(solver, lda.score(X_test, y_test))) # 调用 test_LinearDiscriminantAnalysis_solver
test_LinearDiscriminantAnalysis_solver(X_train,X_test,y_train,y_test)

def test_LinearDiscriminantAnalysis_shrinkage(*data):
'''
测试 LinearDiscriminantAnalysis 的预测性能随 shrinkage 参数的影响
'''
X_train,X_test,y_train,y_test=data
shrinkages=np.linspace(0.0,1.0,num=20)
scores=[]
for shrinkage in shrinkages:
lda = discriminant_analysis.LinearDiscriminantAnalysis(solver='lsqr',shrinkage=shrinkage)
lda.fit(X_train, y_train)
scores.append(lda.score(X_test, y_test))
## 绘图
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
ax.plot(shrinkages,scores)
ax.set_xlabel(r"shrinkage")
ax.set_ylabel(r"score")
ax.set_ylim(0,1.05)
ax.set_title("LinearDiscriminantAnalysis")
plt.show()
# 调用 test_LinearDiscr
test_LinearDiscriminantAnalysis_shrinkage(X_train,X_test,y_train,y_test)

吴裕雄 python 机器学习——线性判断分析LinearDiscriminantAnalysis的更多相关文章

  1. 吴裕雄 python 机器学习——主成份分析PCA降维

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt from sklearn import datas ...

  2. 吴裕雄--天生自然 人工智能机器学习实战代码:线性判断分析LINEARDISCRIMINANTANALYSIS

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

  3. 吴裕雄 python 机器学习——支持向量机线性分类LinearSVC模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets, linear_model,svm fr ...

  4. 吴裕雄 python 机器学习——局部线性嵌入LLE降维模型

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt from sklearn import datas ...

  5. 吴裕雄 python 机器学习——人工神经网络与原始感知机模型

    import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...

  6. 吴裕雄 python 机器学习——分类决策树模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.model_s ...

  7. 吴裕雄 python 机器学习——回归决策树模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.model_s ...

  8. 吴裕雄 python 机器学习——逻辑回归

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

  9. 吴裕雄 python 机器学习——ElasticNet回归

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

随机推荐

  1. HttpWebResponse Post 前端控件数据,后台如何接收?

    MVC视图页: @{ Layout = null; } <!DOCTYPE html> <html> <head> <title>test</ti ...

  2. Android Studio搭建系统App开发环境

    一.前言 在Android的体系中开发普通app使用Android Studio这一利器会非常的方便.但是开发系统app可能就会有些吃力,不过经过一些配置仍然会 很简单.我们知道系统app因为涉及到一 ...

  3. mysql 设置初始密码

    mysqladmin -uroot password "123" 设置初始密码 由于原密码为空,因此-p可以不用 mysqladmin -uroot -p"123&quo ...

  4. vs2010直接调用av_register_all crash问题

    需要做一个视频导出的功能,通过ffmpeg来实现,vs2010里面引用ffmpeg库的 dll 和 lib 文件 第一步av_register_all就直接crash了, 查了近半天的时间,都快崩溃了 ...

  5. vue2.9.5 引入vue-strap时报错

    1.vue2.9.5 引入vue-strap时出错 2.组件中引入vue-strap的具体代码如下: 1 import DatePicker from 'vue-strap/src/Datepicke ...

  6. 7.2.4 else与if配对

    规则是,如果没有花括号,else与离它最近的if匹配,除非最近的if被花括号括起来. 注意:要缩进"语句","语句"可以是一条简单语句或复合语句. 记住,编译器 ...

  7. 软件测试:1.Describe An Error

    软件测试:1.Describe An Error 要求: 1.简要描述你最近完成项目里的一个error: 2.说明原因,错误影响,及你怎样发现的: 或许因为刚开学的缘故,近期我并没有完成大的项目,多少 ...

  8. 通过jedis远程访问redis服务器

    一.jedis简介 类似于mysql数据库,一般开发都需要通过代码去访问redis服务器,对于主流的开发语言,redis提供了访问的客户端接口. https://redis.io/clients 而对 ...

  9. stm32 uart 中断 蜜汁bug

    在项目中,使用stm32f103,配置uart1接收RXNE中断,使用DMA来进行UART1的发送. 初始化代码如下: void uart_init(u32 bound) { GPIO_InitTyp ...

  10. 2018-2019-2 20175328 《Java程序设计》第八周学习总结

    2018-2019-2 20175328 <Java程序设计>第八周学习总结 主要内容 泛型 泛型推出的主要目的是可以建立具有类型安全的集合框架,如链表.散列映射等数据结构. 1.泛型类声 ...