矩阵求导术

本文为知乎上的一篇文章 矩阵求导术 的笔记。

符号约定:以下使用小写字母如 xx 表示标量,粗体小写字母如 a\boldsymbol{a} 表示向量,大写字母如 AA 表示矩阵。为保持一致性,向量如不进行特殊说明,均为列向量,行向量可通过列向量的转置来表示,如 aT\boldsymbol{a}^T

矩阵对标量的求导即逐个元素对标量求导,没什么值得讨论的。我们以下主要介绍标量对矩阵的求导,然后延伸到矩阵对矩阵的求导。

标量对矩阵的导数

定义

首先明确一下标量对矩阵求导的定义:

fX=[,fXij,]\frac{\partial f}{\partial X} = \left[\ddots, \frac{\partial f}{\partial X_{ij}}, \ddots \right]

ffXX 逐元素求导,并排列成与 XX 形状相同的矩阵。

该定义在形式上很容易理解,但在实际计算中却难以使用,因为它破坏了矩阵的整体性。在工程实践中(比如 Matlab、numpy 等),我们倾向于把矩阵看作一个整体,从整体上对其进行的加减乘除等运算,像普通的标量一样。因此,我们应当考虑如何将求导作用于整个矩阵上,而不是作用于矩阵的各个元素上。

我们不妨从最简单的标量函数入手,构建起一套易于掌握的矩阵求导技巧。

导数与微分之间存在着密切的联系:

dy=f(x)dx=dydxdx\mathrm{d}y = f'(x)\mathrm{d}x = \frac{\mathrm{d}y}{\mathrm{d}x}\mathrm{d}x

即:全微分 dy\mathrm{d}y 是导数 dydx\displaystyle \frac{\mathrm{d}y}{\mathrm{d}x} 与微分变量 dx\mathrm{d}x 的积。(推论1)

明确了上述重要定义之后,我们再来看看多元函数的情形。设 f(x1,x2,x3,,xn)f(x_1, x_2, x_3, \dots, x_n) 为多元函数,我们依然从全微分的定义入手:

df=i=1nfxidxi\mathrm{d}f = \sum_{i=1}^{n} \frac{\partial f}{\partial x_i}\mathrm{d}x_i

为了得到一个满足“整体性”的式子,我们需要去掉求和符号。令向量 xT=[x1,x2,x3,,xn]\boldsymbol{x}^T=[x_1, x_2, x_3, \dots, x_n],得到:

df=fxdx\mathrm{d}f = \frac{\partial f}{\partial \boldsymbol{x}} \cdot \mathrm{d}\boldsymbol{x}

其中,fx\displaystyle \frac{\partial f}{\partial \boldsymbol{x}} 代表 ffx\boldsymbol{x} 的所有项的偏导数组成的向量。运算符“\cdot”代表点乘运算,其运算结果称为内积

于是,我们得到了如下推论:

多元函数的全微分 df\mathrm{d}f 是导数向量 fx\displaystyle \frac{\partial f}{\partial \boldsymbol{x}} 与微分变量 dx\mathrm{d}\boldsymbol{x} 的内积。(推论2)

与推论1进行比较可以发现,二者在形式上是完全相同的,唯一的差别在于末尾的“积”和“内积”。然而,标量的积完全可以看作向量内积的一种特殊情况,也就是说,推论2可以涵盖推论1。

现在,我们已经把全微分与导数之间的关系式推广到了关于向量的函数(即多元函数),如果把向量看作一种特殊的矩阵,那么这一推论也很容易推广到关于矩阵的函数

df=fXdX(1)\mathrm{d}f = \frac{\partial f}{\partial X}\cdot \mathrm{d}X \tag{1}

即:关于矩阵的函数的全微分 df\mathrm{d}f 是导数矩阵 fX\displaystyle \frac{\partial f}{\partial X} 与微分变量 dX\mathrm{d}X 的内积。(推论3)

注:这里将内积的定义从向量推广到了矩阵,即先对做逐元素相乘,然后将所有乘积求和。

由于标量和向量都可以看作是矩阵的特殊情况,因此推论3涵盖了推论 1、2。至此,我们得到了通用表达式 (1)(1)

运算规则

有了定义,还需要一套完整的运算规则才能实际应用。

我们依然从最简单的标量函数中汲取灵感。例如,对于函数 f(x)=x2sin(ex)\displaystyle f(x)=x^2 sin(e^x),我如何求导呢?我们通常不是直接从导数的定义出发,而是先建立了初等函数的导数和四则运算、复合等法则,然后运用这些法则求导。因此,我们也需要建立矩阵微分的运算规则。

矩阵的微分运算

  1. 加(减)法:

    d(A±B)=dA±dB\mathrm{d}(A\pm B)=\mathrm{d}A\pm\mathrm{d}B

  2. 乘法:

    d(AB)=dAB+AdB\mathrm{d}(AB)=\mathrm{d}AB+A \mathrm{d}B

  3. 逆:

    dA1=A1dAA1\mathrm{d}A^{-1}=-A^{-1}\mathrm{d}A A^{-1}

  4. 转置:

    dAT=(dA)T\mathrm{d}A^T=(\mathrm{d}A)^T

  5. 迹(trace):

    dtr(A)=tr(dA)\mathrm{d}tr(A)=tr(\mathrm{d}A)

  6. 行列式:

    dA=tr(AdA)\mathrm{d}|A|=tr(A^*\mathrm{d}A)

    其中 AA^* 表示 AA 的伴随矩阵。如果 AA 可逆,则又有:

    dA=Atr(A1dA)\mathrm{d}|A|=|A|tr(A^{-1}\mathrm{d}A)

  7. 逐元素(element-wise)乘法:

    d(AB)=dAB+AdB\mathrm{d}(A\odot B)=\mathrm{d}A\odot B+A\odot \mathrm{d}B

  8. 逐元素函数:

    dh(A)=h(A)dA\mathrm{d}h(A)=h'(A)\odot \mathrm{d}A

    其中,h(X)h(X) 为对矩阵 XX 进行逐元素运算的标量函数。

矩阵的点乘运算

考虑到点乘运算是公式 (1)(1) 中的核心操作,因此我们还需要一些关于点乘的运算规则(注意:点乘要求参加运算的两个矩阵形状相同):

  1. 定义式:

    AB=tr(ATB)A\cdot B=tr(A^TB)

    对于向量来说,该式可以进一步简化:ab=aTb\displaystyle \boldsymbol{a}\cdot \boldsymbol{b}=\boldsymbol{a}^T \boldsymbol{b}

  2. 交换律:

    AB=BAA\cdot B = B\cdot A

  3. 加法分配律:

    A(B+C)=AB+ACA\cdot(B+C)=A\cdot B + A\cdot C

  4. 与逐元素乘法的结合律:

    A(BC)=(AB)CA\cdot(B\odot C)=(A\odot B)\cdot C

  5. 转置的交换:

    ATB=ABTA^T\cdot B = A\cdot B^T

矩阵的迹运算

以上公式中,多处涉及到矩阵的迹(trace)运算,这里提供一些迹技巧(注意:迹运算的对象必须为方阵):

  1. 标量的迹等于其自身:

    tr(x)=xtr(x) = x

  2. 线性:

    tr(A±B)=tr(A)±tr(B)tr(A\pm B)=tr(A)\pm tr(B)

  3. 乘法可交换性:

    tr(AB)=tr(BA)tr(AB) = tr(BA)

    要求 AABTB^T 的形状相同,这样才能保证 ABAB 为方阵。

  4. 转置迹不变:

    tr(AT)=tr(A)tr(A^T)=tr(A)

  5. 迹与点乘的转换

    tr(AB)=ATBtr(AB) = A^T\cdot B

矩阵的其他相关运算

最后,为了方便查阅,附上矩阵的转置、逆、行列式等相关运算规则:

(AT)T=A(AB)T=BTAT(A1)1=AAA1=A1A=I(AB)1=B1A1(AT)1=(A1)TA1=A1(kA)1=k1A1 \begin{aligned} (A^T)^T &= A \\\\ (AB)^T &= B^T A^T \\\\ (A^{-1})^{-1} &= A \\\\ AA^{-1} &= A^{-1}A = I \\\\ (AB)^{-1} &= B^{-1}A^{-1} \\\\ (A^T)^{-1} &= (A^{-1})^T \\\\ |A^{-1}| &= |A|^{-1} \\\\ (kA)^{-1} &= k^{-1}A^{-1} \end{aligned}

有了上述这些运算规则,只要矩阵的函数 ff 是由矩阵 XX 经过加、减、乘、转置、逆、迹、行列式、逐元素乘法、逐元素函数等运算及其复合运算构成的,我们都能利用上述运算规则求得 fX\displaystyle \frac{\partial f}{\partial X}。其基本思路为:对 ff 的表达式求全微分,并设法将其转化为 df=exprdX\mathrm{d}f = \mathrm{expr} \cdot \mathrm{d}X 的形式,那么 exprexpr 即为待求导数的表达式

算例演示

例 1

y=aTXb\displaystyle y=\boldsymbol{a}^T X \boldsymbol{b},求 yX\displaystyle \frac{\partial y}{\partial X}。其中 a\boldsymbol{a}m×1m\times 1 向量,XXm×nm\times n 矩阵,b\boldsymbol{b}n×1n\times 1 向量,yy 为标量。

解:

dy=aTdXb=tr(aTdXb)=tr(baTdX)=tr((abT)TdX)=abTdX \begin{aligned} \mathrm{d}y &= \boldsymbol{a}^T \mathrm{d}X \boldsymbol{b} \\\\ &= tr(\boldsymbol{a}^T \mathrm{d}X \boldsymbol{b}) \\\\ &= tr(\boldsymbol{b} \boldsymbol{a}^T \mathrm{d}X) \\\\ &= tr((\boldsymbol{a} \boldsymbol{b}^T)^T \mathrm{d}X) \\\\ &= \boldsymbol{a} \boldsymbol{b}^T \cdot \mathrm{d}X \end{aligned}

说明:

  • 第0步:对 yy 求全微分;
  • 第1步:标量套上迹;
  • 第2步:迹运算内交换 aTdX\boldsymbol{a}^T \mathrm{d}Xb\boldsymbol{b}
  • 第3步:矩阵乘法的转置;
  • 第4步:转化为点乘形式。

根据公式 (1)(1) 得到:

yX=abT\frac{\partial y}{\partial X}=\boldsymbol{a} \boldsymbol{b}^T

例 2:线性回归

线性回归的损失函数定义为 l=Xwy22\displaystyle l=||X \boldsymbol{w}- \boldsymbol{y}||_2^2,求 w\boldsymbol{w} 的最小二乘估计,即 w\boldsymbol{w} 为何值时 ll 可取得最小值。其中 y\boldsymbol{y}m×1m\times 1 列向量,XXm×nm\times n 矩阵,w\boldsymbol{w}n×1n\times 1 向量,ll 为标量。

解:要求极小值,只需找到 lw\displaystyle \frac{\partial l}{\partial \boldsymbol{w}} 的零点。

我们的运算规则中并未定义二阶范数的微分,但根据向量范数的定义,我们可以将它表示成内积的形式:

l=(Xwy)(Xwy)=(Xwy)T(Xwy)l=(X \boldsymbol{w}- \boldsymbol{y})\cdot(X \boldsymbol{w}- \boldsymbol{y})=(X \boldsymbol{w}- \boldsymbol{y})^T(X \boldsymbol{w}- \boldsymbol{y})

接下来,求 llw\boldsymbol{w} 的微分:

dl=d(Xwy)T(Xwy)+(Xwy)Td(Xwy)=(Xdw)T(Xwy)+(Xwy)T(Xdw)=2(Xwy)TXdw=2XT(Xwy)dw \begin{aligned} \mathrm{d}l &= \mathrm{d}(X \boldsymbol{w}- \boldsymbol{y})^T(X \boldsymbol{w}- \boldsymbol{y}) + (X \boldsymbol{w}- \boldsymbol{y})^T \mathrm{d}(X \boldsymbol{w}- \boldsymbol{y}) \\\\ &= (X \mathrm{d}\boldsymbol{w})^T(X \boldsymbol{w}- \boldsymbol{y})+(X \boldsymbol{w}- \boldsymbol{y})^T(X \mathrm{d}\boldsymbol{w}) \\\\ &= 2(X \boldsymbol{w}- \boldsymbol{y})^T X \mathrm{d}\boldsymbol{w} \\\\ &= 2X^T(X \boldsymbol{w}- \boldsymbol{y})\cdot \mathrm{d}\boldsymbol{w} \end{aligned}

说明:

  • 第2步:加号前后两项均为标量,所以对第一项加上转置,结果不变。
  • 第3步:标量套上迹运算,然后转化为点积。

根据公式 (1)(1) 可得:

lw=2XT(Xwy)\frac{\partial l}{\partial \boldsymbol{w}}=2X^T(X \boldsymbol{w}- \boldsymbol{y})

lw=0\displaystyle \frac{\partial l}{\partial \boldsymbol{w}}=\boldsymbol{0} 得(加粗的 0\boldsymbol{0} 代表零向量,其形状与 w\boldsymbol{w} 相同):

w=(XTX)1XTy\boldsymbol{w}=(X^T X)^{-1}X^T \boldsymbol{y}

例 3:多元逻辑回归

多元逻辑回归的损失函数为 l=yTlog(softmax(Wx))\displaystyle l=-\boldsymbol{y}^T \log (\mathrm{softmax}(W \boldsymbol{x})),求 lW\displaystyle \frac{\partial l}{\partial W}。其中,y\boldsymbol{y}m×1m\times 1 的 one-hot 向量(即除一个元素为 1 外,其他元素均为 0),WWm×nm\times n 矩阵,x\boldsymbol{x}n×1n\times 1 向量,ll 为标量。softmax(a)=exp(a)1Texp(a)\displaystyle \mathrm{softmax}(a)=\frac{\mathrm{exp}(\boldsymbol{a})}{\boldsymbol{1}^T\mathrm{exp}(\boldsymbol{a})},其中 exp()\displaystyle \mathrm{exp}(\boldsymbol{\cdot}) 表示逐元素求指数,1\boldsymbol{1} 代表全 1 向量。

解:首先将 softmax\mathrm{softmax} 的表达式代入:

l=yT(log(exp(Wx))1log(1Texp(Wx)))=yTWx+log(1Texp(Wx))l=-\boldsymbol{y}^T(\log(\mathrm{exp}(W\boldsymbol{x}))- \boldsymbol{1}\log(\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})))=-\boldsymbol{y}^T W \boldsymbol{x}+\log(\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x}))

这里用到了 2 个等式:

\log(\frac{\boldsymbol{v}}{c})=\log(\boldsymbol{v})- \boldsymbol{1}\log(c) $$ $$ \boldsymbol{y}^T \boldsymbol{1}=1

接下来求 llWW 的全微分:

dl=yTdWx+1T(exp(Wx)(dWx))1Texp(Wx) \begin{aligned} \mathrm{d}l &= -\boldsymbol{y}^T\mathrm{d}W \boldsymbol{x}+\frac{\boldsymbol{1}^T(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x}))}{\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})} \end{aligned}

说明:注意逐元素函数 exp()\mathrm{exp}(\cdot) 的微分变换

由于

1T(exp(Wx)(dWx))=1(exp(Wx)(dWx))=(1exp(Wx))dWx=exp(Wx)dWx=exp(Wx)TdWx \begin{aligned} \boldsymbol{1}^T(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x})) &= \boldsymbol{1}\cdot(\mathrm{exp}(W \boldsymbol{x})\odot(\mathrm{d}W \boldsymbol{x})) \\\\ &= (\boldsymbol{1}\odot\mathrm{exp}(W \boldsymbol{x}))\cdot \mathrm{d}W \boldsymbol{x} \\\\ &= \mathrm{exp}(W \boldsymbol{x})\cdot \mathrm{d}W \boldsymbol{x} \\\\ &= \mathrm{exp}(W \boldsymbol{x})^T \mathrm{d}W \boldsymbol{x} \end{aligned}

说明:利用点乘与逐元素乘法的结合律

dl=yTdWx+exp(Wx)TdWx1Texp(Wx)=(yT+softmax(Wx)T)dWx=tr((softmax(Wx)y)TdWx)=tr(x(softmax(Wx)y)TdW)=(softmax(Wx)y)xTdW \begin{aligned} \mathrm{d}l &= -\boldsymbol{y}^T\mathrm{d}W \boldsymbol{x}+\frac{\mathrm{exp}(W \boldsymbol{x})^T \mathrm{d}W \boldsymbol{x}}{\boldsymbol{1}^T\mathrm{exp}(W \boldsymbol{x})} \\\\ &= (-\boldsymbol{y}^T+\mathrm{softmax}(W \boldsymbol{x})^T)\mathrm{d}W \boldsymbol{x} \\\\ &= tr((\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})^T \mathrm{d}W\boldsymbol{x}) \\\\ &= tr(\boldsymbol{x}(\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})^T \mathrm{d}W) \\\\ &= (\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})\boldsymbol{x}^T \cdot \mathrm{d}W \end{aligned}

所以:

lW=(softmax(Wx)y)xT\frac{\partial l}{\partial W}=(\mathrm{softmax}(W \boldsymbol{x})-\boldsymbol{y})\boldsymbol{x}^T

矩阵对矩阵的导数

未完待续。