4. 模型定义
nn.Module 的设计思想
nn.Module 是一个「可组合、可递归、可追踪状态的计算单元」,它可以是
一层(Linear、Conv)
一个 block(ResBlock、Attention)
一个完整模型
甚至一个 loss
nn.Module 的设计,核心在于同时解决这三个问题:
描述前向计算逻辑
管理参数与状态
支持自动求导与设备迁移
nn.Module 与“动态图”的关系
PyTorch 是动态图框架,意味着计算图是在 forward 执行时动态构建的,Python 控制流是真实存在的
动态图有一个天然问题:参数如何管理?怎么知道哪些 Tensor 是模型的一部分?
而 nn.Module 就是在 Python 层管理那些东西属于模型,并不控制计算图本身
所有规范的 nn.Module,都遵循两个隐含约定:
__init__:定义结构、注册参数 / 子模块
forward:定义一次前向计算如何发生
只要继承 nn.Module,正确注册参数,正确实现 forward,就可以自动进行反向传播、用优化器更新、移动至 GPU 做分布式训练
一个最小可训练的 Module 示例
import torch import torch.nn as nn # 继承 nn.Module,意味着这个对象可以被 PyTorch 识别为“模型”;拥有参数注册、状态管理能力;能被 optimizer、cuda、DDP 等系统组件接管 class SimpleModel(nn.Module): def __init__(self): # 初始化 nn.Module 内部的参数追踪系统,包括几个关键容器:参数(Parameter)、子模块(Module)、buffer super().__init__() # 参数必须在 __init__ 中定义 # 注册的关键在于 self.xxx = ... # 只要把一个 nn.Module 或 nn.Parameter 赋值给 self 的属性,PyTorch 就会自动追踪它 self.linear = nn.Linear(10, 1) # 定义一次前向计算如何发生 # forward 不是声明式的,是一次真实执行的 Python 函数,每执行一次,就构建一次计算图 def forward(self, x): out = self.linear(x) return out model = SimpleModel() x = torch.randn(4, 10) y = model(x) loss = y.mean() loss.backward() # 调用 model.__call__(x) # 内部做一些 hook / mode 处理 # 调用 forward(x) # 永远不要手动调用 forward(),所有地方都用 model(x)
参数与子模块的自动注册
PyTorch 通过 nn.Module 对象内部注册过的东西来分析参数
nn.Parameter子模块也是同样的规则
递归注册
nn.Module 不是列表,而是一棵树
根节点即所定义的模型,子节点是模型包含的子模块,叶子是 Parameter
当调用
model.parameters()时,PyTorch 会深度优先遍历整棵 Module 树,收集所有 Parameter
Module 列表
buffers 不是参数,但属于模型状态
权重共享:不是“复制参数”,而是多个计算路径 引用同一个 Parameter 对象
forward、计算图与自动求导
nn.Module 本身不参与求导,参与求导的是 Tensor
Module 做的事情只有两件:管理参数和在 forward 中使用这些参数进行计算
真正被 autograd 追踪的是 forward 中产生的 Tensor 和它们之间的运算关系
forward 每执行一次,就会“新建一次”计算图
requires_grad如何传播假设 model 的参数是 nn.Parameter(requires_grad=True),而 x 是普通 Tensor(requires_grad=False)
在 forward 中有
out = self.linear(x),发生的是 linear.weight 参与运算,输出的 Tensor 自动标记 requires_grad=True,autograd 记录计算历史
forward 的执行是动态的,因为它
依赖内部状态(参数、buffer)
行为会随 train / eval 模式变化
同样输入,输出可能不同(Dropout)
因此不能假设两次运行前向运算的结果相同
在 forward 中不能创建参数
Module 的容器化能力
nn.Module不是“一层”,而是一个可以递归地装其他 Module 的容器,它同时承担四个角色:参数容器(parameters)
子模块容器(submodules)
状态容器(buffers)
计算图节点(forward 组织)
子模块注册示例
PyTorch 提供了一组 “官方容器”,用来管理多个子模块
nn.Sequential:严格的前向顺序,无分支、无条件、无跳连nn.ModuleList:模块列表(不定义 forward),forward 要手动实现nn.ModuleDict:模块字典,模块有“语义名字”,可以根据 key 动态选择
Module 的递归行为
.parameters():遍历所有子 Module,收集所有nn.Parameter.to(device):递归执行所有子 Module,所有 Parameter,所有 Buffer.train()/.eval():递归通知 Dropout / BatchNorm,自定义 Module 中的self.training
残差结构(Residual / Skip Connection)
多分支结构(Branching)
state_dict 与加载机制
state_dict 是一个 Python dict,key 是字符串,value 是 Tensor
它包含所有 nn.Parameter、所有 register_buffer 注册的 buffer,按 Module 树结构递归展开
它不包含 forward 代码、Python 控制流、optimizer 逻辑、临时张量
state_dict 的命名规则来源于 Module 树结构与容器的 key / index,如 ModuleList 用 index、ModuleDict 用 key、Sequential 用 index
保存模型参数
加载模型参数
buffer 的加载行为
buffer 在 state_dict 里,不参与梯度,但影响推理行为
BatchNorm 的 running_mean / running_var
position embedding
mask
eval 模型不对 ≠ 权重问题,很可能是 buffer 没有正确加载
optimizer / scheduler 的 state_dict
在保存 checkpoint 时,一个完整 checkpoint 通常包括:
model.state_dict
optimizer.state_dict
scheduler.state_dict
epoch / step
rng state(可选)
以上是可恢复训练的最低标准
state_dict 与设备无关,其内部的 Tensor 可以在 CPU,也 可以在 GPU,加载时由 load_state_dict 拷贝
训练与推理模式切换
model.train()/model.eval()不是在切换“是否反向传播”,而是在切换“Module 的行为模式”model.train()会递归设置所有子模块的module.training = Truemodel.eval()会递归设置所有子模块的module.training = False实际上,切换训练和推理模式并不会自动关闭梯度的计算等,绝大多数 Module 无视
training标志,只有少数“状态相关”的模块会改变行为Dropout
训练时随机置零,scale 保持期望不变
推理时完全关闭,变成恒等映射
BatchNorm
训练时使用 batch mean / var,更新 running_mean / running_var(buffer)
推理时使用 running_mean / running_var,不再更新 buffer
如果 eval 忘记切换或训练集 batch 很小,模型表现会悄悄变差
nn.AlphaDropout
其他自定义 Module
正确组合方式
Transformer Block
典型的 TransformerBlock 实现如下
Post-LN(原始 Transformer)
Post-LN 的问题
∂x∂L→LN→(x+f(x))LayerNorm 在反向传播中会引入 scale / shift 的耦合,削弱 identity shortcut 的“直通梯度”
导致层数一深,梯度衰减,非常依赖 warmup,超参极其敏感
Post-LN 在 12 层左右还能忍受,24 层以上极难训练,100+ 层基本不可能
Pre-LN(现代标准):残差路径不经过 LN 层,梯度可以沿 identity 路径无衰减传播,而 LN 只作用在分支上,深度几乎不受限制
为什么 Pre-LN 的输出不再 normalize?
如果最后一层输出没有 LN,会不会数值爆炸?
答案是显然的,数值会缓慢漂移但可控
一般的补救措施是在最后加一个 Final LayerNorm 或者在 output 前进行归一化,如 GPT 模型
在非常深的模型中,可能会使用
x=x+α⋅f(LN(x)),α=N1其中 $\alpha$ 也可以是可学习参数,来控制早期训练的激活幅度,稳定深层网络的训练过程
常见模块
全连接与卷积层(基础构建模块)
全连接层
nn.Linear(in_features, out_features)
MLP 的基本组成
1D 卷积
nn.Conv1d(in_channels, out_channels, kernel_size)
时间序列、文本处理
2D 卷积
nn.Conv2d(in_channels, out_channels, kernel_size)
图像常用
3D 卷积
nn.Conv3d(in_channels, out_channels, kernel_size)
视频、医学图像等
转置卷积
nn.ConvTranspose2d(…)
图像上采样
空洞卷积
nn.Conv2d(…, dilation=N)
扩大感受野
激活函数(非线性激活)
ReLU
nn.ReLU()
最常用
LeakyReLU
nn.LeakyReLU(negative_slope=0.01)
防止死神经元
ELU
nn.ELU()
类似于 LeakyReLU
GELU
nn.GELU()
Transformer 默认激活
Sigmoid
nn.Sigmoid()
压缩到 (0,1),用于二分类输出
Tanh
nn.Tanh()
压缩到 (-1,1)
Softmax
nn.Softmax(dim)
多分类输出层使用(通常结合 CrossEntropy)
LogSoftmax
nn.LogSoftmax(dim)
通常用于 nn.NLLLoss()
池化层(Pooling)
最大池化
nn.MaxPool2d(kernel_size)
保留特征最大值
平均池化
nn.AvgPool2d(kernel_size)
取特征平均值
自适应池化
nn.AdaptiveAvgPool2d((H, W))
输出固定大小
GlobalAvgPool
nn.AdaptiveAvgPool2d((1,1))
通常用于分类尾部
归一化层(Normalization)
批归一化
nn.BatchNorm1d / 2d / 3d
对 batch 维度归一化
层归一化
nn.LayerNorm(normalized_shape)
对每个样本归一化,常用于 NLP
实例归一化
nn.InstanceNorm2d(num_features)
图像风格迁移常用
群归一化
nn.GroupNorm(num_groups, num_channels)
小 batch 情况下代替 BN
Dropout / 正则化
Dropout
nn.Dropout(p=0.5)
随机屏蔽神经元,防止过拟合
2D Dropout
nn.Dropout2d()
专用于 CNN
Alpha Dropout
nn.AlphaDropout()
与 SELU 激活函数配套使用
循环网络 / Transformer 模块
RNN
nn.RNN(input_size, hidden_size, …)
简单循环网络
GRU
nn.GRU(input_size, hidden_size, …)
Gated Recurrent Unit
LSTM
nn.LSTM(input_size, hidden_size, …)
最常用的循环网络
TransformerEncoder
nn.TransformerEncoder(…)
多头注意力块堆叠结构
MultiheadAttention
nn.MultiheadAttention(embed_dim, num_heads)
自注意力核心
PositionalEmbedding(自写)
自定义 nn.Embedding
添加位置信息
辅助模块 / 工具层
Flatten
nn.Flatten()
将多维输入拉平,通常用于 CNN 到 MLP 之间
Unflatten
nn.Unflatten(dim, unflattened_size)
展开张量为多维
Identity
nn.Identity()
占位符,用于动态启/停某层
Embedding
nn.Embedding(num_embeddings, embedding_dim)
文本 / 离散变量嵌入层
模块组合封装
Sequential
nn.Sequential(…)
顺序堆叠多个层
ModuleList
nn.ModuleList([…])
可迭代模块列表
ModuleDict
nn.ModuleDict({…})
模块字典(命名访问)
常见任务的模型结构定义
图像分类
Conv2d + BatchNorm2d + ReLU + MaxPool + Linear
文本分类
Embedding + LSTM/GRU + Linear
Transformer
Embedding + PositionalEncoding + MultiheadAttention + LayerNorm
GAN
Conv2d / ConvTranspose2d + BatchNorm + LeakyReLU / Tanh
小样本任务
GroupNorm + Dropout 替代 BN
CNN 实现
最常见的 CNN:Conv → ReLU → Pool → Linear,用来处理图像
代码实现
RNN 实现
从最简单、最“丑”的 vanilla RNN 开始,非常适合理解为什么 RNN 会梯度消失
ht=tanh(Wxhxt+Whhht−1+b)代码实现
这里,时间维度上的递归依赖是 Python for-loop 展开的,这也是 RNN 难并行的根本原因
LSTM 的核心不是“复杂”,而是用门控机制给梯度修路
itftotc~tctht=σ(Wi[xt,ht−1])=σ(Wf[xt,ht−1])=σ(Wo[xt,ht−1])=tanh(Wc[xt,ht−1])=ft⊙ct−1+it⊙c~t=ot⊙tanh(ct)代码实现
之后的 cuDNN LSTM、kernel fusion、gate packing,会在这个基础上做工程级优化
GRU 可以理解为“砍掉 cell state 的 LSTM”,结构更紧凑
ztrth~tht=σ(Wz[xt,ht−1])=σ(Wr[xt,ht−1])=tanh(Wh[xt,rt⊙ht−1])=(1−zt)⊙ht−1+zt⊙h~t代码实现
刚刚实现的 LSTM,在 Python 里是这样的逻辑:对每个时间步 $t$,把 $x_t$ 和 $h_{t-1}$ 拼起来,过一个线性层,得到 $4H$ 维,然后再分块为 $i,f,o,g$,这件事在数学上可以直接写成一行
itftotgt=W[xtht−1]+b这里的关键并不在于“LSTM 有四个门”,而在于这四个门在计算层面上完全共享同一次矩阵乘法
也就是说,在任意一个 time step,LSTM 的主要计算成本几乎全部集中在一次形状为 $(B, D + H) \times (D + H, 4H)$ 的 GEMM 上
其后的 sigmoid、tanh 以及逐元素乘法,即便放在 GPU 上执行,也只是一些规模很小的 kernel,无论是计算量还是访存量都可以忽略不计
cuDNN 所做的第一件关键事情,正是强制采用这种 gate packing 的视角:在 Python 层面,可以选择写四个 Linear(LSTM)或三个 Linear(GRU),但在 cuDNN 内部,权重在内存中始终被组织为一整块连续的 $(D+H, 4H)$ 或 $(D+H, 3H)$,从而严格对应一次 GEMM
真正理解了这一点,其实就已经理解了 cuDNN LSTM 能够如此之快的相当一部分原因
Last updated
Was this helpful?