1.6 学习率

从优化动力学角度系统分析学习率的作用、热身与衰减机制,以及常见的调度策略

什么是学习率

  • 学习率(learning rate)是优化算法中控制参数更新步长的超参数

  • 在一次参数更新中,学习率决定了模型参数沿着梯度方向前进的“幅度”:

    θt+1=θtηθL(θt)\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t)
    • $\theta_t$ 是当前参数

    • $\nabla_\theta \mathcal{L}$ 是损失函数对参数的梯度

    • $\eta$ 是学习率

  • 学习率过大:更新步子太猛,容易越过最优点,甚至发散

  • 学习率过小:训练稳定但收敛极慢,算力被浪费

  • 在大模型(尤其是 Transformer)中,学习率往往比模型结构本身更影响训练是否成功

热身阶段

  • 热身(warmup)指在训练初期,学习率从一个非常小的值逐步增大到设定的最大学习率

  • 一个典型的线性热身形式是:

    ηt=ηmaxtTwarmup,tTwarmup\eta_t = \eta_{\text{max}} \cdot \frac{t}{T_{\text{warmup}}}, \quad t \le T_{\text{warmup}}
  • 引入热身的原因主要有三点:

    • 参数尚未稳定:随机初始化后,模型内部表征非常不稳定,直接使用大学习率容易导致梯度爆炸

    • 自注意力结构敏感:Q/K/V 投影和 LayerNorm 在初期对权重尺度非常敏感

    • Adam 等自适应优化器在初期统计量偏差较大,需要时间“校准”

  • 经验上,Transformer 几乎总是配合 warmup 使用,否则很容易在前几百步直接 loss 崩掉

衰减阶段

为什么要衰减

  • 当学习率达到峰值后,通常会逐步降低,这一阶段称为学习率衰减

    • 精细调整:训练后期模型已进入较优区域,小学习率有助于在低损失区域内进行细粒度搜索,而不是在最优点附近来回跳跃

    • 避免震荡:高学习率在损失曲面较平坦但曲率不一致的区域,容易造成参数更新方向反复变化,表现为 loss 上下震荡

    • 提高泛化能力:较低学习率往往更容易收敛到“宽而平”的最小值,而非尖锐最小值,这通常对应更好的泛化性能

逆平方根衰减

  • Transformer 原论文中使用的经典策略:

    ηt=dmodel0.5min(t0.5, tTwarmup1.5)\eta_t = d_{\text{model}}^{-0.5} \cdot \min \left( t^{-0.5},\ t \cdot T_{\text{warmup}}^{-1.5} \right)
  • 在训练初期线性 warmup,而在训练中后期学习率按 $1/\sqrt{t}$ 缓慢下降

  • 优点:衰减速度温和,适合长时间训练,对 batch size 和训练步数变化相对鲁棒

  • 缺点:不直观,难以和训练总步数精确对齐,在有限步数训练中,后期学习率可能仍然偏大

余弦衰减 (余弦退火)

  • 余弦衰减将学习率视为一个从最大值平滑下降到最小值的半个余弦曲线:

    ηt=ηmin+12(ηmaxηmin)(1+cos(πtT))\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min}) \left(1 + \cos\left(\pi \frac{t}{T}\right)\right)
    • 其中 $T$ 是衰减总步数

  • 前期下降较慢,后期下降较快,整体变化非常平滑,没有突变点

  • 实践中常见组合:linear warmup + cosine decay,在大模型预训练和微调中都非常流行

线性衰减

  • 线性衰减在 warmup 后按固定速率降低学习率:

    ηt=ηmax(1tTwarmupTTwarmup)\eta_t = \eta_{\max} \left(1 - \frac{t - T_{\text{warmup}}}{T - T_{\text{warmup}}}\right)
  • 线性衰减形式简单,行为可预测,与训练总步数强绑定

  • 适合训练步数固定、实验可控的场景

  • 缺点:衰减不够“智能”,在某些阶段可能下降过快或过慢

指数衰减

  • 指数衰减每隔一段时间按比例缩小学习率:

    ηt=η0γt,0<γ<1\eta_t = \eta_0 \cdot \gamma^t, \quad 0 < \gamma < 1
  • 或按 step size 进行:

    ηt=η0γt/k\eta_t = \eta_0 \cdot \gamma^{\lfloor t / k \rfloor}
  • 学习率下降速度快,后期可能过早进入“几乎不学习”的状态

  • 常见于传统 CNN 或训练步数不太长的任务,在大规模 Transformer 训练中相对少用

PyTorch 实现

  • 线性 warmup + 余弦衰减:

  • 逆平方根衰减(Transformer 风格):

  • 指数衰减:

  • 实际训练中,学习率调度几乎总是需要结合模型规模、batch size、优化器类型、训练总步数等进行配置,单独讨论学习率数值本身通常没有意义

Last updated

Was this helpful?