矩阵求导术

本文为知乎上的一篇文章 矩阵求导术 的笔记。

符号约定:以下使用小写字母如 $x$ 表示标量,粗体小写字母如 $\boldsymbol{a}$ 表示向量,大写字母如 $A$ 表示矩阵。为保持一致性,向量如不进行特殊说明,均为列向量,行向量可通过列向量的转置来表示,如 $\boldsymbol{a}^T$。

矩阵对标量的求导即逐个元素对标量求导,没什么值得讨论的。我们以下主要介绍标量对矩阵的求导,然后延伸到矩阵对矩阵的求导。

标量对矩阵的导数

定义

首先明确一下标量对矩阵求导的定义:

$$ \frac{\partial f}{\partial X} = \left[\ddots, \frac{\partial f}{\partial X_{ij}}, \ddots \right] $$

即 $f$ 对 $X$ 逐元素求导,并排列成与 $X$ 形状相同的矩阵。

该定义在形式上很容易理解,但在实际计算中却难以使用,因为它破坏了矩阵的整体性。在工程实践中(比如 Matlab、numpy 等),我们倾向于把矩阵看作一个整体,从整体上对其进行的加减乘除等运算,像普通的标量一样。因此,我们应当考虑如何将求导作用于整个矩阵上,而不是作用于矩阵的各个元素上。

我们不妨从最简单的标量函数入手,构建起一套易于掌握的矩阵求导技巧。

导数与微分之间存在着密切的联系:

$$ \mathrm{d}y = f’(x)\mathrm{d}x = \frac{\mathrm{d}y}{\mathrm{d}x}\mathrm{d}x $$

即:全微分 $\mathrm{d}y$ 是导数 $\displaystyle \frac{\mathrm{d}y}{\mathrm{d}x}$ 与微分变量 $\mathrm{d}x$ 的积。(推论1)

明确了上述重要定义之后,我们再来看看多元函数的情形。设 $f(x_1, x_2, x_3, \dots, x_n)$ 为多元函数,我们依然从全微分的定义入手:

$$ \mathrm{d}f = \sum_{i=1}^{n} \frac{\partial f}{\partial x_i}\mathrm{d}x_i $$

为了得到一个满足“整体性”的式子,我们需要去掉求和符号。令向量 $\boldsymbol{x}^T=[x_1, x_2, x_3, \dots, x_n]$,得到:

$$ \mathrm{d}f = \frac{\partial f}{\partial \boldsymbol{x}} \cdot \mathrm{d}\boldsymbol{x} $$

其中,$\displaystyle \frac{\partial f}{\partial \boldsymbol{x}}$ 代表 $f$ 对 $\boldsymbol{x}$ 的所有项的偏导数组成的向量。运算符“$\cdot$”代表点乘运算,其运算结果称为内积

于是,我们得到了如下推论:

多元函数的全微分 $\mathrm{d}f$ 是导数向量 $\displaystyle \frac{\partial f}{\partial \boldsymbol{x}}$ 与微分变量 $\mathrm{d}\boldsymbol{x}$ 的内积。(推论2)

与推论1进行比较可以发现,二者在形式上是完全相同的,唯一的差别在于末尾的“积”和“内积”。然而,标量的积完全可以看作向量内积的一种特殊情况,也就是说,推论2可以涵盖推论1。

现在,我们已经把全微分与导数之间的关系式推广到了关于向量的函数(即多元函数),如果把向量看作一种特殊的矩阵,那么这一推论也很容易推广到关于矩阵的函数

$$ \mathrm{d}f = \frac{\partial f}{\partial X}\cdot \mathrm{d}X \tag{1} $$

即:关于矩阵的函数的全微分 $\mathrm{d}f$ 是导数矩阵 $\displaystyle \frac{\partial f}{\partial X}$ 与微分变量 $\mathrm{d}X$ 的内积。(推论3)

注:这里将内积的定义从向量推广到了矩阵,即先对做逐元素相乘,然后将所有乘积求和。

由于标量和向量都可以看作是矩阵的特殊情况,因此推论3涵盖了推论 1、2。至此,我们得到了通用表达式 $(1)$。

运算规则

有了定义,还需要一套完整的运算规则才能实际应用。

我们依然从最简单的标量函数中汲取灵感。例如,对于函数 $\displaystyle f(x)=x^2 sin(e^x)$,我如何求导呢?我们通常不是直接从导数的定义出发,而是先建立了初等函数的导数和四则运算、复合等法则,然后运用这些法则求导。因此,我们也需要建立矩阵微分的运算规则。

矩阵的微分运算

  1. 加(减)法:

    $$ \mathrm{d}(A\pm B)=\mathrm{d}A\pm\mathrm{d}B $$

  2. 乘法:

    $$ \mathrm{d}(AB)=\mathrm{d}AB+A \mathrm{d}B $$

  3. 逆:

    $$ \mathrm{d}A^{-1}=-A^{-1}\mathrm{d}A A^{-1} $$

  4. 转置:

    $$ \mathrm{d}A^T=(\mathrm{d}A)^T $$

  5. 迹(trace):

    $$ \mathrm{d}tr(A)=tr(\mathrm{d}A) $$

  6. 行列式:

    $$ \mathrm{d}|A|=tr(A^*\mathrm{d}A) $$

    其中 $A^*$ 表示 $A$ 的伴随矩阵。如果 $A$ 可逆,则又有:

    $$ \mathrm{d}|A|=|A|tr(A^{-1}\mathrm{d}A) $$

  7. 逐元素(element-wise)乘法:

    $$ \mathrm{d}(A\odot B)=\mathrm{d}A\odot B+A\odot \mathrm{d}B $$

  8. 逐元素函数:

    $$ \mathrm{d}h(A)=h’(A)\odot \mathrm{d}A $$

    其中,$h(X)$ 为对矩阵 $X$ 进行逐元素运算的标量函数。

矩阵的点乘运算

考虑到点乘运算是公式 $(1)$ 中的核心操作,因此我们还需要一些关于点乘的运算规则(注意:点乘要求参加运算的两个矩阵形状相同):

  1. 定义式:

    $$ A\cdot B=tr(A^TB) $$

    对于向量来说,该式可以进一步简化:$\displaystyle \boldsymbol{a}\cdot \boldsymbol{b}=\boldsymbol{a}^T \boldsymbol{b}$

  2. 交换律:

    $$ A\cdot B = B\cdot A $$

  3. 加法分配律:

    $$ A\cdot(B+C)=A\cdot B + A\cdot C $$

  4. 与逐元素乘法的结合律:

    $$ A\cdot(B\odot C)=(A\odot B)\cdot C $$

  5. 转置的交换:

    $$ A^T\cdot B = A\cdot B^T $$

矩阵的迹运算

以上公式中,多处涉及到矩阵的迹(trace)运算,这里提供一些迹技巧(注意:迹运算的对象必须为方阵):

  1. 标量的迹等于其自身:

    $$ tr(x) = x $$

  2. 线性:

    $$ tr(A\pm B)=tr(A)\pm tr(B) $$

  3. 乘法可交换性:

    $$ tr(AB) = tr(BA) $$

    要求 $A$ 和 $B^T$ 的形状相同,这样才能保证 $AB$ 为方阵。

  4. 转置迹不变:

    $$ tr(A^T)=tr(A) $$

  5. 迹与点乘的转换

    $$ tr(AB) = A^T\cdot B $$

矩阵的其他相关运算

最后,为了方便查阅,附上矩阵的转置、逆、行列式等相关运算规则:

$$ \begin{aligned}
(A^T)^T &= A \\
(AB)^T &= B^T A^T \\
(A^{-1})^{-1} &= A \\
AA^{-1} &= A^{-1}A = I \\
(AB)^{-1} &= B^{-1}A^{-1} \\
(A^T)^{-1} &= (A^{-1})^T \\
|A^{-1}| &= |A|^{-1} \\
(kA)^{-1} &= k^{-1}A^{-1}
\end{aligned}$$

有了上述这些运算规则,只要矩阵的函数 $f$ 是由矩阵 $X$ 经过加、减、乘、转置、逆、迹、行列式、逐元素乘法、逐元素函数等运算及其复合运算构成的,我们都能利用上述运算规则求得 $\displaystyle \frac{\partial f}{\partial X}$。其基本思路为:对 $f$ 的表达式求全微分,并设法将其转化为 $\mathrm{d}f = \mathrm{expr} \cdot \mathrm{d}X$ 的形式,那么 $expr$ 即为待求导数的表达式

算例演示

例 1

设 $\displaystyle y=\boldsymbol{a}^T X \boldsymbol{b}$,求 $\displaystyle \frac{\partial y}{\partial X}$。其中 $\boldsymbol{a}$ 为 $m\times 1$ 向量,$X$ 为 $m\times n$ 矩阵,$\boldsymbol{b}$ 为 $n\times 1$ 向量,$y$ 为标量。

解:

$$ \begin{aligned}
\mathrm{d}y &= \boldsymbol{a}^T \mathrm{d}X \boldsymbol{b} \\
&= tr(\boldsymbol{a}^T \mathrm{d}X \boldsymbol{b}) \\
&= tr(\boldsymbol{b} \boldsymbol{a}^T \mathrm{d}X) \\
&= tr((\boldsymbol{a} \boldsymbol{b}^T)^T \mathrm{d}X) \\
&= \boldsymbol{a} \boldsymbol{b}^T \cdot \mathrm{d}X
\end{aligned} $$

说明:

  • 第0步:对 $y$ 求全微分;
  • 第1步:标量套上迹;
  • 第2步:迹运算内交换 $\boldsymbol{a}^T \mathrm{d}X$ 与 $\boldsymbol{b}$;
  • 第3步:矩阵乘法的转置;
  • 第4步:转化为点乘形式。

根据公式 $(1)$ 得到:

$$ \frac{\partial y}{\partial X}=\boldsymbol{a} \boldsymbol{b}^T $$

例 2:线性回归

线性回归的损失函数定义为 $\displaystyle l=||X \boldsymbol{w}- \boldsymbol{y}||_2^2$,求 $\boldsymbol{w}$ 的最小二乘估计,即 $\boldsymbol{w}$ 为何值时 $l$ 可取得最小值。其中 $\boldsymbol{y}$ 为 $m\times 1$ 列向量,$X$ 为 $m\times n$ 矩阵,$\boldsymbol{w}$ 为 $n\times 1$ 向量,$l$ 为标量。

解:要求极小值,只需找到 $\displaystyle \frac{\partial l}{\partial \boldsymbol{w}}$ 的零点。

我们的运算规则中并未定义二阶范数的微分,但根据向量范数的定义,我们可以将它表示成内积的形式:

$$ l=(X \boldsymbol{w}- \boldsymbol{y})\cdot(X \boldsymbol{w}- \boldsymbol{y})=(X \boldsymbol{w}- \boldsymbol{y})^T(X \boldsymbol{w}- \boldsymbol{y}) $$

接下来,求 $l$ 对 $\boldsymbol{w}$ 的微分:

$$ \begin{aligned}
\mathrm{d}l &= \mathrm{d}(X \boldsymbol{w}- \boldsymbol{y})^T(X \boldsymbol{w}- \boldsymbol{y}) + (X \boldsymbol{w}- \boldsymbol{y})^T \mathrm{d}(X \boldsymbol{w}- \boldsymbol{y}) \\
&= (X \mathrm{d}\boldsymbol{w})^T(X \boldsymbol{w}- \boldsymbol{y})+(X \boldsymbol{w}- \boldsymbol{y})^T(X \mathrm{d}\boldsymbol{w}) \\
&= 2(X \boldsymbol{w}- \boldsymbol{y})^T X \mathrm{d}\boldsymbol{w} \\
&= 2X^T(X \boldsymbol{w}- \boldsymbol{y})\cdot \mathrm{d}\boldsymbol{w}
\end{aligned} $$

说明:

  • 第2步:加号前后两项均为标量,所以对第一项加上转置,结果不变。
  • 第3步:标量套上迹运算,然后转化为点积。

根据公式 $(1)$ 可得:

$$ \frac{\partial l}{\partial \boldsymbol{w}}=2X^T(X \boldsymbol{w}- \boldsymbol{y}) $$

令 $\displaystyle \frac{\partial l}{\partial \boldsymbol{w}}=\boldsymbol{0}$ 得(加粗的 $\boldsymbol{0}$ 代表零向量,其形状与 $\boldsymbol{w}$ 相同):

$$ \boldsymbol{w}=(X^T X)^{-1}X^T \boldsymbol{y} $$

例 3:多元逻辑回归

多元逻辑回归的损失函数为 $\displaystyle l=-\boldsymbol{y}^T \log (\mathrm{softmax}(W \boldsymbol{x}))$,求 $\displaystyle \frac{\partial l}{\partial W}$。其中,$\boldsymbol{y}$ 为 $m\times 1$ 的 one-hot 向量(即除一个元素为 1 外,其他元素均为 0),$W$ 为 $m\times n$ 矩阵,$\boldsymbol{x}$ 为 $n\times 1$ 向量,$l$ 为标量。$\displaystyle \mathrm{softmax}(a)=\frac{\mathrm{exp}(\boldsymbol{a})}{\boldsymbol{1}^T\mathrm{exp}(\boldsymbol{a})}$,其中 $\displaystyle \mathrm{exp}(\boldsymbol{\cdot})$ 表示逐元素求指数,$\boldsymbol{1}$ 代表全 1 向量。

解:首先将 $\mathrm{softmax}$ 的表达式代入:

$$ l=-\boldsymbol{y}^T(\log(\mathrm{exp}(W\boldsymbol{x}))- \boldsymbol{1}\log(\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})))=-\boldsymbol{y}^T W \boldsymbol{x}+\log(\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})) $$

这里用到了 2 个等式:

$$ \log(\frac{\boldsymbol{v}}{c})=\log(\boldsymbol{v})- \boldsymbol{1}\log© $$ $$ \boldsymbol{y}^T \boldsymbol{1}=1 $$

接下来求 $l$ 对 $W$ 的全微分:

$$ \begin{aligned}
\mathrm{d}l &= -\boldsymbol{y}^T\mathrm{d}W \boldsymbol{x}+\frac{\boldsymbol{1}^T(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x}))}{\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})}
\end{aligned} $$

说明:注意逐元素函数 $\mathrm{exp}(\cdot)$ 的微分变换

由于

$$ \begin{aligned}
\boldsymbol{1}^T(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x})) &= \boldsymbol{1}\cdot(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x})) \\
&= (\boldsymbol{1}\odot\mathrm{exp}(W \boldsymbol{x}))\cdot \mathrm{d}W \boldsymbol{x} \\
&= \mathrm{exp}(W \boldsymbol{x})\cdot \mathrm{d}W \boldsymbol{x} \\
&= \mathrm{exp}(W \boldsymbol{x})^T \mathrm{d}W \boldsymbol{x}
\end{aligned} $$

说明:利用点乘与逐元素乘法的结合律

$$ \begin{aligned}
\mathrm{d}l &= -\boldsymbol{y}^T\mathrm{d}W \boldsymbol{x}+\frac{\mathrm{exp}(W \boldsymbol{x})^T \mathrm{d}W \boldsymbol{x}}{\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})} \\
&= (-\boldsymbol{y}^T+\mathrm{softmax}(W \boldsymbol{x})^T)\mathrm{d}W \boldsymbol{x} \\
&= tr((\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})^T \mathrm{d}W\boldsymbol{x}) \\
&= tr(\boldsymbol{x}(\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})^T \mathrm{d}W) \\
&= (\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})\boldsymbol{x}^T \cdot \mathrm{d}W
\end{aligned} $$

所以:

$$ \frac{\partial l}{\partial W}=(\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})\boldsymbol{x}^T $$

矩阵对矩阵的导数

未完待续。