本编博客继续分享简单的机器学习的R语言实现。

今天是关于简单的线性回归方程问题的优化问题

常用方法,我们会考虑随机梯度递降,好处是,我们不需要遍历数据集中的所有元素,这样可以大幅度的减少运算量。

具体的算法参考下面:

首先我们先定义我们需要的参数的Notation

上述算法中,为了避免过拟合,我们采用了L2的正则化,在更新步骤中,我们会发现,这个正则项目,对参数更新的影响

下面是代码部分:

## Load Library
library(ggplot2)
library(reshape2)
library(mvtnorm) ## Function for reading the data
read_data <- function(fname, sc) {
data <- read.csv(file=fname,head=TRUE,sep=",")
nr = dim(data)[1]
nc = dim(data)[2]
x = data[1:nr,1:(nc-1)]
y = data[1:nr,nc]
if (isTRUE(sc)) {
x = scale(x) ## Scale x
y = scale(y) ## Scale y
}
return (list("x" = x, "y" = y))
}

我们定义了一个读取数据的方程,这里,我们会把数据集给scale一下,可以保证进一步提高运算速度

## Matrix Product Function
predict_func <- function(Phi, w){
return(Phi%*%w)
} ## Function to compute the cost function
train_obj_func <- function (Phi, w, label, lambda){
# Cost funtion including the L2 norm regulaztion
return(.5 * mean((predict_func(Phi, w) - label)^2) + .5 * lambda * t(w) %*% w)
} ## Return the errors for each iteration
get_errors <- function(data, label, W) {
n_weights = dim(W)[1]
Phi <- cbind('X0' = 1, data)
errors = matrix(,nrow=n_weights, ncol=2)
for (tau in 1:n_weights) {
errors[tau,1] = tau
errors[tau,2] = train_obj_func(Phi, W[tau,],label, 0) ## Get the errors, set the lambda to 0
}
return(errors)
}

 同时,我们定义了计算矩阵乘法,计算目标函数以及求误差的方程。

sgd_train <- function(train_x, train_y, lambda, eta, epsilon, max_epoch) {

    ## Prepare the traindata
## Attach the 1 for X0
Phi <- as.matrix(cbind('X0'=1, train.data)) ## Calculate the max iteration time for the SGD
train_len = dim(train_x)[1]
tau_max = max_epoch * train_len W <- matrix(,nrow=tau_max, ncol=ncol(Phi))
set.seed(1234)
## Random Generate the start parameter
W[1,] <- runif(ncol(Phi)) tau = 1 # counter
## Create a dateframe to store the value of cost function for each iteration
obj_func_val <-matrix(,nrow=tau_max, ncol=1)
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda) while (tau <= tau_max){ # check termination criteria
if (obj_func_val[tau,1]<=epsilon) {break} # shuffle data:
train_index <- sample(1:train_len, train_len, replace = FALSE) # loop over each datapoint
for (i in train_index) {
# increment the counter
tau <- tau + 1
if (tau > tau_max) {break} # make the weight update
y_pred <- predict_func(Phi[i,], W[tau-1,])
W[tau,] <- sgd_update_weight(W[tau-1,], Phi[i,], train_y[i], y_pred, lambda, eta) # keep track of the objective funtion
obj_func_val[tau,1] = train_obj_func(Phi, W[tau,],train_y, lambda)
}
}
# resulting values for the training objective function as well as the weights
return(list('vals'=obj_func_val,'W'=W))
} # updating the weight vector
sgd_update_weight <- function(W_prev, x, y_true, y_pred, lambda, eta) {
## Computer the Gradient
grad = - (y_true-y_pred) * x
## Here I just combine the regularisation term with prev w
return(W_prev*(1-eta * lambda) - eta * grad)
}

  根据上述我们写的计算更新目标函数和参数的方法,讲算法用R实现

下面是实验部分

## Load the train data and train label
train.data <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$x
train.label <- read_data('assignment1_datasets/Task1C_train.csv',TRUE)$y
## Load the testdata and test label
test.data <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$x
test.label <- read_data('assignment1_datasets/Task1C_test.csv',TRUE)$y # Set MAX EPOCH max_epoch = 18 ## Implement SGD with Ridge regression
options(warn=-1) ## Initilize
## Set the related parameters
epsilon = .001 ## Terimation threshold
eta = .01 ## Learning Rate
lambda= 0.5 ## Regularisation parmater ## Run SGD
## Cost function values
train_res2 = sgd_train(train.data, train.label, lambda, eta, epsilon, max_epoch)
## Calulate the errors
## To be mentioned here, we will only visulisation for the train error to check the converge result
errors2 = get_errors(train.data, train.label, train_res2$W)  

 接着,我们把SGD的error plot给绘制出来

## Visulastion for SGD
plot(train_res2$val, main="SGD", type="l", col="blue",
xlab="iteration", ylab="training objective function")

  

  

这里我们的方程比较简单,可以看到,目标函数很快就收敛了。

简单线性回归问题的优化(SGD)R语言的更多相关文章

  1. 【数据分析】线性回归与逻辑回归(R语言实现)

    文章来源:公众号-智能化IT系统. 回归模型有多种,一般在数据分析中用的比较常用的有线性回归和逻辑回归.其描述的是一组因变量和自变量之间的关系,通过特定的方程来模拟.这么做的目的也是为了预测,但有时也 ...

  2. 一个简单文本分类任务-EM算法-R语言

    一.问题介绍 概率分布模型中,有时只含有可观测变量,如单硬币投掷模型,对于每个测试样例,硬币最终是正面还是反面是可以观测的.而有时还含有不可观测变量,如三硬币投掷模型.问题这样描述,首先投掷硬币A,如 ...

  3. R语言

    什么是R语言编程? R语言是一种用于统计分析和为此目的创建图形的编程语言.不是数据类型,它具有用于计算的数据对象.它用于数据挖掘,回归分析,概率估计等领域,使用其中可用的许多软件包. R语言中的不同数 ...

  4. R语言:用简单的文本处理方法优化我们的读书体验

    博客总目录:http://www.cnblogs.com/weibaar/p/4507801.html 前言 延续之前的用R语言读琅琊榜小说,继续讲一下利用R语言做一些简单的文本处理.分词的事情.其实 ...

  5. R语言-简单线性回归图-方法

    目标:利用R语言统计描绘50组实验对比结果 第一步:导入.csv文件 X <- read.table("D:abc11.csv",header = TRUE, sep = & ...

  6. R 语言中的简单线性回归

    ... sessionInfo() # 查询版本及系统和库等信息 getwd() path <- "E:/RSpace/R_in_Action" setwd(path) rm ...

  7. R语言解读一元线性回归模型

    转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...

  8. R语言解读多元线性回归模型

    转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...

  9. 多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

随机推荐

  1. KiB和KB的区别

    原文链接:http://blog.csdn.net/starshine/article/details/8226320 原来没太注意MB与MiB的区别,甚至没太关注还有MiB这等单位,今天认真了一下, ...

  2. 2016-2017-2 20155312 实验三敏捷开发与XP实践实验报告

    1.研究code菜单 Move Line/statement Down/Up:将某行.表达式向下.向上移动一行 suround with:用 try-catch,for,if等包裹语句 comment ...

  3. kbmmw 中虚拟文件操作入门

    kbmmw 中一直有一个功能,但是基本上都没有提过,但是在实际应用中,却非常有用,这个功能就是 虚拟文件包功能,他可以把一大堆文件保存到一个文件里面,方便后台管理. kbmmw 的虚拟文件在单元kbm ...

  4. 【转】shell expect spawn、linux expect 用法小记 看着舒服点

    使用expect实现自动登录的脚本,网上有很多,可是都没有一个明白的说明,初学者一般都是照抄.收藏.可是为什么要这么写却不知其然.本文用一个最短的例子说明脚本的原理. 脚本代码如下: ######## ...

  5. Web 开发

    Django(发音:[`dʒæŋɡəʊ]) 是一个开放源代码的Web应用框架,由Python写成.采用了MTV的框架模式,模型(Model).模板(Template)和视图(Views).

  6. Echarts的使用方法

    效果图: 1. 在echarts官网下载包,解压后,将文件Echarts\echarts-2.2.7\echarts-2.2.7\doc\example\www\js\echarts.js文件拷贝,放 ...

  7. boost--function

    1.简介 function是一个模板类,它就像一个包装了函数指针或函数对象的容器(只有一个元素).可以把它想象成一个泛化的函数指针,而且他非常适合代替函数指针,存储用于回调的函数.如下定义了一个能够容 ...

  8. 关于preg_match() / preg_replace()函数的一点小说明

    int preg_match ( string $pattern , string $subject [, array &$matches [, int $flags = 0 [, int $ ...

  9. 第28章:MongoDB-索引--过期索引(TTL)

    ①过期索引(TTL) TTL索引是让文档的某个日期时间满足条件的时候自动删除文档,这是一种特殊的索引,这种索引不是为了提高查询速度的,TTL索引类似于缓存,缓存时间到了就过期了,就要被删除了 ②范例: ...

  10. POJ 3110 Jenny's First Exam (贪心)

    题意:告诉你n 个科目的考试日期,在考试当天不能复习,每一个科目的最早复习时间不能早于考试时间的t天,每一天你可以复习完一科,也只能复习一科,求最晚的复习时间!. 析:由于题目给定的时间都在1900 ...