本文为知乎上的一篇文章 矩阵求导术 的笔记。
符号约定:以下使用小写字母如 x 表示标量,粗体小写字母如 a 表示向量,大写字母如 A 表示矩阵。为保持一致性,向量如不进行特殊说明,均为列向量,行向量可通过列向量的转置来表示,如 aT。
矩阵对标量的求导即逐个元素对标量求导,没什么值得讨论的。我们以下主要介绍标量对矩阵的求导,然后延伸到矩阵对矩阵的求导。
标量对矩阵的导数
定义
首先明确一下标量对矩阵求导的定义:
∂X∂f=[⋱,∂Xij∂f,⋱]
即 f 对 X 逐元素求导,并排列成与 X 形状相同的矩阵。
该定义在形式上很容易理解,但在实际计算中却难以使用,因为它破坏了矩阵的整体性。在工程实践中(比如 Matlab、numpy 等),我们倾向于把矩阵看作一个整体,从整体上对其进行的加减乘除等运算,像普通的标量一样。因此,我们应当考虑如何将求导作用于整个矩阵上,而不是作用于矩阵的各个元素上。
我们不妨从最简单的标量函数入手,构建起一套易于掌握的矩阵求导技巧。
导数与微分之间存在着密切的联系:
dy=f′(x)dx=dxdydx
即:全微分 dy 是导数 dxdy 与微分变量 dx 的积。(推论1)
明确了上述重要定义之后,我们再来看看多元函数的情形。设 f(x1,x2,x3,…,xn) 为多元函数,我们依然从全微分的定义入手:
df=i=1∑n∂xi∂fdxi
为了得到一个满足“整体性”的式子,我们需要去掉求和符号。令向量 xT=[x1,x2,x3,…,xn],得到:
df=∂x∂f⋅dx
其中,∂x∂f 代表 f 对 x 的所有项的偏导数组成的向量。运算符“⋅”代表点乘运算,其运算结果称为内积。
于是,我们得到了如下推论:
多元函数的全微分 df 是导数向量 ∂x∂f 与微分变量 dx 的内积。(推论2)
与推论1进行比较可以发现,二者在形式上是完全相同的,唯一的差别在于末尾的“积”和“内积”。然而,标量的积完全可以看作向量内积的一种特殊情况,也就是说,推论2可以涵盖推论1。
现在,我们已经把全微分与导数之间的关系式推广到了关于向量的函数(即多元函数),如果把向量看作一种特殊的矩阵,那么这一推论也很容易推广到关于矩阵的函数:
df=∂X∂f⋅dX(1)
即:关于矩阵的函数的全微分 df 是导数矩阵 ∂X∂f 与微分变量 dX 的内积。(推论3)
注:这里将内积的定义从向量推广到了矩阵,即先对做逐元素相乘,然后将所有乘积求和。
由于标量和向量都可以看作是矩阵的特殊情况,因此推论3涵盖了推论 1、2。至此,我们得到了通用表达式 (1)。
运算规则
有了定义,还需要一套完整的运算规则才能实际应用。
我们依然从最简单的标量函数中汲取灵感。例如,对于函数 f(x)=x2sin(ex),我如何求导呢?我们通常不是直接从导数的定义出发,而是先建立了初等函数的导数和四则运算、复合等法则,然后运用这些法则求导。因此,我们也需要建立矩阵微分的运算规则。
矩阵的微分运算
-
加(减)法:
d(A±B)=dA±dB
-
乘法:
d(AB)=dAB+AdB
-
逆:
dA−1=−A−1dAA−1
-
转置:
dAT=(dA)T
-
迹(trace):
dtr(A)=tr(dA)
-
行列式:
d∣A∣=tr(A∗dA)
其中 A∗ 表示 A 的伴随矩阵。如果 A 可逆,则又有:
d∣A∣=∣A∣tr(A−1dA)
-
逐元素(element-wise)乘法:
d(A⊙B)=dA⊙B+A⊙dB
-
逐元素函数:
dh(A)=h′(A)⊙dA
其中,h(X) 为对矩阵 X 进行逐元素运算的标量函数。
矩阵的点乘运算
考虑到点乘运算是公式 (1) 中的核心操作,因此我们还需要一些关于点乘的运算规则(注意:点乘要求参加运算的两个矩阵形状相同):
-
定义式:
A⋅B=tr(ATB)
对于向量来说,该式可以进一步简化:a⋅b=aTb
-
交换律:
A⋅B=B⋅A
-
加法分配律:
A⋅(B+C)=A⋅B+A⋅C
-
与逐元素乘法的结合律:
A⋅(B⊙C)=(A⊙B)⋅C
-
转置的交换:
AT⋅B=A⋅BT
矩阵的迹运算
以上公式中,多处涉及到矩阵的迹(trace)运算,这里提供一些迹技巧(注意:迹运算的对象必须为方阵):
-
标量的迹等于其自身:
tr(x)=x
-
线性:
tr(A±B)=tr(A)±tr(B)
-
乘法可交换性:
tr(AB)=tr(BA)
要求 A 和 BT 的形状相同,这样才能保证 AB 为方阵。
-
转置迹不变:
tr(AT)=tr(A)
-
迹与点乘的转换:
tr(AB)=AT⋅B
矩阵的其他相关运算
最后,为了方便查阅,附上矩阵的转置、逆、行列式等相关运算规则:
(AT)T(AB)T(A−1)−1AA−1(AB)−1(AT)−1∣A−1∣(kA)−1=A=BTAT=A=A−1A=I=B−1A−1=(A−1)T=∣A∣−1=k−1A−1
有了上述这些运算规则,只要矩阵的函数 f 是由矩阵 X 经过加、减、乘、转置、逆、迹、行列式、逐元素乘法、逐元素函数等运算及其复合运算构成的,我们都能利用上述运算规则求得 ∂X∂f。其基本思路为:对 f 的表达式求全微分,并设法将其转化为 df=expr⋅dX 的形式,那么 expr 即为待求导数的表达式。
算例演示
例 1
设 y=aTXb,求 ∂X∂y。其中 a 为 m×1 向量,X 为 m×n 矩阵,b 为 n×1 向量,y 为标量。
解:
dy=aTdXb=tr(aTdXb)=tr(baTdX)=tr((abT)TdX)=abT⋅dX
说明:
- 第0步:对 y 求全微分;
- 第1步:标量套上迹;
- 第2步:迹运算内交换 aTdX 与 b;
- 第3步:矩阵乘法的转置;
- 第4步:转化为点乘形式。
根据公式 (1) 得到:
∂X∂y=abT
例 2:线性回归
线性回归的损失函数定义为 l=∣∣Xw−y∣∣22,求 w 的最小二乘估计,即 w 为何值时 l 可取得最小值。其中 y 为 m×1 列向量,X 为 m×n 矩阵,w 为 n×1 向量,l 为标量。
解:要求极小值,只需找到 ∂w∂l 的零点。
我们的运算规则中并未定义二阶范数的微分,但根据向量范数的定义,我们可以将它表示成内积的形式:
l=(Xw−y)⋅(Xw−y)=(Xw−y)T(Xw−y)
接下来,求 l 对 w 的微分:
dl=d(Xw−y)T(Xw−y)+(Xw−y)Td(Xw−y)=(Xdw)T(Xw−y)+(Xw−y)T(Xdw)=2(Xw−y)TXdw=2XT(Xw−y)⋅dw
说明:
- 第2步:加号前后两项均为标量,所以对第一项加上转置,结果不变。
- 第3步:标量套上迹运算,然后转化为点积。
根据公式 (1) 可得:
∂w∂l=2XT(Xw−y)
令 ∂w∂l=0 得(加粗的 0 代表零向量,其形状与 w 相同):
w=(XTX)−1XTy
例 3:多元逻辑回归
多元逻辑回归的损失函数为 l=−yTlog(softmax(Wx)),求 ∂W∂l。其中,y 为 m×1 的 one-hot 向量(即除一个元素为 1 外,其他元素均为 0),W 为 m×n 矩阵,x 为 n×1 向量,l 为标量。softmax(a)=1Texp(a)exp(a),其中 exp(⋅) 表示逐元素求指数,1 代表全 1 向量。
解:首先将 softmax 的表达式代入:
l=−yT(log(exp(Wx))−1log(1Texp(Wx)))=−yTWx+log(1Texp(Wx))
这里用到了 2 个等式:
\log(\frac{\boldsymbol{v}}{c})=\log(\boldsymbol{v})- \boldsymbol{1}\log(c) $$ $$ \boldsymbol{y}^T \boldsymbol{1}=1
接下来求 l 对 W 的全微分:
dl=−yTdWx+1Texp(Wx)1T(exp(Wx)⊙(dWx))
说明:注意逐元素函数 exp(⋅) 的微分变换
由于
1T(exp(Wx)⊙(dWx))=1⋅(exp(Wx)⊙(dWx))=(1⊙exp(Wx))⋅dWx=exp(Wx)⋅dWx=exp(Wx)TdWx
说明:利用点乘与逐元素乘法的结合律
故
dl=−yTdWx+1Texp(Wx)exp(Wx)TdWx=(−yT+softmax(Wx)T)dWx=tr((softmax(Wx)−y)TdWx)=tr(x(softmax(Wx)−y)TdW)=(softmax(Wx)−y)xT⋅dW
所以:
∂W∂l=(softmax(Wx)−y)xT
矩阵对矩阵的导数
未完待续。