感知机的对偶形式——python3实现
运用对偶的(对应原始)感知机算法实现线性分类。
参考书目:《统计学习方法》(李航)
算法原理:



代码实现:
环境:win7 32bit + Anaconda3 +spyder
和原始算法的实现基本框架是类似的,只是判断和权值的更新算法有点变化。
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 18 01:29:35 2016 @author: Administrator
""" import numpy as np
from matplotlib import pyplot as plt # train matrix
def get_train_data():
M1 = np.random.random((100,2))
# 将label加到最后,方便后面操作
M11 = np.column_stack((M1,np.ones(100))) M2 = np.random.random((100,2)) - 0.7
M22 = np.column_stack((M2,np.ones(100)*(-1)))
# 合并两类,并将位置索引加到最后
MA = np.vstack((M11,M22))
MA = np.column_stack((MA,range(0,200))) # 作图操作
plt.plot(M1[:,0],M1[:,1], 'ro')
plt.plot(M2[:,0],M2[:,1], 'go')
# 为了美观,根据数据点限制之后分类线的范围
min_x = np.min(M2)
max_x = np.max(M1)
# 分隔x,方便作图
x = np.linspace(min_x, max_x, 100)
# 此处返回 x 是为了之后作图方便
return MA,x # GRAM计算
def get_gram(MA):
GRAM = np.empty(shape=(200,200))
for i in range(len(MA)):
for j in range(len(MA)):
GRAM[i,j] = np.dot(MA[i,][:2], MA[j,][:2])
return GRAM # 方便在train函数中识别误分类点
def func(alpha,b,xi,yi,yN,index,GRAM):
pa1 = alpha*yN
pa2 = GRAM[:,index]
num = yi*(np.dot(pa1,pa2)+b)
return num # 训练training data
def train(MA, alpha, b, GRAM, yN):
# M 存储每次处理后依旧处于误分类的原始数据
M = []
for sample in MA:
xi = sample[0:2]
yi = sample[-2]
index = int(sample[-1])
# 如果为误分类,改变alpha,b
# n 为学习率
if func(alpha,b,xi,yi,yN,index,GRAM) <= 0:
alpha[index] += n
b += n*yi
M.append(sample)
if len(M) > 0:
# print('迭代...')
train(M, alpha, b, GRAM, yN)
return alpha,b # 作出分类线的图
def plot_classify(w,b,x, rate0):
y = (w[0]*x+b)/((-1)*w[1])
plt.plot(x,y)
plt.title('Accuracy = '+str(rate0)) # 随机生成testing data 并作图
def get_test_data():
M = np.random.random((50,2))
plt.plot(M[:,0],M[:,1],'*y')
return M
# 对传入的testing data 的单个样本进行分类
def classify(w,b,test_i):
if np.sign(np.dot(w,test_i)+b) == 1:
return 1
else:
return 0 # 测试数据,返回正确率
def test(w,b,test_data):
right_count = 0
for test_i in test_data:
classx = classify(w,b,test_i)
if classx == 1:
right_count += 1
rate = right_count/len(test_data)
return rate if __name__=="__main__":
MA,x= get_train_data()
test_data = get_test_data()
GRAM = get_gram(MA)
yN = MA[:,2]
xN = MA[:,0:2]
# 定义初始值
alpha = [0]*200
b = 0
n = 1
# 初始化最优的正确率
rate0 = 0 # print(alpha,b)
# 循环不同的学习率n,寻求最优的学习率,即最终的rate0
# w0,b0为对应的最优参数
for i in np.linspace(0.01,1,100):
n = i
alpha,b = train(MA, alpha, b, GRAM, yN)
alphap = np.column_stack((alpha*yN,alpha*yN))
w = sum(alphap*xN)
rate = test(w,b,test_data)
# print(w,b)
rate = test(w,b,test_data)
if rate > rate0:
rate0 = rate
w0 = w
b0 = b
print('Until now, the best result of the accuracy on test data is '+str(rate))
print('with w='+str(w0)+' b='+str(b0))
print('---------------------------------------------')
# 在选定最优的学习率后,作图
plot_classify(w0,b0,x,rate0)
plt.show()
输出:


感知机的对偶形式——python3实现的更多相关文章
- 2. 感知机(Perceptron)基本形式和对偶形式实现
1. 感知机原理(Perceptron) 2. 感知机(Perceptron)基本形式和对偶形式实现 3. 支持向量机(SVM)拉格朗日对偶性(KKT) 4. 支持向量机(SVM)原理 5. 支持向量 ...
- 1. 感知机原理(Perceptron)
1. 感知机原理(Perceptron) 2. 感知机(Perceptron)基本形式和对偶形式实现 3. 支持向量机(SVM)拉格朗日对偶性(KKT) 4. 支持向量机(SVM)原理 5. 支持向量 ...
- 机器学习理论基础学习3.1--- Linear classification 线性分类之感知机PLA(Percetron Learning Algorithm)
一.感知机(Perception) 1.1 原理: 感知机是二分类的线性模型,其输入是实例的特征向量,输出的是事例的类别,分别是+1和-1,属于判别模型. 假设训练数据集是线性可分的,感知机学习的目标 ...
- 机器学习笔记(一)· 感知机算法 · 原理篇
这篇学习笔记强调几何直觉,同时也注重感知机算法内部的动机.限于篇幅,这里仅仅讨论了感知机的一般情形.损失函数的引入.工作原理.关于感知机的对偶形式和核感知机,会专门写另外一篇文章.关于感知机的实现代码 ...
- 统计学习方法与Python实现(一)——感知机
统计学习方法与Python实现(一)——感知机 iwehdio的博客园:https://www.cnblogs.com/iwehdio/ 1.定义 假设输入的实例的特征空间为x属于Rn的n维特征向量, ...
- 吴裕雄 python 机器学习——人工神经网络与原始感知机模型
import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...
- 感知机算法(PLA)代码实现
目录 1. 引言 2. 载入库和数据处理 3. 感知机的原始形式 4. 感知机的对偶形式 5. 多分类情况-one vs. rest 6. 多分类情况-one vs. one 7. sklearn实现 ...
- 【python与机器学习实战】感知机和支持向量机学习笔记(一)
对<Python与机器学习实战>一书阅读的记录,对于一些难以理解的地方查阅了资料辅以理解并补充和记录,重新梳理一下感知机和SVM的算法原理,加深记忆. 1.感知机 感知机的基本概念 感知机 ...
- 原始感知机入门——python3实现
运用最简单的原始(对应的有对偶)感知机算法实现线性分类. 参考书目:<统计学习方法>(李航) 算法原理: 踩到的坑:以为误分类的数据只使用一次,造成分类结果很差,在train函数内加个简单 ...
随机推荐
- VPN安装后报错:Reason442 & Error56
VPN安装后一直报错,同样的32位安装包别人安装是正常,自己安装就不正常了,考虑到是自己电脑配置的问题. 经过一番努力,解决了问题,下面就本次解决过程做一个小小的总结. (1)确保VPN Servic ...
- fiddler监听127.0.0.1或localhost
localhost/127.0.0.1的请求不会通过任何代理发送,fiddler也就无法截获. 解决方案 1,用 http://localhost. (locahost紧跟一个点号)2,用 http: ...
- ElasticSearch集群配置
因机器有限,本文只做单机3个节点的集群测试. 1.集群测试信息 elasticsearch版本:elasticsearch-2.4.1 windowns版本:win10 2.解压elasticsear ...
- DEV控件中GridView中的复选框与CheckBox实现联动的全选功能
最初的界面图如图1-1(全选框ID: cb_checkall DEV控件名称:gcCon ): 要实现的功能如下图(1-2 1-3 1-4)及代码所示: 图1-2 图1-3 图1-4 O(∩_∩ ...
- 删除ORACLE的步骤
1.关闭oracle所有的服务.可以在windows的服务管理器中关闭: 2.打开注册表:regedit 打开路径: HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlS ...
- P53 T5
北京某高校可用的电话号码有以下几类:校内电话号码由4位数字,第1位数字不是0:校外电话又分为本市电话和外地电话两类,拔校外电话需先拔0,若是本市电话则再接着拔8位数字(第一位不是0),若是外地电话则拔 ...
- 【66测试20161115】【树】【DP_LIS】【SPFA】【同余最短路】【递推】【矩阵快速幂】
还有3天,今天考试又崩了.状态还没有调整过来... 第一题:小L的二叉树 勤奋又善于思考的小L接触了信息学竞赛,开始的学习十分顺利.但是,小L对数据结构的掌握实在十分渣渣.所以,小L当时卡在了二叉树. ...
- django:field字段类型
字段类型(Field types) AutoField 它是一个根据 ID 自增长的 IntegerField 字段.通常,你不必直接使用该字段.如果你没在别的字段上指定主 键,Django 就会自动 ...
- SAE使用心得1
最近准备在新浪云端SAE上挂点自己的小网站,这样自己开发个什么东西别人能用.但是第一次接触SAE,遇到一些问题,记下来给大家看. 1.安装的svn版本不能高于 1.8,否则无法向SAE提交代码. 2. ...
- AWS-CDH5.5安装-软件下载
1.下载安装介质 下载CM安装文件: [root@ip---- cm5.5.0]# wget -c -r -nd -np -k -L -A rpm http://archive-primary.clo ...