运用对偶的(对应原始)感知机算法实现线性分类。

  参考书目:《统计学习方法》(李航)

  算法原理:

  代码实现:

  环境:win7 32bit + Anaconda3 +spyder

  和原始算法的实现基本框架是类似的,只是判断和权值的更新算法有点变化。

 # -*- coding: utf-8 -*-
"""
Created on Fri Nov 18 01:29:35 2016 @author: Administrator
""" import numpy as np
from matplotlib import pyplot as plt # train matrix
def get_train_data():
M1 = np.random.random((100,2))
# 将label加到最后,方便后面操作
M11 = np.column_stack((M1,np.ones(100))) M2 = np.random.random((100,2)) - 0.7
M22 = np.column_stack((M2,np.ones(100)*(-1)))
# 合并两类,并将位置索引加到最后
MA = np.vstack((M11,M22))
MA = np.column_stack((MA,range(0,200))) # 作图操作
plt.plot(M1[:,0],M1[:,1], 'ro')
plt.plot(M2[:,0],M2[:,1], 'go')
# 为了美观,根据数据点限制之后分类线的范围
min_x = np.min(M2)
max_x = np.max(M1)
# 分隔x,方便作图
x = np.linspace(min_x, max_x, 100)
# 此处返回 x 是为了之后作图方便
return MA,x # GRAM计算
def get_gram(MA):
GRAM = np.empty(shape=(200,200))
for i in range(len(MA)):
for j in range(len(MA)):
GRAM[i,j] = np.dot(MA[i,][:2], MA[j,][:2])
return GRAM # 方便在train函数中识别误分类点
def func(alpha,b,xi,yi,yN,index,GRAM):
pa1 = alpha*yN
pa2 = GRAM[:,index]
num = yi*(np.dot(pa1,pa2)+b)
return num # 训练training data
def train(MA, alpha, b, GRAM, yN):
# M 存储每次处理后依旧处于误分类的原始数据
M = []
for sample in MA:
xi = sample[0:2]
yi = sample[-2]
index = int(sample[-1])
# 如果为误分类,改变alpha,b
# n 为学习率
if func(alpha,b,xi,yi,yN,index,GRAM) <= 0:
alpha[index] += n
b += n*yi
M.append(sample)
if len(M) > 0:
# print('迭代...')
train(M, alpha, b, GRAM, yN)
return alpha,b # 作出分类线的图
def plot_classify(w,b,x, rate0):
y = (w[0]*x+b)/((-1)*w[1])
plt.plot(x,y)
plt.title('Accuracy = '+str(rate0)) # 随机生成testing data 并作图
def get_test_data():
M = np.random.random((50,2))
plt.plot(M[:,0],M[:,1],'*y')
return M
# 对传入的testing data 的单个样本进行分类
def classify(w,b,test_i):
if np.sign(np.dot(w,test_i)+b) == 1:
return 1
else:
return 0 # 测试数据,返回正确率
def test(w,b,test_data):
right_count = 0
for test_i in test_data:
classx = classify(w,b,test_i)
if classx == 1:
right_count += 1
rate = right_count/len(test_data)
return rate if __name__=="__main__":
MA,x= get_train_data()
test_data = get_test_data()
GRAM = get_gram(MA)
yN = MA[:,2]
xN = MA[:,0:2]
# 定义初始值
alpha = [0]*200
b = 0
n = 1
# 初始化最优的正确率
rate0 = 0 # print(alpha,b)
# 循环不同的学习率n,寻求最优的学习率,即最终的rate0
# w0,b0为对应的最优参数
for i in np.linspace(0.01,1,100):
n = i
alpha,b = train(MA, alpha, b, GRAM, yN)
alphap = np.column_stack((alpha*yN,alpha*yN))
w = sum(alphap*xN)
rate = test(w,b,test_data)
# print(w,b)
rate = test(w,b,test_data)
if rate > rate0:
rate0 = rate
w0 = w
b0 = b
print('Until now, the best result of the accuracy on test data is '+str(rate))
print('with w='+str(w0)+' b='+str(b0))
print('---------------------------------------------')
# 在选定最优的学习率后,作图
plot_classify(w0,b0,x,rate0)
plt.show()

  输出:

感知机的对偶形式——python3实现的更多相关文章

  1. 2. 感知机(Perceptron)基本形式和对偶形式实现

    1. 感知机原理(Perceptron) 2. 感知机(Perceptron)基本形式和对偶形式实现 3. 支持向量机(SVM)拉格朗日对偶性(KKT) 4. 支持向量机(SVM)原理 5. 支持向量 ...

  2. 1. 感知机原理(Perceptron)

    1. 感知机原理(Perceptron) 2. 感知机(Perceptron)基本形式和对偶形式实现 3. 支持向量机(SVM)拉格朗日对偶性(KKT) 4. 支持向量机(SVM)原理 5. 支持向量 ...

  3. 机器学习理论基础学习3.1--- Linear classification 线性分类之感知机PLA(Percetron Learning Algorithm)

    一.感知机(Perception) 1.1 原理: 感知机是二分类的线性模型,其输入是实例的特征向量,输出的是事例的类别,分别是+1和-1,属于判别模型. 假设训练数据集是线性可分的,感知机学习的目标 ...

  4. 机器学习笔记(一)&#183; 感知机算法 &#183; 原理篇

    这篇学习笔记强调几何直觉,同时也注重感知机算法内部的动机.限于篇幅,这里仅仅讨论了感知机的一般情形.损失函数的引入.工作原理.关于感知机的对偶形式和核感知机,会专门写另外一篇文章.关于感知机的实现代码 ...

  5. 统计学习方法与Python实现(一)——感知机

    统计学习方法与Python实现(一)——感知机 iwehdio的博客园:https://www.cnblogs.com/iwehdio/ 1.定义 假设输入的实例的特征空间为x属于Rn的n维特征向量, ...

  6. 吴裕雄 python 机器学习——人工神经网络与原始感知机模型

    import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...

  7. 感知机算法(PLA)代码实现

    目录 1. 引言 2. 载入库和数据处理 3. 感知机的原始形式 4. 感知机的对偶形式 5. 多分类情况-one vs. rest 6. 多分类情况-one vs. one 7. sklearn实现 ...

  8. 【python与机器学习实战】感知机和支持向量机学习笔记(一)

    对<Python与机器学习实战>一书阅读的记录,对于一些难以理解的地方查阅了资料辅以理解并补充和记录,重新梳理一下感知机和SVM的算法原理,加深记忆. 1.感知机 感知机的基本概念 感知机 ...

  9. 原始感知机入门——python3实现

    运用最简单的原始(对应的有对偶)感知机算法实现线性分类. 参考书目:<统计学习方法>(李航) 算法原理: 踩到的坑:以为误分类的数据只使用一次,造成分类结果很差,在train函数内加个简单 ...

随机推荐

  1. VPN安装后报错:Reason442 & Error56

    VPN安装后一直报错,同样的32位安装包别人安装是正常,自己安装就不正常了,考虑到是自己电脑配置的问题. 经过一番努力,解决了问题,下面就本次解决过程做一个小小的总结. (1)确保VPN Servic ...

  2. fiddler监听127.0.0.1或localhost

    localhost/127.0.0.1的请求不会通过任何代理发送,fiddler也就无法截获. 解决方案 1,用 http://localhost. (locahost紧跟一个点号)2,用 http: ...

  3. ElasticSearch集群配置

    因机器有限,本文只做单机3个节点的集群测试. 1.集群测试信息 elasticsearch版本:elasticsearch-2.4.1 windowns版本:win10 2.解压elasticsear ...

  4. DEV控件中GridView中的复选框与CheckBox实现联动的全选功能

    最初的界面图如图1-1(全选框ID: cb_checkall  DEV控件名称:gcCon ): 要实现的功能如下图(1-2  1-3  1-4)及代码所示: 图1-2 图1-3 图1-4 O(∩_∩ ...

  5. 删除ORACLE的步骤

    1.关闭oracle所有的服务.可以在windows的服务管理器中关闭: 2.打开注册表:regedit 打开路径: HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlS ...

  6. P53 T5

    北京某高校可用的电话号码有以下几类:校内电话号码由4位数字,第1位数字不是0:校外电话又分为本市电话和外地电话两类,拔校外电话需先拔0,若是本市电话则再接着拔8位数字(第一位不是0),若是外地电话则拔 ...

  7. 【66测试20161115】【树】【DP_LIS】【SPFA】【同余最短路】【递推】【矩阵快速幂】

    还有3天,今天考试又崩了.状态还没有调整过来... 第一题:小L的二叉树 勤奋又善于思考的小L接触了信息学竞赛,开始的学习十分顺利.但是,小L对数据结构的掌握实在十分渣渣.所以,小L当时卡在了二叉树. ...

  8. django:field字段类型

    字段类型(Field types) AutoField 它是一个根据 ID 自增长的 IntegerField 字段.通常,你不必直接使用该字段.如果你没在别的字段上指定主 键,Django 就会自动 ...

  9. SAE使用心得1

    最近准备在新浪云端SAE上挂点自己的小网站,这样自己开发个什么东西别人能用.但是第一次接触SAE,遇到一些问题,记下来给大家看. 1.安装的svn版本不能高于 1.8,否则无法向SAE提交代码. 2. ...

  10. AWS-CDH5.5安装-软件下载

    1.下载安装介质 下载CM安装文件: [root@ip---- cm5.5.0]# wget -c -r -nd -np -k -L -A rpm http://archive-primary.clo ...