CMU DLSys 课程笔记 2 - ML Refresher / Softmax Regression

本节 Slides | 本节课程视频

这一节课是对机器学习内容的一个复习,以 Softmax Regression 为例讲解一个典型的有监督机器学习案例的整个流程以及其中的各种概念。预期读者应当对机器学习的基本概念有一定的了解。

目录

机器学习基础

针对于手写数字识别这一问题,传统的图像识别算法可能是首先找到每个数字的特征,然后手写规则来识别每个数字。这种方式的问题在于,当我们想要识别的对象的种类很多时,我们需要手动设计的规则就会变得非常复杂,而且这些规则很难设计得很好,因为我们很难找到一个完美的特征来区分所有的对象。

而机器学习方法则是让计算机自己学习如何区分这些对象,我们只需要给计算机一些数据,让它自己学习如何区分这些数据,这样的方法就可以很好地解决这个问题。

具体到有监督机器学习方法,我们需要给计算机一些数据,这些数据包含了我们想要识别的对象的一些特征,以及这些对象的标签,计算机需要从这些数据中学习到如何区分这些对象,如下图

图里中间部分即为我们需要建立的机器学习模型,通常由以下内容组成:

  1. 模型假设:描述我们如何将输入(例如数字的图像)映射到输出(例如类别标签或不同类别标签的概率)的“程序结构”,通过一组参数进行参数化。
  2. 损失函数:指定给定假设(即参数选择)在所关注任务上的表现“好坏”的函数。
  3. 优化方法:确定一组参数(近似)最小化训练集上损失总和的过程。

Softmax Regression 案例

问题定义

让我们考虑一个 k 类分类问题,其中我们有:

  • 训练数据:\(x^{(i)} \in \R^n\), \(y^{(i)} \in {1,\dots, k}\) for \(i = 1, … , m\)
  • 其中 \(n\) 为输入数据的维度,\(m\) 为训练数据的数量,\(k\) 为分类类别的数量
  • 针对 28x28 的 MNIST 数字进行分类,\(n = 28 \cdot 28 = 784\), \(k = 10\), \(m = 60,000\)

模型假设

我们的模型假设是一个线性模型,即

\[h_\theta(x) = \theta^T x
\]

其中 \(\theta \in \R^{n\times k}\) 是我们的模型参数,\(x \in \R^n\) 是输入数据。

机器学习中,经常使用的形式是多个输入叠加在一起的形式,即

\[X \in \R^{m\times n}= \begin{bmatrix} {x^{(1)}}^T \\ \vdots \\ {x^{(m)}}^T \end{bmatrix}, \quad y = \begin{bmatrix} y^{(1)} \\ \vdots \\ y^{(m)} \end{bmatrix}
\]

然后线性模型假设可以写为

\[h_\theta(X) = \begin{bmatrix} {x^{(1)}}^T\theta \\ \vdots \\ {x^{(m)}}^T\theta \end{bmatrix} = X\theta
\]

损失函数

最简单的损失函数就是根据是否预测正确,如

\[\ell_{e r r}(h(x), y)=\left\{\begin{array}{ll}
0 & \text { if } \operatorname{argmax}_{i} h_{i}(x)=y \\
1 & \text { otherwise }
\end{array}\right.
\]

我们经常用这个函数来评价分类器的质量。但是这个函数有一个重大的缺陷是非连续,因此我们无法使用梯度下降等优化方法来优化这个函数。

取而代之,我们会用一个连续的损失函数,即交叉熵损失函数

\[z_{i}=p(\text { label }=i)=\frac{\exp \left(h_{i}(x)\right)}{\sum_{j=1}^{k} \exp \left(h_{j}(x)\right)} \Longleftrightarrow z \equiv \operatorname{softmax}(h(x))
\]
\[\ell_{ce}(h(x), y) = -\log p(\text { label }=y) = -h_y(x) + \log \sum_{j=1}^k \exp(h_j(x))
\]

这个损失函数是连续的,而且是凸的,因此我们可以使用梯度下降等优化方法来优化这个损失函数。

优化方法

我们的目标是最小化损失函数,即

\[\min_{\theta} \frac{1}{m} \sum_{i=1}^m \ell_{ce}(h_\theta(x^{(i)}), y^{(i)})
\]

我们使用梯度下降法来优化这个损失函数,针对函数\(f:\R^{n\times k} \rightarrow \R\),其梯度为

\[\nabla_\theta f(\theta) = \begin{bmatrix} \frac{\partial f}{\partial \theta_{11}} & \dots & \frac{\partial f}{\partial \theta_{1k}} \\ \vdots & \ddots & \vdots \\ \frac{\partial f}{\partial \theta_{n1}} & \dots & \frac{\partial f}{\partial \theta_{nk}} \end{bmatrix}
\]

梯度的几何含义为函数在某一点的梯度是函数在该点上升最快的方向,如下图

我们可以使用梯度下降法来优化这个损失函数,即

\[\theta \leftarrow \theta - \alpha \nabla_\theta f(\theta)
\]

其中 \(\alpha \gt 0\) 为学习率,即每次更新的步长。学习率过大会导致无法收敛,学习率过小会导致收敛速度过慢。

我们不需要针对每个样本都计算一次梯度,而是可以使用一个 batch 的样本来计算梯度,这样可以减少计算量,同时也可以减少梯度的方差,从而加快收敛速度,这种方法被称为随机梯度下降法(Stochastic Gradient Descent, SGD)。该方法的算法描述如下

\[\left.
\begin{array}{l}
\text { Repeat:} \\
\text { \quad Sample a batch of data } X \in \R^{B\times n}, y \in \{1, \dots, k\}^B \\
\text { \quad Update parameters } \theta \leftarrow \theta-\alpha \nabla_{\theta} \frac{1}{B} \sum_{i=1}^{B} \ell_{ce}\left(h_{\theta}\left(x^{(i)}\right), y^{(i)}\right)
\end{array}
\right.
\]

前面都是针对 SGD 的描述,但是损失函数的梯度还没有给出,我们一般使用链式法则进行计算,首先计算 softmax 函数本身的梯度

\[\frac{\partial \ell(h, y)}{\partial h_i} = \frac{\partial}{\partial h_i} \left( -h_y + \log \sum_{j=1}^k \exp(h_j) \right) = -e_y + \frac{\exp(h_i)}{\sum_{j=1}^k \exp(h_j)}
\]

写成矩阵形式即为

\[\nabla_h \ell(h, y) = -e_y + \operatorname{softmax}(h)
\]

然后计算损失函数对模型参数的梯度

\[\frac{\partial \ell(h, y)}{\partial \theta} = \frac{\partial \ell(\theta^T x, y)}{\partial \theta} = \frac{\partial \ell(h, y)}{\partial h} \frac{\partial h}{\partial \theta} = x(\operatorname{softmax}(h) - e_y)^T
\]

写成矩阵形式即为

\[\nabla_\theta \ell(h, y) = X^T (\operatorname{softmax}(X\theta) - \mathbb{I}_y)
\]

完整算法描述

最终算法描述为

\[\left.
\begin{array}{l}
\text { Repeat:} \\
\text { \quad Sample a batch of data } X \in \R^{B\times n}, y \in \{1, \dots, k\}^B \\
\text { \quad Update parameters } \theta \leftarrow \theta-\alpha X^T (\operatorname{softmax}(X\theta) - \mathbb{I}_y)
\end{array}
\right.
\]

以上就是完整的 Softmax Regression 的算法描述,最终在 hw0 中我们会实现这个算法,其分类错误率将低于 8 %。

CMU DLSys 课程笔记 2 - ML Refresher / Softmax Regression的更多相关文章

  1. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  2. 斯坦福CS229机器学习课程笔记 Part1:线性回归 Linear Regression

    机器学习三要素 机器学习的三要素为:模型.策略.算法. 模型:就是所要学习的条件概率分布或决策函数.线性回归模型 策略:按照什么样的准则学习或选择最优的模型.最小化均方误差,即所谓的 least-sq ...

  3. ML:吴恩达 机器学习 课程笔记(Week1~2)

    吴恩达(Andrew Ng)机器学习课程:课程主页 由于博客编辑器有些不顺手,所有的课程笔记将全部以手写照片形式上传.有机会将在之后上传课程中各个ML算法实现的Octave版本. Linear Reg ...

  4. 深度学习课程笔记(十二) Matrix Capsule

    深度学习课程笔记(十二) Matrix Capsule with EM Routing  2018-02-02  21:21:09  Paper: https://openreview.net/pdf ...

  5. 深度学习课程笔记(十一)初探 Capsule Network

    深度学习课程笔记(十一)初探 Capsule Network  2018-02-01  15:58:52 一.先列出几个不错的 reference: 1. https://medium.com/ai% ...

  6. ng-深度学习-课程笔记-0: 概述

    课程概述 这是一个专项课程(Specialization),包含5个独立的课程,学习这门课程后做了相关的笔记记录. (1) 神经网络和深度学习 (2)  改善深层神经网络:超参数调试,正则化,优化 ( ...

  7. CS231n课程笔记翻译9:卷积神经网络笔记

    译者注:本文翻译自斯坦福CS231n课程笔记ConvNet notes,由课程教师Andrej Karpathy授权进行翻译.本篇教程由杜客和猴子翻译完成,堃堃和李艺颖进行校对修改. 原文如下 内容列 ...

  8. CS231n课程笔记翻译8:神经网络笔记 part3

    译者注:本文智能单元首发,译自斯坦福CS231n课程笔记Neural Nets notes 3,课程教师Andrej Karpathy授权翻译.本篇教程由杜客翻译完成,堃堃和巩子嘉进行校对修改.译文含 ...

  9. CS231n课程笔记翻译7:神经网络笔记 part2

    译者注:本文智能单元首发,译自斯坦福CS231n课程笔记Neural Nets notes 2,课程教师Andrej Karpathy授权翻译.本篇教程由杜客翻译完成,堃堃进行校对修改.译文含公式和代 ...

  10. CS231n课程笔记翻译6:神经网络笔记 part1

    译者注:本文智能单元首发,译自斯坦福CS231n课程笔记Neural Nets notes 1,课程教师Andrej Karpathy授权翻译.本篇教程由杜客翻译完成,巩子嘉和堃堃进行校对修改.译文含 ...

随机推荐

  1. 【RocketMQ】Rebalance负载均衡总结

    消费者负载均衡,是指为消费组下的每个消费者分配订阅主题下的消费队列,分配了消费队列消费者就可以知道去消费哪个消费队列上面的消息,这里针对集群模式,因为广播模式,所有的消息队列可以被消费组下的每个消费者 ...

  2. linux的认知与基本命令

    一.linux的了解 1. 什么是Linux?       a,Linux是一种免费使用和自由传播的类UNIX操作系统,其内核由林纳斯·本纳第克特·托瓦兹于1991年10月5日首次发布.它主要受到Mi ...

  3. 【Unity3D】UI Toolkit样式选择器

    1 前言 ​ UI Toolkit简介 中介绍了样式属性,UI Toolkit容器 和 UI Toolkit元素 中介绍了容器和元素,本文将介绍样式选择器(Selector),主要包含样式类选择器(C ...

  4. 如何为你的WSL2更换最新的6.5.7kernel

    1.如果你像我一样,喜欢折腾你的 WSL2 ,这里是安装内核 6.X 的方法. 2.这是一个坏主意,可能会导致系统不稳定.数据损坏和其他问题.也可能会没事的,但不要怪我. Arch linux的wsl ...

  5. 让物体动起来,Unity的几种移动方式

    一.前言 在大部分的Unity游戏开发中,移动是极其重要的一部分,移动的手感决定着游戏的成败,一个优秀的移动手感无疑可以给游戏带来非常舒服的体验.而Unity中有多种移动方法,使用Transform, ...

  6. c#中单例模式详解

    基础介绍:   确保一个类只有一个实例,并提供一个全局访问点.   适用于需要频繁实例化然后销毁的对象,创建对象消耗资源过多,但又经常用到的对象,频繁访问数据库或文件的对象.   其本质就是保证在整个 ...

  7. 小景的Dba之路--压力测试和Oracle数据库缓存

    小景最近在做系统查询接口的压测相关的工作,其中涉及到了查询接口的数据库缓存相关的内容,在这里做一个汇总和思维发散,顺便简单说下自己的心得: 针对系统的查询接口,首次压测执行的时候TPS较低,平均响应时 ...

  8. JUC并发编程学习笔记(十七)彻底玩转单例模式

    彻底玩转单例模式 单例中最重要的思想------->构造器私有! 恶汉式.懒汉式(DCL懒汉式!) 恶汉式 package single; //饿汉式单例(问题:因为一上来就把对象加载了,所以可 ...

  9. 二叉搜索树 & 平衡树

    二叉搜索树 & 平衡树 专题 0x00 前言 我 AFO 了,但不代表不写 Code 了... CSP-S 在数据结构上吃了大亏,就差这一点就一等了,所以觉得好好整整. 本篇博客主要研究二叉搜 ...

  10. 比较Spring Security6.X 和 Spring Security 5.X的不同

    项目使用了SpringBoot3 ,因此 SpringSecurity也相应进行了升级 版本由5.4.5升级到了6.1.5 写法上发生了很大的变化,最显著的变化之一就是对 WebSecurityCon ...