运用对偶的(对应原始)感知机算法实现线性分类。

  参考书目:《统计学习方法》(李航)

  算法原理:

  代码实现:

  环境:win7 32bit + Anaconda3 +spyder

  和原始算法的实现基本框架是类似的,只是判断和权值的更新算法有点变化。

 # -*- coding: utf-8 -*-
"""
Created on Fri Nov 18 01:29:35 2016 @author: Administrator
""" import numpy as np
from matplotlib import pyplot as plt # train matrix
def get_train_data():
M1 = np.random.random((100,2))
# 将label加到最后,方便后面操作
M11 = np.column_stack((M1,np.ones(100))) M2 = np.random.random((100,2)) - 0.7
M22 = np.column_stack((M2,np.ones(100)*(-1)))
# 合并两类,并将位置索引加到最后
MA = np.vstack((M11,M22))
MA = np.column_stack((MA,range(0,200))) # 作图操作
plt.plot(M1[:,0],M1[:,1], 'ro')
plt.plot(M2[:,0],M2[:,1], 'go')
# 为了美观,根据数据点限制之后分类线的范围
min_x = np.min(M2)
max_x = np.max(M1)
# 分隔x,方便作图
x = np.linspace(min_x, max_x, 100)
# 此处返回 x 是为了之后作图方便
return MA,x # GRAM计算
def get_gram(MA):
GRAM = np.empty(shape=(200,200))
for i in range(len(MA)):
for j in range(len(MA)):
GRAM[i,j] = np.dot(MA[i,][:2], MA[j,][:2])
return GRAM # 方便在train函数中识别误分类点
def func(alpha,b,xi,yi,yN,index,GRAM):
pa1 = alpha*yN
pa2 = GRAM[:,index]
num = yi*(np.dot(pa1,pa2)+b)
return num # 训练training data
def train(MA, alpha, b, GRAM, yN):
# M 存储每次处理后依旧处于误分类的原始数据
M = []
for sample in MA:
xi = sample[0:2]
yi = sample[-2]
index = int(sample[-1])
# 如果为误分类,改变alpha,b
# n 为学习率
if func(alpha,b,xi,yi,yN,index,GRAM) <= 0:
alpha[index] += n
b += n*yi
M.append(sample)
if len(M) > 0:
# print('迭代...')
train(M, alpha, b, GRAM, yN)
return alpha,b # 作出分类线的图
def plot_classify(w,b,x, rate0):
y = (w[0]*x+b)/((-1)*w[1])
plt.plot(x,y)
plt.title('Accuracy = '+str(rate0)) # 随机生成testing data 并作图
def get_test_data():
M = np.random.random((50,2))
plt.plot(M[:,0],M[:,1],'*y')
return M
# 对传入的testing data 的单个样本进行分类
def classify(w,b,test_i):
if np.sign(np.dot(w,test_i)+b) == 1:
return 1
else:
return 0 # 测试数据,返回正确率
def test(w,b,test_data):
right_count = 0
for test_i in test_data:
classx = classify(w,b,test_i)
if classx == 1:
right_count += 1
rate = right_count/len(test_data)
return rate if __name__=="__main__":
MA,x= get_train_data()
test_data = get_test_data()
GRAM = get_gram(MA)
yN = MA[:,2]
xN = MA[:,0:2]
# 定义初始值
alpha = [0]*200
b = 0
n = 1
# 初始化最优的正确率
rate0 = 0 # print(alpha,b)
# 循环不同的学习率n,寻求最优的学习率,即最终的rate0
# w0,b0为对应的最优参数
for i in np.linspace(0.01,1,100):
n = i
alpha,b = train(MA, alpha, b, GRAM, yN)
alphap = np.column_stack((alpha*yN,alpha*yN))
w = sum(alphap*xN)
rate = test(w,b,test_data)
# print(w,b)
rate = test(w,b,test_data)
if rate > rate0:
rate0 = rate
w0 = w
b0 = b
print('Until now, the best result of the accuracy on test data is '+str(rate))
print('with w='+str(w0)+' b='+str(b0))
print('---------------------------------------------')
# 在选定最优的学习率后,作图
plot_classify(w0,b0,x,rate0)
plt.show()

  输出:

感知机的对偶形式——python3实现的更多相关文章

  1. 2. 感知机(Perceptron)基本形式和对偶形式实现

    1. 感知机原理(Perceptron) 2. 感知机(Perceptron)基本形式和对偶形式实现 3. 支持向量机(SVM)拉格朗日对偶性(KKT) 4. 支持向量机(SVM)原理 5. 支持向量 ...

  2. 1. 感知机原理(Perceptron)

    1. 感知机原理(Perceptron) 2. 感知机(Perceptron)基本形式和对偶形式实现 3. 支持向量机(SVM)拉格朗日对偶性(KKT) 4. 支持向量机(SVM)原理 5. 支持向量 ...

  3. 机器学习理论基础学习3.1--- Linear classification 线性分类之感知机PLA(Percetron Learning Algorithm)

    一.感知机(Perception) 1.1 原理: 感知机是二分类的线性模型,其输入是实例的特征向量,输出的是事例的类别,分别是+1和-1,属于判别模型. 假设训练数据集是线性可分的,感知机学习的目标 ...

  4. 机器学习笔记(一)&#183; 感知机算法 &#183; 原理篇

    这篇学习笔记强调几何直觉,同时也注重感知机算法内部的动机.限于篇幅,这里仅仅讨论了感知机的一般情形.损失函数的引入.工作原理.关于感知机的对偶形式和核感知机,会专门写另外一篇文章.关于感知机的实现代码 ...

  5. 统计学习方法与Python实现(一)——感知机

    统计学习方法与Python实现(一)——感知机 iwehdio的博客园:https://www.cnblogs.com/iwehdio/ 1.定义 假设输入的实例的特征空间为x属于Rn的n维特征向量, ...

  6. 吴裕雄 python 机器学习——人工神经网络与原始感知机模型

    import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...

  7. 感知机算法(PLA)代码实现

    目录 1. 引言 2. 载入库和数据处理 3. 感知机的原始形式 4. 感知机的对偶形式 5. 多分类情况-one vs. rest 6. 多分类情况-one vs. one 7. sklearn实现 ...

  8. 【python与机器学习实战】感知机和支持向量机学习笔记(一)

    对<Python与机器学习实战>一书阅读的记录,对于一些难以理解的地方查阅了资料辅以理解并补充和记录,重新梳理一下感知机和SVM的算法原理,加深记忆. 1.感知机 感知机的基本概念 感知机 ...

  9. 原始感知机入门——python3实现

    运用最简单的原始(对应的有对偶)感知机算法实现线性分类. 参考书目:<统计学习方法>(李航) 算法原理: 踩到的坑:以为误分类的数据只使用一次,造成分类结果很差,在train函数内加个简单 ...

随机推荐

  1. 使用caffe时遇到的问题

    1.Error: (unix time) try if you are using GNU date 问题所在: 在训练train.txt图片列表位置和生成的lmbd数据不符. 解决方案: 修改tra ...

  2. onthink 数据库连接配置

    define('UC_DB_DSN', 'mysql://root:@127.0.0.1:3306/app'); // 数据库连接,使用Model方式调用API必须配置此项 /* 数据库配置 */ ' ...

  3. 周爱民:真正的架构师是没有title的(图灵访谈)

    周爱民,现任豌豆荚架构师,国内软件开发界资深软件工程师.从1996年起开始涉足商业软件开发,历任部门经理.区域总经理.高级软件工程师.平台架构师等职,有18年的软件开发与架构.项目管理及团队建设经验, ...

  4. java 选择排序

    import java.util.Scanner; public class SelectionSort { public static void sort(int[] a, int n){ if(n ...

  5. Socket通讯

    复习贴,资料大多来自百科.看了一遍理解了一遍,把绕口的话按语义给改了`_>` 对于一个网络连接来说,套接字是平等的,并没有差别,不因为在服务器端或在客户端而产生不同级别.不管是Socket还是S ...

  6. jboss设置图片上传大小

    <http-listener name="default" socket-binding="http" max-post-size="10485 ...

  7. [转]搭建高可用mongodb集群(二)—— 副本集

    在上一篇文章<搭建高可用MongoDB集群(一)——配置MongoDB> 提到了几个问题还没有解决. 主节点挂了能否自动切换连接?目前需要手工切换. 主节点的读写压力过大如何解决? 从节点 ...

  8. 浏览器主页被hao123贱贱的篡改的一种方式

    今天打开一个PDF文件(正经文件,不要想歪了),误点了一个“编辑”按钮,出来发现浏览器主页被篡改了,包括chrome和IE.通过一个网址"www.qquuu8.com"跳转到hao ...

  9. 【PCB】电子元件封装大全及封装常识

    电子元件封装大全及封装常识 电子元件封装大全及封装常识 一.什么叫封装封装,就是指把硅片上的电路管脚,用导线接引到外部接头处,以便与其它器件连接.封装形式是指安装半导体集成电路芯片用的外壳.它不仅起着 ...

  10. [转载]Bison-Flex 笔记

    FLEX 什么是FLEX?它是一个自动化工具,可以按照定义好的规则自动生成一个C函数yylex(),也成为扫描器(Scanner).这个C函数把文本串作为输入,按照定义好的规则分析文本串中的字符,找到 ...