Решения

FSDP и DeepSpeed: масштабирование обучения LLM

Задача страницы. Дать ясный выбор между PyTorch FSDP и DeepSpeed ZeRO, показать рабочие конфиги (с offload на CPU/NVMe), чек‑лист стабильности и ссылки на смежные темы.

Когда FSDP, а когда ZeRO (DeepSpeed)

  • FSDP (PyTorch): нативный модуль шардирования параметров/градиентов/состояний оптимизатора по ранкам. Удобная интеграция, поддержка шардированных стейт‑диктов и Distributed Checkpoint; в новых версиях есть FSDP2 (fully_shard) на базе DTensor.
  • ZeRO (DeepSpeed): трёхстадийное шардирование (Stage‑1/2/3) плюс ZeRO‑Offload (CPU) и ZeRO‑Infinity (CPU+NVMe), что позволяет тренировать очень большие модели на ограниченной VRAM. Поддерживаются pipeline/tensor параллелизм и «3D‑параллелизм».

Правило: если хватает возможностей «чистого» data‑parallel шардирования и нужен простой онбординг — начните с FSDP. Если критична агрессивная разгрузка на CPU/NVMe или требуется DP×PP×TP, выбирайте DeepSpeed (ZeRO‑3 + Offload/Infinity).

Как это работает (коротко и по делу)

  • FSDP: вне вычислений параметры хранятся шардированными; перед forward/backward нужные куски all‑gather, в backward градиенты reduce‑scatter; оптимизатор работает со шардированными состояниями. Это радикально снижает VRAM по сравнению с DDP.
  • ZeRO:
    • Stage‑1: шардирует состояния оптимизатора.
    • Stage‑2: + шардирует градиенты.
    • Stage‑3: + шардирует параметры (аналогично идее FSDP).

Offload: перенос optimizer state (Stg‑1/2/3) и/или parameters (Stg‑3) на CPU/NVMe.

Чекпоинтинг на практике

  • FSDP:
    • Для «монолитного» файла — FullStateDictConfig(offload_to_cpu=True, rank0_only=True); просто, но может упереться в RAM/время на больших моделях. 
    • Для масштаба и гибкости — шардированные чекпоинты (рекомендуется); удобно через StateDictType.SHARDED_STATE_DICT и/или Distributed Checkpoint (DCP) — сохраняет/загружает параллельно и умеет reshard под новую топологию.

DeepSpeed: чекпоинты управляются рантаймом; при Offload/Infinity убедитесь, что пути на NVMe корректны и быстры (реальные NVMe диски).

Точность, память и батчинг

  • Mixed precision: на Ampere/Hopper обычно BF16; в DeepSpeed вкл. через «bf16.enabled»: true, в FSDP — через политику смешанной точности/настройки тренера.
  • Gradient checkpointing: снижает пик активаций ценой дополнительного compute; включайте в «тяжёлых» блоках.
  • Полезные флаги FSDP: limit_all_gathers, forward_prefetch, sync_module_states, use_orig_params — влияют на стабильность и утилизацию.
  • Эффективный батч: увеличивайте grad_accumulation, удерживая micro_batch у порога VRAM.

Производительность и топологии

  • Один узел, много GPU: чаще хватает FSDP (или ZeRO‑3 без offload) при хорошей NVLink‑межсвязи.
  • Много узлов / очень большие модели: DeepSpeed 3D (DP×PP×TP) или связки Megatron‑DeepSpeed. TP размещайте внутри узла, PP — поперёк узлов, поверх — DP/ZeRO.
  • I/O и чекпоинты вынесите на быстрые диски/NVMe; для ZeRO‑Infinity — обязательно NVMe‑путь.

Быстрый старт: минимальные конфиги

A) PyTorch FSDP (один узел, несколько GPU)

				
					# train_fsdp.py
import torch, functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from model import MyTransformer  # ваш код модели


def main():
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())


    model = MyTransformer().cuda()
    auto_wrap = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e7))
    model = FSDP(model, auto_wrap_policy=auto_wrap)  # шардирование по блокам


    optim = torch.optim.AdamW(model.parameters(), lr=2e-4)


    for step, (x, y) in enumerate(loader()):
        optim.zero_grad(set_to_none=True)
        loss = model(x).loss(y)
        loss.backward()
        optim.step()


    # Полный чекпоинт (rank0 CPU)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):
        ckpt = model.state_dict()
    torch.save(ckpt, "model_full.pt")


if __name__ == "__main__":
    main()

				
			

Запуск:

				
					torchrun --standalone --nproc_per_node=8 train_fsdp.py

				
			

FSDP шардирует параметры/градиенты/состояния и умеет сохранять полный или шардированный стейт‑дикт. Для крупных моделей используйте шардированные чекпоинты или Distributed Checkpoint.

Примечания по FSDP: auto‑wrap‑политики (например, size‑based или transformer‑policy) упрощают «нарезку» больших трансформеров на блоки.

B) DeepSpeed ZeRO‑3 + Offload (CPU)

				
					{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 64,
  "zero_optimization": {
    "stage": 3,
    "offload_param":     { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  },
  "bf16": { "enabled": true },
  "steps_per_print": 100,
  "wall_clock_breakdown": false
}

				
			

Запуск:

				
					deepspeed --num_gpus=8 train.py --deepspeed ds_config.json

				
			

ZeRO‑3 делит параметры/градиенты/оптимизатор между ранками; Offload переносит параметры/оптимизатор на CPU (а с ZeRO‑Infinity — ещё и на NVMe).

Частые проблемы и быстрые решения

  • OOM при FSDP/ZeRO‑3. Уменьшите micro_batch, увеличьте grad_accum, включите gradient checkpointing; в ZeRO — добавьте Offload.
  • «Зависания»/ошибки при загрузке чекпоинтов. Для FSDP используйте согласованные типы стейт‑диктов и Distributed Checkpoint — он поддерживает reshard между конфигурациями кластера.
  • Неравномерные all‑gather. Включите limit_all_gathers и forward_prefetch на статичных графах.
  • Нужно ещё больше памяти. ZeRO‑Infinity (параметры/оптимизатор на NVMe) — крайняя мера; учитывайте стоимость I/O.

Чек‑лист запуска (минимум)

  1. Выберите стек: FSDP или ZeRO‑3(+Offload/Infinity).
  2. Включите BF16, задайте grad accumulation и checkpointing по блокам.
  3. Для FSDP задайте auto‑wrap policy (size‑based/transformer) и тип чекпоинта (sharded/DCP).
  4. Для DeepSpeed настройте «zero_optimization.stage»: 3 и Offload‑опции при дефиците VRAM.