计算图是现代深度学习框架如 Tensorflow、PyTorch 等的核心概念,其中涉及的所有计算几乎都依赖于计算图提供的自动求导功能,因此研究计算图对深入理解反向传播等深度学习的底层算法大有帮助。
手工求导
求导在数学上非常容易实现,例如以下函数:
$$ f(x) = \sin(e^{x^2}) $$
我们能够轻易地求得其导函数为:
$$ f’(x) = \cos(e^{x^2}) \cdot e^{x^2} \cdot (2x) $$
那么能否通过编程语言实现该函数及其导函数?答案是可以,而且非常容易,只需要把式子逐项翻译即可:
1 2 3 4 5 6 7 8 import numpy as npdef f (x ): return np.sin(np.exp(np.power(x, 2 ))) def f_prime (x ): t = np.exp(np.power(x, 2 )) return np.cos(t) * t * (2 *x)
到目前为止,求导对于编程语言来说似乎没什么难的,但是不要忘记,我们这里研究的函数是一个具体的实例。在实际应用中,我们用到的函数将会非常丰富,它们的组合方式更是千变万化,光是上面这个简单的例子就会有无数种变形!例如:
$$ \begin{gathered}
f_2(x) = \cos(e^{x^2}) \\
f_3(x) = \sin(2^{x^2}) \\
f_4(x) = \sin(e^{x^3}) \\
\cdots
\end{gathered} $$
更不要提各式各样的其他复杂函数:
$$ \begin{gathered}
g(x) = -y \ln(\frac{1}{1 + e^{-wx}}) - (1-y)\ln(1 - \frac{1}{1 + e^{-wx}}) \\
h(x) = (w_n\cdot(\cdots(\mathrm{relu}(w_2\cdot(\mathrm{relu}(w_1\cdot x))))) - y) ^ 2 \\
\cdots
\end{gathered} $$
如果坚持手工求导的话,我们不仅需要无数次地推导公式,而且对于某些复杂的函数,求导公式并不简单,显然不可能完成。
链式法则
因此,我们需要一套抽象的求导规则,使得无论函数的具体形式如何,都能自动对其求导。也就是实现如下抽象函数的求导法则:
$$ f(x) = g(h(k(\cdots(x))) $$
尽管这个问题听上去要比具体函数的求导困难得多,但它依然有章可循。回想我们求导的一般过程,不过是运用了以下两点技术而已:
基本函数的求导法则 。包括三角函数、指数函数、幂函数等。
链式法则 。
链式法则使得我们可以对复合函数进行求导。针对上面的例子,为了显式地调用链式法则,我们可以引入如下中间变量:
$$ \begin{aligned}
u &= x^2 \\
v &= e^u \\
w &= sin(v)
\end{aligned} $$
使用链式法则描述的求导过程如下:
$$ \frac{\mathrm{d}y}{\mathrm{d}x} = \frac{\mathrm{d}y}{\mathrm{d}w} \cdot \frac{\mathrm{d}w}{\mathrm{d}v} \cdot \frac{\mathrm{d}v}{\mathrm{d}u} \cdot \frac{\mathrm{d}u}{\mathrm{d}x} $$
有了链式法则,我们就能够“机械”地搬运任意基本函数的导函数,从而对非常复杂的复合函数求导。
计算图
由上述分析可知,一旦我们实现了(1)基本函数的求导法则以及(2)链式法则,就能够让程序模仿我们手工求导的过程,从而做到“以不变应万变”。计算图非常适合用来描述这两个法则。
计算图在数据结构上属于有向图 (Directed Graph),图的每个节点对应一个“基本函数”,而节点之间的有向边则可用于描述链式法则。
上面的例子使用计算图描述如下:
$$ x \to (\cdot)^2 \to e^{(\cdot)} \to \sin(\cdot) \to y $$
计算图能够非常清晰地展现数据的流动过程。从输入 $x$ 开始,中间依次经过平方、自然指数、正弦函数三个基本运算依次作用,最终得到输出 $y$。
注意:这个例子并非典型的计算图,因为其中所涉及的运算都是一元运算 ,导致图结构是线性的,没有分支,更像是链表 。
这种线性结构的计算图无法描述加法、乘法等多元运算 ,例如 $x + sin(x)$、$x\sin(x)$。但它的好处是非常简单,便于理解和实现,因此我们将继续使用这种线性结构完成演示。
计算图的每一个节点都包含一个基本函数,并且其导函数是已知的。节点在进行一次“前向计算”时,除了要根据输入值计算输出值之外,还要调用导函数计算梯度值,并缓存在节点中。最终,我们将所有节点的梯度值相乘(链式法则)即可得到整个计算流程的总梯度。
代码实现
在实现代码之前,我们首先要明确接口的设计,即假想用户将会如何调用计算图,这是一个非常重要的工程原则。
我们期望用户以如下方式调用计算图:
1 2 3 4 5 6 7 8 9 10 11 12 13 >>> import compute_graph as cg>>> inp = cg.Input()>>> out = cg.power(inp, 2 )>>> out = cg.exp(out)>>> out = cg.sin(out)>>> graph = ComputeGraph(inp, out)>>> >>> import numpy as np>>> x = np.linspace(0 , 1 , 5 )>>> graph.forward(x)array([0.84147098 , 0.87454388 , 0.95916224 , 0.98307241 , 0.41078129 ]) >>> graph.gradarray([ 0. , 0.25811137 , 0.36319491 , -0.48233501 , -4.95669947 ])
这种 API 风格与 Keras 非常接近,符合一般用户的使用习惯。
下面我们开始着手实施我们的想法。我们计划为计算图、计算图节点分别设计一个类。
图节点类
首先定义所有图节点的基类,代表节点的通用结构。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 class Node (object ): """Node of compute graph""" def __init__ (self, x, *args, **kw ): if not isinstance (x, Node): raise ValueError('the input should be a compute graph Node object' ) x.next = self self.next = None self.grad = None self.init(*args, **kw) def init (self, *args, **kw ): pass def fun (self, x ): """节点中保存的基本函数""" raise NotImplementedError() def fun_grad (self, x, out ): """基本函数的导函数,用于计算梯度。 x, out 分别是 self.fun 的输入和输出。 理论上只需要 x 即可计算出梯度,但很多函数的导函数会引用自身,例如指数函数。 引入 out 作为参数可避免计算梯度时重复计算自身。 """ raise NotImplementedError() def forward (self, x ): """计算输出,同时缓存梯度""" out = self.fun(x) self.grad = self.fun_grad(x, out) return out def __str__ (self ): return self.__class__.__name__ def __repr__ (self ): return '<"{}" node of compute graph>' .format (str (self))
一般的计算节点只需要继承节点基类,并实现 fun
和 fun_grad
两个方法即可。
正弦函数节点
1 2 3 4 5 6 7 class sin (Node ): """Node of sin function""" def fun (self, x ): return np.sin(x) def fun_grad (self, x, out ): return np.cos(x)
指数函数节点
1 2 3 4 5 6 7 class exp (Node ): """Node of exp function""" def fun (self, x ): return np.exp(x) def fun_grad (self, x, out ): return out
幂函数节点
注意,幂函数需要在初始化时传入额外的参数,即幂指数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 class power (Node ): """Node of power function""" def init (self, p ): self.p = p def fun (self, x ): return np.power(x, self.p) def fun_grad (self, x, out ): return self.p * np.power(x, self.p - 1 ) def __str__ (self ): return '{}(., {})' .format (self.__class__.__name__, self.p)
输入节点
与普通节点不同,输入节点没有前驱节点,也不需要对数据进行加工和求导,因此需要单独进行定义。
1 2 3 4 5 6 7 8 9 10 class Input (Node ): """Input Node""" def __init__ (self ): self.next = None def fun (self, x ): return x def fun_grad (self, x, out ): return 1
计算图类
我们已经把主要的计算过程定义在了图节点类中,因此计算图类的任务就非常轻松了,只需要整合图节点的计算结果即可。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 class ComputeGraph (object ): """Compute Graph""" def __init__ (self, inp, out ): self.head = inp self.tail = out self.grad = None def forward (self, x ): if self.head is None : raise ValueError('the graph is empty' ) out = x grad = 1.0 node = self.head while node: out = node.forward(out) grad *= node.grad node = node.next self.grad = grad return out def __str__ (self ): node = self.head desc = [] while node: desc.append(str (node)) node = node.next return ' --> ' .join(desc)
到此为止,我们的代码已经全部完成了,是不是简单地出乎意料?
验证代码
在进行接口设计时,我们给出了一段样板代码,现在我们可以用它来验证我们的程序。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 >>> import compute_graph as cg>>> inp = cg.Input()>>> out = cg.power(inp, 2 )>>> out<"power(., 2)" node of compute graph> >>> out = cg.exp(out)>>> out = cg.sin(out)>>> graph = ComputeGraph(inp, out)>>> print(graph)Input --> power(., 2) --> exp --> sin >>> >>> import numpy as np>>> x = np.linspace(0 , 1 , 5 )>>> graph.forward(x)array([0.84147098 , 0.87454388 , 0.95916224 , 0.98307241 , 0.41078129 ]) >>> graph.gradarray([ 0. , 0.25811137 , 0.36319491 , -0.48233501 , -4.95669947 ])
代码无误,且输出完全符合预期。
但我们还未考察计算结果是否正确无误,毕竟这才是最重要的。我们可以通过之前手动推导的公式对计算结果加以验证,函数的定义如下:
1 2 3 4 5 6 def f (x ): return np.sin(np.exp(np.power(x, 2 ))) def f_prime (x ): t = np.exp(np.power(x, 2 )) return 2 * x * np.cos(t) * t
我们进行如下验证:
1 2 3 4 5 6 7 8 >>> f(x)array([0.84147098 , 0.87454388 , 0.95916224 , 0.98307241 , 0.41078129 ]) >>> f_prime(x)array([ 0. , 0.25811137 , 0.36319491 , -0.48233501 , -4.95669947 ]) >>> np.all (f(x) == graph.forward(x))True >>> np.all (f_prime(x) == graph.grad)True
说明计算图的计算结果和梯度值均准确无误。
上述代码在一些细节问题上可能有所欠缺,但足以从宏观上理解计算图的实现原理。