在matlab中做Regularized logistic regression

原理:

我的代码:

function [J, grad] = costFunctionReg(theta, X, y, lambda)
%COSTFUNCTIONREG Compute cost and gradient for logistic regression with regularization
% J = COSTFUNCTIONREG(theta, X, y, lambda) computes the cost of using
% theta as the parameter for regularized logistic regression and the
% gradient of the cost w.r.t. to the parameters. % Initialize some useful values
m = length(y); % number of training examples % You need to return the following variables correctly
J = 0;
grad = zeros(size(theta)); % ====================== YOUR CODE HERE ======================
% Instructions: Compute the cost of a particular choice of theta.
% You should set J to the cost.
% Compute the partial derivatives and set grad to the partial
% derivatives of the cost w.r.t. each parameter in theta h = sigmoid(X*theta);
theta2=[0;theta(2:end)]; J_partial = sum((-y).*log(h)+(y-1).*log(1-h))./m;
J_regularization= (lambda/(2*m)).*sum(theta2.^2);
J = J_partial+J_regularization; grad_partial = sum((h-y).*X)/m;
grad_regularization = lambda.*theta2./m;
grad = grad_partial+grad_regularization; % ============================================================= end

运行结果:

标黄的与下面的预期对比发现不同

尝试删去

+grad_regularization

.rtcContent { padding: 30px }
.lineNode { font-size: 10pt; font-family: Menlo, Monaco, Consolas, "Courier New", monospace; font-style: normal; font-weight: normal }

部分结果符合预期,部分不符合

尝试大佬代码

%Hypotheses
hx = sigmoid(X * theta);
%%The cost without regularization
J_partial = (-y' * log(hx) - (1 - y)' * log(1 - hx)) ./ m;
%%Regularization Cost Added
J_regularization = (lambda/(2*m)) * sum(theta(2:end).^2);
%%Cost when we add regularization
J = J_partial + J_regularization;
%Grad without regularization
grad_partial = (1/m) * (X' * (hx -y));
%%Grad Cost Added
grad_regularization = (lambda/m) .* theta(2:end);
grad_regularization = [0; grad_regularization];
grad = grad_partial + grad_regularization;

完全成功!?我不李姐……

观察大佬代码发现,我和大佬的区别在于:

最开始的theta向量和计算J(theta)和grad时候使用sum的数目

故尝试修改和大佬数目一样多的sum

    h = sigmoid(X*theta);
theta2=[0;theta(2:end)]; J_partial = (-y).*log(h)+(y-1).*log(1-h)./m;
J_regularization= (lambda/(2*m)).*sum(theta2.^2);
J = J_partial+J_regularization; grad_partial = (h-y).*X/m;
grad_regularization = lambda.*theta2./m;
grad = grad_partial+grad_regularization;

结果:incompatible不兼容

文档对该错误的解释如下

事已至此,只好向大佬更近一步!

    h = sigmoid(X*theta);

    J_partial = (-y).*log(h)+(y-1).*log(1-h)./m;
J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
J = J_partial+J_regularization; grad_partial = (h-y).*X/m;
grad_regularization = lambda.*theta(2:end)./m;
grad_regularization2=[0;grad_regularization]; grad = grad_partial+grad_regularization2;

为什么还是不兼容?

到底哪里出了问题?

最后,尝试离大佬更近一步,把grad_partial里的(h-y).*X/m变成了(1/m) * (X' * (h -y))

    h = sigmoid(X*theta);

    J_partial = (1/m).*((-y).*log(h)+(y-1).*log(1-h));
J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
J = J_partial+J_regularization; grad_partial = (1/m) * (X' * (h -y));
grad_regularization = (lambda/m).*theta(2:end);
grad_regularization = [0; grad_regularization];
grad = grad_partial+ grad_regularization;

舒服了!

但,等等,上面怎么那么多行,数值还不对?看来不能完全靠大佬,还得自己改!!!

    h = sigmoid(X*theta);

    J_partial = (1/m).*sum((-y).*log(h)+(y-1).*log(1-h));
J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
J = J_partial+J_regularization; grad_partial = (1/m) * (X' * (h -y));
grad_regularization = (lambda/m).*theta(2:end);
grad_regularization = [0; grad_regularization];
grad = grad_partial+ grad_regularization;

最终,得到了满意的答案

以及

 总结一下出现的问题

01不兼容,就像上面说明的那样,行列不匹配

(解决方法:查看有无sum、是值还是array,把系数往前放,修改两数相乘的顺序)

02加入grad_regularization后,grad(1,5)的后四项都出现了问题(很神奇地值相等),

一旦去掉又与正确值有小范围差距(缺少grad_regularization导致的)

说明grad_regularization存在问题

而如果一开始就将theta变为第一行元素是0的矩阵,很容易出现不兼容的问题

大佬的代码提示我们特殊情况可以分出来特殊处理,也就是:

在计算J(θ)不使用矩阵,而是用除0外、后面的θ直接产出需要的值

在计算grad时,由于输出也是矩阵,所以可以创建一个含0和其他θ的矩阵

这样既可以避免不兼容,也可以得出正确的结果

最终的部分code如下

    h = sigmoid(X*theta);

    J_partial = (1/m).*sum((-y).*log(h)+(y-1).*log(1-h));
J_regularization= (lambda/(2*m)).*sum(theta(2:end).^2);
J = J_partial+J_regularization; grad_partial = (1/m) * (X' * (h -y));
grad_regularization = (lambda/m).*theta(2:end);
grad_regularization = [0; grad_regularization];
grad = grad_partial+ grad_regularization;

Logistic regression中regularization失败的解决方法探索(文末附解决后code)的更多相关文章

  1. Machine Learning - 第3周(Logistic Regression、Regularization)

    Logistic regression is a method for classifying data into discrete outcomes. For example, we might u ...

  2. logistic regression中的cost function选择

    一般的线性回归使用的cost function为: 但由于logistic function: 本身非凸函数(convex function), 如果直接使用线性回归的cost function的话, ...

  3. Windows 共享无线上网 无法启动ICS服务解决方法(WIN7 ICS服务启动后停止)

    Windows 共享无线上网 无法启动ICS服务解决方法(WIN7 ICS服务启动后停止) ICS 即Internet Connection Sharing,internet连接共享,可以使局域网上其 ...

  4. 斯坦福机器学习视频笔记 Week3 逻辑回归与正则化 Logistic Regression and Regularization

    我们将讨论逻辑回归. 逻辑回归是一种将数据分类为离散结果的方法. 例如,我们可以使用逻辑回归将电子邮件分类为垃圾邮件或非垃圾邮件. 在本模块中,我们介绍分类的概念,逻辑回归的损失函数(cost fun ...

  5. Andrew Ng Machine Learning 专题【Logistic Regression & Regularization】

    此文是斯坦福大学,机器学习界 superstar - Andrew Ng 所开设的 Coursera 课程:Machine Learning 的课程笔记. 力求简洁,仅代表本人观点,不足之处希望大家探 ...

  6. week3编程作业: Logistic Regression中一些难点的解读

    %% ============ Part : Compute Cost and Gradient ============ % In this part of the exercise, you wi ...

  7. 在IE浏览器中执行OpenFlashChart的reload方法时无法刷新的解决方法

    由于项目需求,需要在网页上利用图表展示相关数据的统计信息,采用了OpenFlashChart技术.OpenFlashChart是一款开源的以Flash和Javascript为技术基础的免费图表,用它能 ...

  8. (蓝牙)网络编程中,使用InputStream read方法读取数据阻塞的解决方法

    问题如题,这个问题困扰了我好几天,今天终于解决了,感谢[1]. 首先,我要做的是android手机和电脑进行蓝牙通信,android发一句话,电脑端程序至少就要做到接受到那句话.android端发送信 ...

  9. blocked because of many connection errors; unblock with 'mysqladmin flush-hosts;MySQL在远程访问时非常慢的解决方法;MySql链接慢的解决方法

     一:服务器异常:Host 'xx.xxx.xx.xxx' is blocked because of many connection errors; unblock with 'mysqladmin ...

随机推荐

  1. C#编程基础之字符串操作

    本文来源于复习基础知识的学习笔记.自用的同时希望也能帮到其他童鞋. 什么是编程语言? 计算机可以执行的指令.这些指令成为源代码或者代码 有什么用? 以人们可读可理解的方式编写指令.人们希望计算机执行指 ...

  2. 除了增删改查你对MySQL还了解多少?

    目录 除了增删改查你对MySQL还了解多少? MySQL授权远程连接 创建用户.授权 客户端与服务器连接的过程 TCP/IP 命名管道和共享内存 Unix域套接字文件 查询优化 MySQL中走与不走索 ...

  3. winform 学习之qq邮箱正则验证及常用正则

    这段时间一直再做winform相关的项目,记录了一些东西 qq邮箱正则表达式: 第一种:字母和数字组合邮箱判断 string str = "justin1107@qq.com"; ...

  4. 5分钟了解二叉树之LeetCode里的二叉树

    有读者反馈,现在谁不是为了找工作才学的数据结构,确实很有道理,是我肤浅了.所以为了满足大家的需求,这里总结下LeetCode里的数据结构.对于我们这种职场老人来说,刷LeetCode会遇到个很尴尬的问 ...

  5. JavaScript day03 循环

    循环 while循环 循环是重复性做一件事情 没有办法控制每次循环的时间长度 循环会增大程序时间复杂度(不建议无限循环嵌套 一般情况下不会嵌套超过两次) 死循环 是不会停止的循环 会导致电脑内存溢出 ...

  6. Spring Boot 的配置文件有哪几种格式?它们有什么区别?

    .properties 和 .yml,它们的区别主要是书写格式不同.    1).properties    app.user.name = javastack    2).yml    app:   ...

  7. 四种类型的数据节点 Znode?

    1.PERSISTENT-持久节点 除非手动删除,否则节点一直存在于 Zookeeper 上 2.EPHEMERAL-临时节点 临时节点的生命周期与客户端会话绑定,一旦客户端会话失效(客户端与 zoo ...

  8. Dubbo 如何停机?

    Dubbo 是通过 JDK 的 ShutdownHook 来完成优雅停机的,所以如果使用 kill -9 PID 等强制关闭指令,是不会执行优雅停机的,只有通过 kill PID 时,才 会执行.

  9. 为什么要用Spring

    1.方便解耦,简化开发 通过Spring提供的IoC容器,我们可以将对象之间的依赖关系交由Spring进行控制,避免硬编码所造成的过度程序耦合.有了Spring,用户不必再为单实例模式类.属性文件解析 ...

  10. window10使用putty传输文件到Linux服务器

    由于Linux和Linux可以使用scp进行传输文件,而window系统无法向Linux传输文件,当然,有xshell等等类似的工具可以进行操作:putty工具就可以实现,毕竟zip压缩包也不大,启动 ...