机器学习基石笔记:Homework #4 Regularization&Validation相关习题
原文地址:https://www.jianshu.com/p/3f7d4aa6a7cf
问题描述



程序实现
# coding: utf-8
import numpy as np
import math
import matplotlib.pyplot as plt
def sign(x):
if(x>=0):
return 1
else:
return -1
def read_data(dataFile):
with open(dataFile,'r') as f:
lines=f.readlines()
data_list=[]
for line in lines:
line=line.strip().split()
data_list.append([1.0] + [float(l) for l in line])
dataArray=np.array(data_list)
num_data=dataArray.shape[0]
num_dim=dataArray.shape[1]-1
dataX=dataArray[:,:-1].reshape((num_data,num_dim))
dataY=dataArray[:,-1].reshape((num_data,1))
return dataX,dataY
def w_reg(dataX,dataY,namuta):
num_dim=dataX.shape[1]
dataX_T=np.transpose(dataX)
tmp=np.dot(np.linalg.inv(np.dot(dataX_T,dataX)+namuta*np.eye(num_dim)),dataX_T)
return np.dot(tmp,dataY)
def pred(wREG,dataX):
pred=np.dot(dataX,wREG)
num_data=dataX.shape[0]
for i in range(num_data):
pred[i][0]=sign(pred[i][0])
return pred
def zero_one_cost(pred,dataY):
return np.sum(pred!=dataY)/dataY.shape[0]
if __name__=="__main__":
# train
dataX,dataY=read_data("hw4_train.dat")
print("\n13")
wREG=w_reg(dataX,dataY,namuta=10)
Ein=zero_one_cost(pred(wREG,dataX),dataY)
print("the Ein on the train set: ",Ein)
# test
testX,testY=read_data("hw4_test.dat")
Eout=zero_one_cost(pred(wREG,testX),testY)
print("the Eout on the test set: ",Eout)
l=[2,1,0,-1,-2,-3,-4,-5,-6,-7,-8,-9,-10]
print("\n14")
Ein_list=[]
Eout_list=[]
for i in l:
namuta=math.pow(10,i)
wREG=w_reg(dataX,dataY,namuta)
Ein_list.append(zero_one_cost(pred(wREG,dataX),dataY))
Eout_list.append(zero_one_cost(pred(wREG,testX),testY))
id_in=Ein_list.index(min(Ein_list))
plt.figure()
plt.plot(np.power(np.full(shape=(len(l),),fill_value=10,dtype=np.int32),l),Ein_list)
plt.xlabel("namuta")
plt.xlim((math.pow(10,l[0]),math.pow(10,l[-1])))
plt.ylabel("Ein")
plt.savefig("14.png")
print("the namuta with the minimun Ein: ",math.pow(10,l[id_in]))
print("the Eout on such namuta: ", Eout_list[id_in])
print("\n15")
id_out = Eout_list.index(min(Eout_list))
plt.figure()
plt.plot(np.power(np.full(shape=(len(l),),fill_value=10,dtype=np.int32),l),Eout_list)
plt.xlabel("namuta")
plt.xlim((math.pow(10,l[0]),math.pow(10,l[-1])))
plt.ylabel("Eout")
plt.savefig("15.png")
print("the namuta with the minimun Eout: ", math.pow(10, l[id_out]))
trainX=dataX[:120]
trainY=dataY[:120]
validX=dataX[120:]
validY=dataY[120:]
# validation
print("\n16")
Ein_list.clear()
Eout_list.clear()
Eval_list=[]
for i in l:
namuta=math.pow(10,i)
wREG=w_reg(trainX,trainY,namuta)
Ein_list.append(zero_one_cost(pred(wREG,trainX),trainY))
Eout_list.append(zero_one_cost(pred(wREG,testX),testY))
Eval_list.append(zero_one_cost(pred(wREG,validX),validY))
id_in=Ein_list.index(min(Ein_list))
plt.figure()
plt.plot(np.power(np.full(shape=(len(l),),fill_value=10,dtype=np.int32),l),Ein_list)
plt.xlabel("namuta")
plt.xlim((math.pow(10,l[0]),math.pow(10,l[-1])))
plt.ylabel("Ein")
plt.savefig("16.png")
print("the namuta with the minimun Ein: ",math.pow(10,l[id_in]))
print("the Eout on such namuta: ", Eout_list[id_in])
print("\n17")
id_val=Eval_list.index(min(Eval_list))
plt.figure()
plt.plot(np.power(np.full(shape=(len(l),),fill_value=10,dtype=np.int32),l),Eval_list)
plt.xlabel("namuta")
plt.xlim((math.pow(10,l[0]),math.pow(10,l[-1])))
plt.ylabel("Eval")
plt.savefig("17.png")
print("the namuta with the minimun Eval: ",math.pow(10,l[id_val]))
print("the Eout on such namuta: ", Eout_list[id_val])
print("\n18")
wREG=w_reg(dataX,dataY,namuta=math.pow(10,l[id_val]))
Ein=zero_one_cost(pred(wREG,dataX),dataY)
Eout = zero_one_cost(pred(wREG, testX), testY)
print("Ein: ",Ein)
print("Eout: ",Eout)
# 5-fold cross validation
print("\n19")
Eval_list.clear()
splX=np.split(dataX,5,axis=0)
splY=np.split(dataY,5,axis=0)
for j in l:
Eval = 0
namuta=math.pow(10,j)
for i in range(5):
li=[a for a in range(5)]
li.pop(i)
trainX=np.concatenate([splX[k] for k in li],axis=0)
trainY=np.concatenate([splY[k] for k in li],axis=0)
wREG=w_reg(trainX,trainY,namuta)
Eval+=zero_one_cost(pred(wREG,splX[i]),splY[i])/5
Eval_list.append(Eval)
id_val=Eval_list.index(min(Eval_list))
plt.figure()
plt.plot(np.power(np.full(shape=(len(l),),fill_value=10,dtype=np.int32),l),Eval_list)
plt.xlabel("namuta")
plt.xlim((math.pow(10,l[0]),math.pow(10,l[-1])))
plt.ylabel("Ecv")
plt.savefig("19.png")
print("the namuta with the minimun Ecv: ",math.pow(10,l[id_val]))
print("\n20")
wREG=w_reg(dataX,dataY,namuta=math.pow(10,l[id_val]))
Ein=zero_one_cost(pred(wREG,dataX),dataY)
Eout = zero_one_cost(pred(wREG, testX), testY)
print("Ein: ",Ein)
print("Eout: ",Eout)
运行结果
13

14


15


16


17


18

19


20

机器学习基石笔记:Homework #4 Regularization&Validation相关习题的更多相关文章
- 机器学习基石笔记:14 Regularization
一.正则化的假设集合 通过从高次多项式的H退回到低次多项式的H来降低模型复杂度, 以降低过拟合的可能性, 如何退回? 通过加约束条件: 如果加了严格的约束条件, 没有必要从H10退回到H2, 直接使用 ...
- 机器学习基石笔记:Homework #1 PLA&PA相关习题
原文地址:http://www.jianshu.com/p/5b4a64874650 问题描述 程序实现 # coding: utf-8 import numpy as np import matpl ...
- 机器学习基石笔记:Homework #2 decision stump相关习题
原文地址:http://www.jianshu.com/p/4bc01760ac20 问题描述 程序实现 17-18 # coding: utf-8 import numpy as np import ...
- 机器学习基石笔记:Homework #3 LinReg&LogReg相关习题
原文地址:http://www.jianshu.com/p/311141f2047d 问题描述 程序实现 13-15 # coding: utf-8 import numpy as np import ...
- 机器学习基石笔记:15 Validation
一.模型选择问题 如何选择? 视觉上 NO 不是所有资料都能可视化;人脑模型复杂度也得算上. 通过Ein NO 容易过拟合;泛化能力差. 通过Etest NO 能保证好的泛化,不过往往没法提前获得测试 ...
- 机器学习基石:Homework #0 SVD相关&常用矩阵求导公式
- 机器学习基石笔记:13 Hazard of Overfitting
泛化能力差和过拟合: 引起过拟合的原因: 1)过度VC维(模型复杂度高)------确定性噪声: 2)随机噪声: 3)有限的样本数量N. 具体实验来看模型复杂度Qf/确定性噪声.随机噪声sigma2. ...
- 【原】Coursera—Andrew Ng机器学习—课程笔记 Lecture 7 Regularization 正则化
Lecture7 Regularization 正则化 7.1 过拟合问题 The Problem of Overfitting7.2 代价函数 Cost Function7.3 正则化线性回归 R ...
- 林轩田机器学习基石笔记1—The Learning Problem
机器学习分为四步: When Can Machine Learn? Why Can Machine Learn? How Can Machine Learn? How Can Machine Lear ...
随机推荐
- python join函数
join()函数 语法: 'sep'.join(seq) 参数说明sep:分隔符.可以为空seq:要连接的元素序列.字符串.元组.字典上面的语法即:以sep作为分隔符,将seq所有的元素合并成一个新 ...
- 2017《Java技术》预备作业 计科1502任秀兴
阅读邹欣老师的博客,谈谈你期望的师生关系是什么样的? 我认为,学生和老师的关系,应该亦师亦友.可以以一种朋友的身份去进行教学,是我们理想中的课堂. 在生活中,老师和我们应该多沟通,成为朋友,在有感情的 ...
- 大数据学习之BigData常用算法和数据结构
大数据学习之BigData常用算法和数据结构 1.Bloom Filter 由一个很长的二进制向量和一系列hash函数组成 优点:可以减少IO操作,省空间 缺点:不支持删除,有 ...
- C++——运行时类型识别RTTI
1.实现方式 typeid运算符,返回表达式的类型 dynamic_cast运算符,基类的指针或引用安全地转换成派生类的指针或引用 2.适用于:使用基类的指针或引用执行派生类的操作,且该操作不是虚函数 ...
- HDU 1875 畅通工程再续 (Prim)
题目链接:HDU 1875 Problem Description 相信大家都听说一个"百岛湖"的地方吧,百岛湖的居民生活在不同的小岛中,当他们想去其他的小岛时都要通过划小船来实现 ...
- git和svn的比较
当前的市场上主流的两种项目开发版本控制软件就是Git和SVN,那么这二者到底有什么区别呢? 在我们公司,其实两个都用,跟对个人体验,我觉得两者差不多,都是进行代码的版本管理. 我觉得1.由于我是实习生 ...
- python作业/练习/实战:3、实现商品管理的一个程序
作业要求 实现一个商品管理的一个程序,运行程序有三个选项,输入1添加商品:输入2删除商品:输入3 查看商品信息1.添加商品: 商品名称:xx 商品如果已经存在,提示商品已存在 商品价格:xx数量只能为 ...
- nginx支持http2协议
1.http2协议 HTTP 2.0 的主要目标是改进传输性能,实现低延迟和高吞吐量.从另一方面看,HTTP 的高层协议语义并不会因为这次版本升级而受影响.所有HTTP 首部.值,以及它们的使用场景都 ...
- Gradle教程
Ant和Maven共享在Java市场上相当大的成功.ANT是在2000年发布了第一个版本的工具,它是基于程序编程思想的发展. 后来,人们在 Apache-Ivy的帮助下,网络接受插件和依赖管理的能力有 ...
- android中的Serveice组件
创建 配置 Service: 1.定义一个继承了Service类的子类 2.在 AndroidManifest.xml清单文件中对开发的Service进行配置 Service和Activity很相似, ...