CS231n Convolutional Neural Networks for Visual Recognition —— optimization

1. 多类 SVM 的损失函数(Multiclass SVM loss)

在给出类别预测前的输出结果是实数值, 也即根据 score function 得到的 score(s=f(xi,W)),

Li=∑j≠yimax(0,sj−syi+Δ),Δ=1(一般情况下)
  • yi 表示真实的类别,syi 在真实类别上的得分;
  • sj,j≠yi 在其他非真实类别上的得分,也即预测错误时的得分;

则在全体训练样本上的平均损失为:

L=1N∑i=1NLi
delta = 1
scores = np.dot(W, X)
correct_scores = scores[y, np.arange(num_samples)] diff = score - correct_scores + delta
diff[y, np.arange(num_samples)] = 0 thresh = np.maximum(0, diff)
loss = np.sum(thresh)
loss /= num_samples

2. 优化(optimization):梯度计算

首先来看损失函数的定义,如下为第 i 个样本的损失值(Wc×d⋅Xd×N,d 特征向量的维度,c:输出类别的个数):

Li==∑j≠yimax(0,sj−syi+1)∑j≠yi[max(0,wTjxi−wTyixi+1)]
  • 遍历 j,就是遍历 W 每一列的每一个元素, wTjxi⇒j=1,…,c;i=1,…,N
  • wTj 表示 W 的每一行,共 c 行;

下面的额关键是如何求得损失函数关于参数 wj,wyi 的梯度:

∇wyiLi=−⎛⎝∑j≠yi1(wTjxi−wTyixi+Δ>0)⎞⎠xi∇wjLi=1(wTjxi−wTyixi+Δ>0)xij≠yi
binary = thresh
binary[thresh > 0] = 1 # 实现 indicator 函数 col_sum = np.sum(binary, axis=0)
binary[y, np.arange(num_samples)] = -col_sum dW = np.dot(binary, X.T) # binary 维度信息:c*N, X 维度信息:d*N
dW /= N dW += reg * W

多类 SVM 的损失函数及其梯度计算的更多相关文章

  1. Softmax 损失-梯度计算

    本文介绍Softmax运算.Softmax损失函数及其反向传播梯度计算, 内容上承接前两篇博文 损失函数 & 手推反向传播公式. Softmax 梯度 设有K类, 那么期望标签y形如\([0, ...

  2. 【CS231N】2、多类SVM

    一.疑问 1. assignments1 linear_svm.py文件的函数 svm_loss_naive中,使用循环的方式实现梯度计算 linear_svm.py文件的函数 svm_loss_ve ...

  3. 实现属于自己的TensorFlow(二) - 梯度计算与反向传播

    前言 上一篇中介绍了计算图以及前向传播的实现,本文中将主要介绍对于模型优化非常重要的反向传播算法以及反向传播算法中梯度计算的实现.因为在计算梯度的时候需要涉及到矩阵梯度的计算,本文针对几种常用操作的梯 ...

  4. [吴恩达机器学习笔记]12支持向量机1从逻辑回归到SVM/SVM的损失函数

    12.支持向量机 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考资料 斯坦福大学 2014 机器学习教程中文笔记 by 黄海广 12.1 SVM损失函数 从逻辑回归到支持向量机 为了描述 ...

  5. 实验四 (1):定义一个形状类(Shape)方法:计算周长,计算面积

    (1)定义一个形状类(Shape)方法:计算周长,计算面积子类:矩形类(Rectangle) :额外的方法:differ() 计算长宽差圆形类(Circle)三角形类(Triangle)正方形类(Sq ...

  6. Theano学习-梯度计算

    1. 计算梯度 创建一个函数 \(y\) ,并且计算关于其参数 \(x\) 的微分. 为了实现这一功能,将使用函数 \(T.grad\) . 例如:计算 \(x^2\) 关于参数 \(x\) 的梯度. ...

  7. 机器学习进阶-图像梯度计算-scharr算子与laplacian算子(拉普拉斯) 1.cv2.Scharr(使用scharr算子进行计算) 2.cv2.laplician(使用拉普拉斯算子进行计算)

    1. cv2.Scharr(src,ddepth, dx, dy), 使用Scharr算子进行计算 参数说明:src表示输入的图片,ddepth表示图片的深度,通常使用-1, 这里使用cv2.CV_6 ...

  8. 日期时间类:Date,Calendar,计算类:Math

    日期时间类 计算机如何表示时间? 时间戳(timestamp):距离特定时间的时间间隔. 计算机时间戳是指距离历元(1970-01-01 00:00:00:000)的时间间隔(ms). 计算机中时间2 ...

  9. pytorch 反向梯度计算问题

    计算如下\begin{array}{l}{x_{1}=w_{1} * \text { input }} \\ {x_{2}=w_{2} * x_{1}} \\ {x_{3}=w_{3} * x_{2} ...

随机推荐

  1. RvmTranslator6.0

    RvmTranslator6.0 eryar@163.com 1. Introduction RvmTranslator can translate the RVM file exported by ...

  2. javascript变量类型及作用域

    javascript变量类型及作用域 一.简介 变量类型 ECMAScript变量可能包含两种不同类型的数据值:基本类型和引用类型. 基本类型 基本类型指的是简单的数据段,5种基本数据类型:undef ...

  3. 51Nod 飞行员配对(二分图最大匹配)(匈牙利算法模板题)

    第二次世界大战时期,英国皇家空军从沦陷国征募了大量外籍飞行员.由皇家空军派出的每一架飞机都需要配备在航行技能和语言上能互相配合的2名飞行员,其中1名是英国飞行员,另1名是外籍飞行员.在众多的飞行员中, ...

  4. 什么是事件委托?jquery和js怎么去实现?

    事件委托又叫事件代理,事件委托就是利用事件冒泡,只指定一个事件处理程序,就可以管理某一类型的所有事件. js: window.onload = function(){ var oul = docume ...

  5. wget---从指定的URL下载文件

    wget命令用来从指定的URL下载文件.wget非常稳定,它在带宽很窄的情况下和不稳定网络中有很强的适应性,如果是由于网络的原因下载失败,wget会不断的尝试,直到整个文件下载完毕.如果是服务器打断下 ...

  6. 兔子--ps中的基本工具总结(ps cs5)

    矩形选框工具 椭圆选框工具 单行选框工具 单列选框工具 移动工具 套索工具柜 多边形套索工具 磁性套索工具 魔棒工具 高速选择工具 裁剪工具 切片工具 切片选择工具 吸管工具 颜色取样器工具 标尺工具 ...

  7. ReactNavtive框架教程(4)

    开头的响应码, 这些代码都很实用. 比如202 和 200表示返回一个推荐位置的列表.当完毕这个实例后.你能够尝试处理这些返回码.并将列表提供给用户选择. 保存,返回模拟器,按下Cmd+R ,然后搜索 ...

  8. local-语言切换监听事件

    今天在更改时钟的问题的时候,需要监听语言切换来刷新时钟的显示.记录下监听方法 //注册监听事件 intentFilter.addAction(Intent.ACTION_LOCALE_CHANGED) ...

  9. https://github.com/ 英文库

    https://github.com/ https://github.com/sachinchoolur

  10. Unix下后门查找{上}

    本文出自 "李晨光原创技术博客" 博客,请务必保留此出处http://chenguang.blog.51cto.com/350944/683699