Reivew Adam with me
Published:
Adam作为一种自适应的优化算法, 结合了Momentum以及RMSprop算法, 一方面参考动量作为参数更新方向, 一方面计算梯度的指数加权平方. Adam在深度学习领域有广泛的实用性, 同时也是过去五年来被cite数量最多的scientific paper, 根据Nature Index and Google Scholar, 被戏称为AI Paper计数器.
Fundamental
基于$\beta_1$和$\beta_2$这两个移动平均的衰减率, 我们可以得到一阶矩$M_t$, 二阶矩$G_t$, 可以分别看作梯度的平均值和方差.
\[M_t = \beta_1 M_{t-1} + (1-\beta_1) \boldsymbol{g}_{t} \\ G_t = \beta_2 G_{t-1} + (1-\beta_2) \boldsymbol{g}_{t} \odot \boldsymbol{g}_{t}\]通常我们会需要对$M_t$以及$G_t$进行偏差修正(bias correction):
\[\hat{M}_{t} = \frac{M_t}{1-\beta^t_1} \\ \hat{G}_{t} = \frac{G_t}{1-\beta^t_2}\]从而计算参数更新差值为:
\[\begin{align*} \Delta\boldsymbol{\theta}_{t} &= - \alpha \cdot \frac{\hat{M}_{t}} {\sqrt{\hat{G}_{t}} + \epsilon} \\ &= - \alpha \cdot \frac{M_t}{1-\beta^t_1} \frac{1-\beta^t_2}{G_t} \\ &= - \alpha \cdot \frac{\sqrt{1-\beta^t_2}}{1-\beta^t_1} \frac{M_t}{\sqrt{G_t} + \epsilon} \end{align*}\]其中$\alpha$是学习率, 也就是我们常说的learning rate. 但是在实际implement中, 在代码实现层面通常会将bias correction加入到$\alpha$的计算中, 即:
\[\hat{\alpha} = \alpha \frac{\sqrt{1-\beta^t_2}}{1-\beta^t_1}\]实际上得到的参数更新差值为:
\[\Delta\hat{\boldsymbol{\theta}}_{t} = - \hat{\alpha} \cdot \frac{M_t}{\sqrt{G_t}+ \epsilon}\]上述步骤也在Adam原文中的算法流程中得到体现:
Bias Correction
那么为什么我们要进行偏差修正呢, Bias Correction在训练过程中扮演了什么样的作用呢, 我们先来看一下论文原文的解释:
In case of sparse gradients, for a reliable estimate of the second moment one needs to average over many gradients by chosing a small value of β2; however it is exactly this case of small β2 where a lack of initialisation bias correction would lead to initial steps that are much larger.
以公式来具体论证的话, 在时间步$t$, 我们以指数移动平均来举例, 假设$G_0$有初始化$ G_0=0$, 根据$G_t = \beta_2 G_{t-1} + (1-\beta_1) \boldsymbol{g}_{t}^2 $, 我们可以得到$G_t$在任意时间步的累加状态:
\[G_{t}=\left(1-\beta_{2}\right) \sum_{i=1}^{t} \beta_{2}^{t-i} \cdot g_{i}^{2}\]我们想要知道的是Adam中模拟的指数移动平均$G_t$的期望 $\mathbb{E}\left[G_{t}\right]$ 能否对应真实二阶矩 $\mathbb{E}\left[g_{t}^2\right]$ , 这里有没有存在一定偏差呢, 我们不妨来计算一下 $\mathbb{E}\left[G_{t}\right]$:
\[\begin{aligned} \mathbb{E}\left[G_{t}\right] &=\mathbb{E}\left[\left(1-\beta_{2}\right) \sum_{i=1}^{t} \beta_{2}^{t-i} \cdot g_{i}^{2}\right] \\ &=\mathbb{E}\left[g_{t}^{2}\right] \cdot\left(1-\beta_{2}\right) \sum_{i=1}^{t} \beta_{2}^{t-i}+\zeta \\ &=\mathbb{E}\left[g_{t}^{2}\right] \cdot\left(1-\beta_{2}\right) \frac{1-\beta_2^t}{1-\beta_2}+\zeta \\ &=\mathbb{E}\left[g_{t}^{2}\right] \cdot\left(1-\beta_{2}^{t}\right)+\zeta \end{aligned}\]计算一个简单的等比数列求和完成推导后, 我们可以看到偏移系数$\left(1-\beta_{2}^{t}\right)$已经出现在我们面前了, 对于一阶矩$M_t$同样可以推导出偏移系数$\left(1-\beta_{1}^{t}\right)$, 这是由零值初始化导致的, 所以我们需要计算参数更新差值时进行修正.
举例来说, 在实际训练中通常我们会采用$\beta_1 =0.9$, $\beta_2 =0.999$作为超参数, 并将$M_0,G_0$初始化$M_0=0, G_0=0$. 将$\beta_1, \beta_2$ 带入上述公式中, 我们不妨来看一下训练初期会发生什么:
\[\begin{align*} M_1 &= 0.1 \boldsymbol{g}_{t} \\ G_1 &= 0.001 \boldsymbol{g}_{t}^2 \\ \Delta\boldsymbol{\theta}_{t} &= - \alpha \frac{M_1}{\sqrt{G_1 + \epsilon}} \\ &= - \alpha \frac{0.1 \boldsymbol{g}_{t}}{\sqrt{0.001 \boldsymbol{g}_{t}^2}+ \epsilon} \end{align*}\]看到矩估计值(moment estimates)实际上会很大, 很容易往初始值0的方向偏移, 特别是在模型训练初期或者$1-\beta_1, 1-\beta_2$很小的时候(这往往也就是我们通常训练会出现的场景).
下图为bias correction的比例$\frac{\sqrt{1-\beta^t_2}}{1-\beta^t_1}$在前10000个steps下的曲率图, 调整比例先减小再变大最后趋近于1.
Weight Decay
权重衰减是一种比较常用的正则化方法, 会在参数更新时, 引入一个衰减系数$\lambda$.
\[\begin{align*} \boldsymbol{\theta}_{t+1} & =(1-\lambda) \boldsymbol{\theta}_{t}-\alpha \nabla f_{t}\left(\boldsymbol{\theta}_{t}\right) \\ & =\boldsymbol{\theta}_{t}-\alpha \nabla f_{t}\left(\boldsymbol{\theta}_{t}\right) -\color{red}{\lambda\boldsymbol{\theta}_{t}} \end{align*}\]其中在标准SGD优化算法下, L2 regularization 和 Weight decay时等价的. L2 regularization在参数更新时加上一个L2惩罚项$f_t(\boldsymbol{\theta_t)}$:
\[\begin{align*} f_{t}^{reg}(\boldsymbol{\theta})&=f_{t}(\boldsymbol{\theta})+\frac{\lambda^{\prime}}{2}\|\boldsymbol{\theta}\|_{2}^{2} \\ \boldsymbol\theta_{t+1} &= \boldsymbol\theta_{t}-\alpha \nabla f_{t}^{\mathrm{reg}}\left(\boldsymbol{\theta}_{t}\right) \\ &=\boldsymbol{\theta}_{t}-\alpha \nabla f_{t}\left(\boldsymbol{\theta}_{t}\right)-\color{red}{\alpha \lambda^{\prime} \boldsymbol\theta_{t}} \end{align*}\]上述公式中的$\lambda^{\prime}/2$为惩罚系数, 可以看到当$\lambda^{\prime} = \frac{\lambda}{\alpha}$时, L2 regularization等价于Weight decay.
也正是因为二者在标准SGD优化算法下的等价性, 二者也经常同日而语, 在很多深度学习库中也会以Weight decay的形式来实现L2 regularization, 即不直接改变loss function以及训练目标函数, 直接在梯度更新值上加上$\lambda^{\prime} \boldsymbol\theta_{t}$. 这样的做法固然方便合理没有增加额外的计算开支, 但是[ICLR 2019] Decoupled Weight Decay Regularization指出二者在引入动量概念或者Adam这样的自适应梯度下降算法下其实并不是等价的.
考虑最简单的Momentum的方法来证明这种不等价性, $M_0=0$, $\mu$为动量因子:
\[\begin{align*} M_t &= \mu M_{t-1} + \boldsymbol{g}_{t} \\ \Delta \boldsymbol{\theta_{t}} &= - \alpha \cdot M_t \\ &= - \alpha \cdot \mu M_{t-1} - \alpha \cdot \boldsymbol{g}_{t}\\ &= - \alpha \sum^t_{i=1} \mu^{t-i} \boldsymbol{g}_{i} \end{align*}\]为需要优化的函数$f_t(\boldsymbol\theta)$加入L2 regularization $f_{t}^{reg}(\boldsymbol{\theta})=f_{t}(\boldsymbol{\theta})+\frac{\lambda^{\prime}}{2}|\boldsymbol{\theta}|_{2}^{2}$, 得到正则化后的动量$M_t^{reg}$
\[\begin{align*} \boldsymbol{g}_{t}^{reg} &= \boldsymbol{g}_{t} + \color{red}{\lambda^{\prime}\boldsymbol{\theta_t}} \\ M_t^{reg} &= \mu M_{t-1}^{reg} + \boldsymbol{g}_{t}^{reg} \\ &= \mu M_{t-1}^{reg} + \boldsymbol{g}_{t} + \color{red}{\lambda^{\prime}\boldsymbol{\theta_t}} \end{align*}\]以及在时间步$t$下的参数更新值$\Delta \boldsymbol{\theta_{t}^{reg}}$
\[\begin{align*} \Delta \boldsymbol{\theta_{t}^{reg}} &= - \alpha \cdot M_t^{reg} \\ &= - \alpha \sum^t_{i=1} \mu^{t-i} (\boldsymbol{g}_{i} + \color{red}{\lambda^{\prime}\boldsymbol{\theta_i}}) \end{align*}\]对于Weight decay, 我们直接在权重更新进行衰减:
\[\begin{align*} \boldsymbol{\theta}_{t+1}^{wd} & = (1- \lambda)\boldsymbol{\theta}_{t}-\alpha M_t\\ \Delta \boldsymbol{\theta}_{t}^{wd} &= - \alpha M_t -\color{red}{\lambda\boldsymbol{\theta}_{t}} \\ & = - \alpha \sum^t_{i=1} \mu^{t-i} \boldsymbol{g}_{i} - \color{red}{\lambda\boldsymbol{\theta}_{t}} \end{align*}\]而一般情况下不存在超参数$\alpha, \lambda, \lambda^{\prime}$ 对于所有时间步下的$\lambda\boldsymbol{\theta}{t}$ 满足 $\alpha \sum^t{i=1} \mu^{t-i} \lambda^{\prime}\boldsymbol{\theta_i} = \lambda\boldsymbol{\theta}_{t}$, 从而证伪一致性. 而Adam这种更加复杂的梯度下降算法只会让这种差异性更大.
经过一番整理与讨论, 我们来归纳一下Adam optimizer中二者的差异:
- L2 regularization: 在计算梯度$g_t$时直接改变梯度值, 损失函数的梯度和L2惩罚项都被调整了
- Weight Decay: 在权值$\Delta \boldsymbol{\theta_{t}}$更新时进行衰减, 只有损失函数的梯度被调整了
尽管上述两种正则化机制都以相同的速率迫使权重接近于零. 但是后者会导致惩罚系数$\lambda^\prime$与学习率$\alpha$之间存在耦合, 为了解耦这两个超参数的影响, 进而孵化出了Adam with decoupled weight decay (AdamW)算法.
AdamW
用公式的形式来翻译一下上述伪代码, 理解一下最后呈现出来的$\Delta \boldsymbol{\theta_t}$
\[\begin{align*} \Delta \boldsymbol{\theta_t} & = -\eta_{t} \cdot \alpha \frac{\hat{M}_{t}}{\sqrt{\hat{G}_{t}} + \epsilon} \\ & \overset{\text{line 9}} = -\eta_{t} \cdot \alpha \frac{M_{t}}{(1-\beta^t_1)} \cdot \frac{1}{\sqrt{\hat{G}_{t}} + \epsilon} \\ & \overset{\text{line 7}}= -\eta_{t} \cdot \alpha \frac{\beta_1 M_{t-1} + (1-\beta_1) \boldsymbol{g}_{t}}{(1-\beta^t_1)(\sqrt{\hat{G}_{t}} + \epsilon)} \\ & \overset{\text{line 6}} = - \eta_{t} \cdot \alpha \frac{\beta_1 M_{t-1} + (1-\beta_1) \color{red}{(\nabla f_{t}\left(\boldsymbol{\theta}_{t-1}\right)} + \lambda \boldsymbol{\theta}_{t-1})}{(1-\beta^t_1)(\sqrt{\hat{G}_{t}} + \epsilon)} \\ \end{align*}\]L2 regularization中, 我们通过时直接改变梯度值$g_t$, 损失函数的梯度($\nabla f_{t}\left(\boldsymbol{\theta}{t-1}\right)$)和L2惩罚项的梯度($\lambda \boldsymbol{\theta}{t-1}$)都被调整了, 然后调整系数为$g_t=$ $\nabla f_{t}\left(\boldsymbol{\theta}{t-1}\right) + \lambda \boldsymbol{\theta}{t-1}$ 而非 $\nabla f_{t}\left(\boldsymbol{\theta}{t-1}\right)$本身. 对应的$\sqrt{\hat{G}{t}} + \epsilon$会有更大的值, 从而导致在梯度变大比较快速的方向上, $\frac{\lambda \boldsymbol{\theta}{t-1}}{\sqrt{\hat{G}{t}} + \epsilon}$偏小, 从而降低了正则化的有效性.
乾坤大挪移搬运一下原论文的解读:
with L2 regularization, the sums of the gradient of the loss function and the gradient of the regularizer (i.e., the L2 norm of the weights) are adapted, whereas with decoupled weight decay, only the gradients of the loss function are adapted (with the weight decay step separated from the adaptive gradient mechanism). With L2 regularization both types of gradients are normalized by their typical (summed) magnitudes, and therefore weights x with large typical gradient magnitude s are regularized by a smaller relative amount than other weights. In contrast, decoupled weight decay regularizes all weights with the same rate λ, effectively regularizing weights x with large s more than standard L2 regularization
Implement
最后我们从代码层面来回顾一下adam梯度下降算法的各种变体:
- 带bias correction的adam算法可以简单的表示为(implemented by mxnet)
rescaled_grad = clip(grad * rescale_grad, clip_gradient) + wd * weight
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
w = w - lr * m / (sqrt(v) + epsilon)
- 对应的AdamW算法:
grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
w = w - lr * (m / (sqrt(v) + epsilon) + wd * w)
- 不带bias correction的AdamW算法表示为:
grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad**2)
lr = learning_rate
w = w - lr * (m / (sqrt(v) + epsilon) + wd * w)
BERT等相关预训练模型的tensorflow官方实现)当中应用的就是不带bias correction的AdamW算法
参考资料:
都9102年了,别再用Adam + L2 regularization了 - paperplanet的文章 - 知乎 https://zhuanlan.zhihu.com/p/63982470
AdamW and Super-convergence is now the fastest way to train neural nets
邱锡鹏,神经网络与深度学习,机械工业出版社,https://nndl.github.io/, 2020.