[ch04-04] 多样本单特征值计算
系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击star加星不要吝啬,星越多笔者越努力。
4.4 多样本单特征值计算
在前面的代码中,我们一直使用单样本计算来实现神经网络的训练过程,但是单样本计算有一些缺点:
- 很有可能前后两个相邻的样本,会对反向传播产生相反的作用而互相抵消。假设样本1造成了误差为0.5,w的梯度计算结果是0.1;紧接着样本2造成的误差为-0.5,w的梯度计算结果是-0.1,那么前后两次更新w就会产生互相抵消的作用。
- 在样本数据量大时,逐个计算会花费很长的时间。由于我们在本例中样本量不大(200个样本),所以计算速度很快,觉察不到这一点。在实际的工程实践中,动辄10万甚至100万的数据量,轮询一次要花费很长的时间。
如果使用多样本计算,就要涉及到矩阵运算了,而所有的深度学习框架,都对矩阵运算做了优化,会大幅提升运算速度。打个比方:如果200个样本,循环计算一次需要2秒的话,那么把200个样本打包成矩阵,做一次计算也许只需要0.1秒。
下面我们来看看多样本运算会对代码实现有什么影响,假设我们一次用3个样本来参与计算,每个样本只有1个特征值。
4.4.1 前向计算
由于有多个样本同时计算,所以我们使用\(x_i\)表示第 \(i\) 个样本,X是样本组成的矩阵,Z是计算结果矩阵,w和b都是标量:
\[
Z = X \cdot w + b \tag{1}
\]
把它展开成3个样本(3行,每行代表一个样本)的形式:
\[
X=\begin{pmatrix}
x_1 \\
x_2 \\
x_3
\end{pmatrix}
\]
\[
Z=
\begin{pmatrix}
x_1 \\
x_2 \\
x_3
\end{pmatrix} \cdot w + b
=
\begin{pmatrix}
x_1 \cdot w + b \\
x_2 \cdot w + b \\
x_3 \cdot w + b
\end{pmatrix}
=
\begin{pmatrix}
z_1 \\
z_2 \\
z_3
\end{pmatrix} \tag{2}
\]
\(z_1、z_2、z_3\)是三个样本的计算结果。根据公式1和公式2,我们的前向计算python代码可以写成:
def __forwardBatch(self, batch_x):
Z = np.dot(batch_x, self.w) + self.b
return Z
Python中的矩阵乘法命名有些问题,np.dot()并不是矩阵点乘,而是矩阵叉乘,请读者习惯。
4.4.2 损失函数
用传统的均方差函数,其中,z是每一次迭代的预测输出,y是样本标签数据。我们使用m个样本参与计算,因此损失函数为:
\[J(w,b) = \frac{1}{2m}\sum_{i=1}^{m}(z_i - y_i)^2\]
其中的分母中有个2,实际上是想在求导数时把这个2约掉,没有什么原则上的区别。
我们假设每次有3个样本参与计算,即m=3,则损失函数实例化后的情形是:
\[
\begin{aligned}
J(w,b) &= \frac{1}{2\times3}[(z_1-y_1)^2+(z_2-y_2)^2+(z_3-y_3)^2] \\
&=\frac{1}{2\times3}\sum_{i=1}^3[(z_i-y_i)^2]
\end{aligned} \tag{3}
\]
公式3中大写的Z和Y都是矩阵形式,用代码实现:
def __checkLoss(self, dataReader):
X,Y = dataReader.GetWholeTrainSamples()
m = X.shape[0]
Z = self.__forwardBatch(X)
LOSS = (Z - Y)**2
loss = LOSS.sum()/m/2
return loss
Python中的矩阵减法运算,不需要对矩阵中的每个对应的元素单独做减法,而是整个矩阵相减即可。做求和运算时,也不需要自己写代码做遍历每个元素,而是简单地调用求和函数即可。
4.4.3 求w的梯度
我们用 J 的值作为基准,去求 w 对它的影响,也就是 J 对 w 的偏导数,就可以得到w的梯度了。从公式3看 J 的计算过程,\(z_1、z_2、z_3\)都对它有贡献;再从公式2看\(z_1、z_2、z_3\)的生成过程,都有w的参与。所以,J 对 w 的偏导应该是这样的:
\[
\begin{aligned}
\frac{\partial{J}}{\partial{w}}&=\frac{\partial{J}}{\partial{z_1}}\frac{\partial{z_1}}{\partial{w}}+\frac{\partial{J}}{\partial{z_2}}\frac{\partial{z_2}}{\partial{w}}+\frac{\partial{J}}{\partial{z_3}}\frac{\partial{z_3}}{\partial{w}} \\
&=\frac{1}{3}[(z_1-y_1)x_1+(z_2-y_2)x_2+(z_3-y_3)x_3] \\
&=\frac{1}{3}
\begin{pmatrix}
x_1 & x_2 & x_3
\end{pmatrix}
\begin{pmatrix}
z_1-y_1 \\
z_2-y_2 \\
z_3-y_3
\end{pmatrix} \\
&=\frac{1}{m} \sum^m_{i=1} (z_i-y_i)x_i \\
&=\frac{1}{m} X^T \cdot (Z-Y) \\
\end{aligned} \tag{4}
\]
其中:
\[X =
\begin{pmatrix}
x_1 \\
x_2 \\
x_3
\end{pmatrix}, X^T =
\begin{pmatrix}
x_1 & x_2 & x_3
\end{pmatrix}
\]
公式4中最后两个等式其实是等价的,只不过倒数第二个公式用求和方式计算每个样本,最后一个公式用矩阵方式做一次性计算。
4.4.4 求b的梯度
\[
\begin{aligned}
\frac{\partial{J}}{\partial{b}}&=\frac{\partial{J}}{\partial{z_1}}\frac{\partial{z_1}}{\partial{b}}+\frac{\partial{J}}{\partial{z_2}}\frac{\partial{z_2}}{\partial{b}}+\frac{\partial{J}}{\partial{z_3}}\frac{\partial{z_3}}{\partial{b}} \\
&=\frac{1}{3}[(z_1-y_1)+(z_2-y_2)+(z_3-y_3)] \\
&=\frac{1}{m} \sum^m_{i=1} (z_i-y_i) \\
&=\frac{1}{m}(Z-Y)
\end{aligned} \tag{5}
\]
公式5中最后两个等式也是等价的,在python中,可以直接用最后一个公式求矩阵的和,免去了一个个计算\(z_i-y_i\)最后再求和的麻烦,速度还快。
def __backwardBatch(self, batch_x, batch_y, batch_z):
m = batch_x.shape[0]
dZ = batch_z - batch_y
dW = np.dot(batch_x.T, dZ)/m
dB = dZ.sum(axis=0, keepdims=True)/m
return dW, dB
代码位置
ch04, HelperClass/NeuralNet.py
[ch04-04] 多样本单特征值计算的更多相关文章
- bt 介绍以及 bt 种子的hash值(特征值)计算
bt种子的hansh值计算,近期忽然对bt种子感兴趣了(原因勿问) 1. bt种子(概念) bt 是一个分布式文件分发协议,每一个文件下载者在下载的同一时候向其他下载者不断的上传已经下载的数据,这样保 ...
- 样本打散后计算单特征 NDCG
单特征 NDCG 能计算模型的 NDCG,也就能计算单特征的 NDCG,用于评估单特征的有效性,跟 Group AUC 用途一样 单特征 NDCG 如何衡量好坏 如果是 AUC,越大于或小于 0.5, ...
- ubuntu14.04 下安装 gsl 科学计算库
GSL(GNU Scientific Library)作为三大科学计算库之一,除了涵盖基本的线性代数,微分方程,积分,随机数,组合数,方程求根,多项式求根,排序等,还有模拟退火,快速傅里叶变换,小波, ...
- Video Target Tracking Based on Online Learning—TLD单目标跟踪算法详解
视频目标跟踪问题分析 视频跟踪技术的主要目的是从复杂多变的的背景环境中准确提取相关的目标特征,准确地识别出跟踪目标,并且对目标的位置和姿态等信息精确地定位,为后续目标物体行为分析提供足 ...
- PCB 加投率计算实现基本原理--K最近邻算法(KNN)
PCB行业中,客户订购5000pcs,在投料时不会直接投5000pcs,因为实际在生产过程不可避免的造成PCB报废, 所以在生产前需计划多投一定比例的板板, 例:订单 量是5000pcs,加投3%,那 ...
- Python数模笔记-Sklearn(2)样本聚类分析
1.分类的分类 分类的分类?没错,分类也有不同的种类,而且在数学建模.机器学习领域常常被混淆. 首先我们谈谈有监督学习(Supervised learning)和无监督学习(Unsupervised ...
- ICCV2021 | TOOD:任务对齐的单阶段目标检测
前言 单阶段目标检测通常通过优化目标分类和定位两个子任务来实现,使用具有两个平行分支的头部,这可能会导致两个任务之间的预测出现一定程度的空间错位.本文提出了一种任务对齐的一阶段目标检测(TOOD) ...
- [zz]计算 协方差矩阵
http://www.cnblogs.com/chaosimple/p/3182157.html http://blog.csdn.net/goodshot/article/details/86111 ...
- vue 开发系列(八) 动态表单开发
概要 动态表单指的是我们的表单不是通过vue 组件一个个编写的,我们的表单是根据后端生成的vue模板,在前端通过vue构建出来的.主要的思路是,在后端生成vue的模板,前端通过ajax的方式加载后端的 ...
随机推荐
- Head First设计模式——单例模式
单例模式是所有设计模式中最简单的模式,也是我们平常经常用到的,单例模式通常被我们应用于线程池.缓存操作.队列操作等等. 单例模式旨在创建一个类的实例,创建一个类的实例我们用全局静态变量或者约定也能办到 ...
- Can't connect to MySQL server on 'localhost' (10061),连接Navicat报错问题解决
今天,装了Mysql 1.1.7后,连接Navicat 时报错,后来找了一阵,发现问题所在. 原因是我在安装时把默认端口号3306修改成了3303, 连接时,把默认端口也修改下就好啦.
- Apache服务及个人用户主页功能和密码验证
Apache服务程序中有个默认未开启的个人用户主页功能,能够为所有系统内的用户生成个人网站,确实很实用哦 第1步:开启个人用户主页功能: 1.vim /etc/httpd/conf.d/userdir ...
- JAVA Rest High Level Client如何取聚合后得数据
对于刚刚学习es的童鞋来说,很容易不清楚怎么获取客户端对es文档的聚合结果,下面就演示一下模仿DSL写聚合,然后获取到聚合对结果. 一, 对于下面这个简单的聚合,目的是对于文档全文匹配,聚合颜色字段. ...
- 使用springcloud开发测试问题总结
使用springcloud开发测试 如下描述的问题,没有指明是linux部署的,都是在windows开发环境上部署验证发现的. Issue1配置客户端不使用配置中心 问题描述: 配置客户端使用配置中心 ...
- Java升级那么快,多个版本如何灵活切换和管理?
前言 近两年,Java 版本升级频繁,感觉刚刚掌握 Java8,写本文时,已听到 java14 的消息,无论是尝鲜新特性(Java12 中 Collectors.teeing 超强功能使用),还是由于 ...
- C# IV: 数据库基础操作2
需上一篇C# III:数据库基础操作 另外一个经常碰到的数据库操作是,单次执行多个SQL语句,譬如,一次性插入多条数据. 方法一,拼凑长SQL语句 拼凑长SQL语句实际上是String的操作.如下示例 ...
- tomcat 部署springboot 项目
Springboot项目默认jar包,且内置Tomcat.现需要将项目打成war包,并部署到服务器tomcat中. 1.修改pom.xml文件.将jar修改为war. <packaging> ...
- nyoj 209 + poj 2492 A Bug's Life (并查集)
A Bug's Life 时间限制:1000 ms | 内存限制:65535 KB 难度:4 描述 Background Professor Hopper is researching th ...
- nyoj 46-最少乘法次数 (递推)
46-最少乘法次数 内存限制:64MB 时间限制:1000ms Special Judge: No accepted:5 submit:18 题目描述: 给你一个非零整数,让你求这个数的n次方,每次相 ...