简单线性回归问题的优化(SGD)R语言
本编博客继续分享简单的机器学习的R语言实现。
今天是关于简单的线性回归方程问题的优化问题

常用方法,我们会考虑随机梯度递降,好处是,我们不需要遍历数据集中的所有元素,这样可以大幅度的减少运算量。
具体的算法参考下面:
首先我们先定义我们需要的参数的Notation


上述算法中,为了避免过拟合,我们采用了L2的正则化,在更新步骤中,我们会发现,这个正则项目,对参数更新的影响
下面是代码部分:
## Load Library
library(ggplot2)
library(reshape2)
library(mvtnorm) ## Function for reading the data
read_data <- function(fname, sc) {
data <- read.csv(file=fname,head=TRUE,sep=",")
nr = dim(data)[1]
nc = dim(data)[2]
x = data[1:nr,1:(nc-1)]
y = data[1:nr,nc]
if (isTRUE(sc)) {
x = scale(x) ## Scale x
y = scale(y) ## Scale y
}
return (list("x" = x, "y" = y))
}
我们定义了一个读取数据的方程,这里,我们会把数据集给scale一下,可以保证进一步提高运算速度
## Matrix Product Function
predict_func <- function(Phi, w){
return(Phi%*%w)
} ## Function to compute the cost function
train_obj_func <- function (Phi, w, label, lambda){
# Cost funtion including the L2 norm regulaztion
return(.5 * mean((predict_func(Phi, w) - label)^2) + .5 * lambda * t(w) %*% w)
} ## Return the errors for each iteration
get_errors <- function(data, label, W) {
n_weights = dim(W)[1]
Phi <- cbind('X0' = 1, data)
errors = matrix(,nrow=n_weights, ncol=2)
for (tau in 1:n_weights) {
errors[tau,1] = tau
errors[tau,2] = train_obj_func(Phi, W[tau,],label, 0) ## Get the errors, set the lambda to 0
}
return(errors)
}
同时,我们定义了计算矩阵乘法,计算目标函数以及求误差的方程。
sgd_train <- function(train_x, train_y, lambda, eta, epsilon, max_epoch) {
## Prepare the traindata
## Attach the 1 for X0
Phi <- as.matrix(cbind('X0'=1, train.data))
## Calculate the max iteration time for the SGD
train_len = dim(train_x)[1]
tau_max = max_epoch * train_len
W <- matrix(,nrow=tau_max, ncol=ncol(Phi))
set.seed(1234)
## Random Generate the start parameter
W[1,] <- runif(ncol(Phi))
tau = 1 # counter
## Create a dateframe to store the value of cost function for each iteration
obj_func_val <-matrix(,nrow=tau_max, ncol=1)
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
while (tau <= tau_max){
# check termination criteria
if (obj_func_val[tau,1]<=epsilon) {break}
# shuffle data:
train_index <- sample(1:train_len, train_len, replace = FALSE)
# loop over each datapoint
for (i in train_index) {
# increment the counter
tau <- tau + 1
if (tau > tau_max) {break}
# make the weight update
y_pred <- predict_func(Phi[i,], W[tau-1,])
W[tau,] <- sgd_update_weight(W[tau-1,], Phi[i,], train_y[i], y_pred, lambda, eta)
# keep track of the objective funtion
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
}
}
# resulting values for the training objective function as well as the weights
return(list('vals'=obj_func_val,'W'=W))
}
# updating the weight vector
sgd_update_weight <- function(W_prev, x, y_true, y_pred, lambda, eta) {
## Computer the Gradient
grad = - (y_true-y_pred) * x
## Here I just combine the regularisation term with prev w
return(W_prev*(1-eta * lambda) - eta * grad)
}
根据上述我们写的计算更新目标函数和参数的方法,讲算法用R实现
下面是实验部分
## Load the train data and train label
train.data <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$x
train.label <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$y
## Load the testdata and test label
test.data <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$x
test.label <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$y # Set MAX EPOCH max_epoch = 18 ## Implement SGD with Ridge regression
options(warn=-1) ## Initilize
## Set the related parameters
epsilon = .001 ## Terimation threshold
eta = .01 ## Learning Rate
lambda= 0.5 ## Regularisation parmater ## Run SGD
## Cost function values
train_res2 = sgd_train(train.data, train.label, lambda, eta, epsilon, max_epoch)
## Calulate the errors
## To be mentioned here, we will only visulisation for the train error to check the converge result
errors2 = get_errors(train.data, train.label, train_res2$W)
接着,我们把SGD的error plot给绘制出来
## Visulastion for SGD
plot(train_res2$val, main="SGD", type="l", col="blue",
xlab="iteration", ylab="training objective function")

这里我们的方程比较简单,可以看到,目标函数很快就收敛了。
简单线性回归问题的优化(SGD)R语言的更多相关文章
- 【数据分析】线性回归与逻辑回归(R语言实现)
文章来源:公众号-智能化IT系统. 回归模型有多种,一般在数据分析中用的比较常用的有线性回归和逻辑回归.其描述的是一组因变量和自变量之间的关系,通过特定的方程来模拟.这么做的目的也是为了预测,但有时也 ...
- 一个简单文本分类任务-EM算法-R语言
一.问题介绍 概率分布模型中,有时只含有可观测变量,如单硬币投掷模型,对于每个测试样例,硬币最终是正面还是反面是可以观测的.而有时还含有不可观测变量,如三硬币投掷模型.问题这样描述,首先投掷硬币A,如 ...
- R语言
什么是R语言编程? R语言是一种用于统计分析和为此目的创建图形的编程语言.不是数据类型,它具有用于计算的数据对象.它用于数据挖掘,回归分析,概率估计等领域,使用其中可用的许多软件包. R语言中的不同数 ...
- R语言:用简单的文本处理方法优化我们的读书体验
博客总目录:http://www.cnblogs.com/weibaar/p/4507801.html 前言 延续之前的用R语言读琅琊榜小说,继续讲一下利用R语言做一些简单的文本处理.分词的事情.其实 ...
- R语言-简单线性回归图-方法
目标:利用R语言统计描绘50组实验对比结果 第一步:导入.csv文件 X <- read.table("D:abc11.csv",header = TRUE, sep = & ...
- R 语言中的简单线性回归
... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...
- R语言解读一元线性回归模型
转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...
- R语言解读多元线性回归模型
转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...
- 多元线性回归公式推导及R语言实现
多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...
随机推荐
- [C#]“正在终止线程”的问题
在C#中启用线程后,如果试图使用Abort方法来终止线程,那么必定会抛出“正在终止线程”的异常,一开始我也想过如何来避免这种异常出现,花了不少气力,但最后发现全是徒劳. 原因是一个正在运行的线程被终止 ...
- hdu-6058 Kanade's sum
题意:略 思路:要我们求每个区间第K大数之和,其实可以转换为求多少个区间的第K大数是X,然后我们在求和就好了. 那么我们可以从小到大枚举所有可能成为第K大的数.为什么从小到大呢? 因为从小到大我们就略 ...
- jquery库google加载
加载js库的时候可以加载google CDN,可以同时加载多个jquery库<script src="http://www.google.com/jsapi">< ...
- linux_关闭防火墙
centos6版本 永久关闭 chkconfig iptables off 查看状态 chkconfig iptables --list 此时关闭开机重新启动 service iptables sto ...
- fabric 安装
fabric 是一个python的库,fabric可以通过ssh批量管理服务器. 第一步安装依赖包 安装fabric依赖及pip yum install -y python-pip gcc pytho ...
- Oracle中根据当前时间和活动类型去数据库查询活动id
活动类型默认是1,代表邀请好友 select * from t_invite_activityinfo twhere sysdate >= t.begintime and sysdate< ...
- application.properties /application.yml官网查看配置;springboot application.properties 官网查看,info yml 查看;springboot.yml查看info;springboot.yml查看Actuator监控中心info
官网查看: https://docs.spring.io/spring-boot/docs/current-SNAPSHOT/reference/htmlsingle/#appendix 查看info ...
- js实现百度搜索框滑动固定顶部
现在很多主流系统例如百度.有道.爱奇艺等的搜索框都有一个特点,滑动到刚好看不到搜索框时,固定搜索框到顶部,这也算是一个对用户友好型的操 作. 在看了百度的js和css后自己摸索出来实现效果,还是学艺不 ...
- _编程语言_C++_Lambda函数与表达式
C++11提供了对匿名函数的支持,称为Lambda表达式函数 Lambda 表达式把函数看作对象.Lambda 表达式可以像对象一样使用,比如可以将它们赋给变量和作为参数传递,还可以像函数一样对其求值 ...
- Shell编程-08-Shell中的循环语句
目录 while语句 until语句 for语句 select语句 循环中断控制 循环语句总结 循环语句常用于重复执行一条命令或一组命令等,直到达到结束条件后,则终止执行.在Shell中常见的 ...