Calculus on Computational Graphs: Backpropagation

Introduction

Backpropagation is the key algorithm that makes training deep models computationally tractable. For modern neural networks, it can make training with gradient descent as much as ten million times faster, relative to a naive implementation. That’s the difference between a model taking a week to train and taking 200,000 years.

Beyond its use in deep learning, backpropagation is a powerful computational tool in many other areas, ranging from weather forecasting to analyzing numerical stability – it just goes by different names. In fact, the algorithm has been reinvented at least dozens of times in different fields (see Griewank (2010)). The general, application-independent name is “reverse-mode differentiation.”

Fundamentally, it’s a technique for calculating derivatives quickly. And it’s an essential trick to have in your bag, not only in deep learning, but in a wide variety of numerical computing situations.

Computational Graphs

Computational graphs are a nice way to think about mathematical expressions. For example, consider the expression e=(a+b)∗(b+1). There are three operations: two additions and one multiplication. To help us talk about this, let’s introduce two intermediary variables, c and d so that every function’s output has a variable. We now have:

c=a+b
d=b+1
e=c∗d

To create a computational graph, we make each of these operations, along with the input variables, into nodes. When one node’s value is the input to another node, an arrow goes from one to another.

These sorts of graphs come up all the time in computer science, especially in talking about functional programs. They are very closely related to the notions of dependency graphs and call graphs. They’re also the core abstraction behind the popular deep learning framework Theano.

We can evaluate the expression by setting the input variables to certain values and computing nodes up through the graph. For example, let’s set a=2 and b=1:

The expression evaluates to 6.
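
As a minimal sketch of that evaluation, the graph can be computed node by node in dependency order. The function name `evaluate` and the dictionary of node values are illustrative choices, not something from the original post.

```python
def evaluate(a, b):
    """Evaluate e = (a + b) * (b + 1), keeping every node's value."""
    c = a + b  # first addition node
    d = b + 1  # second addition node
    e = c * d  # multiplication node
    return {"a": a, "b": b, "c": c, "d": d, "e": e}


values = evaluate(a=2, b=1)
print(values)  # {'a': 2, 'b': 1, 'c': 3, 'd': 2, 'e': 6}
```

Keeping the intermediate values c and d around matters later: the derivatives on the edges into e depend on them.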

Derivatives on Computational Graphs

If one wants to understand derivatives in a computational graph, the key is to understand derivatives on the edges. If a directly affects c, then we want to know how it affects c. If a changes a little bit, how does c change? We call this the partial derivative of c with respect to a.

To evaluate the partial derivatives in this graph, we need the sum rule and the product rule:

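$$\frac{\partial}{\partial a}(a+b) = \frac{\partial a}{\partial a} + \frac{\partial b}{\partial a} = 1$$
$$\frac{\partial}{\partial u}uv = u\frac{\partial v}{\partial u} + v\frac{\partial u}{\partial u} = v$$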

Below, the graph has the derivative on each edge labeled.

What if we want to understand how nodes that aren’t directly connected affect each other? Let’s consider how e is affected by a. If we change a at a speed of 1, c also changes at a speed of 1. In turn, c changing at a speed of 1 causes e to change at a speed of 2. So e changes at a rate of 1∗2 with respect to a.

The general rule is to sum over all possible paths from one node to the other, multiplying the derivatives on each edge of the path together. For example, to get the derivative of e with respect to b we get:

$$\frac{\partial e}{\partial b} = 1*2 + 1*3$$

This accounts for how b affects e through c and also how it affects it through d.

This general “sum over paths” rule is just a different way of thinking about the multivariate chain rule.
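
The rule can be sketched directly as a recursion over outgoing edges. This is only an illustration: the `edges` dictionary and the `path_sum` helper are hypothetical names, and the edge derivatives are the ones for a=2, b=1 (so c=3 and d=2).

```python
# Edge derivatives at a=2, b=1: key (src, dst) -> d(dst)/d(src).
edges = {
    ("a", "c"): 1.0,
    ("b", "c"): 1.0,
    ("b", "d"): 1.0,
    ("c", "e"): 2.0,  # de/dc = d = 2
    ("d", "e"): 3.0,  # de/dd = c = 3
}


def path_sum(src, dst):
    """Sum, over every path src -> dst, of the product of edge derivatives."""
    if src == dst:
        return 1.0  # the empty path contributes an empty product
    total = 0.0
    for (u, v), deriv in edges.items():
        if u == src:
            total += deriv * path_sum(v, dst)
    return total


de_da = path_sum("a", "e")  # one path: 1*2
de_db = path_sum("b", "e")  # two paths: 1*2 + 1*3
print(de_da, de_db)  # 2.0 5.0
```

Note that this naive recursion revisits shared sub-paths, so its cost grows with the number of paths rather than the number of edges.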

Factoring Paths

The problem with just “summing over the paths” is that it’s very easy to get a combinatorial explosion in the number of possible paths.

In the above diagram, there are three paths from X to Y, and a further three paths from Y to Z. If we want to get the derivative $\frac{\partial Z}{\partial X}$ by summing over all paths, we need to sum over 3∗3=9 paths:

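$$\frac{\partial Z}{\partial X} = \alpha\delta + \alpha\epsilon + \alpha\zeta + \beta\delta + \beta\epsilon + \beta\zeta + \gamma\delta + \gamma\epsilon + \gamma\zeta$$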

The above only has nine paths, but the number of paths can easily grow exponentially as the graph becomes more complicated.

Instead of just naively summing over the paths, it would be much better to factor them:

$$\frac{\partial Z}{\partial X} = (\alpha + \beta + \gamma)(\delta + \epsilon + \zeta)$$
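
A quick numerical check of this factoring, as a sketch: the values assigned to the six edge derivatives below are made up for illustration and do not come from the post.

```python
from itertools import product

# Hypothetical numeric values for the edge derivatives.
x_to_y = [0.5, 2.0, 3.0]   # alpha, beta, gamma
y_to_z = [4.0, 1.0, 0.25]  # delta, epsilon, zeta

# Naive: one product per path, 3 * 3 = 9 terms.
paths = list(product(x_to_y, y_to_z))
naive = sum(p * q for p, q in paths)

# Factored: two sums and a single product.
factored = sum(x_to_y) * sum(y_to_z)

print(len(paths), naive, factored)  # 9 28.875 28.875
```

With k such layers of three edges each, the naive sum has 3^k terms, while the factored form needs only k small sums and k-1 multiplications.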

This is where “forward-mode differentiation” and “reverse-mode differentiation” come in. They’re algorithms for efficiently computing the sum by factoring the paths. Instead of summing over all of the paths explicitly, they compute the same sum more efficiently by merging paths back together at every node. In fact, both algorithms touch each edge exactly once!

Forward-mode differentiation starts at an input to the graph and moves towards the end. At every node, it sums all the paths feeding in. Each of those paths represents one way in which the input affects that node. By adding them up, we get the total way in which the node is affected by the input: its derivative.

Though you probably didn’t think of it in terms of graphs, forward-mode differentiation is very similar to what you implicitly learned to do if you took an introduction to calculus class.

Reverse-mode differentiation, on the other hand, starts at an output of the graph and moves towards the beginning. At each node, it merges all paths which originated at that node.

Forward-mode differentiation tracks how one input affects every node. Reverse-mode differentiation tracks how every node affects one output. That is, forward-mode differentiation applies the operator $\frac{\partial}{\partial X}$ to every node, while reverse-mode differentiation applies the operator $\frac{\partial Z}{\partial}$ to every node.[1]

Computational Victories

At this point, you might wonder why anyone would care about reverse-mode differentiation. It looks like a strange way of doing the same thing as the forward-mode. Is there some advantage?

Let’s consider our original example again:

We can use forward-mode differentiation from b up. This gives us the derivative of every node with respect to b.

We’ve computed $\frac{\partial e}{\partial b}$, the derivative of our output with respect to one of our inputs.
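
One way to sketch this forward pass is to carry each node's derivative with respect to the chosen input alongside its value. The function `forward_mode` and its `seed` parameter are hypothetical names introduced for illustration.

```python
def forward_mode(a, b, seed):
    """Derivative of every node of e = (a + b) * (b + 1) w.r.t. one input."""
    # Seed: the chosen input has derivative 1, the other input 0.
    da = 1.0 if seed == "a" else 0.0
    db = 1.0 if seed == "b" else 0.0
    c, dc = a + b, da + db          # sum rule
    d, dd = b + 1, db               # the constant 1 contributes nothing
    e, de = c * d, dc * d + c * dd  # product rule
    return {"a": da, "b": db, "c": dc, "d": dd, "e": de}


wrt_b = forward_mode(2.0, 1.0, seed="b")
print(wrt_b)  # {'a': 0.0, 'b': 1.0, 'c': 1.0, 'd': 1.0, 'e': 5.0}
```

Getting the derivative with respect to a requires a second pass with `seed="a"`.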

What if we do reverse-mode differentiation from e down? This gives us the derivative of e with respect to every node:

When I say that reverse-mode differentiation gives us the derivative of e with respect to every node, I really do mean every node. We get both $\frac{\partial e}{\partial a}$ and $\frac{\partial e}{\partial b}$, the derivatives of e with respect to both inputs. Forward-mode differentiation gave us the derivative of our output with respect to a single input, but reverse-mode differentiation gives us all of them.
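
A hand-written sketch of the reverse pass for this particular graph follows. It is not a general autodiff implementation, and `reverse_mode` and the `grad` dictionary are illustrative names.

```python
def reverse_mode(a, b):
    """Derivative of e = (a + b) * (b + 1) with respect to every node."""
    # Forward pass: compute and keep the intermediate values.
    c = a + b
    d = b + 1
    # Backward pass: start from de/de = 1 and walk the edges in reverse.
    grad = {"e": 1.0}
    grad["c"] = grad["e"] * d  # e = c * d, so de/dc = d
    grad["d"] = grad["e"] * c  # e = c * d, so de/dd = c
    grad["a"] = grad["c"] * 1.0  # a reaches e only through c
    grad["b"] = grad["c"] * 1.0 + grad["d"] * 1.0  # b feeds both c and d
    return grad


grads = reverse_mode(2.0, 1.0)
print(grads["a"], grads["b"])  # 2.0 5.0
```

A single backward sweep fills in the derivative for both inputs, and each edge derivative is used exactly once.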

For this graph, that’s only a factor of two speedup, but imagine a function with a million inputs and one output. Forward-mode differentiation would require us to go through the graph a million times to get the derivatives. Reverse-mode differentiation can get them all in one fell swoop! A speedup of a factor of a million is pretty nice!

When training neural networks, we think of the cost (a value describing how badly a neural network performs) as a function of the parameters (numbers describing how the network behaves). We want to calculate the derivatives of the cost with respect to all the parameters, for use in gradient descent. Now, there are often millions, or even tens of millions, of parameters in a neural network. So, reverse-mode differentiation, called backpropagation in the context of neural networks, gives us a massive speedup!

(Are there any cases where forward-mode differentiation makes more sense? Yes, there are! Where the reverse-mode gives the derivatives of one output with respect to all inputs, the forward-mode gives us the derivatives of all outputs with respect to one input. If one has a function with lots of outputs, forward-mode differentiation can be much, much, much faster.)

Isn’t This Trivial?

When I first understood what backpropagation was, my reaction was: “Oh, that’s just the chain rule! How did it take us so long to figure out?” I’m not the only one who’s had that reaction. It’s true that if you ask “is there a smart way to calculate derivatives in feedforward neural networks?” the answer isn’t that difficult.

But I think it was much more difficult than it might seem. You see, at the time backpropagation was invented, people weren’t very focused on the feedforward neural networks that we study. It also wasn’t obvious that derivatives were the right way to train them. Those are only obvious once you realize you can quickly calculate derivatives. There was a circular dependency.

Worse, it would be very easy to write off any piece of the circular dependency as impossible on casual thought. Training neural networks with derivatives? Surely you’d just get stuck in local minima. And obviously it would be expensive to compute all those derivatives. It’s only because we know this approach works that we don’t immediately start listing reasons it’s likely not to.

That’s the benefit of hindsight. Once you’ve framed the question, the hardest work is already done.

Conclusion

Derivatives are cheaper than you think. That’s the main lesson to take away from this post. In fact, they’re unintuitively cheap, and we silly humans have had to repeatedly rediscover this fact. That’s an important thing to understand in deep learning. It’s also a really useful thing to know in other fields, and only more so if it isn’t common knowledge.

Are there other lessons? I think there are.

Backpropagation is also a useful lens for understanding how derivatives flow through a model. This can be extremely helpful in reasoning about why some models are difficult to optimize. The classic example of this is the problem of vanishing gradients in recurrent neural networks.

Finally, I claim there is a broad algorithmic lesson to take away from these techniques. Backpropagation and forward-mode differentiation use a powerful pair of tricks (linearization and dynamic programming) to compute derivatives more efficiently than one might think possible. If you really understand these techniques, you can use them to efficiently calculate several other interesting expressions involving derivatives. We’ll explore this in a later blog post.

This post gives a very abstract treatment of backpropagation. I strongly recommend reading Michael Nielsen’s chapter on it for an excellent discussion, more concretely focused on neural networks.

Acknowledgments

Thank you to Greg Corrado, Jon Shlens, Samy Bengio and Anelia Angelova for taking the time to proofread this post.

Thanks also to Dario Amodei, Michael Nielsen and Yoshua Bengio for discussion of approaches to explaining backpropagation. Also thanks to all those who tolerated me practicing explaining backpropagation in talks and seminar series!


  [1] This might feel a bit like dynamic programming. That’s because it is!
