反向传播(back-propagation)笔记

本文由**罗周杨stupidme.me.lzy@gmail.com**原创,转载请注明出处。 本文已经发表在原作者博客 blog.stupidme.me/2018/08/25/…

反向传播是深度学习的基石。

导数

先回顾下导数:

\frac{df(x)}{dx}=\lim_{h->0}\frac{f(x+h)-f(x)}{h}

函数在每个变量的导数就是偏导数。

对于函数f(x,y)=x+y,\frac{\partial f}{\partial x}=1,同时\frac{\partial f}{\partial y}=1

梯度就是偏导数组成的矢量。上述例子中,\Delta f=[\frac{\partial f}{\partial x},\frac{\partial f}{\partial y}]

链式法则

对于简单函数,我们可以根据公式直接计算出其导数。但是对于复杂的函数,我们就没那么容易直接写出导数。但是我们有链式法则(chain rule)

定义不多说,咱们举个例子,感受一下链式法则的魅力。

我们熟悉的sigmoid函数\sigma(x)=\frac{1}{1+e^{-x}} ,如果你记不住它的导数,我们怎么求解呢?

求解步骤如下:

  • 将函数模块化,分成多个基本的部分,对于每一个部分都可以使用简单的求导法则进行求导
  • 使用链式法则,将这些导数链接起来,计算出最终的导数

具体如下:

a=x,则 \frac{\partial a}{\partial x}=1

b=-a,则\frac{\partial b}{\partial a}=-1

c=e^{b},则 \frac{\partial c}{\partial b}=e^{b}

d=1+c,则 \frac{\partial d}{\partial c}=1

e=\frac{1}{d},则\frac{\partial e}{\partial d}=\frac{-1}{d^2}

上面的e实际上就是我们的$$\sigma(x)$$,那么根据链式法则,有:

chain_rule

sigmoid函数的导数可以直接用自身表示,这也是很奇妙的性质了。这样的求导过程是不是很简单?

反向传播代码实现

求导和链式法则我都会了,那么具体的前向传播和反向传播的代码是怎么样的呢?

这次我们使用一个更复杂一点点的例子:

f(x,y)=\frac{x+\sigma(x)}{\sigma(x)+(x+y)^2}

我们先看下它地forward pass代码:

import math

x = 3
y = -4

sigy = 1.0 / (1 + math.exp(-y)) # sigmoid function
num = x + sigy # 分子
sigx = 1.0 / (1 + math.exp(-x))
xpy = x + y
xpy_sqr = xpy**2
den = sigx + xpy_sqr # 分母
invden = 1.0 / den
f = num * invden # 函数
复制代码

上述过程很简单对不对,就是把复杂的函数拆解成一个一个简单函数。

我们看看接下来的反向传播过程:

dnum = invden
复制代码

因为

f = num * invden

所以有

\frac{\partial f}{\partial num} = invden

也就是

dnum=invden
dinvden = num # 同理

dden = (-1.0 / (den**2)) * dinvden # 链式法则
复制代码

展开来说:

\frac{\partial invden}{\partial den}=\frac{-1}{den^2}

\frac{\partial f}{\partial invden}=num

所以

dden=\frac{\partial f}{\partial  den}=\frac{partial f}{\partial invden}\cdot \frac{\partial invden}{\partial den} = \frac{-1.0}{den^2}\cdot dinvden

所以,同理,我们可以写出所有的导数:

dsigx = (1) * dden 
dxpy_sqr = (1) * dden

dxpy = (2 * xpy) * dxpy_sqr

# backprob xpy = x + y
dx = (1) * dxpy
dy = (1) * dxpy

# 这里开始,请注意使用的是"+=",而不是"=”
dx += ((1 - sigx) * sigx) * dsigx # dsigma(x) = (1 - sigma(x))*sigma(x)
dx += (1) * dnum

# backprob num = x + sigy
dsigy = (1) * dnum
# 注意“+=”
dy += ((1 - sigy) * sigy) * dsigy
复制代码

问题:

  • 上面计算过程中,为什么要用“+=”替代“=”呢?

如果变量x,y在前向传播的表达式中出现多次,那么进行反向传播的时候就要非常小心,使用+=而不是=来累计这些变量的梯度(不然就会造成覆写)。这是遵循了在微积分中的多元链式法则,该法则指出如果变量在线路中分支走向不同的部分,那么梯度在回传的时候,就应该进行累加。

联系我