4. 模型定义

nn.Module 的设计思想

  • nn.Module 是一个「可组合、可递归、可追踪状态的计算单元」,它可以是

    • 一层(Linear、Conv)

    • 一个 block(ResBlock、Attention)

    • 一个完整模型

    • 甚至一个 loss

  • nn.Module 的设计,核心在于同时解决这三个问题:

    • 描述前向计算逻辑

    • 管理参数与状态

    • 支持自动求导与设备迁移

  • nn.Module 与“动态图”的关系

    • PyTorch 是动态图框架,意味着计算图是在 forward 执行时动态构建的,Python 控制流是真实存在的

    • 动态图有一个天然问题:参数如何管理?怎么知道哪些 Tensor 是模型的一部分?

    • 而 nn.Module 就是在 Python 层管理那些东西属于模型,并不控制计算图本身

  • 所有规范的 nn.Module,都遵循两个隐含约定:

    • __init__:定义结构、注册参数 / 子模块

    • forward:定义一次前向计算如何发生

  • 只要继承 nn.Module,正确注册参数,正确实现 forward,就可以自动进行反向传播、用优化器更新、移动至 GPU 做分布式训练

  • 一个最小可训练的 Module 示例

    import torch
    import torch.nn as nn
    
    # 继承 nn.Module,意味着这个对象可以被 PyTorch 识别为“模型”;拥有参数注册、状态管理能力;能被 optimizer、cuda、DDP 等系统组件接管
    class SimpleModel(nn.Module):
        def __init__(self):
            # 初始化 nn.Module 内部的参数追踪系统,包括几个关键容器:参数(Parameter)、子模块(Module)、buffer
            super().__init__()
            # 参数必须在 __init__ 中定义
            # 注册的关键在于 self.xxx = ...
            # 只要把一个 nn.Module 或 nn.Parameter 赋值给 self 的属性,PyTorch 就会自动追踪它
            self.linear = nn.Linear(10, 1)
    
        # 定义一次前向计算如何发生
        # forward 不是声明式的,是一次真实执行的 Python 函数,每执行一次,就构建一次计算图
        def forward(self, x):
            out = self.linear(x)
            return out
    
    model = SimpleModel()
    x = torch.randn(4, 10)
    y = model(x)
    loss = y.mean()
    loss.backward()
    # 调用 model.__call__(x)
    # 内部做一些 hook / mode 处理
    # 调用 forward(x)
    # 永远不要手动调用 forward(),所有地方都用 model(x)

参数与子模块的自动注册

  • PyTorch 通过 nn.Module 对象内部注册过的东西来分析参数

  • nn.Parameter

  • 子模块也是同样的规则

  • 递归注册

    • nn.Module 不是列表,而是一棵树

    • 根节点即所定义的模型,子节点是模型包含的子模块,叶子是 Parameter

    • 当调用 model.parameters() 时,PyTorch 会深度优先遍历整棵 Module 树,收集所有 Parameter

  • Module 列表

  • buffers 不是参数,但属于模型状态

  • 权重共享:不是“复制参数”,而是多个计算路径 引用同一个 Parameter 对象

forward、计算图与自动求导

  • nn.Module 本身不参与求导,参与求导的是 Tensor

  • Module 做的事情只有两件:管理参数和在 forward 中使用这些参数进行计算

  • 真正被 autograd 追踪的是 forward 中产生的 Tensor 和它们之间的运算关系

  • forward 每执行一次,就会“新建一次”计算图

  • requires_grad 如何传播

    • 假设 model 的参数是 nn.Parameter(requires_grad=True),而 x 是普通 Tensor(requires_grad=False)

    • 在 forward 中有 out = self.linear(x),发生的是 linear.weight 参与运算,输出的 Tensor 自动标记 requires_grad=True,autograd 记录计算历史

  • forward 的执行是动态的,因为它

    • 依赖内部状态(参数、buffer)

    • 行为会随 train / eval 模式变化

    • 同样输入,输出可能不同(Dropout)

    • 因此不能假设两次运行前向运算的结果相同

  • 在 forward 中不能创建参数

Module 的容器化能力

  • nn.Module 不是“一层”,而是一个可以递归地装其他 Module 的容器,它同时承担四个角色:

    • 参数容器(parameters)

    • 子模块容器(submodules)

    • 状态容器(buffers)

    • 计算图节点(forward 组织)

  • 子模块注册示例

  • PyTorch 提供了一组 “官方容器”,用来管理多个子模块

    • nn.Sequential:严格的前向顺序,无分支、无条件、无跳连

    • nn.ModuleList:模块列表(不定义 forward),forward 要手动实现

    • nn.ModuleDict:模块字典,模块有“语义名字”,可以根据 key 动态选择

  • Module 的递归行为

    • .parameters():遍历所有子 Module,收集所有 nn.Parameter

    • .to(device):递归执行所有子 Module,所有 Parameter,所有 Buffer

    • .train() / .eval():递归通知 Dropout / BatchNorm,自定义 Module 中的 self.training

  • 残差结构(Residual / Skip Connection)

  • 多分支结构(Branching)

state_dict 与加载机制

  • state_dict 是一个 Python dict,key 是字符串,value 是 Tensor

  • 它包含所有 nn.Parameter、所有 register_buffer 注册的 buffer,按 Module 树结构递归展开

  • 它不包含 forward 代码、Python 控制流、optimizer 逻辑、临时张量

  • state_dict 的命名规则来源于 Module 树结构与容器的 key / index,如 ModuleList 用 index、ModuleDict 用 key、Sequential 用 index

  • 保存模型参数

  • 加载模型参数

  • buffer 的加载行为

    • buffer 在 state_dict 里,不参与梯度,但影响推理行为

      • BatchNorm 的 running_mean / running_var

      • position embedding

      • mask

    • eval 模型不对 ≠ 权重问题,很可能是 buffer 没有正确加载

  • optimizer / scheduler 的 state_dict

  • 在保存 checkpoint 时,一个完整 checkpoint 通常包括:

    • model.state_dict

    • optimizer.state_dict

    • scheduler.state_dict

    • epoch / step

    • rng state(可选)

    • 以上是可恢复训练的最低标准

  • state_dict 与设备无关,其内部的 Tensor 可以在 CPU,也 可以在 GPU,加载时由 load_state_dict 拷贝

训练与推理模式切换

  • model.train() / model.eval() 不是在切换“是否反向传播”,而是在切换“Module 的行为模式”

  • model.train() 会递归设置所有子模块的 module.training = True

  • model.eval() 会递归设置所有子模块的 module.training = False

  • 实际上,切换训练和推理模式并不会自动关闭梯度的计算等,绝大多数 Module 无视 training 标志,只有少数“状态相关”的模块会改变行为

    • Dropout

      • 训练时随机置零,scale 保持期望不变

      • 推理时完全关闭,变成恒等映射

    • BatchNorm

      • 训练时使用 batch mean / var,更新 running_mean / running_var(buffer)

      • 推理时使用 running_mean / running_var,不再更新 buffer

      • 如果 eval 忘记切换或训练集 batch 很小,模型表现会悄悄变差

    • nn.AlphaDropout

    • 其他自定义 Module

  • 正确组合方式

Transformer Block

  • 典型的 TransformerBlock 实现如下

  • Post-LN(原始 Transformer)

  • Post-LN 的问题

    LxLN(x+f(x))\frac{\partial \mathcal{L}}{\partial x} \rightarrow \text{LN} \rightarrow (x + f(x))
    • LayerNorm 在反向传播中会引入 scale / shift 的耦合,削弱 identity shortcut 的“直通梯度”

    • 导致层数一深,梯度衰减,非常依赖 warmup,超参极其敏感

    • Post-LN 在 12 层左右还能忍受,24 层以上极难训练,100+ 层基本不可能

  • Pre-LN(现代标准):残差路径不经过 LN 层,梯度可以沿 identity 路径无衰减传播,而 LN 只作用在分支上,深度几乎不受限制

  • 为什么 Pre-LN 的输出不再 normalize?

    • 如果最后一层输出没有 LN,会不会数值爆炸?

    • 答案是显然的,数值会缓慢漂移但可控

    • 一般的补救措施是在最后加一个 Final LayerNorm 或者在 output 前进行归一化,如 GPT 模型

  • 在非常深的模型中,可能会使用

    x=x+αf(LN(x)),α=1Nx = x + \alpha \cdot f(\text{LN}(x)),\quad \alpha = \frac{1}{\sqrt{N}}
    • 其中 $\alpha$ 也可以是可学习参数,来控制早期训练的激活幅度,稳定深层网络的训练过程

常见模块

全连接与卷积层(基础构建模块)

名称
类名
说明

全连接层

nn.Linear(in_features, out_features)

MLP 的基本组成

1D 卷积

nn.Conv1d(in_channels, out_channels, kernel_size)

时间序列、文本处理

2D 卷积

nn.Conv2d(in_channels, out_channels, kernel_size)

图像常用

3D 卷积

nn.Conv3d(in_channels, out_channels, kernel_size)

视频、医学图像等

转置卷积

nn.ConvTranspose2d(…)

图像上采样

空洞卷积

nn.Conv2d(…, dilation=N)

扩大感受野

激活函数(非线性激活)

名称
类名
特性

ReLU

nn.ReLU()

最常用

LeakyReLU

nn.LeakyReLU(negative_slope=0.01)

防止死神经元

ELU

nn.ELU()

类似于 LeakyReLU

GELU

nn.GELU()

Transformer 默认激活

Sigmoid

nn.Sigmoid()

压缩到 (0,1),用于二分类输出

Tanh

nn.Tanh()

压缩到 (-1,1)

Softmax

nn.Softmax(dim)

多分类输出层使用(通常结合 CrossEntropy)

LogSoftmax

nn.LogSoftmax(dim)

通常用于 nn.NLLLoss()

池化层(Pooling)

名称
类名
用途

最大池化

nn.MaxPool2d(kernel_size)

保留特征最大值

平均池化

nn.AvgPool2d(kernel_size)

取特征平均值

自适应池化

nn.AdaptiveAvgPool2d((H, W))

输出固定大小

GlobalAvgPool

nn.AdaptiveAvgPool2d((1,1))

通常用于分类尾部

归一化层(Normalization)

名称
类名
用途

批归一化

nn.BatchNorm1d / 2d / 3d

对 batch 维度归一化

层归一化

nn.LayerNorm(normalized_shape)

对每个样本归一化,常用于 NLP

实例归一化

nn.InstanceNorm2d(num_features)

图像风格迁移常用

群归一化

nn.GroupNorm(num_groups, num_channels)

小 batch 情况下代替 BN

Dropout / 正则化

名称
类名
用途

Dropout

nn.Dropout(p=0.5)

随机屏蔽神经元,防止过拟合

2D Dropout

nn.Dropout2d()

专用于 CNN

Alpha Dropout

nn.AlphaDropout()

与 SELU 激活函数配套使用

循环网络 / Transformer 模块

名称
类名
用途

RNN

nn.RNN(input_size, hidden_size, …)

简单循环网络

GRU

nn.GRU(input_size, hidden_size, …)

Gated Recurrent Unit

LSTM

nn.LSTM(input_size, hidden_size, …)

最常用的循环网络

TransformerEncoder

nn.TransformerEncoder(…)

多头注意力块堆叠结构

MultiheadAttention

nn.MultiheadAttention(embed_dim, num_heads)

自注意力核心

PositionalEmbedding(自写)

自定义 nn.Embedding

添加位置信息

辅助模块 / 工具层

名称
类名
用途

Flatten

nn.Flatten()

将多维输入拉平,通常用于 CNN 到 MLP 之间

Unflatten

nn.Unflatten(dim, unflattened_size)

展开张量为多维

Identity

nn.Identity()

占位符,用于动态启/停某层

Embedding

nn.Embedding(num_embeddings, embedding_dim)

文本 / 离散变量嵌入层

模块组合封装

名称
类名
说明

Sequential

nn.Sequential(…)

顺序堆叠多个层

ModuleList

nn.ModuleList([…])

可迭代模块列表

ModuleDict

nn.ModuleDict({…})

模块字典(命名访问)

常见任务的模型结构定义

任务类型
常用模块组合

图像分类

Conv2d + BatchNorm2d + ReLU + MaxPool + Linear

文本分类

Embedding + LSTM/GRU + Linear

Transformer

Embedding + PositionalEncoding + MultiheadAttention + LayerNorm

GAN

Conv2d / ConvTranspose2d + BatchNorm + LeakyReLU / Tanh

小样本任务

GroupNorm + Dropout 替代 BN

CNN 实现

  • 最常见的 CNN:Conv → ReLU → Pool → Linear,用来处理图像

  • 代码实现

RNN 实现

  • 从最简单、最“丑”的 vanilla RNN 开始,非常适合理解为什么 RNN 会梯度消失

    ht=tanh(Wxhxt+Whhht1+b)h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b)
  • 代码实现

  • 这里,时间维度上的递归依赖是 Python for-loop 展开的,这也是 RNN 难并行的根本原因

  • LSTM 的核心不是“复杂”,而是用门控机制给梯度修路

    it=σ(Wi[xt,ht1])ft=σ(Wf[xt,ht1])ot=σ(Wo[xt,ht1])c~t=tanh(Wc[xt,ht1])ct=ftct1+itc~tht=ottanh(ct)\begin{align} i_t &= \sigma(W_i [x_t, h_{t-1}]) \\ f_t &= \sigma(W_f [x_t, h_{t-1}]) \\ o_t &= \sigma(W_o [x_t, h_{t-1}]) \\ \tilde{c}_t &= \tanh(W_c [x_t, h_{t-1}]) \\ c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ h_t &= o_t \odot \tanh(c_t) \end{align}
  • 代码实现

  • 之后的 cuDNN LSTM、kernel fusion、gate packing,会在这个基础上做工程级优化

  • GRU 可以理解为“砍掉 cell state 的 LSTM”,结构更紧凑

    zt=σ(Wz[xt,ht1])rt=σ(Wr[xt,ht1])h~t=tanh(Wh[xt,rtht1])ht=(1zt)ht1+zth~t\begin{align} z_t &= \sigma(W_z [x_t, h_{t-1}]) \\ r_t &= \sigma(W_r [x_t, h_{t-1}]) \\ \tilde{h}_t &= \tanh(W_h [x_t, r_t \odot h_{t-1}]) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{align}
  • 代码实现

  • 刚刚实现的 LSTM,在 Python 里是这样的逻辑:对每个时间步 $t$,把 $x_t$ 和 $h_{t-1}$ 拼起来,过一个线性层,得到 $4H$ 维,然后再分块为 $i,f,o,g$​,这件事在数学上可以直接写成一行

    [itftotgt]=W[xtht1]+b\begin{bmatrix} i_t \\ f_t \\ o_t \\ g_t \end{bmatrix} =W \begin{bmatrix} x_t \\ h_{t-1} \end{bmatrix} +b
  • 这里的关键并不在于“LSTM 有四个门”,而在于这四个门在计算层面上完全共享同一次矩阵乘法

  • 也就是说,在任意一个 time step,LSTM 的主要计算成本几乎全部集中在一次形状为 $(B, D + H) \times (D + H, 4H)$ 的 GEMM 上

  • 其后的 sigmoid、tanh 以及逐元素乘法,即便放在 GPU 上执行,也只是一些规模很小的 kernel,无论是计算量还是访存量都可以忽略不计

  • cuDNN 所做的第一件关键事情,正是强制采用这种 gate packing 的视角:在 Python 层面,可以选择写四个 Linear(LSTM)或三个 Linear(GRU),但在 cuDNN 内部,权重在内存中始终被组织为一整块连续的 $(D+H, 4H)$ 或 $(D+H, 3H)$,从而严格对应一次 GEMM

  • 真正理解了这一点,其实就已经理解了 cuDNN LSTM 能够如此之快的相当一部分原因

Last updated

Was this helpful?