[ch04-04] 多样本单特征值计算
系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击star加星不要吝啬,星越多笔者越努力。
4.4 多样本单特征值计算
在前面的代码中,我们一直使用单样本计算来实现神经网络的训练过程,但是单样本计算有一些缺点:
- 很有可能前后两个相邻的样本,会对反向传播产生相反的作用而互相抵消。假设样本1造成了误差为0.5,w的梯度计算结果是0.1;紧接着样本2造成的误差为-0.5,w的梯度计算结果是-0.1,那么前后两次更新w就会产生互相抵消的作用。
- 在样本数据量大时,逐个计算会花费很长的时间。由于我们在本例中样本量不大(200个样本),所以计算速度很快,觉察不到这一点。在实际的工程实践中,动辄10万甚至100万的数据量,轮询一次要花费很长的时间。
如果使用多样本计算,就要涉及到矩阵运算了,而所有的深度学习框架,都对矩阵运算做了优化,会大幅提升运算速度。打个比方:如果200个样本,循环计算一次需要2秒的话,那么把200个样本打包成矩阵,做一次计算也许只需要0.1秒。
下面我们来看看多样本运算会对代码实现有什么影响,假设我们一次用3个样本来参与计算,每个样本只有1个特征值。
4.4.1 前向计算
由于有多个样本同时计算,所以我们使用\(x_i\)表示第 \(i\) 个样本,X是样本组成的矩阵,Z是计算结果矩阵,w和b都是标量:
\[
Z = X \cdot w + b \tag{1}
\]
把它展开成3个样本(3行,每行代表一个样本)的形式:
\[
X=\begin{pmatrix}
x_1 \\
x_2 \\
x_3
\end{pmatrix}
\]
\[
Z=
\begin{pmatrix}
x_1 \\
x_2 \\
x_3
\end{pmatrix} \cdot w + b
=
\begin{pmatrix}
x_1 \cdot w + b \\
x_2 \cdot w + b \\
x_3 \cdot w + b
\end{pmatrix}
=
\begin{pmatrix}
z_1 \\
z_2 \\
z_3
\end{pmatrix} \tag{2}
\]
\(z_1、z_2、z_3\)是三个样本的计算结果。根据公式1和公式2,我们的前向计算python代码可以写成:
def __forwardBatch(self, batch_x):
Z = np.dot(batch_x, self.w) + self.b
return Z
Python中的矩阵乘法命名有些问题,np.dot()并不是矩阵点乘,而是矩阵叉乘,请读者习惯。
4.4.2 损失函数
用传统的均方差函数,其中,z是每一次迭代的预测输出,y是样本标签数据。我们使用m个样本参与计算,因此损失函数为:
\[J(w,b) = \frac{1}{2m}\sum_{i=1}^{m}(z_i - y_i)^2\]
其中的分母中有个2,实际上是想在求导数时把这个2约掉,没有什么原则上的区别。
我们假设每次有3个样本参与计算,即m=3,则损失函数实例化后的情形是:
\[
\begin{aligned}
J(w,b) &= \frac{1}{2\times3}[(z_1-y_1)^2+(z_2-y_2)^2+(z_3-y_3)^2] \\
&=\frac{1}{2\times3}\sum_{i=1}^3[(z_i-y_i)^2]
\end{aligned} \tag{3}
\]
公式3中大写的Z和Y都是矩阵形式,用代码实现:
def __checkLoss(self, dataReader):
X,Y = dataReader.GetWholeTrainSamples()
m = X.shape[0]
Z = self.__forwardBatch(X)
LOSS = (Z - Y)**2
loss = LOSS.sum()/m/2
return loss
Python中的矩阵减法运算,不需要对矩阵中的每个对应的元素单独做减法,而是整个矩阵相减即可。做求和运算时,也不需要自己写代码做遍历每个元素,而是简单地调用求和函数即可。
4.4.3 求w的梯度
我们用 J 的值作为基准,去求 w 对它的影响,也就是 J 对 w 的偏导数,就可以得到w的梯度了。从公式3看 J 的计算过程,\(z_1、z_2、z_3\)都对它有贡献;再从公式2看\(z_1、z_2、z_3\)的生成过程,都有w的参与。所以,J 对 w 的偏导应该是这样的:
\[
\begin{aligned}
\frac{\partial{J}}{\partial{w}}&=\frac{\partial{J}}{\partial{z_1}}\frac{\partial{z_1}}{\partial{w}}+\frac{\partial{J}}{\partial{z_2}}\frac{\partial{z_2}}{\partial{w}}+\frac{\partial{J}}{\partial{z_3}}\frac{\partial{z_3}}{\partial{w}} \\
&=\frac{1}{3}[(z_1-y_1)x_1+(z_2-y_2)x_2+(z_3-y_3)x_3] \\
&=\frac{1}{3}
\begin{pmatrix}
x_1 & x_2 & x_3
\end{pmatrix}
\begin{pmatrix}
z_1-y_1 \\
z_2-y_2 \\
z_3-y_3
\end{pmatrix} \\
&=\frac{1}{m} \sum^m_{i=1} (z_i-y_i)x_i \\
&=\frac{1}{m} X^T \cdot (Z-Y) \\
\end{aligned} \tag{4}
\]
其中:
\[X =
\begin{pmatrix}
x_1 \\
x_2 \\
x_3
\end{pmatrix}, X^T =
\begin{pmatrix}
x_1 & x_2 & x_3
\end{pmatrix}
\]
公式4中最后两个等式其实是等价的,只不过倒数第二个公式用求和方式计算每个样本,最后一个公式用矩阵方式做一次性计算。
4.4.4 求b的梯度
\[
\begin{aligned}
\frac{\partial{J}}{\partial{b}}&=\frac{\partial{J}}{\partial{z_1}}\frac{\partial{z_1}}{\partial{b}}+\frac{\partial{J}}{\partial{z_2}}\frac{\partial{z_2}}{\partial{b}}+\frac{\partial{J}}{\partial{z_3}}\frac{\partial{z_3}}{\partial{b}} \\
&=\frac{1}{3}[(z_1-y_1)+(z_2-y_2)+(z_3-y_3)] \\
&=\frac{1}{m} \sum^m_{i=1} (z_i-y_i) \\
&=\frac{1}{m}(Z-Y)
\end{aligned} \tag{5}
\]
公式5中最后两个等式也是等价的,在python中,可以直接用最后一个公式求矩阵的和,免去了一个个计算\(z_i-y_i\)最后再求和的麻烦,速度还快。
def __backwardBatch(self, batch_x, batch_y, batch_z):
m = batch_x.shape[0]
dZ = batch_z - batch_y
dW = np.dot(batch_x.T, dZ)/m
dB = dZ.sum(axis=0, keepdims=True)/m
return dW, dB
代码位置
ch04, HelperClass/NeuralNet.py
[ch04-04] 多样本单特征值计算的更多相关文章
- bt 介绍以及 bt 种子的hash值(特征值)计算
bt种子的hansh值计算,近期忽然对bt种子感兴趣了(原因勿问) 1. bt种子(概念) bt 是一个分布式文件分发协议,每一个文件下载者在下载的同一时候向其他下载者不断的上传已经下载的数据,这样保 ...
- 样本打散后计算单特征 NDCG
单特征 NDCG 能计算模型的 NDCG,也就能计算单特征的 NDCG,用于评估单特征的有效性,跟 Group AUC 用途一样 单特征 NDCG 如何衡量好坏 如果是 AUC,越大于或小于 0.5, ...
- ubuntu14.04 下安装 gsl 科学计算库
GSL(GNU Scientific Library)作为三大科学计算库之一,除了涵盖基本的线性代数,微分方程,积分,随机数,组合数,方程求根,多项式求根,排序等,还有模拟退火,快速傅里叶变换,小波, ...
- Video Target Tracking Based on Online Learning—TLD单目标跟踪算法详解
视频目标跟踪问题分析 视频跟踪技术的主要目的是从复杂多变的的背景环境中准确提取相关的目标特征,准确地识别出跟踪目标,并且对目标的位置和姿态等信息精确地定位,为后续目标物体行为分析提供足 ...
- PCB 加投率计算实现基本原理--K最近邻算法(KNN)
PCB行业中,客户订购5000pcs,在投料时不会直接投5000pcs,因为实际在生产过程不可避免的造成PCB报废, 所以在生产前需计划多投一定比例的板板, 例:订单 量是5000pcs,加投3%,那 ...
- Python数模笔记-Sklearn(2)样本聚类分析
1.分类的分类 分类的分类?没错,分类也有不同的种类,而且在数学建模.机器学习领域常常被混淆. 首先我们谈谈有监督学习(Supervised learning)和无监督学习(Unsupervised ...
- ICCV2021 | TOOD:任务对齐的单阶段目标检测
前言 单阶段目标检测通常通过优化目标分类和定位两个子任务来实现,使用具有两个平行分支的头部,这可能会导致两个任务之间的预测出现一定程度的空间错位.本文提出了一种任务对齐的一阶段目标检测(TOOD) ...
- [zz]计算 协方差矩阵
http://www.cnblogs.com/chaosimple/p/3182157.html http://blog.csdn.net/goodshot/article/details/86111 ...
- vue 开发系列(八) 动态表单开发
概要 动态表单指的是我们的表单不是通过vue 组件一个个编写的,我们的表单是根据后端生成的vue模板,在前端通过vue构建出来的.主要的思路是,在后端生成vue的模板,前端通过ajax的方式加载后端的 ...
随机推荐
- 关于JQUery.parseJSON()函数的知识札记
JSON数据也许大家都很陌生,而对我来讲属于半成品,由于项目问题,做web虽然用的是JSON数据格式传输,但是关于解析这一块还真不知道该注意什么,更不知道它是如何解析的,由于最近要把串口通信协议与此一 ...
- Linux下基本操作
强行转Linux,开始以为会很不适应,其实还好,换汤不换药 本文只讲基本操作,足够让你愉快的打代码,想飞上天的自行百度,或找其他大神(友链) Update 6/20:由于写得太烂被学长爆踩了一顿 直接 ...
- 大数据之路week01--自学之集合_1(Collection)
经过我个人的调查,发现,在今后的大数据道路上,集合.线程.网络编程变得尤为重要,为什么? 因为大数据大数据,我们必然要对数据进行处理,而这些数据往往是以集合形式存放,掌握对集合的操作非常重要. 在学习 ...
- 2018年7月份JAVA开源软件TOP3
微信开发 Java SDK Weixin Java Tools 评分: 9.6 介绍: 信开发 Java 开发工具包(SDK),支持包括微信支付.微信开放平台.小程序.企业号/企业微信.公众号(包括服 ...
- jquery写$ document.getElementById效果
jquery写$ document.getElementById效果<pre>document.getElementById('video-canvas')和$('#video-canva ...
- PHP file_get_contents 读取js脚本的问题
PHP file_get_contents 读取js脚本的问题 如果文件中带有js脚本 会触发 比方说alert 这个时候 你不用去管他
- SpringBoot基本配置详解
SpringBoot项目有一些基本的配置,比如启动图案(banner),比如默认配置文件application.properties,以及相关的默认配置项. 示例项目代码在:https://githu ...
- 通过canvas合成图片
通过canvas合成图片 效果图 页面布局部分 两个图片以及一个canvas画布 <img src="https://qnlite.gtimg.com/qqnewslite/20190 ...
- Tsx写一个通用的button组件
一年又要到年底了,vue3.0都已经出来了,我们也不能一直还停留在过去的js中,是时候学习并且在项目中使用一下Ts了. 如果说jsx是基于js的话,那么tsx就是基于typescript的 废话也不多 ...
- RHEL7.2 安装Eclipse-oxygen Hadoop开发环境
1 Eclipse-oxygen下载地址 http://www.eclipse.org/downloads/download.php?file=/technology/epp/downloads/re ...