原文地址:http://www.jianshu.com/p/4bc01760ac20

问题描述



程序实现

17-18

# coding: utf-8

import numpy as np
import matplotlib.pyplot as plt def sign(n):
if(n>0):
return 1
else:
return -1 def gen_data():
data_X=np.random.uniform(-1,1,(20,1))# [-1,1)
data_Y=np.zeros((20,1))
idArray=np.random.permutation([i for i in range(20)])
for i in range(20):
if(i<20*0.2):
data_Y[idArray[i]][0]=-sign(data_X[idArray[i]][0])
else:
data_Y[idArray[i]][0] = sign(data_X[idArray[i]][0])
data=np.concatenate((data_X,data_Y),axis=1)
return data def decision_stump(dataArray):
minErrors=20
min_s_theta_list=[]
num_data=dataArray.shape[0]
data=dataArray.tolist()
data.sort(key=lambda x:x[0])
for s in [-1.0,1.0]:
for i in range(num_data):
if(i==num_data-1):
theta=(data[i][0]+1.0)/2
else:
theta=(data[i][0]+data[i+1][0])/2
errors=0
for i in range(20):
pred=s*sign(data[i][0]-theta)
if(pred!=data[i][1]):
errors+=1
if(minErrors>errors):
minErrors=errors
min_s_theta_list=[]
elif(minErrors<errors):
continue
min_s_theta_list.append((s, theta))
i=np.random.randint(low=0,high=len(min_s_theta_list))
min_s,min_theta=min_s_theta_list[i]
return minErrors,min_s,min_theta def computeEinEout(minErrors,min_s,min_theta):
Ein=minErrors/20
Eout=0.5+0.3*min_s*(abs(min_theta)-1)
return Ein,Eout if __name__=="__main__":
Ein_list=[]
Eout_list=[]
for i in range(5000):
dataArray=gen_data()
minErrors,min_s,min_theta=decision_stump(dataArray)
Ein,Eout=computeEinEout(minErrors,min_s,min_theta)
Ein_list.append(Ein)
Eout_list.append(Eout) # show results
# 17 & 18
print("the average Ein: ",sum(Ein_list)/5000)
print("the average Eout: ",sum(Eout_list)/5000) plt.figure(figsize=(16,6))
plt.subplot(121)
plt.hist(Ein_list)
plt.xlabel("Ein")
plt.ylabel("frequency")
plt.subplot(122)
plt.hist(Eout_list)
plt.xlabel("Eout")
plt.ylabel("frequency")
plt.savefig("EinEout.png")

19-20

# coding: utf-8

import numpy as np

def read_data(dataFile):
with open(dataFile, 'r') as file:
data_list = []
for line in file.readlines():
line = line.strip().split()
data_list.append([float(l) for l in line])
data_array = np.array(data_list)
return data_array def predict(s,theta,dataX):
num_data=dataX.shape[0]
res=s*np.sign(dataX-theta)
return res def decision_stump(dataArray):
min_s_theta_list=[]
num_data=dataArray.shape[0]
minErrors=num_data
data=dataArray.tolist()
data.sort(key=lambda x:x[0])
dataArray=np.array(data)
dataX=dataArray[:,0].reshape(num_data,1)
dataY=dataArray[:,1].reshape(num_data,1)
for s in [-1.0,1.0]:
for i in range(num_data):
if(i==num_data-1):
theta=(dataX[i][0]*2+1)/2
else:
theta=(dataX[i][0]+dataX[i+1][0])/2
pred=predict(s,theta,dataX)
errors=np.sum(pred!=dataY)
if(minErrors>errors):
minErrors=errors
min_s_theta_list=[]
elif(minErrors<errors):
continue
min_s_theta_list.append((s, theta))
i=np.random.randint(low=0,high=len(min_s_theta_list))
min_s,min_theta=min_s_theta_list[i]
return minErrors,min_s,min_theta def best_of_best(candidate):
candidate.sort(key=lambda x:x[1])
counts=0
for i in range(len(candidate)):
if(candidate[i][1]!=candidate[0][1]):
break
counts+=1
i=np.random.randint(low=0,high=counts)
return candidate[i][0],candidate[i][1],candidate[i][2],candidate[i][3] if __name__=="__main__":
data_array=read_data("hw2_train.dat")
num_data=data_array.shape[0]
num_dim=data_array.shape[1]-1
candidate=[]
dataY=data_array[:,-1].reshape(num_data,1)
for i in range(num_dim):
dataX=data_array[:,i].reshape(num_data,1)
min_errors,min_s,min_theta=decision_stump(np.concatenate((dataX,dataY),axis=1))
candidate.append([i,min_errors,min_s,min_theta])
min_id,min_errors,min_s,min_theta=best_of_best(candidate)
print("the optimal decision stump:\n","s: ",min_s,"\ntheta: ",min_theta)
print("the Ein of the optimal decision stump:\n",min_errors/num_data) test_array=read_data("hw2_test.dat")
num_test=test_array.shape[0]
testY=test_array[:,-1].reshape(num_test,1)
num_dim=test_array.shape[1]-1
testX=test_array[:,min_id].reshape(num_test,1)
pred=predict(min_s,min_theta,testX)
print("the Eout of the optimal decision stump by Etest:\n",np.sum(pred!=testY)/num_test)

运行结果

17-18



19-20

机器学习基石笔记:Homework #2 decision stump相关习题的更多相关文章

  1. 机器学习基石笔记:Homework #1 PLA&PA相关习题

    原文地址:http://www.jianshu.com/p/5b4a64874650 问题描述 程序实现 # coding: utf-8 import numpy as np import matpl ...

  2. 机器学习基石笔记:Homework #4 Regularization&Validation相关习题

    原文地址:https://www.jianshu.com/p/3f7d4aa6a7cf 问题描述 程序实现 # coding: utf-8 import numpy as np import math ...

  3. 机器学习基石笔记:Homework #3 LinReg&LogReg相关习题

    原文地址:http://www.jianshu.com/p/311141f2047d 问题描述 程序实现 13-15 # coding: utf-8 import numpy as np import ...

  4. 机器学习基石:Homework #0 SVD相关&常用矩阵求导公式

  5. 林轩田机器学习基石笔记1—The Learning Problem

    机器学习分为四步: When Can Machine Learn? Why Can Machine Learn? How Can Machine Learn? How Can Machine Lear ...

  6. 机器学习基石笔记:01 The Learning Problem

    原文地址:https://www.jianshu.com/p/bd7cb6c78e5e 什么时候适合用机器学习算法? 存在某种规则/模式,能够使性能提升,比如准确率: 这种规则难以程序化定义,人难以给 ...

  7. 机器学习基石笔记:04 Feasibility of Learning

    原文地址:https://www.jianshu.com/p/f2f4d509060e 机器学习是设计算法\(A\),在假设集合\(H\)里,根据给定数据集\(D\),选出与实际模式\(f\)最为相近 ...

  8. 机器学习基石笔记:03 Types of Learning

    原文地址:https://www.jianshu.com/p/86b2a9cef742 一.学习的分类 根据输出空间\(Y\):分类(二分类.多分类).回归.结构化(监督学习+输出空间有结构): 根据 ...

  9. 机器学习技法笔记:09 Decision Tree

    Roadmap Decision Tree Hypothesis Decision Tree Algorithm Decision Tree Heuristics in C&RT Decisi ...

随机推荐

  1. 理解B+树算法和Innodb索引

    一.innodb存储引擎索引概述: innodb存储引擎支持两种常见的索引:B+树索引和哈希索引. innodb支持哈希索引是自适应的,innodb会根据表的使用情况自动生成哈希索引. B+树索引就是 ...

  2. K:找寻数组中第n大的数组元素的三个算法

    相关介绍:  给定一个数组,找出该数组中第n大的元素的值.其中,1<=n<=length.例如,给定一个数组A={2,3,6,5,7,9,8,1,4},当n=1时,返回9.解决该问题的算法 ...

  3. Zookeeper Curator API 使用

    0. 原生 ZOOKEEPER JAVA API  http://www.cnblogs.com/rocky-fang/p/9030438.html 1. 概述 Curator采用cache封装对事件 ...

  4. 利用PIE实现全球云分布图的效果

    1.问题背景: 最近项目中获得了一份全球云分布图,客户要求把云显示在全球地图上,出现云的效果,如下图所示: [全球云分布图] [世界地图云示意图] 2.解决思路 咨询专业的业务人员,建议我测试下试试地 ...

  5. RGB与INT类型的转换

    开发时遇到的问题,设置图层样式时颜色的返回值是uint,一时不知改怎么转换为C#常用的RGB值了. 一番百度,结果如下: RGB = R + G * 256 + B * 256 * 256 因此可得到 ...

  6. MUI框架-08-窗口管理-创建子页面

    MUI框架-08-窗口管理-创建子页面 之前写过这一篇,不知道为什么被删了,我就大概写了,抱歉 创建子页面是为了,页面切换时,外面的页面不动,让 MUI 写出来的页面更接近原生 app 官方文档:ht ...

  7. Android链接蓝牙电子称

    蓝牙一直是我内心屏蔽的一个模块哈哈哈哈!然而今天我不得不正视它了,我百度了看了好多因为需要设备匹配所以设备不在没办法测试,几天之后设备到了.因为没有接触过,看到返回的打印出来的菱形方块就以为是错了.于 ...

  8. 微信小程序开发7-JavaScript脚本

    1.小程序的主要开发语言是 JavaScript ,开发者使用 JavaScript 来开发业务逻辑以及调用小程序的 API 来完成业务需求. 2.ECMAScript 在大部分开发者看来,ECMAS ...

  9. GPG error: http://extras.ubuntu.com trusty Release: The following signatures couldn't be verified because the public key is not available: NO_PUBKEY F60F4B3D7FA2AF80

    今天在更新运行apt-get update的时候出现了如下的错误: W: GPG error: http://extras.ubuntu.com trusty Release: The followi ...

  10. MessageFormat使用记录

    1.日志里面需要记录入参,之前一般使用StringUtils.formt()方法,但是如果入参含有空值,就会报错.这个时候可以使用MessageFormat方法.用法 format(String pa ...