机器学习技法笔记:Homework #8 kNN&RBF&k-Means相关习题
原文地址:https://www.jianshu.com/p/1db700f866ee
问题描述



程序实现
# kNN_RBFN.py
# coding:utf-8
import numpy as np
import matplotlib.pyplot as plt
def ReadData(dataFile):
with open(dataFile, 'r') as f:
lines = f.readlines()
data_list = []
for line in lines:
line = line.strip().split()
data_list.append([float(l) for l in line])
dataArray = np.array(data_list)
return dataArray
def sign(n):
if(n>=0):
return 1
else:
return -1
def kNN(k,trainArray,dataX):
num_data=dataX.shape[0]
predY=np.zeros((num_data,))
for n in range(num_data):
distArray=np.sum((trainArray[:,:-1]-dataX[n,:])**2,axis=1)
id_list=np.argsort(distArray,axis=0).tolist()[:k]
for i in id_list:
predY[n]+=trainArray[i,-1]
predY[n]=sign(predY[n])
return predY
def GetZeroOneError(predY,dataY):
return (predY!=dataY).sum()/dataY.shape[0]
def plot_bar_chart(X,Y,nameX,nameY,saveName):
plt.figure(figsize=(10,6))
plt.bar(left=X,height=Y,width=0.8,align="center",yerr=0.000001)
for (c,w) in zip(X,Y):
plt.text(c,w*1.03,str(round(w,4)))
plt.xlabel(nameX)
plt.ylabel(nameY)
plt.xlim(X[0]-1,X[-1]+1)
plt.xticks(X)
plt.ylim(0,1)
plt.title(nameY+" versus "+nameX)
plt.savefig(saveName)
return
def RBFNetwork(k,gamma,trainArray,dataX):
num_data=dataX.shape[0]
predY=np.zeros((num_data,))
for n in range(num_data):
gaussianDistArray=np.exp(-gamma*np.sum((trainArray[:,:-1]-dataX[n,:])**2,axis=1))
id_list=np.argsort(gaussianDistArray,axis=0).tolist()[:k]
for i in id_list:
predY[n]+=trainArray[i,-1]
predY[n]=sign(predY[n])
return predY
if __name__=="__main__":
dataArray=ReadData("hw8_train.dat")
testArray=ReadData("hw8_test.dat")
k_list=[1,3,5,7,9]
ein_list=[]
eout_list=[]
for k in k_list:
predY=kNN(k,dataArray,dataArray[:,:-1])
ein_list.append(GetZeroOneError(predY,dataArray[:,-1]))
predY=kNN(k,dataArray,testArray[:,:-1])
eout_list.append(GetZeroOneError(predY,testArray[:,-1]))
# 12
plot_bar_chart(k_list,ein_list,nameX="k",nameY="Ein(gk-nbor)",saveName="12.png")
# 14
plot_bar_chart(k_list,eout_list,nameX='k',nameY="Eout(gk-bor)",saveName="14.png")
gamma_list=[-3,-1,0,1,2]
ein_list=[]
eout_list=[]
for gamma in gamma_list:
predY=RBFNetwork(dataArray.shape[0],10**gamma,dataArray,dataArray[:,:-1])
ein_list.append(GetZeroOneError(predY,dataArray[:,-1]))
predY=RBFNetwork(dataArray.shape[0],10**gamma,dataArray,testArray[:,:-1])
eout_list.append(GetZeroOneError(predY,testArray[:,-1]))
# 16
plot_bar_chart(X=gamma_list,Y=ein_list,nameX="log10(gamma)",nameY="Ein(guniform)",saveName="16.png")
# 18
plot_bar_chart(X=gamma_list,Y=eout_list,nameX="log10(gamma)",nameY="Eout(guniform)",saveName="18.png")
# kMeans.py
# coding:utf-8
from numpy import random
from kNN_RBFN import *
def kMeans(t,k,dataArray):
num_data=dataArray.shape[0]
random.seed(t)
centreIDList=random.randint(0,num_data,k).tolist()
nowCentreArray=dataArray[centreIDList,:]
tmpCentreArray=np.array(nowCentreArray)
ein=1000000
nowEin=ein-1
dict={}
while(nowEin<ein):
ein=nowEin
dict = {}
for n in range(num_data):
distArray=np.sum((nowCentreArray-dataArray[n,:])**2,axis=1)
minID=np.argmin(distArray)
tmpCentreArray[minID]=(tmpCentreArray[minID]+dataArray[n,:])/2
try:
dict[minID].append(dataArray[n,:])
except:
dict[minID]=[]
dict[minID].append(dataArray[n,:])
nowCentreArray=np.array(tmpCentreArray)
nowEin=GetEin(nowCentreArray,dict)
return nowCentreArray,dict
def GetEin(nowCentreArray,dict):
k=nowCentreArray.shape[0]
ein=0
for i in range(k):
if i not in dict.keys():
continue
data=np.array(dict[i])
ein+=np.average(np.sum((data-nowCentreArray[i])**2,axis=1))
return ein
def plot_bar_chart(X,Y,nameX,nameY,saveName):
plt.figure(figsize=(10,6))
plt.bar(left=X,height=Y,width=0.8,align="center",yerr=0.000001)
for (c,w) in zip(X,Y):
plt.text(c,w*1.03,str(round(w,4)))
plt.xlabel(nameX)
plt.ylabel(nameY)
plt.xlim(X[0]-1,X[-1]+1)
plt.xticks(X)
plt.title(nameY+" versus "+nameX)
plt.savefig(saveName)
return
if __name__=="__main__":
dataArray=ReadData("hw8_nolabel_train.dat")
k_list=[2,4,6,8,10]
ein_list=[]
for k in k_list:
ein=0
for t in range(500):
nowCentreArray,dict=kMeans(t,k,dataArray)
ein+=GetEin(nowCentreArray,dict)
ein_list.append(ein/500)
plot_bar_chart(k_list,ein_list,nameX="k",nameY="the average Ein over 500 experiments",saveName="20.png")
运行结果





机器学习技法笔记:Homework #8 kNN&RBF&k-Means相关习题的更多相关文章
- 机器学习技法笔记(2)-Linear SVM
从这一节开始学习机器学习技法课程中的SVM, 这一节主要介绍标准形式的SVM: Linear SVM 引入SVM 首先回顾Percentron Learning Algrithm(感知器算法PLA)是 ...
- 机器学习十大算法之KNN(K最近邻,k-NearestNeighbor)算法
机器学习十大算法之KNN算法 前段时间一直在搞tkinter,机器学习荒废了一阵子.如今想重新写一个,发现遇到不少问题,不过最终还是解决了.希望与大家共同进步. 闲话少说,进入正题. KNN算法也称最 ...
- 机器学习技法笔记:Homework #6 AdaBoost&Kernel Ridge Regression相关习题
原文地址:http://www.jianshu.com/p/9bf9e2add795 AdaBoost 问题描述 程序实现 # coding:utf-8 import math import nump ...
- 机器学习技法笔记:Homework #5 特征变换&Soft-Margin SVM相关习题
原文地址:https://www.jianshu.com/p/6bf801bdc644 特征变换 问题描述 程序实现 # coding: utf-8 import numpy as np from c ...
- 机器学习技法笔记:Homework #7 Decision Tree&Random Forest相关习题
原文地址:https://www.jianshu.com/p/7ff6fd6fc99f 问题描述 程序实现 13-15 # coding:utf-8 # decision_tree.py import ...
- 机器学习技法笔记:14 Radial Basis Function Network
Roadmap RBF Network Hypothesis RBF Network Learning k-Means Algorithm k-Means and RBF Network in Act ...
- 机器学习技法笔记:08 Adaptive Boosting
Roadmap Motivation of Boosting Diversity by Re-weighting Adaptive Boosting Algorithm Adaptive Boosti ...
- 机器学习技法笔记:15 Matrix Factorization
Roadmap Linear Network Hypothesis Basic Matrix Factorization Stochastic Gradient Descent Summary of ...
- 机器学习技法笔记:16 Finale
Roadmap Feature Exploitation Techniques Error Optimization Techniques Overfitting Elimination Techni ...
随机推荐
- js千位符 | js 千位分隔符 | js 金额格式化
js 千位分隔符 千位分隔符,其实就是数字中的逗号.依西方的习惯,人们在数字中加进一个符号,以免因数字位数太多而难以看出它的值.所以人们在数字中,每隔三位数加进一个逗号,也就是千位分隔符,以便更加容易 ...
- POJ3641 Pseudoprime numbers (幂取模板子)
给你两个数字p,a.如果p是素数,并且ap mod p = a,输出“yes”,否则输出“no”. 很简单的板子题.核心算法是幂取模(算法详见<算法竞赛入门经典>315页). 幂取模板子: ...
- List、Map、Set三个接口存取元素时,各有什么特点
List接口以特定索引来存取元素,可以有重复元素 Set接口不可以存放重复元素(使用equals方法区分是否重复) Map接口保存的是键值对(key-value-pair)映射,映射关系可以是一对一或 ...
- javascript 中的函数
/* 第二天 */ 函数 函数是js里最有趣的东西了,函数实际上就是对象,每个函数Function类型的实例,函数名实际上是指向函数对象的指针.不带圆括号的函数时访问函数的指针,带圆括号的是调 ...
- Javascript基础一(介绍)
Javascript的发展历史: JavaScript在设计之初只是为了做表单验证.但是现如今,JavaScript已经成为了一门功能全面的编程语言,已经是WEB中不可缺少的一部分,如今的JavaSc ...
- ionic3 多级联动城市选择插件 ion-multi-picker
1.效果演示 2.npm安装扩展包依赖 ion-multi-picker 组件 npm install ion-multi-picker --save 3.在app.module.ts中导入插件模块 ...
- 2018-12-2-C#-Span-入门
title author date CreateTime categories C# Span 入门 lindexi 2018-12-02 11:32:46 +0800 2018-06-18 11:1 ...
- teb教程4
障碍物避障以及机器人足迹模型 简介:障碍物避障的实现,以及必要参数的设置对于机器人足迹模型和其对应的影响 1.障碍物避障是怎样工作的 1.1 惩罚项 障碍物避障作为整个路径优化的一部分.显然,优化是找 ...
- CICS FILE OPEN
CEMT I CECD V FILE() GROUP() CEDA check error log in JESYSMSG FILE OPEN/CLOSE STATUS CICS ACTION res ...
- 【学术篇】The Xuanku Inversion Magic学习笔记
退役之前写的 然后因为退役就咕咕咕了... 后来发现数学考试能用的到个鬼就发布出来了QwQ 主要是方便自己没登录的时候查阅... 显然子集什么的是没有学会的QwQ 所以学OI的话不要看本文!!!!!& ...