矩阵求导术
本文为知乎上的一篇文章 矩阵求导术 的笔记。
符号约定:以下使用小写字母如 $x$ 表示标量,粗体小写字母如 $\boldsymbol{a}$ 表示向量,大写字母如 $A$ 表示矩阵。为保持一致性,向量如不进行特殊说明,均为列向量,行向量可通过列向量的转置来表示,如 $\boldsymbol{a}^T$。
矩阵对标量的求导即逐个元素对标量求导,没什么值得讨论的。我们以下主要介绍标量对矩阵的求导,然后延伸到矩阵对矩阵的求导。
标量对矩阵的导数
定义
首先明确一下标量对矩阵求导的定义:
$$ \frac{\partial f}{\partial X} = \left[\ddots, \frac{\partial f}{\partial X_{ij}}, \ddots \right] $$
即 $f$ 对 $X$ 逐元素求导,并排列成与 $X$ 形状相同的矩阵。
该定义在形式上很容易理解,但在实际计算中却难以使用,因为它破坏了矩阵的整体性。在工程实践中(比如 Matlab、numpy 等),我们倾向于把矩阵看作一个整体,从整体上对其进行的加减乘除等运算,像普通的标量一样。因此,我们应当考虑如何将求导作用于整个矩阵上,而不是作用于矩阵的各个元素上。
我们不妨从最简单的标量函数入手,构建起一套易于掌握的矩阵求导技巧。
导数与微分之间存在着密切的联系:
$$ \mathrm{d}y = f’(x)\mathrm{d}x = \frac{\mathrm{d}y}{\mathrm{d}x}\mathrm{d}x $$
即:全微分 $\mathrm{d}y$ 是导数 $\displaystyle \frac{\mathrm{d}y}{\mathrm{d}x}$ 与微分变量 $\mathrm{d}x$ 的积。(推论1)
明确了上述重要定义之后,我们再来看看多元函数的情形。设 $f(x_1, x_2, x_3, \dots, x_n)$ 为多元函数,我们依然从全微分的定义入手:
$$ \mathrm{d}f = \sum_{i=1}^{n} \frac{\partial f}{\partial x_i}\mathrm{d}x_i $$
为了得到一个满足“整体性”的式子,我们需要去掉求和符号。令向量 $\boldsymbol{x}^T=[x_1, x_2, x_3, \dots, x_n]$,得到:
$$ \mathrm{d}f = \frac{\partial f}{\partial \boldsymbol{x}} \cdot \mathrm{d}\boldsymbol{x} $$
其中,$\displaystyle \frac{\partial f}{\partial \boldsymbol{x}}$ 代表 $f$ 对 $\boldsymbol{x}$ 的所有项的偏导数组成的向量。运算符“$\cdot$”代表点乘运算,其运算结果称为内积。
于是,我们得到了如下推论:
多元函数的全微分 $\mathrm{d}f$ 是导数向量 $\displaystyle \frac{\partial f}{\partial \boldsymbol{x}}$ 与微分变量 $\mathrm{d}\boldsymbol{x}$ 的内积。(推论2)
与推论1进行比较可以发现,二者在形式上是完全相同的,唯一的差别在于末尾的“积”和“内积”。然而,标量的积完全可以看作向量内积的一种特殊情况,也就是说,推论2可以涵盖推论1。
现在,我们已经把全微分与导数之间的关系式推广到了关于向量的函数(即多元函数),如果把向量看作一种特殊的矩阵,那么这一推论也很容易推广到关于矩阵的函数:
$$ \mathrm{d}f = \frac{\partial f}{\partial X}\cdot \mathrm{d}X \tag{1} $$
即:关于矩阵的函数的全微分 $\mathrm{d}f$ 是导数矩阵 $\displaystyle \frac{\partial f}{\partial X}$ 与微分变量 $\mathrm{d}X$ 的内积。(推论3)
注:这里将内积的定义从向量推广到了矩阵,即先对做逐元素相乘,然后将所有乘积求和。
由于标量和向量都可以看作是矩阵的特殊情况,因此推论3涵盖了推论 1、2。至此,我们得到了通用表达式 $(1)$。
运算规则
有了定义,还需要一套完整的运算规则才能实际应用。
我们依然从最简单的标量函数中汲取灵感。例如,对于函数 $\displaystyle f(x)=x^2 sin(e^x)$,我如何求导呢?我们通常不是直接从导数的定义出发,而是先建立了初等函数的导数和四则运算、复合等法则,然后运用这些法则求导。因此,我们也需要建立矩阵微分的运算规则。
矩阵的微分运算
-
加(减)法:
$$ \mathrm{d}(A\pm B)=\mathrm{d}A\pm\mathrm{d}B $$
-
乘法:
$$ \mathrm{d}(AB)=\mathrm{d}AB+A \mathrm{d}B $$
-
逆:
$$ \mathrm{d}A^{-1}=-A^{-1}\mathrm{d}A A^{-1} $$
-
转置:
$$ \mathrm{d}A^T=(\mathrm{d}A)^T $$
-
迹(trace):
$$ \mathrm{d}tr(A)=tr(\mathrm{d}A) $$
-
行列式:
$$ \mathrm{d}|A|=tr(A^*\mathrm{d}A) $$
其中 $A^*$ 表示 $A$ 的伴随矩阵。如果 $A$ 可逆,则又有:
$$ \mathrm{d}|A|=|A|tr(A^{-1}\mathrm{d}A) $$
-
逐元素(element-wise)乘法:
$$ \mathrm{d}(A\odot B)=\mathrm{d}A\odot B+A\odot \mathrm{d}B $$
-
逐元素函数:
$$ \mathrm{d}h(A)=h’(A)\odot \mathrm{d}A $$
其中,$h(X)$ 为对矩阵 $X$ 进行逐元素运算的标量函数。
矩阵的点乘运算
考虑到点乘运算是公式 $(1)$ 中的核心操作,因此我们还需要一些关于点乘的运算规则(注意:点乘要求参加运算的两个矩阵形状相同):
-
定义式:
$$ A\cdot B=tr(A^TB) $$
对于向量来说,该式可以进一步简化:$\displaystyle \boldsymbol{a}\cdot \boldsymbol{b}=\boldsymbol{a}^T \boldsymbol{b}$
-
交换律:
$$ A\cdot B = B\cdot A $$
-
加法分配律:
$$ A\cdot(B+C)=A\cdot B + A\cdot C $$
-
与逐元素乘法的结合律:
$$ A\cdot(B\odot C)=(A\odot B)\cdot C $$
-
转置的交换:
$$ A^T\cdot B = A\cdot B^T $$
矩阵的迹运算
以上公式中,多处涉及到矩阵的迹(trace)运算,这里提供一些迹技巧(注意:迹运算的对象必须为方阵):
-
标量的迹等于其自身:
$$ tr(x) = x $$
-
线性:
$$ tr(A\pm B)=tr(A)\pm tr(B) $$
-
乘法可交换性:
$$ tr(AB) = tr(BA) $$
要求 $A$ 和 $B^T$ 的形状相同,这样才能保证 $AB$ 为方阵。
-
转置迹不变:
$$ tr(A^T)=tr(A) $$
-
迹与点乘的转换:
$$ tr(AB) = A^T\cdot B $$
矩阵的其他相关运算
最后,为了方便查阅,附上矩阵的转置、逆、行列式等相关运算规则:
$$ \begin{aligned}
(A^T)^T &= A \\
(AB)^T &= B^T A^T \\
(A^{-1})^{-1} &= A \\
AA^{-1} &= A^{-1}A = I \\
(AB)^{-1} &= B^{-1}A^{-1} \\
(A^T)^{-1} &= (A^{-1})^T \\
|A^{-1}| &= |A|^{-1} \\
(kA)^{-1} &= k^{-1}A^{-1}
\end{aligned}$$
有了上述这些运算规则,只要矩阵的函数 $f$ 是由矩阵 $X$ 经过加、减、乘、转置、逆、迹、行列式、逐元素乘法、逐元素函数等运算及其复合运算构成的,我们都能利用上述运算规则求得 $\displaystyle \frac{\partial f}{\partial X}$。其基本思路为:对 $f$ 的表达式求全微分,并设法将其转化为 $\mathrm{d}f = \mathrm{expr} \cdot \mathrm{d}X$ 的形式,那么 $expr$ 即为待求导数的表达式。
算例演示
例 1
设 $\displaystyle y=\boldsymbol{a}^T X \boldsymbol{b}$,求 $\displaystyle \frac{\partial y}{\partial X}$。其中 $\boldsymbol{a}$ 为 $m\times 1$ 向量,$X$ 为 $m\times n$ 矩阵,$\boldsymbol{b}$ 为 $n\times 1$ 向量,$y$ 为标量。
解:
$$ \begin{aligned}
\mathrm{d}y &= \boldsymbol{a}^T \mathrm{d}X \boldsymbol{b} \\
&= tr(\boldsymbol{a}^T \mathrm{d}X \boldsymbol{b}) \\
&= tr(\boldsymbol{b} \boldsymbol{a}^T \mathrm{d}X) \\
&= tr((\boldsymbol{a} \boldsymbol{b}^T)^T \mathrm{d}X) \\
&= \boldsymbol{a} \boldsymbol{b}^T \cdot \mathrm{d}X
\end{aligned} $$
说明:
- 第0步:对 $y$ 求全微分;
- 第1步:标量套上迹;
- 第2步:迹运算内交换 $\boldsymbol{a}^T \mathrm{d}X$ 与 $\boldsymbol{b}$;
- 第3步:矩阵乘法的转置;
- 第4步:转化为点乘形式。
根据公式 $(1)$ 得到:
$$ \frac{\partial y}{\partial X}=\boldsymbol{a} \boldsymbol{b}^T $$
例 2:线性回归
线性回归的损失函数定义为 $\displaystyle l=||X \boldsymbol{w}- \boldsymbol{y}||_2^2$,求 $\boldsymbol{w}$ 的最小二乘估计,即 $\boldsymbol{w}$ 为何值时 $l$ 可取得最小值。其中 $\boldsymbol{y}$ 为 $m\times 1$ 列向量,$X$ 为 $m\times n$ 矩阵,$\boldsymbol{w}$ 为 $n\times 1$ 向量,$l$ 为标量。
解:要求极小值,只需找到 $\displaystyle \frac{\partial l}{\partial \boldsymbol{w}}$ 的零点。
我们的运算规则中并未定义二阶范数的微分,但根据向量范数的定义,我们可以将它表示成内积的形式:
$$ l=(X \boldsymbol{w}- \boldsymbol{y})\cdot(X \boldsymbol{w}- \boldsymbol{y})=(X \boldsymbol{w}- \boldsymbol{y})^T(X \boldsymbol{w}- \boldsymbol{y}) $$
接下来,求 $l$ 对 $\boldsymbol{w}$ 的微分:
$$ \begin{aligned}
\mathrm{d}l &= \mathrm{d}(X \boldsymbol{w}- \boldsymbol{y})^T(X \boldsymbol{w}- \boldsymbol{y}) + (X \boldsymbol{w}- \boldsymbol{y})^T \mathrm{d}(X \boldsymbol{w}- \boldsymbol{y}) \\
&= (X \mathrm{d}\boldsymbol{w})^T(X \boldsymbol{w}- \boldsymbol{y})+(X \boldsymbol{w}- \boldsymbol{y})^T(X \mathrm{d}\boldsymbol{w}) \\
&= 2(X \boldsymbol{w}- \boldsymbol{y})^T X \mathrm{d}\boldsymbol{w} \\
&= 2X^T(X \boldsymbol{w}- \boldsymbol{y})\cdot \mathrm{d}\boldsymbol{w}
\end{aligned} $$
说明:
- 第2步:加号前后两项均为标量,所以对第一项加上转置,结果不变。
- 第3步:标量套上迹运算,然后转化为点积。
根据公式 $(1)$ 可得:
$$ \frac{\partial l}{\partial \boldsymbol{w}}=2X^T(X \boldsymbol{w}- \boldsymbol{y}) $$
令 $\displaystyle \frac{\partial l}{\partial \boldsymbol{w}}=\boldsymbol{0}$ 得(加粗的 $\boldsymbol{0}$ 代表零向量,其形状与 $\boldsymbol{w}$ 相同):
$$ \boldsymbol{w}=(X^T X)^{-1}X^T \boldsymbol{y} $$
例 3:多元逻辑回归
多元逻辑回归的损失函数为 $\displaystyle l=-\boldsymbol{y}^T \log (\mathrm{softmax}(W \boldsymbol{x}))$,求 $\displaystyle \frac{\partial l}{\partial W}$。其中,$\boldsymbol{y}$ 为 $m\times 1$ 的 one-hot 向量(即除一个元素为 1 外,其他元素均为 0),$W$ 为 $m\times n$ 矩阵,$\boldsymbol{x}$ 为 $n\times 1$ 向量,$l$ 为标量。$\displaystyle \mathrm{softmax}(a)=\frac{\mathrm{exp}(\boldsymbol{a})}{\boldsymbol{1}^T\mathrm{exp}(\boldsymbol{a})}$,其中 $\displaystyle \mathrm{exp}(\boldsymbol{\cdot})$ 表示逐元素求指数,$\boldsymbol{1}$ 代表全 1 向量。
解:首先将 $\mathrm{softmax}$ 的表达式代入:
$$ l=-\boldsymbol{y}^T(\log(\mathrm{exp}(W\boldsymbol{x}))- \boldsymbol{1}\log(\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})))=-\boldsymbol{y}^T W \boldsymbol{x}+\log(\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})) $$
这里用到了 2 个等式:
$$ \log(\frac{\boldsymbol{v}}{c})=\log(\boldsymbol{v})- \boldsymbol{1}\log© $$ $$ \boldsymbol{y}^T \boldsymbol{1}=1 $$
接下来求 $l$ 对 $W$ 的全微分:
$$ \begin{aligned}
\mathrm{d}l &= -\boldsymbol{y}^T\mathrm{d}W \boldsymbol{x}+\frac{\boldsymbol{1}^T(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x}))}{\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})}
\end{aligned} $$
说明:注意逐元素函数 $\mathrm{exp}(\cdot)$ 的微分变换
由于
$$ \begin{aligned}
\boldsymbol{1}^T(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x})) &= \boldsymbol{1}\cdot(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x})) \\
&= (\boldsymbol{1}\odot\mathrm{exp}(W \boldsymbol{x}))\cdot \mathrm{d}W \boldsymbol{x} \\
&= \mathrm{exp}(W \boldsymbol{x})\cdot \mathrm{d}W \boldsymbol{x} \\
&= \mathrm{exp}(W \boldsymbol{x})^T \mathrm{d}W \boldsymbol{x}
\end{aligned} $$
说明:利用点乘与逐元素乘法的结合律
故
$$ \begin{aligned}
\mathrm{d}l &= -\boldsymbol{y}^T\mathrm{d}W \boldsymbol{x}+\frac{\mathrm{exp}(W \boldsymbol{x})^T \mathrm{d}W \boldsymbol{x}}{\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})} \\
&= (-\boldsymbol{y}^T+\mathrm{softmax}(W \boldsymbol{x})^T)\mathrm{d}W \boldsymbol{x} \\
&= tr((\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})^T \mathrm{d}W\boldsymbol{x}) \\
&= tr(\boldsymbol{x}(\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})^T \mathrm{d}W) \\
&= (\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})\boldsymbol{x}^T \cdot \mathrm{d}W
\end{aligned} $$
所以:
$$ \frac{\partial l}{\partial W}=(\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})\boldsymbol{x}^T $$
矩阵对矩阵的导数
未完待续。