Goal of training a model is to find a set of weights and biases that have low loss, on average, across all examples. —— Descending into ML: Training and Loss

注释:教程中的 loss ≠ 平均方差,而是指单个 labeled example 的方差(也就是误差 loss ),这里的 reducing loss 是指减小整体的误差(就是 MSE 了)

An Iterative Approach

我们的最终目的就得到一个较好的 model(假设 feature 只有一个,那么这个 model 很可能是一条直线),这个 model 可以比较准确地帮助我们推断、预测。

那么什么是比较好的 model 呢?具有总体低 loss 的 model 就是好的 model ,问题在于,如何计算总体 loss 以及怎样减小总体 loss 以逼近我们想要的 model ,在上一节已经谈到,我们经常用平均方差来判断一个 model 的好坏,平均方差大的就是总体 loss 大的,说明 model 不好,平均方差越趋近于零,则 model 越完美。

可以想象这样一个过程:

我们先随便画一条直线,然后计算它的 loss (假设已经一个这样的函数,例如 getLoss 什么的),上面 MSE算出来是 8 ,发现好大 ,这个时候我们微调直线的斜率 w 以及 y轴截距 b 得到一条新的直线,就像下面这样,MSE 是 4 ,发现更小了, 于是我们继续 ......

就这样循环往复: 微调斜率w 和截距b → 计算 loss → 微调斜率w 和截距b →  计算 loss → 微调斜率w 和截距b → ...

直到 loss 减小到几乎不再变化(术语叫做模型已经收敛),我们就成功了,整个过程可以用下面这张图描述:

Gradient Descent

这一小节讲述具体如何“微调斜率w ”。

假设我们有足够的时间和计算资源对每一个w的可能取值计算 loss ,那么我们一定会得到一个这样的图像:

收敛问题只有一个最小值,就像图像上看到的,仅有一个地方的斜率为 0 .

如果我们真有那么多的时间和计算资源,像上面那样做就可以很直观、容易地得到最恰当的 w 了,不幸地是,现实中我们可没有那么多时间,“对每个可能的 w 在整个数据集上计算 loss ”这一做法效率太低了。有一种更好地方式来找最低点,它在机器学习中非常流行,叫“梯度下降法”。

这种方式的第一步是随机取一个点(随机定下一个 w),很多算法都直接取 0 ,取哪一点都是无关紧要的。

之后,梯度下降算法计算这一点的斜率(导数),如果有多个权重 w ,那么梯度就是这一点关于各个 w的偏导数构成的向量。

记住,梯度是一个向量,因此它具有方向和大小。因此,梯度下降算法朝着负梯度迈出一步(step),以便尽快的减少 loss . 它将梯度大小的一部分加到起点处得到下一个点,并不断重复上述步骤,越来越接近最小值。

Learning Rate

As noted, the gradient vector has both a direction and a magnitude. Gradient descent algorithms multiply the gradient by a scalar known as the learning rate (also sometimes called step size) to determine the next point. For example, if the gradient magnitude is 2.5 and the learning rate is 0.01, then the gradient descent algorithm will pick the next point 0.025 away from the previous point.

很多程序员都花费大量的时间调整学习速率,学习速率太小,那么整个学习过程会非常漫长,但是,如果学习速率太大,你甚至可能永远得不到最终的结果(点总是在最低点的两端来回弹跳)。

每一个回归问题都有一个比较恰当的学习速率,它取决于函数的平缓程度。如果你知道 loss-权重 函数的梯度很小,就可以放心地用大的学习速率尝试。(因为下一点的距离是学习速率 * 梯度,梯度小的话,学习速率大一点也无妨,并不容易因为前进太多而错过最低点)

PS. The Goldilocks learning rate 代表着最佳学习速率,实践中,找到完美的学习速率并非必要的,我们只需要找到一个“足够大又不过大”的学习速率就好了。

Stochastic Gradient Descent 随机梯度下降

full-batch iteration 每次迭代都用整个数据集

Stochastic gradient descent (SGD) 每次迭代随机仅仅选择 1 个 example

Mini-batch stochastic gradient descent (mini-batch SGD) 每次迭代随机选择 10 ~ 1000 个 example

Google's Machine Learning Crash Course #03# Reducing Loss的更多相关文章

  1. Google's Machine Learning Crash Course #01# Introducing ML & Framing & Fundamental terminology

    INDEX Introducing ML Framing Fundamental machine learning terminology Introducing ML What you learn ...

  2. Google's Machine Learning Crash Course #02# Descending into ML

    INDEX How do we know if we have a good line Linear Regression Training and Loss How do we know if we ...

  3. Google's Machine Learning Crash Course #04# First Steps with TensorFlow

    1.使用 TensorFlow 的建议 Which API(s) should you use? You should use the highest level of abstraction tha ...

  4. 学习笔记之Machine Learning Crash Course | Google Developers

    Machine Learning Crash Course  |  Google Developers https://developers.google.com/machine-learning/c ...

  5. Machine Learning 学习笔记 03 最小二乘法、极大似然法、交叉熵

    损失函数. 最小二乘法. 极大似然估计. 复习一下对数. 交叉熵. 信息量. 系统熵的定义. KL散度

  6. How do I learn machine learning?

    https://www.quora.com/How-do-I-learn-machine-learning-1?redirected_qid=6578644   How Can I Learn X? ...

  7. 学习笔记之机器学习(Machine Learning)

    机器学习 - 维基百科,自由的百科全书 https://zh.wikipedia.org/wiki/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0 机器学习是人工智能的一个分 ...

  8. 基于Windows 机器学习(Machine Learning)的图像分类(Image classification)实现

    今天看到一篇文章  Google’s Image Classification Model is now Free to Learn  说是狗狗的机器学习速成课程(Machine Learning C ...

  9. machine learning----->谷歌Cloud Machine Learning平台

    1.谷歌Cloud Machine Learning平台简介: 机器学习的三要素是数据源.计算资源和模型.谷歌在这三个方面都有强大的支撑:谷歌不仅有种类丰富且数量庞大的数据资源,而且有强大的计算机群提 ...

随机推荐

  1. Java基础之理解封装,继承,多态三大特性

    目录 封装 继承 多态 封装 封装隐藏了类的内部实现机制,可以在不影响使用的情况下改变类的内部结构,同时也保护了数据.对外界而已它的内部细节是隐藏的,暴露给外界的只是它的访问方法. 代码理解 publ ...

  2. [EF]vs15+ef6+mysql code first方式

    写在前面 前面有篇文章,尝试了db first方式,但不知道是什么原因一直没有成功,到最后也没解决,今天就尝试下code first的方式. [EF]vs15+ef6+mysql这个问题,你遇到过么? ...

  3. windows乱码

    对于支持 UNICODE的应用程序,Windows 会默认使用 Unicode编码.对于不支持Unicode的应用程序Windows 会采用 ANSI编码 (也就是各个国家自己制定的标准编码方式,如对 ...

  4. 初次安装hive-2.1.0启动报错问题解决方法

    首次安装hive-2.1.0,通过bin/hive登录hive shell命令行,报错如下: [hadoop@db03 hive-2.1.0]$ bin/hive which: no hbase in ...

  5. 火币Huobi API

    本文介绍火币Huobi API REST API 简介 火币为用户提供了一套全新的API,可以帮用户快速接入火币PRO站及HADAX站的交易系统,实现程序化交易. 访问地址 适用站点 适用功能 适用交 ...

  6. Tunnel Warfare--- hdu1540 线段树求连续子区间

    题目链接 题意:有n个村庄,编号分别为1-n:由于战争会破坏村庄,但是我们也会修复: D x代表村庄x被破坏: Q x是求与x相连的有几个没有被破坏: R 是修复最后一次被破坏的村庄: 接下来有m个操 ...

  7. csv文件的读写

    # -*- coding: utf-8 -*- """ Spyder Editor This is a temporary script file. "&quo ...

  8. 如何正确的把 Java 数组 Array 转为列表 List

    最近想把 java 数组转成 List,网上普遍的答案都是 Arrays.asList: String[] a = new String[] {"hello", "wor ...

  9. Spark2.x学习笔记:Spark SQL的SQL

    Spark SQL所支持的SQL语法 select [distinct] [column names]|[wildcard] from tableName [join clause tableName ...

  10. 如何给Pycharm加上头行 # *_*coding:utf-8 *_*?

    File>Setting>Editor>Code Style>File and Code Templates>Python Script  后面加上 # *_*codin ...