Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Pretrain 的实施细节和常见坑

上一节我们更多讨论的是模型、目标函数和优化器这些偏原理的内容.这一节开始,我们进入训练的实现部分,也就是一个 pretrain 脚本到底是怎么真正跑起来的.

这部分未必总是在讲很复杂的理论,但它对训练能不能稳定进行、能不能顺利复现、出了问题能不能快速排查,其实非常重要.很多时候,真正让预训练跑不顺的,并不是模型结构本身,而是这些实现细节没有处理好.

所以这一节我还是想按照 QA 的形式,把训练里一些很关键但又很容易被忽略的问题串起来:

  1. 训练配置和模型配置分别是什么?
  2. 一个完整的 pretrain 训练脚本通常由哪些环节组成?
  3. 为什么要设置随机种子? 应该怎么设置?
  4. 分布式训练主要在解决什么问题? 梯度累积又是怎么接到训练循环里的?
  5. 什么是混合精度训练?
  6. 梯度稳定性操作:梯度缩放,梯度剪裁? 梯度累积、梯度剪裁和 GradScaler 的顺序应该怎么放?
  7. checkpoint 应该保存什么? 又该怎么做断点恢复?
  8. 可视化工具主要看什么?

Q1: 一个完整的 Pretrain 训练脚本通常由哪些环节组成?

如果只看最核心的训练过程,好像事情很简单,无非就是:

  1. forward 算出 loss.
  2. backward 算出梯度.
  3. optimizer 更新参数.

但真实能跑起来的 pretrain 脚本,远不止这三步.它通常还要处理分布式初始化、随机种子、数据加载、混合精度、梯度累积、日志记录、checkpoint 保存和断点恢复这些事情.

以 MiniMind 为例,它的 pretrain 主流程可以粗略拆成下面几个环节:

  1. 初始化分布式环境,并确定当前进程使用的设备.
  2. 设置随机种子,尽量保证实验可复现.
  3. 创建模型配置,初始化模型和 tokenizer.
  4. 读取预训练数据,构造 DatasetSamplerDataLoader.
  5. 设置优化器、学习率调度、混合精度上下文和 GradScaler.
  6. 对每个 batch 执行 forward,计算 loss,然后 backward.
  7. 按照梯度累积步数决定什么时候真正更新参数,并在更新前后处理梯度裁剪、梯度缩放和梯度清零.
  8. 定期打印日志、记录可视化指标、保存 checkpoint.

在代码里,这条主线主要集中在 src/minimind_learning/trainer/train_pretrain.py:

# src/minimind_learning/trainer/train_pretrain.py
local_rank = init_distributed_mode()
if dist.is_initialized():
    args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

model, tokenizer = init_model(lm_config, args.from_weight, device=args.device)
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

这段代码也说明了一件事: 训练脚本不是只有 model(X)loss.backward(). 真正完整的训练流程,还包括随机性控制、数据采样、混合精度、优化器状态、分布式包装和断点恢复.

下面我们逐一进入细节:

Q1.1: 训练配置和模型配置里一般有什么?

这是一个很容易被忽略,但实际上非常重要的点.一个完整的训练流程,最好要有比较清晰的配置系统,专门去管理“模型怎么定义”和“训练怎么执行”这两类信息.这样会让整个训练流程更清楚,也会让复现容易很多.

MiniMind 现在主要是用 argparse 去传训练参数.这种方式在项目比较小时很直接,但参数一多,训练脚本和训练配置本身就会慢慢混在一起.这样一来,如果你想改训练设置,经常就得去改脚本本身,配置来源也会变得比较分散,后面无论是复现实验还是维护代码,都会有点乱.

比较常见也更清晰的做法,是把这些配置单独整理到配置文件里,比如 jsonyaml,或者再进一步,封装成专门的配置类.这样做的核心好处,其实就是把“数据”和“逻辑”分开: 训练脚本负责执行流程,配置文件负责描述这次实验到底要怎么跑.这样你只需要改配置,就可以比较快地调整实验,也更容易复现之前的结果.

这里可以粗略把配置分成两类:

1. 模型配置

模型配置回答的是: 你到底要训练一个什么样的模型?

一般会包括这些内容:

模型基础信息:

  • 模型类型或者模型名称
  • 最大序列长度 max_seq_len

词表和 tokenizer 相关:

  • 词表大小 vocab_size
  • 特殊 token 的定义,比如 bos_ideos_idpad_id
  • tokenizer 通常会和模型一起考虑,所以这里先放在模型配置里一起讨论.

Embedding 相关:

  • 位置编码或者 RoPE 相关参数
  • 普通的 embedding 一般由 hidden_sizevocab_size 定义,但如果有特殊设计,也可能有额外参数

模型结构相关:

  • 隐藏层维度 hidden_size
  • Transformer 层数 num_layers
  • 注意力头数 num_heads
  • 前馈网络维度 intermediate_size
  • 如果项目把每个 attention head 的维度单独写出来,也可能会有 head_dim

其他细节:

  • dropout 相关参数
  • 是否共享输入输出 embedding
  • normalization 相关参数,比如 norm_eps
  • 激活函数类型,比如 GELUSwiGLU

2. 训练配置

一般会包括这些内容:

路径类配置:

  • 数据集路径 (验证集,测试集)
  • tokenizer 路径或者词表路径
  • 日志输出路径
  • checkpoint 保存路径
  • resume checkpoint 路径
  • 配置文件本身的保存路径

实验设置类:

  • 实验名称或者 run id
  • 设备类型,比如 cpucuda
  • 随机种子
  • 是从头训练,还是从已有权重开始训练
  • 日志打印频率、验证频率、checkpoint 保存频率

实验可视化类:

  • 是否开启可视化工具,以及对应的 project / run 配置

训练系统类(工程参数 分布式 混合精度):

  • 分布式训练相关参数,比如 world_sizeranklocal_rank
  • 混合精度类型,比如 float16bfloat16
  • DataLoader 相关参数,比如 num_workerspin_memory

训练参数:

  • 优化器类型,比如 AdamW
  • batch size、梯度累积步数、训练轮数、最大训练步数
  • 学习率、权重衰减、warmup 步数、学习率调度策略
  • 梯度裁剪阈值
  • 是否 shuffle、是否 drop_last

3. 更多细节

还有一些更细节的配置,不一定每个项目都会单独写出来,但在一些实现里也很常见.如果继续往下拆,其实也可以按前面的分类方式来理解:

模型配置里更细节的内容:

  • attention 细节: 有些项目除了 num_heads 之外,还会单独配置 num_kv_heads, 用来区分普通多头注意力、MQAGQA. 这些设置会直接影响 attention 的参数规模、KV cache 的大小,以及推理时的效率.

  • FFN 细节: 有的实现会直接写 intermediate_size, 也有的实现不会直接给这个值,而是通过某种扩展倍率,比如 ffn_mult,去间接确定它. 如果模型用了 SwiGLU 这类结构,FFN 部分的实际维度设计也可能和最基础的 MLP 写法不太一样.

  • 参数初始化: 有些项目会把参数初始化方式也放进配置里,比如初始化标准差、不同层是否采用不同的初始化策略. 这些内容平时不一定最先关注,但它们会影响训练刚开始时的稳定性.

  • norm 和激活函数的进一步细节: 除了前面提到的 norm_eps 和激活函数类型,有的实现还会进一步区分用的是哪一种 normalization. 也有的项目会在不同模块里采用不同的激活函数.

  • 模型默认 dtype: 有些项目会在模型配置里单独写模型权重默认使用什么 dtype,比如 float32bfloat16. 这个设置有时也会影响模型初始化、加载权重和后续训练流程.

训练配置里更细节的内容:

数据相关配置:

  • 训练集路径和验证集路径分别是什么.
  • 数据格式是什么,比如 jsonlbinmemmap.
  • 是否流式读取.
  • 是否 shuffle.
  • num_workers.
  • pin_memory.
  • 是否 drop_last.
  • 是否做 packing.
  • 每条样本的截断 / 拼接策略.

恢复与初始化相关配置:

  • 是从头训练,还是从已有权重开始.
  • from_pretrained 还是 from_scratch.
  • 是否加载 optimizer state.
  • 是否恢复 scheduler state.
  • 是否恢复 scaler state.
  • 恢复到哪个 epoch / step.

保存策略相关配置:

  • checkpoint 是按 step 保存还是按 epoch 保存.
  • 保存间隔是多少.
  • 最多保留多少个 checkpoint.
  • 是否额外保存一个 best checkpoint.
  • 是只保存模型权重,还是保存完整训练状态.

验证 / 评估相关配置:

  • 是否开启验证.
  • 验证频率.
  • 验证集路径.
  • 验证指标.
  • 验证时的 batch size.
  • 如果是生成任务,还可能会带上 max_new_tokenstemperature 这类推理配置.

训练目标相关配置:

  • loss 类型.
  • 是否做 label smoothing.
  • ignore_index 是多少.
  • 是否对某些 token 做 mask.
  • 是否有 auxiliary loss.

系统与复现相关配置:

  • 使用哪些 GPU.
  • cudnn / matmul 精度设置.
  • 除了随机种子之外,是否开启 deterministic.
  • 运行时的代码版本、git commit id.
  • 配置文件本身是否要保存到实验目录里.

从工程角度看,路径类配置尤其值得单独检查.至少在训练开始前,应该确认下面这些路径是否存在、是否可读、是否可写:

  • 训练集路径
  • 验证集路径
  • tokenizer 或词表路径
  • 日志目录
  • checkpoint 输出目录
  • resume checkpoint 路径
  • 配置文件本身的保存路径

这一类检查听起来很基础,但非常有必要.因为如果路径问题不提前处理,训练很可能不是一开始报错,而是跑到中间某个地方才因为找不到文件或者目录不可写而停下来,这样调试起来就很浪费时间.

Q2: Seed怎么设置?

训练里有很多随机性来源,比如参数初始化、数据 shuffle、dropout、CUDA 算子选择等。设置 seed,是保证两次运行复现实验现象的必要条件。

MiniMind 的工具函数里设置了 Python、NumPy、PyTorch 和 CUDA 的随机种子:

# src/minimind_learning/trainer/trainer_utils.py
def setup_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

这里还有一个细节:分布式训练时,不同 rank 会使用略微不同的 seed:

# src/minimind_learning/trainer/train_pretrain.py
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

这样可以避免所有进程的随机状态完全一样。对于分布式训练来说,这通常更合理。

Q3: 分布式训练怎么进行?

分布式训练的目标是把训练任务拆到多张 GPU 上执行,从而提高吞吐量,或者让更大的模型、更大的 batch 能够训练起来.

MiniMind 使用的是 PyTorch 的 DistributedDataParallel, 简称 DDP. 这里很容易有一个误解: DDP 并不是“自动把一个 batch 拆到很多张卡上”,它更像是在每个进程里的模型外面包了一层“梯度同步外壳”.

可以先把它理解成下面这件事:

  • 每张 GPU 对应一个进程.
  • 每个进程里各自放一份完整的模型副本.
  • 每个进程各自拿到自己那一份数据,独立做 forward 和 backward.
  • backward 结束之后,DDP 会自动把不同进程上的梯度同步起来.
  • 梯度同步完成后,每个进程再各自执行 optimizer.step(),于是所有模型参数仍然保持一致.

所以从职责上看:

  • DistributedSampler 负责“把数据分开”.
  • DistributedDataParallel 负责“把梯度同步回来”.

Q3.1: 分布式训练的初始化代码有哪些?

结合 MiniMind 的代码来看,PyTorch 分布式训练的主线其实没有多神秘,主要就是先把分布式环境初始化好,然后把数据集和模型分别接到分布式机制里:

# src/minimind_learning/trainer/train_pretrain.py
# 第一步 初始化分布式环境,并设置设备和随机种子
local_rank = init_distributed_mode()
if dist.is_initialized():
    args.device = f"cuda:{local_rank}"
setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

# 第二步 创建 sampler
train_ds = PretrainDataset(args.data_path, tokenizer, max_length=args.max_seq_len)
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None

# 第三步 包装模型
if dist.is_initialized():
    model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
    model = DistributedDataParallel(model, device_ids=[local_rank])

这里可以按顺序理解:

  1. 先初始化分布式环境.

init_distributed_mode() 会先检查当前是不是分布式模式.如果环境变量里没有 RANK,那就说明现在不是 DDP 训练,直接按单卡处理.如果检测到是分布式模式,它就会调用 dist.init_process_group(...) 建立通信组,然后读取 LOCAL_RANK,再通过 torch.cuda.set_device(local_rank) 把当前进程绑定到对应的 GPU 上.

这一步很关键,因为后面每个进程都必须明确知道: “我到底在用哪一张卡”.

  1. 再设置设备和随机种子.

初始化完分布式环境之后,代码里会把当前进程的设备写成:

if dist.is_initialized():
    args.device = f"cuda:{local_rank}"

然后再设置 seed:

setup_seed(42 + (dist.get_rank() if dist.is_initialized() else 0))

这里不是简单地让所有进程用完全相同的 seed,而是让不同 rank 的 seed 略微错开.这样做的目的是避免所有进程的随机状态完全一样.对于分布式训练来说,这种写法通常更合理.

  1. DistributedSampler 把数据分给不同进程.
train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None

如果不开分布式,数据集就正常顺序或 shuffle 读取.如果开了分布式,DistributedSampler 会负责把整个数据集切分给不同进程,避免每张卡都重复训练同一批数据.

  1. DistributedDataParallel 把模型包起来.
if dist.is_initialized():
    model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
    model = DistributedDataParallel(model, device_ids=[local_rank])

这一句 DistributedDataParallel(...) 就是在告诉 PyTorch: 这个模型现在要进入分布式同步模式了.

其中 device_ids=[local_rank] 的意思是,当前这个进程只负责当前这张 GPU 上的模型副本.后面在 backward 的时候,DDP 会自动帮我们把不同进程上的梯度做同步,通常可以理解成一次 all-reduce.

这里的

model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}

说明有一些 buffer 不希望被 DDP 按默认方式处理,所以提前显式忽略掉了.这也说明分布式训练虽然主线不复杂,但具体实现里还是会有一些模型相关的细节.

Q3.2: DistributedSampler 在每个 epoch 里是怎么工作的?

DistributedSampler,顾名思义,就是用来在分布式训练里给数据集配一个 sampler 的. 它的功能就是让每一个进程分布式的从总体数据集里面拿自己的部分.

先看这句:

train_sampler = DistributedSampler(train_ds) if dist.is_initialized() else None

它的意思其实很直接:

  • 如果当前是分布式训练,就给数据集配一个 DistributedSampler.
  • 如果当前不是分布式训练,那 train_sampler 就是 None.

然后在训练循环里,每个 epoch 开始前会调用:

train_sampler and train_sampler.set_epoch(epoch)

这句可以理解成:

if train_sampler is not None:
    train_sampler.set_epoch(epoch)

所以如果现在不是分布式环境,train_sampler 就是 None,这句代码什么都不会做.如果现在是分布式环境,那就会调用 set_epoch(epoch),让 sampler 在新的 epoch 使用新的 shuffle 顺序.

这一点很重要,因为 DistributedSampler 本身既负责“把数据按进程拆开”,也经常负责“在分布式场景下怎么 shuffle”.如果不在每个 epoch 里调用 set_epoch(epoch),那每个 epoch 的采样顺序可能就固定住了.

再看真正创建 DataLoader 的地方:

loader = DataLoader(
    train_ds,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    sampler=train_sampler,
    num_workers=args.num_workers,
    pin_memory=True
)

这里的逻辑是:

  • 如果 train_sampler is None,说明当前不是分布式训练,那就让 DataLoader 自己做普通的 shuffle.
  • 如果 train_sampler 不是 None,说明当前是分布式训练,那就由 DistributedSampler 来决定每个进程读哪些样本,这时 DataLoader 自己就不再额外 shuffle 了.

所以从代码上看,非分布式和分布式其实共用的是同一套训练流程,只是数据采样这一步根据 train_sampler 是否存在,自动切换成了不同模式.

Q3.3: train_epoch() 是每张卡各跑各的,还是同步后统一跑的?

答案是: 每个进程各自跑一遍 train_epoch().

如果有 4 张 GPU,通常就会有 4 个进程.这 4 个进程都会各自执行:

train_epoch(epoch, loader, ...)

所以并不是“主进程负责训练,其他进程等结果”,而是:

  • 每个进程各自有自己的 DataLoader
  • 每个进程各自拿自己的 batch
  • 每个进程各自有自己的模型副本
  • 每个进程都独立执行 forward、loss 计算和 backward

真正发生同步的地方,不是整个 train_epoch() 跑完之后,而是在 backward 阶段.

Q3.4: 梯度累积是在各自显卡上做,还是统一同步后做?

更准确地说,梯度是先在各自进程本地累积,但每次 backward() 时就已经发生了同步.

也就是说:

  • 每个进程先拿自己的 micro-batch
  • 每个进程本地做 forward
  • 每个进程本地做 backward
  • backward 的过程中,DDP 自动把梯度和其他进程同步
  • 同步后的梯度继续累积在各自本地参数的 .grad
  • 等到满足 accumulation_steps,每个进程再各自执行一次 optimizer.step()

所以它不是“先各自累积很多步,最后再统一同步”,而是“每一步 backward 都会同步,同步后的结果再继续本地累积”.

这份 MiniMind 代码里也没有用 no_sync(),所以当前这份实现里,每次 backward() 都会触发同步.这样写更直接,也更容易理解,只是通信开销会更大一些.

优化的办法是在累积的过程中,只在最后一步 backward() 时才同步,前面几步都用 with model.no_sync(): 来跳过同步. 这样可以减少通信开销,但代码会稍微复杂一些.

Q3.5: DDP 的同步细节和执行流程到底是什么?

这里最容易混淆的一点是: DDP 不是“主进程把 batch 发给其他显卡,其他显卡再把梯度传回主进程”.

更准确的流程是:

  1. 每个进程自己通过 DataLoader 读取自己的 batch.
  2. 每个进程自己执行 forward().
  3. 每个进程自己计算 loss.
  4. 每个进程自己调用 backward().
  5. backward() 阶段,DDP 自动把各个进程上的梯度做同步,通常可以理解成一次 all-reduce.
  6. 同步完成后,每个进程本地都会拿到一致的梯度.
  7. 每个进程再各自调用 optimizer.step(),完成本地参数更新.

所以真正负责“同步”的关键函数是 backward(),更准确地说,是 DDP 在 backward 过程中注册的那些梯度同步逻辑.

optimizer.step() 本身并不是用来做跨进程同步的,它做的事情主要是: 使用已经同步好的梯度,在每个进程本地更新参数.

之所以每个进程都可以各自 step(),但最后模型还能保持一致,原因就在于:

  • 初始参数是一致的
  • backward 之后梯度是一致的
  • optimizer 的更新规则也是一样的

所以更新完成后,每个进程上的参数仍然会保持一致.

你也可以把这件事压缩成一句话:

  • DistributedSampler 负责决定“每个进程读哪一份数据”.
  • DDP 负责决定“每个进程算出来的梯度怎么在 backward 阶段同步”.
  • optimizer.step() 负责用已经同步好的梯度,在每个进程本地更新参数.

Q3.6: DDP 的局限是什么? 更大的模型又该怎么训练?

到这里其实就能看出 DDP 的一个核心限制: 它解决的是“数据并行”的问题,但默认并不解决“单张显卡放不下完整模型”的问题.

因为在 DDP 里:

  • 每张卡上都要放一份完整的模型副本
  • 每个进程都会算出自己这一份梯度
  • backward() 阶段,这些梯度还要做同步

所以模型一旦变大,会同时遇到两个压力:

  1. 显存压力
    单张卡必须先能放下一整份模型,否则 DDP 连启动都很困难.

  2. 通信压力
    梯度的规模通常和模型参数规模是同一个量级.模型越大,每次 backward() 时要同步的内容也就越多.如果卡很多、网络带宽又不够强,那通信就可能成为瓶颈.

所以你可以粗略地把 DDP 理解成: 模型能放下时,它是一个很好用的数据并行方案; 模型放不下时,就不能只靠 DDP 了.

这时候更大的模型通常会用另外几类并行方案:

  • FSDP
    FullyShardedDataParallel 的核心思路是: 不再让每张卡都完整保存一份模型,而是把参数、梯度、优化器状态分片到不同设备上.
    这样做的直接好处就是显存占用显著下降,所以它适合“模型太大,单卡放不下”的场景.
    从 PyTorch 官方文档的描述来看,FSDP 本质上就是把 DDP 那种“整份复制”的方式,换成了“分片保存、按需聚合”的方式.

  • Tensor Parallel
    它不是把“数据”拆开,而是把“层内部的计算”拆开.
    比如一个很大的线性层,可以按列切到不同 GPU,或者按行切到不同 GPU.这样每张卡只负责这个层的一部分参数和计算.
    这种方式更像是在“一个算子内部做并行”,比较适合超大矩阵运算.

  • Pipeline Parallel
    它的思路是把模型按层切成几段,不同 GPU 分别放不同的 stage.
    比如前几层在 GPU0,中间几层在 GPU1,后几层在 GPU2.
    然后把一个 batch 再拆成多个 micro-batch,像流水线一样穿过这些 stage.

所以如果只做一个很粗的划分:

  • DDP: 每张卡一份完整模型,主要解决“如何并行处理更多数据”
  • FSDP: 把参数分片,主要解决“模型太大,单卡放不下”
  • Tensor Parallel: 把层内部的张量计算拆开
  • Pipeline Parallel: 把模型按层切成不同 stage

工业级大模型训练,通常不会只用其中一种,而是把这些并行策略组合起来.

一个很常见的思路是:

  • 数据并行负责把样本分到不同 worker
  • FSDP 或 ZeRO 类方案负责降低参数、梯度、优化器状态的显存占用
  • Tensor Parallel 负责把单层内部的大矩阵计算拆到多张卡
  • Pipeline Parallel 负责把整个模型沿层的方向切成多个阶段

所以工业级大模型训练,确实往往不是“只写一个普通 PyTorch 脚本就够了”,而是需要一整套更复杂的训练系统去协调这些并行策略.这也是为什么很多大模型项目最后都会发展出比较专门的训练框架或者训练基础设施.

不过从 PyTorch 现在的生态来看,也不能简单说“PyTorch 不支持这些”.更准确地说是:

  • DDP 是 PyTorch 里最成熟、最直接的数据并行方案
  • FSDP 已经是 PyTorch 官方支持的重要路线
  • Tensor Parallel 和 Pipeline Parallel 也已经有官方能力,只是相对更复杂,有些接口还比较新
  • 真正的工业级方案,往往会在 PyTorch 之上再叠一层自己的并行封装和训练框架

所以这一块如果要压缩成一句话,可以这样理解:

小模型或者中等规模模型,DDP 往往就够用了.
更大的模型,尤其是单卡放不下的时候,就必须从“只做数据并行”走向“数据并行 + 参数分片 + 模型并行”的组合方案.

Q4: 混合精度训练怎么做?

混合精度训练的目标是用更低精度的数据类型提升训练速度、降低显存占用,同时尽量保持训练稳定。

MiniMind 里通过 autocast 创建混合精度上下文:

# src/minimind_learning/trainer/train_pretrain.py
device_type = "cuda" if "cuda" in args.device else "cpu"
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

with autocast_ctx:
    res = model(X)
    loss = loss_fct(
        res.logits.view(-1, res.logits.size(-1)),
        Y.view(-1)
    ).view(Y.size())

autocast 的作用是让 PyTorch 自动决定哪些操作可以用低精度计算,哪些操作应该保留较高精度。这样比手动把所有张量都转成 float16bfloat16 更安全。本质上可以理解成一个“自动混合精度的上下文环境”. 也就是说,只要代码写在

with autocast_ctx:
    ...

这个块里面,PyTorch 就会按照自己的规则,自动决定当前这些算子应该用什么精度来算.

所以它做的事情不是“粗暴地把所有张量都变成 float16bfloat16”,而是:

  • 对适合低精度的算子,尽量用低精度计算
  • 对数值更敏感的算子,保留更高精度

这样做的好处是:

  • 速度通常会更快
  • 显存占用通常会更低
  • 同时又比手动把所有内容都转成低精度更安全

这里也可以顺手解释一下这句:

autocast_ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast(dtype=dtype)

它的意思是:

  • 如果当前在 CPU 上,那就用 nullcontext(),相当于“什么都不做”
  • 如果当前在 GPU 上,那就真正启用 autocast

所以 autocast_ctx 可以理解成一个统一接口. 这样后面训练代码就不用专门写两套分支,而是统一写成:

with autocast_ctx:
    ...

autocast 在这里真正起作用的地方,主要是 forward 和 loss 计算. 它负责让这些计算尽量以合适的低精度进行,从而达到“混合精度训练”的效果.

Q5: GradScaler 到底在缩放什么?

GradScaler 主要用于 float16 混合精度训练。因为 float16 的数值范围较小,梯度太小时可能下溢成 0,导致训练不稳定。它和梯度剪裁不是一件事情

它的核心思路是: 先把 loss 放大,这样反向传播时得到的梯度也会一起被放大,从而尽量避免过小的梯度在 float16 下直接下溢成 0.

对应代码就是:


# src/minimind_learning/trainer/train_pretrain.py
# 这里的 `enabled=(args.dtype == 'float16')` 就是说,只有当我们使用 `float16` 时才启用 `GradScaler`.
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16')) #自动缩放 loss,防止梯度下溢

# 反向传播时,先用 `scaler.scale(loss)` 把 loss 放大,再调用 `backward()`.
scaler.scale(loss).backward()

所以更准确地说,GradScaler 表面上是在 scale loss,但它真正想保护的是 backward 过程中产生的梯度.

等真正更新参数之前,再把梯度恢复到正常尺度:

scaler.unscale_(optimizer)

这一步之后,参数上的 .grad 才重新回到“真实梯度大小”. 也正因为如此,如果后面还要做梯度裁剪,就必须放在 unscale_ 之后.

然后才能进行梯度裁剪:

torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

最后再通过 scaler.step()scaler.update() 执行更新并调整缩放因子:

scaler.step(optimizer)
scaler.update()

所以这一整套顺序其实是:

  1. scale(loss)
    先把 loss 放大
  2. backward()
    让放大后的 loss 参与反向传播
  3. unscale_(optimizer)
    在真正更新前,把梯度恢复到正常尺度
  4. clip_grad_norm_(...)
    如果要做梯度裁剪,在真实梯度上裁剪
  5. scaler.step(optimizer)
    执行参数更新
  6. scaler.update()
    动态调整缩放因子

这套顺序很重要. 如果要做梯度裁剪,应该先 unscale_,再 clip_grad_norm_. 否则裁剪的就是被放大后的梯度,阈值就失去了原本的意义.

Q6:scaler.scale(loss).backward() 应该放在 autocast 里面吗?

一般来说,autocast 包住 forward 和 loss 计算,而 backward 放在 autocast 外面。

MiniMind 里的写法是:

with autocast_ctx:
    res = model(X)
    loss = loss_fct(
        res.logits.view(-1, res.logits.size(-1)),
        Y.view(-1)
    ).view(Y.size())
    loss = (loss * loss_mask).sum() / loss_mask.sum()
    loss = loss / args.accumulation_steps

scaler.scale(loss).backward()

这样写是合理的. autocast 主要影响的是 forward 过程和 loss 计算过程中的算子精度选择. 反向传播会沿着 forward 过程中记录下来的计算图进行,它并不需要再额外包在 autocast 里面.

所以这里可以把两者的分工理解成:

  • autocast 负责决定 forward 和 loss 计算时“用什么精度算”
  • GradScaler 负责在 backward 前后处理梯度缩放,尽量避免 float16 下的梯度下溢

如果把整个过程串起来,更完整的逻辑其实是:

  1. 进入 autocast_ctx
  2. 做 forward
  3. 计算 loss
  4. 退出 autocast_ctx
  5. scaler.scale(loss).backward() 做反向传播
  6. scaler.unscale_(optimizer) 恢复真实梯度
  7. 如果需要,再做梯度裁剪
  8. 最后 scaler.step()scaler.update()

autocastGradScaler 就是配合关系,而不是二选一的关系:

  • autocast 解决“前向计算尽量安全地用低精度”
  • GradScaler 解决“反向传播时低精度梯度可能下溢”
  • grad_clip 解决“梯度可能过大导致训练不稳定”

三者的配合顺序至关重要

Q6.1 Minimind 中完整代码片段:

# src/minimind_learning/trainer/train_pretrain.py

def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
    '''
        epoch: 有 num_epochs 个 epoch   
        steps: 每个 epoch 有 steps_per_epoch 个 batch
        iters: iters = steps_per_epoch 最大迭代步数

    '''
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
        # B batch size L seq_len
        X = X.to(args.device) #[B,L-1] #去掉最后一个 token的index  
        Y = Y.to(args.device) #[B,L-1] #去掉第一个 
        loss_mask = loss_mask.to(args.device) #[B,L-1]
        # 手动修改LR
        lr = get_lr(epoch * iters + step, args.epochs * iters, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with autocast_ctx:
            res = model(X) #CausalLMOutputWithPast  #[batch_size, seq_len, vocab_size (logit/raw score)]
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)), # [batch_size * seq_len ,vocab_size]
                Y.view(-1) #[batch_size * seq_len]
            ).view(Y.size()) #[batch_size , seq_len]

            loss = (loss * loss_mask).sum() / loss_mask.sum()
            if res.aux_loss : loss += res.aux_loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        # accumulation_steps 原本大batch拆成小batch 但是为了保证梯度的稳定性(降低方差) 还是需要用大batch估计梯度 所以把小batch的梯度保留然后拼回去
        # 相当于一个正常batch结束了 要更新一下
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer) #作用:把梯度从 GradScaler 的缩放状态恢复到正常大小。在混合精度训练中,scaler.scale(loss).backward() 会把梯度放大,以避免 fp16 下的数值下溢。在做梯度裁剪 或其他需要真实梯度值的操作前,必须先调用 unscale_。
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) #对梯度进行 裁剪,防止梯度爆炸。
            scaler.step(optimizer) # :执行一次参数更新。和普通的 optimizer.step() 不同,scaler.step() 会检查梯度是否为 NaN 或 Inf(数值不稳定)。
            scaler.update() #动态调整缩放因子。如果梯度稳定,GradScaler 会逐步增大缩放因子,提高精度利用率。如果出现溢出(NaN/Inf),它会减小缩放因子,保证安全。

            optimizer.zero_grad(set_to_none=True)#清空梯度,为下一次迭代做准备。set_to_none=True 会把 .grad 设为 None 而不是 0,这样更节省显存和计算开销。下次反向传播时,PyTorch 会重新分配梯度张量。
            

Q7: 为什么需要梯度裁剪?

训练深层网络时,梯度有时会突然变得非常大,这就是常说的梯度爆炸。梯度爆炸会让参数更新过猛,轻则 loss 剧烈波动,重则直接出现 NaNInf

梯度裁剪的思想是限制梯度范数。如果梯度整体范数超过阈值,就把它缩放回阈值范围内。

MiniMind 默认的裁剪阈值是:

parser.add_argument("--grad_clip", type=float, default=1.0, help="梯度裁剪阈值")

对应训练循环中的:

torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

它不会改变梯度方向,只会限制梯度大小。因此它更像是给训练过程加了一个安全阀。

Q8: 为什么要做梯度累积?

梯度累积的目标是模拟更大的 batch size。

如果显存不够,我们无法一次放入很大的 batch。但可以把一个大 batch 拆成多个小 batch,分别 forward 和 backward,让梯度先累积在参数的 .grad 里,等累积到一定步数后再更新一次参数。

这部分的实践技术细节,和理论方面的内容,我们分别在上一小节的分布式训练,和上一章节Effective Batch size也有涉及.相关的内容可以先参阅前述章节.

MiniMind 里默认:

parser.add_argument("--accumulation_steps", type=int, default=8, help="梯度累积步数")

训练时先把 loss 除以累积步数:

loss = loss / args.accumulation_steps
scaler.scale(loss).backward()

然后每隔 accumulation_steps 才执行一次 optimizer step:

if (step + 1) % args.accumulation_steps == 0:
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
loss = loss / args.accumulation_steps

特别注意,这里需要把 loss 除以 accumulation_steps 是为了让累积后的梯度尺度接近真正的大 batch 平均梯度。如果不除,累积后的梯度会被放大,等价于改变了学习率。

Q8.1: PyTorch 里的 backward()optimizer.step()zero_grad() 分别是什么意思? 它们又是怎么配合完成梯度累积的?

这一组操作几乎是所有 PyTorch 训练循环里最核心的三步,但它们各自负责的事情其实不一样:

  • backward()
    负责反向传播,并把当前这一步算出来的梯度累加到参数的 .grad 上.
  • optimizer.step()
    负责读取当前参数上的 .grad,然后按照优化器规则去更新参数.
  • optimizer.zero_grad()
    负责把参数上旧的梯度清掉,让下一轮计算从干净状态开始.

这里一个很重要的细节是: PyTorch 默认不会在每次迭代后自动清空梯度.也就是说,如果你连续调用多次 backward(),而中间不 zero_grad(),那么新的梯度就会继续累加到旧的 .grad 上.

这也是梯度累积能成立的根本原因.

所以从机制上看,梯度累积并不是什么额外的黑盒功能,而是利用了 PyTorch 默认“梯度会累加,不会自动清空”这个特性.

可以把一个最基本的训练循环理解成:

  1. loss.backward()
  2. optimizer.step()
  3. optimizer.zero_grad()

而梯度累积做的事情,其实就是把“每次 backward 之后立刻更新参数”这件事延后:

  1. 多次调用 backward()
  2. 中间先不 optimizer.step()
  3. 也先不 optimizer.zero_grad()
  4. 让多次 backward 的结果都累积在 .grad
  5. 等累积够了,再执行一次 optimizer.step()
  6. 最后再 optimizer.zero_grad()

也就是说,梯度累积不是因为“不调用 optimizer.step() 就会自动保留梯度”,而是因为 PyTorch 本来就不会自动清空梯度,真正负责清空的是 zero_grad().

所以更准确地说:

  • backward() 负责“把梯度加上去”
  • optimizer.step() 负责“拿当前梯度去更新参数”
  • zero_grad() 负责“把旧梯度清掉”

Q8.2: 梯度累积会带来不稳定吗? 如果会,通常怎么避免?

梯度累积本身不等于梯度爆炸,但如果处理不当,它确实会让训练更容易不稳定.

最常见的风险有两个:

  1. 没有把 loss 除以 accumulation_steps

如果你想累积 8 步,但每一步都直接对原始 loss 做 backward,那最后参数上的 .grad 大约就会变成 8 个 micro-batch 梯度的和. 这样梯度尺度会明显变大,效果上很像学习率也被一起放大了,训练就更容易抖动甚至发散.

所以通常要写:

loss = loss / args.accumulation_steps

这样累积完成之后,梯度尺度会更接近真正的大 batch 平均梯度.

  1. 梯度累积后,更大的梯度会在 .grad 里持续保留

因为梯度累积的过程就是不断把多步的梯度往 .grad 里加,所以如果某一步梯度本身就特别大,它也会被继续保留下来. 在模型本来就不稳定、学习率偏大,或者混合精度下数值不太稳的时候,这种影响会更明显.

通常的避免办法有这几个:

  • 把 loss 除以 accumulation_steps,控制梯度尺度
  • 在真正更新参数之前做梯度裁剪
  • 混合精度训练时,先 unscale_() 再做梯度裁剪
  • 不要把学习率设置得过大

这里还要注意一个顺序问题: 梯度裁剪通常应该放在“累积完成之后”,而不是每个 micro-batch 的 backward 之后.

也就是说,更常见的做法是:

  1. 多次 backward,先把梯度累积起来
  2. 累积完成后,先 unscale_(optimizer)
    如果使用了混合精度
  3. 对最终要用于更新参数的总梯度做 clip_grad_norm_(...)
  4. 再执行 optimizer.step()

原因在于,我们真正想控制的,是“这一次参数更新实际会用到的总梯度”,而不是某一个 micro-batch 的局部梯度. 如果每一步都先裁一遍再去累积,那得到的就不再是原本那个大 batch 梯度的近似了.

Q9: 为什么 optimizer、learning rate 和 batch size 不只算工程参数?

优化器、学习率和数据设置决定了参数怎样更新,也决定了每次更新用什么样的数据估计梯度。它们比 checkpoint、可视化更接近训练理论,所以单独放到上一节讨论。

在实际看训练代码时,可以先把这里当作一个连接点:本节继续关注训练能不能稳定执行,而优化器、学习率、数据集和 batch size 的理论含义,会在上一节中单独展开。

Q10: 如何恢复训练, checkpoint 保存什么?

训练不是一次性跑完就结束的。尤其是 pretrain,常常需要中断、恢复、换机器、换 GPU 数量。因此 checkpoint 很重要。

先从更一般的角度看,如果我们希望训练能够真正从中断处继续下去,那通常至少要保存下面这些东西:

  1. 模型权重

这是最基本的一项. 不管是推理还是继续训练,都首先需要模型当前的参数.

  1. 优化器状态

如果只恢复模型权重,但不恢复优化器状态,那训练虽然还能继续跑,但优化器内部的动量、二阶统计量这些信息都丢了. 这样后续的训练轨迹通常就和中断前不一致了.

  1. 混合精度相关状态

如果训练里用了 GradScaler,那最好也把 scaler.state_dict() 一起保存下来. 否则恢复训练后,梯度缩放会重新从头开始,数值行为可能和中断前不一致.

  1. 训练进度

最常见的是保存 epochstep. 这样恢复训练时,至少可以知道是从哪个 epoch、哪个 step 继续.

  1. 数据进度

这一点有时容易被忽略. 因为训练恢复不只是“模型接着算”,还包括“数据读到哪里了”.

最简单的情况下,epoch + step 往往已经够用了. 但如果数据读取方式更复杂,比如 sampler 状态更复杂、数据是流式读取的、或者中间做了 packing,那严格来说还可能需要保存更完整的数据管线状态.

  1. 其他训练状态

比如分布式训练时的 world_size,可视化实验的 run id,以及本次训练的配置快照. 这些内容虽然不直接参与参数更新,但对恢复实验和复现实验都很重要.

所以也可以换个角度理解:

  • 如果只是为了推理或者后续加载模型,保存模型权重就够了
  • 如果是为了真正恢复训练,那就应该保存完整训练状态,而不只是模型参数

MiniMind 这里实际上保存了两类东西:

  • 半精度模型权重:方便推理或后续加载。
  • resume checkpoint:保存模型、优化器、训练进度、可视化 run id 等,用于继续训练。

对应代码在 lm_checkpoint 中:

# src/minimind_learning/trainer/trainer_utils.py
resume_data = {
    'model': state_dict,
    'optimizer': optimizer.state_dict(),
    'epoch': epoch,
    'step': step,
    'world_size': dist.get_world_size() if dist.is_initialized() else 1,
    'wandb_id': wandb_id
}

这里只截了一部分核心字段,但结合完整代码来看,MiniMind 的保存逻辑其实分成两层:

第一层是保存模型权重:

state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
torch.save({k: v.half() for k, v in state_dict.items()}, ckp_tmp)

这里保存的是一份半精度权重. 它更适合“模型加载”这个目的,比如推理或者后续再继续作为初始权重使用.

第二层是保存 resume checkpoint:

resume_data = {
    'model': state_dict,
    'optimizer': optimizer.state_dict(),
    'epoch': epoch,
    'step': step,
    'world_size': dist.get_world_size() if dist.is_initialized() else 1,
    'wandb_id': wandb_id
}
for key, value in kwargs.items():
    if value is not None:
        if hasattr(value, 'state_dict'):
            resume_data[key] = value.state_dict()
        else:
            resume_data[key] = value

这里比较关键的一点是: resume_data 不只保存模型和优化器,它还会继续把 kwargs 里传进来的训练状态一起保存. 在 train_pretrain.py 里,调用 lm_checkpoint(...) 时也把 scaler 传进去了,所以恢复训练时混合精度状态也能一起恢复.

所以从实际效果看,MiniMind 这里保存的内容主要包括:

  • 模型权重
  • 优化器状态
  • GradScaler 状态
  • epoch
  • step
  • world_size
  • 可视化 run id

然后再看加载逻辑:

ckp_data = lm_checkpoint(lm_config, weight=args.save_weight, save_dir='../checkpoints') if args.from_resume==1 else None

if ckp_data:
    model.load_state_dict(ckp_data['model'])
    optimizer.load_state_dict(ckp_data['optimizer'])
    scaler.load_state_dict(ckp_data['scaler'])
    start_epoch = ckp_data['epoch']
    start_step = ckp_data.get('step', 0)

这里做的事情也很清楚:

  • 恢复模型参数
  • 恢复优化器状态
  • 恢复 GradScaler 状态
  • 恢复 epochstep

也就是说,它恢复的不是“某一份权重文件”,而是一整套训练状态.

还有一个细节值得注意. MiniMind 并没有直接把 sampler 的完整状态保存下来,而是主要依靠 epoch + step 来恢复训练进度,并在恢复时通过 SkipBatchSampler 跳过前面的 batch.

这说明它对“数据进度”的恢复思路是:

  • epoch 确定恢复到第几个训练轮次
  • step 确定当前 epoch 里已经跑到了哪里
  • 再通过跳过前面 batch 的方式,近似恢复数据读取位置

这种做法在很多普通训练脚本里已经够用了,实现上也比较简单. 但如果数据管线特别复杂,那严格来说,还可能需要更完整的数据状态恢复机制.

这里还可以单独补充一下 world_size 的作用. 在 MiniMind 这份代码里,world_size 并不是训练主流程里持续依赖的状态,它主要是在恢复 checkpoint 时,用于处理“保存时的 GPU 数量”和“恢复时的 GPU 数量”不一致的情况:

saved_ws = ckp_data.get('world_size', 1)
current_ws = dist.get_world_size() if dist.is_initialized() else 1
if saved_ws != current_ws:
    ckp_data['step'] = ckp_data['step'] * saved_ws // current_ws

也就是说,MiniMind 保存 world_size,主要是为了在 GPU 数量发生变化时,对 step 做一个近似换算. 它并没有进一步恢复更完整的分布式数据状态,所以这更像是一种实用的工程补偿,而不是严格意义上的数据进度恢复.

Q10.1: MiniMind 在单卡和分布式场景下,是怎么恢复数据进度的?

先直接看训练循环里的代码:

for epoch in range(start_epoch, args.epochs):
    train_sampler and train_sampler.set_epoch(epoch)
    if epoch == start_epoch and start_step > 0:
        batch_sampler = SkipBatchSampler(
            train_sampler or range(len(train_ds)),
            args.batch_size,
            start_step + 1
        )
        loader = DataLoader(
            train_ds,
            batch_sampler=batch_sampler,
            num_workers=args.num_workers,
            pin_memory=True
        )
        train_epoch(epoch, loader, len(loader) + start_step + 1, start_step, wandb)
    else:
        loader = DataLoader(
            train_ds,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            num_workers=args.num_workers,
            pin_memory=True
        )
        train_epoch(epoch, loader, len(loader), 0, wandb)

这段代码的思路其实很清楚:

  • 先通过 start_epoch 决定从第几个 epoch 开始恢复
  • 如果是恢复训练,并且当前 epoch 里已经跑过一部分 step,那就通过 SkipBatchSampler 跳过前面的 batch
  • 如果不是恢复训练,那就正常创建 DataLoader

这里有一个写法很巧:

train_sampler or range(len(train_ds))

它的意思是:

  • 如果当前是分布式训练,那就用 train_sampler
  • 如果当前是单卡训练,train_samplerNone,那就退回到普通的顺序索引 range(len(train_ds))

所以 MiniMind 其实用同一套恢复逻辑,同时兼容了单卡和分布式两种场景:

  • 单卡时,根据 epoch + step 跳过前面已经训练过的 batch
  • 分布式时,先由 DistributedSampler 决定当前进程应该读取哪些样本,再在这个基础上继续跳过已经训练过的 batch

从工程上看,这是一个比较实用的方案. 它不需要保存完整的 sampler 内部状态,也不需要保存 DataLoader 的全部运行时信息,只要保存:

  • epoch
  • step
  • 当前训练时的 world_size

再配合 SkipBatchSampler,通常就已经能把训练大致恢复到之前的位置.

不过它的局限也很明显: 这种恢复方式更像是在“近似恢复训练位置”,而不是“精确恢复整个数据管线状态”:

MiniMind 的数据恢复不是严格意义上的 sampler 状态恢复。它主要以 epoch 为锚点,在恢复时先通过 set_epoch(epoch) 重建当前轮次的数据顺序,再结合 step 和 SkipBatchSampler 跳过前面的 batch。因此它比单纯的 epoch-level 恢复更细一些,但仍然更接近一种“基于 epoch 的 step 级近似恢复”,而不是完全精确的 batch-level 恢复。

如果数据管线更复杂,通常还可能需要额外保存下面这些信息:

  • sampler 的随机状态
  • 当前 epoch 对应的 shuffle 顺序
  • 流式数据读取时的游标位置
  • packing / buffer 中还没消费完的样本片段
  • 数据增强或随机裁剪使用的随机数状态

可以举一个更复杂的例子. 假设你的数据不是普通的 map-style dataset,而是一个流式读取的文本管线:

  • 数据来自多个 shard
  • 每个 shard 内部还在不断顺序读取
  • 中间会做 document packing
  • packing buffer 里可能还残留一些还没拼完的 token

这种情况下,如果只保存 epoch + step,恢复训练时虽然知道“大概跑了多少步”,但并不知道:

  • 当前到底读到了哪个 shard
  • shard 内部读到了哪一条样本
  • packing buffer 里还剩下哪些 token

这时更完整的做法通常是把数据管线自己的状态也做成一个 state_dict,例如:

resume_data = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
    "epoch": epoch,
    "step": step,
    "sampler_state": sampler.state_dict(),
    "data_cursor": datapipe.cursor_state(),
    "packing_buffer": packer.state_dict(),
}

恢复时再分别加载这些状态:

sampler.load_state_dict(ckp_data["sampler_state"])
datapipe.restore_cursor(ckp_data["data_cursor"])
packer.load_state_dict(ckp_data["packing_buffer"])

这样做的代价是实现更复杂,但好处是恢复训练时会更接近真正的“无缝续训”.

所以这一节最后可以压缩成一句话:

checkpoint 不只是“把模型存下来”,更重要的是把“继续训练所需的状态”一起存下来. 如果只保存模型权重,那更像是保存了一份模型快照; 如果连优化器、scaler、epoch、step 这些都一起恢复,那才更接近真正意义上的断点续训.

Q11: Epoch、Step、Iteration, Batch 这些概念是怎么对应的?

这一组概念在训练代码里非常常见,但不同项目、不同文章里的叫法并不总是完全一致. 有的人把 step 当成“读了一个 batch”,有的人把 step 当成“做了一次参数更新”; 有的人把 iterationstep 混用,有的人又会单独区分 global stepmicro stepupdate step.

所以这一节最好先把这些术语尽量定义清楚,再放回训练循环里看它们是怎么运转的.

1. 先定义最基本的几个概念

sample

最小的数据单位. 比如一条训练样本、一段文本、一个句子对,都可以看成一个 sample.

batch

一次 forward / backward 里一起送进模型的一组 sample.
如果 batch_size=32,那通常就是一次送 32 条 sample 进去.

不过在大模型训练里,batch 这个词有时并不够精确,因为它可能指的是:

  • 单卡上的一个 local batch
  • 梯度累积里的一个 micro-batch
  • 所有卡合起来的 global batch

所以只说 batch 时,最好结合上下文判断.

epoch

指“把整个训练集大致跑一遍”.
如果数据集大小固定,一个 epoch 通常表示所有样本都被遍历过一次.

在分布式训练里,更准确地说,是所有进程配合起来,把这一轮数据各自处理完,合起来相当于把整个数据集跑过一遍.

iteration

这个词通常指训练循环中的“一次迭代”.
很多时候它和“读一个 batch,做一次 forward/backward”是对应的.

但要注意,它的用法并不绝对统一. 有些代码里会把 iteration 直接等同于 step,有些地方又会单独区分.

step

这是最容易混的一个词.

在很多训练代码里,step 经常有两种常见用法:

  1. 表示“训练循环往前走了一次”,也就是处理了一个 batch
  2. 表示“参数真正更新了一次”,也就是执行了一次 optimizer.step()

所以看代码时,一定要结合上下文判断它到底指的是哪一种.

2. 放到 PyTorch 训练循环里看,这些概念通常怎么对应?

先看一个最基础的训练循环:

for epoch in range(num_epochs):
    for step, batch in enumerate(loader):
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

在这段最简单的代码里:

  • 外层 for epoch ... 对应 epoch
  • 内层 for step, batch ... 的每一次循环,通常可以看成一次 iteration
  • 这里的 step 也通常等于“一次 batch 迭代”
  • 因为每次循环都调用了 optimizer.step(),所以这里的 step 也同时等于“一次参数更新”

也就是说,在没有梯度累积时:

  • 一个 iteration 通常处理一个 batch
  • 一个 batch 通常对应一次 backward
  • 一次 backward 通常也就对应一次 optimizer.step()

这时候很多人把 iterationstepupdate step 混着说,问题通常也不大.

3. 一旦有了梯度累积,这些概念就不能再混着用了

假设现在有:

  • batch_size = 8
  • accumulation_steps = 4

训练循环可能会变成:

for epoch in range(num_epochs):
    for step, batch in enumerate(loader):
        loss = model(batch)
        loss = loss / accumulation_steps
        loss.backward()

        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

这时就要明确区分两类 step:

  • data step / iteration step
    指内层循环走了一次,处理了一个 micro-batch
  • update step / optimizer step
    指真正执行了一次参数更新

比如上面的例子里:

  • loader 每循环 1 次,处理 1 个 batch
  • 但每 4 次循环,才执行 1 次 optimizer.step()

所以:

  • 4 个 batch iteration
  • 4 次 backward
  • 1 次参数更新

这时候如果还只是笼统地说 “step”,就很容易混乱.

4. 再补几个大模型训练里常见的概念

micro-batch

梯度累积时,每次真正送进模型做 forward / backward 的那一小批数据,通常就叫 micro-batch.

也就是说,在做梯度累积时,内层循环里每次读到的 batch,很多时候其实更准确地说应该叫 micro-batch.

local batch size

单个进程、单张卡上,每次 forward 实际处理的数据量.

global batch size

所有卡、再乘上梯度累积之后,一次参数更新等效看到的总 batch size.这个等效于之前提到的 effective batch size, 更强调“从优化角度看,这次更新等效使用了多大的 batch”.

如果定义:

  • 记每张卡上的 batch size 为 \( B_{local} \)
  • 记 GPU 数量为 \( N_{gpu} \)
  • 记梯度累积步数为 \( N_{acc} \)

那么 global batch size 通常可以写成:

\[ B_{global} = B_{local} \times N_{gpu} \times N_{acc} \]

这里:

  • \( B_{local} \) 表示单卡每次 forward 的 batch size
  • \( N_{gpu} \) 表示并行使用的 GPU 数量
  • \( N_{acc} \) 表示梯度累积步数

这个公式有助于把“单卡 batch”“多卡 batch”“梯度累积后的等效 batch”统一起来理解.

5. 结合 MiniMind 的代码,这些词更准确地应该怎么叫?

MiniMind 的训练循环里有:

for step, (X, Y, loss_mask) in enumerate(loader, start=start_step + 1):
    ...
    scaler.scale(loss).backward()

    if (step + 1) % args.accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

这里的 step 更接近:

  • 当前 dataloader 迭代到了第几个 batch
  • 也就是第几个 micro-batch iteration

不完全等于“第几次参数更新”,因为参数更新只在满足:

(step + 1) % args.accumulation_steps == 0

时才会发生.

所以如果按更严格的术语来讲:

  • epoch: 第几个训练轮次
  • step: 当前 epoch 里处理到第几个 dataloader batch
  • iteration: 在这里基本可以近似理解成一次 dataloader 循环
  • optimizer step / update step: 真正执行参数更新的时刻

也就是说,MiniMind 这里的 step 更接近“batch step”或“iteration step”,而不是严格意义上的“update step”.

6. 所以这几个概念在训练里一般是怎么运转的?

可以把一个更完整的训练过程概括成这样:

  1. 外层按 epoch 循环
  2. 每个 epoch 内,DataLoader 不断产生 batch / micro-batch
  3. 每来一个 batch,就做一次 forward 和 backward
  4. 如果不开梯度累积,通常每个 batch 后都立即 optimizer.step()
  5. 如果开了梯度累积,那就会先积累多个 micro-batch 的梯度,再统一 optimizer.step()
  6. 所以 iteration 的次数通常大于等于 update step 的次数

如果只记一句最实用的话,可以记成:

  • epoch: 数据集跑了第几轮
  • batch / micro-batch: 一次 forward / backward 处理多少数据
  • iteration: 内层训练循环往前走了一次
  • optimizer step / update step: 参数真正更新了一次

而一旦进入分布式训练和梯度累积,最容易混乱的地方就是:
“一个 batch iteration” 不一定等于 “一次参数更新 step”. 这也是为什么很多项目后面会额外引入 global_stepupdate_step 这类更明确的命名.

Q12: 可视化怎么做?

MiniMind 训练脚本里记录了 loss、learning rate 和预计剩余时间:

if wandb:
    wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

这里代码里变量名叫 wandb,实际导入的是 swanlab

import swanlab as wandb

不管使用 SwanLab、WandB 还是 TensorBoard,核心目的都一样:把训练过程从终端里解放出来,让我们能看到曲线。

Q12.1 Wandb完整API示例

不过如果继续往代码层面看,这一类可视化工具通常至少会涉及两类最基本的 API:

  1. 初始化一个实验 run
  2. 在训练过程中不断往这个 run 里记录数据

MiniMind 里初始化的代码是:

if args.use_wandb and is_main_process():
    import swanlab as wandb
    wandb_id = ckp_data.get('wandb_id') if ckp_data else None
    resume = 'must' if wandb_id else None
    wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
    wandb.init(project=args.wandb_project, name=wandb_run_name, id=wandb_id, resume=resume)

这里虽然变量名写的是 wandb,但实际导入的是 swanlab,所以更准确地说,这里用的是一套 WandB 风格的接口去记录实验.

这里几个参数可以顺手解释一下:

  • project
    表示这次实验属于哪个项目. 一般同一类实验会放在同一个 project 下面,方便后面统一比较.
  • name
    表示当前这个 run 的显示名称. 它通常更适合给人看,比如把 epoch、batch size、learning rate 直接拼进名字里.
  • id
    表示这个 run 的唯一标识. 如果恢复训练时还想继续往原来的实验里写数据,这个 id 就很重要.
  • resume
    表示是否继续已有的 run. 这里 resume='must' 的意思可以理解成: 如果已经有这个 run,那就强制接着它写; 不要新建一个同名但不连续的实验.

这里还有一个分布式训练里的关键点:

if args.use_wandb and is_main_process():

为什么只在主进程记录日志? 因为在分布式训练里,会有很多个进程同时跑训练循环. 如果每个进程都自己 initlog,那同一组指标就会被重复记录很多次,曲线也会变得很乱.

所以更常见的做法是:

  • 只让主进程负责可视化记录
  • 其他进程正常训练,但不单独写日志

然后在训练过程中,通过 wandb.log(...) 不断追加指标:

if wandb:
    wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})

wandb.log(...) 最常见的用法就是记录一组标量.
每调用一次,平台就会把这些值追加到当前 run 的时间序列里,最后画成曲线.

所以这段代码本质上记录的是三条标量曲线:

  • loss
  • lr
  • epoch_Time

再结合 Q10 里的 checkpoint 恢复逻辑,这里还有一个细节很值得注意: MiniMind 会把 wandb_id 一起保存进 resume checkpoint. 这样恢复训练时,可以重新拿回原来的 run id,再继续往同一个实验里写数据.

这里还有一个很自然的问题: 如果第一次训练时 wandb_id 本来就是 None,那保存 checkpoint 的时候这个 id 又是从哪里来的?

答案是: 第一次 wandb.init(...) 时,即使传进去的 idNone,平台也会自动新建一个 run,并给它分配一个真正的 run id. 之后在保存 checkpoint 时,MiniMind 再从当前 run 对象里把这个 id 取出来保存.

对应代码在 lm_checkpoint(...) 里:

wandb_id = None
if wandb:
    if hasattr(wandb, 'get_run'):
        run = wandb.get_run()
        wandb_id = getattr(run, 'id', None) if run else None
    else:
        wandb_id = getattr(wandb, 'id', None)

所以这件事的完整逻辑其实是:

  • 第一次训练时,wandb_id 可以是 None
  • wandb.init(...) 会自动创建一个新的 run
  • 新 run 创建好之后,平台内部就已经有了真正的 id
  • 保存 checkpoint 时,再把这个 id 读出来,写进 resume checkpoint
  • 下次恢复训练时,再把这个旧 id 传回 wandb.init(..., id=..., resume='must')

也就是说,可视化的断点续训并不是“重新开一个实验再接着看”,而是:

  • 先从 checkpoint 里读出原来的 wandb_id
  • wandb.init(..., id=wandb_id, resume='must')
  • 让后续日志继续追加到原来的 run 上

这样恢复训练之后,曲线才会是连续的.

除了记录标量之外,WandB 风格的接口通常还支持记录别的数据类型. 比如比较常见的还有:

wandb.log({"samples": wandb.Text("hello world")})

这个用法适合记录文本内容,比如模型生成样例、prompt 和 response、某一步的输出片段.

再比如:

wandb.log({"pred_table": wandb.Table(data=[["input", "output"]], columns=["x", "y"])})

Table 适合记录结构化结果,比如若干条样本的输入、输出、标签、分数. 后面做对比分析时会比较方便.

所以如果只记最常用的几种 API,大概可以记成:

  • wandb.init(...)
    创建或恢复一个实验 run
  • wandb.log({...})
    记录标量或其他可视化数据
  • wandb.Text(...)
    记录文本
  • wandb.Table(...)
    记录表格化结果

Q13: Tips:

Q13.1: 训练里经常因为 tensor 的 shape 或 dtype 报错,有没有办法尽量提前检查?

这是训练代码里非常常见的一类问题. 尤其是写 dataset、collate、model forward、loss 对齐这些地方的时候,一旦 tensor 的 shape 或 dtype 没对上,往往要等跑到中间某一层才会报错,而且报错位置还不一定是问题真正出现的地方.

先说结论: 在 PyTorch 里,这类问题很难做到像静态类型语言那样“彻底静态检查”. 但我们完全可以通过一些工程手段,把很多错误尽量提前暴露出来.

最常见的思路有下面几类:

  1. 给关键函数补类型注解

最基础的一层是先把普通 Python 类型关系理清楚,比如函数输入输出、可选值、配置对象、batch 结构这些.
这类检查虽然不能直接证明 tensor shape 一定正确,但至少能减少很多“参数类型传错”“返回值结构不一致”这类问题.

  1. 在关键边界加 shape / dtype 断言

这是最实用的一招.
因为训练代码里最容易出问题的地方,通常都集中在几个边界上:

  • dataset 输出
  • collate_fn 输出
  • model 输入
  • logits 和 labels 送进 loss 之前

比如可以直接写:

assert X.ndim == 2
assert Y.ndim == 2
assert X.shape == Y.shape
assert loss_mask.shape == Y.shape
assert X.dtype == torch.long
assert Y.dtype == torch.long

如果模型输出再进一步检查:

assert res.logits.shape[:2] == Y.shape
assert res.logits.size(-1) == vocab_size

这样做的价值是:
与其等到后面的 viewloss_fctmatmul 里报一个很绕的错,不如在边界上更早地把问题拦下来.

  1. 给 batch 封装成更明确的数据结构

如果训练里到处都在传 (X, Y, mask) 这种裸 tuple,就很容易把顺序写错,或者后面自己都忘了第三个张量到底是什么意思.

更稳一点的写法是把 batch 封装成 dataclass 或者命名结构,比如:

from dataclasses import dataclass
import torch

@dataclass
class LMTrainingBatch:
    input_ids: torch.Tensor
    labels: torch.Tensor
    loss_mask: torch.Tensor

这样一来,接口语义会更清楚,后面也更容易统一加 validate() 一类的检查逻辑.

  1. 对 dataset / dataloader / forward 写最小测试

很多 shape 问题其实没必要等到正式训练时才发现.
只要写几个很小的测试,比如:

  • dataset 取一条样本时 shape 对不对
  • dataloader 取一个 batch 时 shape 和 dtype 对不对
  • model forward 输出 shape 对不对
  • loss 计算能不能正常对齐

通常就已经能提前挡掉一大批错误.

  1. 如果还想更严格,可以引入带 shape 的类型注解工具

比如 jaxtypingtorchtyping 这类工具,可以把 tensor 的维度约定直接写进函数签名里.
它们更像是“类型注解 + 运行时检查”的组合,虽然不算完全静态检查,但能让接口变得清楚很多.

例如:

from jaxtyping import Int, Float
from torch import Tensor

def forward(
    x: Int[Tensor, "batch seq"],
    mask: Float[Tensor, "batch seq"]
) -> Float[Tensor, "batch seq vocab"]:
    ...

这样至少能把“这个函数期望什么 shape”明确写出来.

所以如果只总结成最实用的一套做法,我会更推荐:

  • 用普通类型注解把训练代码结构理顺
  • 在 dataset / collate / forward / loss 前加 shape 和 dtype 断言
  • 给 batch 封装成更明确的数据结构
  • 写几个最小单元测试

换句话说,在 PyTorch 里,shape / dtype 问题很难被纯静态检查彻底解决. 但只要把关键边界守住,很多训练时才会爆出来的问题,其实都能更早发现.

Q13.2: 训练里经常出现 CPU / GPU device 不一致报错,这种问题怎么尽量减少?

这一类报错也非常常见. 它和 shape / dtype 问题有点像: 真正的问题往往出在前面,但报错经常是在某个算子真正开始计算时才出现.

先说一个最核心的原则:

  • 参与同一次计算的 tensor,通常必须在同一个 device 上
  • 但日志、可视化、打印、保存这些操作,又经常需要把 tensor 转回 CPU

所以很多 device 报错,本质上不是“不会用 GPU”,而是“训练计算流”和“日志 / 可视化流”混在了一起.

最常见的几类问题有:

  1. 模型在 GPU,但输入数据还在 CPU
  2. labelsloss_mask 这种辅助 tensor 忘了 .to(device)
  3. 中间手动新建 tensor 时,默认建在 CPU
  4. tensor 还在 GPU 上,就直接拿去 numpy()、可视化或者打印
  5. 记录日志时,直接把 GPU tensor 塞给可视化工具

比较实用的做法通常有下面几条:

  1. batch 一进入训练循环,就统一搬到 device

比如:

X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)

这一点非常重要. 最好形成一个固定习惯:
所有参与 forward 和 loss 计算的 tensor,都在训练循环入口统一 .to(device).

  1. 新建 tensor 时,尽量继承已有 tensor 的 device

比如下面这种写法就很容易埋坑:

mask = torch.zeros(B, L)

因为它默认会建在 CPU 上.

更稳一点的写法是:

mask = torch.zeros(B, L, device=X.device)

或者:

mask = torch.zeros_like(X)

这样可以尽量避免“模型和输入都在 GPU,但中间新建的 tensor 突然跑到 CPU”这种问题.

  1. 做日志、可视化、保存时,尽量先转成 CPU 友好的形式

如果只是一个标量,最常见的做法是:

loss_value = loss.item()

如果是张量还要继续送去 numpy() 或 matplotlib,更常见的顺序是:

arr = tensor.detach().cpu().numpy()

这里的顺序一般是:

  • detach()
    先脱离计算图
  • cpu()
    再搬回 CPU
  • numpy()
    最后转成 NumPy

如果 tensor 还在 GPU 上就直接 numpy(),通常就会报错.

  1. 给 wandb / swanlab 记录数据时,尽量传 Python 标量或 CPU 数据

比如:

wandb.log({
    "loss": loss.item(),
    "lr": current_lr,
})

如果你直接把一个 GPU tensor 丢进去,有些时候可视化工具会帮你处理,有些时候又会引出新的问题. 所以更稳的做法通常是:

  • 标量用 .item()
  • 数组用 .detach().cpu()
  • 图片、表格、文本也尽量先转成更明确的 CPU 侧格式
  1. 尽量统一 device 的来源

不要在代码里到处手写 "cuda:0""cuda:1""cpu".
更稳一点的做法是统一从一个地方拿 device,比如:

device = args.device

或者直接从模型参数拿:

device = next(model.parameters()).device

这样更不容易出现“模型在一张卡上,数据却被送到另一张卡上”的问题.

  1. 在关键边界直接加 device 断言

调试阶段这其实非常有用. 比如:

assert X.device == next(model.parameters()).device
assert Y.device == X.device
assert loss_mask.device == X.device

这样你就不用等到后面某个矩阵乘法或者 loss 计算才看到一长串 device mismatch 报错.

如果把这些经验压缩成一句话,可以记成:

  • 算的时候,尽量保证参与同一次计算的 tensor 都在同一个 device 上
  • 记的时候,尽量尽早 .item().detach().cpu()

也就是说,训练代码里最好主动区分两条流:

  • 训练计算流
    model、input、label、mask、logits、loss 这些,应该尽量保持在同一个 device 上
  • 日志 / 可视化流
    print、wandb、numpy、matplotlib、保存展示结果这些,通常应该尽快转成 CPU 侧更安全的形式

很多 device 报错,本质上就是这两条流没有分开.