优化与深度学习

优化与估计

尽管优化方法可以最小化深度学习中的损失函数值,但本质上优化方法达到的目标与深度学习的目标并不相同。

  • 优化方法目标:训练集损失函数值
  • 深度学习目标:测试集损失函数值(泛化性)
 %matplotlib inline
import sys
import d2lzh1981 as d2l
from mpl_toolkits import mplot3d # 三维画图
import numpy as np
def f(x): return x * np.cos(np.pi * x)
def g(x): return f(x) + 0.2 * np.cos(5 * np.pi * x) d2l.set_figsize((5, 3))
x = np.arange(0.5, 1.5, 0.01)
fig_f, = d2l.plt.plot(x, f(x),label="train error")
fig_g, = d2l.plt.plot(x, g(x),'--', c='purple', label="test error")
fig_f.axes.annotate('empirical risk', (1.0, -1.2), (0.5, -1.1),arrowprops=dict(arrowstyle='->'))
fig_g.axes.annotate('expected risk', (1.1, -1.05), (0.95, -0.5),arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('risk')
d2l.plt.legend(loc="upper right")

优化在深度学习中的挑战

  1. 局部最小值
  2. 鞍点
  3. 梯度消失

局部最小值

 def f(x):
return x * np.cos(np.pi * x) d2l.set_figsize((4.5, 2.5))
x = np.arange(-1.0, 2.0, 0.1)
fig, = d2l.plt.plot(x, f(x))
fig.axes.annotate('local minimum', xy=(-0.3, -0.25), xytext=(-0.77, -1.0),
arrowprops=dict(arrowstyle='->'))
fig.axes.annotate('global minimum', xy=(1.1, -0.95), xytext=(0.6, 0.8),
arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)');

鞍点

 x = np.arange(-2.0, 2.0, 0.1)
fig, = d2l.plt.plot(x, x**3)
fig.axes.annotate('saddle point', xy=(0, -0.2), xytext=(-0.52, -5.0),
arrowprops=dict(arrowstyle='->'))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)');

 x, y = np.mgrid[-1: 1: 31j, -1: 1: 31j]
z = x**2 - y**2 d2l.set_figsize((6, 4))
ax = d2l.plt.figure().add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, **{'rstride': 2, 'cstride': 2})
ax.plot([0], [0], [0], 'ro', markersize=10)
ticks = [-1, 0, 1]
d2l.plt.xticks(ticks)
d2l.plt.yticks(ticks)
ax.set_zticks(ticks)
d2l.plt.xlabel('x')
d2l.plt.ylabel('y');

梯度消失

 x = np.arange(-2.0, 5.0, 0.01)
fig, = d2l.plt.plot(x, np.tanh(x))
d2l.plt.xlabel('x')
d2l.plt.ylabel('f(x)')
fig.axes.annotate('vanishing gradient', (4, 1), (2, 0.0) ,arrowprops=dict(arrowstyle='->'))

凸性 (Convexity)

基础

集合

函数

 def f(x):
return 0.5 * x**2 # Convex def g(x):
return np.cos(np.pi * x) # Nonconvex def h(x):
return np.exp(0.5 * x) # Convex x, segment = np.arange(-2, 2, 0.01), np.array([-1.5, 1])
d2l.use_svg_display()
_, axes = d2l.plt.subplots(1, 3, figsize=(9, 3)) for ax, func in zip(axes, [f, g, h]):
ax.plot(x, func(x))
ax.plot(segment, func(segment),'--', color="purple")
# d2l.plt.plot([x, segment], [func(x), func(segment)], axes=ax)

Jensen 不等式

性质

  1. 无局部极小值
  2. 与凸集的关系
  3. 二阶条件

无局部最小值

与凸集的关系

 x, y = np.meshgrid(np.linspace(-1, 1, 101), np.linspace(-1, 1, 101),
indexing='ij') z = x**2 + 0.5 * np.cos(2 * np.pi * y) # Plot the 3D surface
d2l.set_figsize((6, 4))
ax = d2l.plt.figure().add_subplot(111, projection='3d')
ax.plot_wireframe(x, y, z, **{'rstride': 10, 'cstride': 10})
ax.contour(x, y, z, offset=-1)
ax.set_zlim(-1, 1.5) # Adjust labels
for func in [d2l.plt.xticks, d2l.plt.yticks, ax.set_zticks]:
func([-1, 0, 1])

凸函数与二阶导数

 def f(x):
return 0.5 * x**2 x = np.arange(-2, 2, 0.01)
axb, ab = np.array([-1.5, -0.5, 1]), np.array([-1.5, 1]) d2l.set_figsize((3.5, 2.5))
fig_x, = d2l.plt.plot(x, f(x))
fig_axb, = d2l.plt.plot(axb, f(axb), '-.',color="purple")
fig_ab, = d2l.plt.plot(ab, f(ab),'g-.') fig_x.axes.annotate('a', (-1.5, f(-1.5)), (-1.5, 1.5),arrowprops=dict(arrowstyle='->'))
fig_x.axes.annotate('b', (1, f(1)), (1, 1.5),arrowprops=dict(arrowstyle='->'))
fig_x.axes.annotate('x', (-0.5, f(-0.5)), (-1.5, f(-0.5)),arrowprops=dict(arrowstyle='->'))

限制条件

拉格朗日乘子法

惩罚项

投影

机器学习(ML)十四之凸优化的更多相关文章

  1. Stanford机器学习---第十四讲.机器学习应用举例之Photo OCR

    http://blog.csdn.net/l281865263/article/details/50278745 本栏目(Machine learning)包括单参数的线性回归.多参数的线性回归.Oc ...

  2. SIGAI机器学习第十四集 支持向量机1

    讲授线性分类器,分类间隔,线性可分的支持向量机原问题与对偶问题,线性不可分的支持向量机原问题与对偶问题,核映射与核函数,多分类问题,libsvm的使用,实际应用 大纲: 支持向量机简介线性分类器分类间 ...

  3. 【转】机器学习教程 十四-利用tensorflow做手写数字识别

    模式识别领域应用机器学习的场景非常多,手写识别就是其中一种,最简单的数字识别是一个多类分类问题,我们借这个多类分类问题来介绍一下google最新开源的tensorflow框架,后面深度学习的内容都会基 ...

  4. Redis教程(十四):内存优化介绍

    转载于:http://www.itxuexiwang.com/a/shujukujishu/redis/2016/0216/142.html 一.特殊编码: 自从Redis 2.2之后,很多数据类型都 ...

  5. SIGAI机器学习第二十四集 聚类算法1

    讲授聚类算法的基本概念,算法的分类,层次聚类,K均值算法,EM算法,DBSCAN算法,OPTICS算法,mean shift算法,谱聚类算法,实际应用. 大纲: 聚类问题简介聚类算法的分类层次聚类算法 ...

  6. JMeter学习(三十四)测试报告优化

    如果按JMeter默认设置,生成报告如下: 从上图可以看出,结果信息比较简单,对于运行成功的case,还可以将就用着.但对于跑失败的case,就只有一行assert错误信息.(信息量太少了,比较难找到 ...

  7. 机器学习(十四)— kMeans算法

    参考文献:https://www.jianshu.com/p/5314834f9f8e # -*- coding: utf-8 -*- """ Created on Mo ...

  8. 猪猪的机器学习笔记(十四)EM算法

    EM算法 作者:樱花猪   摘要: 本文为七月算法(julyedu.com)12月机器学习第十次次课在线笔记.EM算法全称为Expectation Maximization Algorithm,既最大 ...

  9. 只需十四步:从零开始掌握 Python 机器学习(附资源)

    分享一篇来自机器之心的文章.关于机器学习的起步,讲的还是很清楚的.原文链接在:只需十四步:从零开始掌握Python机器学习(附资源) Python 可以说是现在最流行的机器学习语言,而且你也能在网上找 ...

随机推荐

  1. Linux学习_菜鸟教程_3

    我是在UBANTO上运行Linux的,开机启动时按下shift或者Esc都不能进入到grub,没有百度到可靠的教程. 暂时先这样吧.免得我把系统搞坏了,先学点实用的知识~~ Next Chapter

  2. JPA或Hibernate中使用原生SQL实现分页查询、排序

    发生背景:前端展示的数据需要来自A表和D表拼接,A表和D表根据A表的主键进行关联,D表的非主键字段关联C表的主键,根据条件筛选出符合的数据,并且根据A表的主键关联B表的主键(多主键)的条件,过滤A表中 ...

  3. 【tf.keras】使用手册

    目录 0. 简介 1. 安装 1.1 安装 CUDA 和 cuDNN 2. 数据集 2.1 使用 tensorflow_datasets 导入公共数据集 2.2 数据集过大导致内存溢出 2.3 加载 ...

  4. 【转】在Eclipse下搭建Android开发环境教程

    本文将全程演示Android开发环境的搭建过程,无需配置环境变量.所有软件都是写该文章时最新版本,希望大家喜欢.   一 相关下载 三 Eclipse配置 (1)Java JDK下载 1 安装andr ...

  5. 开发一个简单的ip解析webservice接口,并用springmvc生成客户端调用

    1.创建webservice工程,这次先采用jax-ws框架,下次再尝试jax-rs(restful) 2.写个实现ip解析的类,接收传入的ip,并返回解析信息 3.Myeclipse——>Ne ...

  6. Mysql中使用mysqldump进行导入导出sql文件

    纪念工作中的第一次删库跑路的经历 今天接到一个任务,是将一个测试库数据导到另一个测试库,然而我们公司的数据库是不让直连的,所以只能通过远程连接进行导库操作. 老大布置任务的时候让用dump命令进行操作 ...

  7. Django3.0.2学习踩坑记

    配置文件settings.py相关: 新增app INSTALLED_APPS = [ 'polls.apps.PollsConfig', # 这个是新增的APP 'django.contrib.ad ...

  8. Spirng Boot2 系列教程(二十二)| 启动原理

    一个读者,也是我的好朋友投稿的一篇关于 SpringBoot 启动原理的文章,才大二就如此优秀,未来可期. 我一直想了解一下 SpirngBoot 的是如何启动的,我想就来写一篇关于 SpirngBo ...

  9. Nginx的一理解(1)

    1.请解释一下什么是Nginx? 答:Nginx是一个web服务器和反向代理服务器,用于HTTP.HTTPS.SMTP.POP3和IMAP协议. 2.请列举Nginx的一些特性? 答:Nginx服务器 ...

  10. 9.Java三大版本以及JDK,JRE,JVM简单介绍

    Write Once,Run Anywhere. JavaSE:标准版(桌面程序,控制台开发...),是Java的基础和核心. JavaME:嵌入式开发(手机,小家电...),现在基本不用已经过时. ...