Решения

Предобучение LLM: VRAM, I/O и чекпоинты

Задача страницы. Помочь быстро собрать устойчивый и производительный пайплайн предпритрейна: как «влезть» в память, чем кормить GPU без простоев и как не терять прогресс при прерываниях.

Картина целиком: что ограничивает скорость и бюджет

  • Память и параллелизм. Когда модель не помещается на один GPU, используйте шардирование модели/состояний (FSDP/ZeRO) вместо чистого DDP. Это снижает VRAM‑пик за счёт распределения параметров, градиентов и состояний оптимизатора между участниками.
  • Вычисления в слое. Если после шардирования «узкое место» остаётся в матемике крупных слоёв (Attention/FFN), добавляйте tensor/pipeline parallel (как в Megatron‑подходах) и комбинируйте с DDP/FSDP.
  • Подача данных. Для больших корпусов — последовательный I/O и шардинг (tar‑шарды/WebDataset) вместо миллионов мелких файлов; это упрощает стриминг в DataLoader и снижает накладные расходы.

План VRAM: на что уходит память и как её экономить

Компоненты VRAM: параметры модели, активации (forward), градиенты (backward), состояния оптимизатора. Для больших LLM это кратно превышает объём самих параметров — поэтому DDP быстро упирается в пределы. Решение — FSDP/ZeRO (шардирование), offload на CPU/NVMe и приёмы снижения активаций.

Приёмы, которые работают практически всегда:

  • Смешанная точность. На Ampere+ используйте BF16/TF32 для обучения, на Hopper возможны FP8‑режимы через соответствующий стек (валидируйте качество).
  • Gradient checkpointing и grad accumulation — уменьшают пик активаций и VRAM‑требования.
  • FSDP / ZeRO‑3 — шардирование параметров/градиентов/состояний; при необходимости включайте offload на CPU/NVMe.
  • torch.compile (PyTorch 2.x) — уменьшает накладные расходы Python и ускоряет вычисления при минимальных правках кода.

I/O и данные: чтобы GPU не простаивал

Правила большого датасета:

Шардируйте в tar по 0.5–2 ГБ и подавайте последовательным I/O (WebDataset):

				
					 import webdataset as wds
ds = (wds.WebDataset("/data/corpus/{000000..000999}.tar").to_tuple("txt"))
loader = wds.WebLoader(ds, batch_size=..., num_workers=8)

				
			
  •  Такой формат изначально рассчитан на стриминг и параллельное чтение.
  • Префетч и pinned memory в загрузчике, NVMe‑кэш для «горячих» шардов, фоновая докачка (обучение начинается до полной синхронизации).
  • Распределяйте шард‑диапазоны между узлами, чтобы не тянуть один и тот же файл на все машины.

Подготовка корпусов: токенизация, фильтрация, дедупликация, шардирование, индекс/манифест (см. /solutions/llm-training/datasets/ и /solutions/llm-training/distributed-io/).

Стратегии параллелизма (когда и что включать)

  1. FSDP / ZeRO‑3 — если модель не влезает:
    Шардирование параметров/градиентов/состояний, при необходимости offload на CPU/NVMe (ZeRO‑Infinity).
  2. Tensor / Pipeline parallel — если узкое место в вычислениях слоя или глубине модели:
    внутри‑слойное деление (TP) и конвейер по стадиям (PP), как в Megatron‑подходах; можно комбинировать с FSDP/ZeRO.
  3. DDP — если модель помещается на 1 GPU и нужен рост throughput по данным; для предпритрейна часто используется 3D‑комбинация DP×TP×PP.

Чекпоинтинг и перезапуск: никаких потерь прогресса

  • Частота: по времени (каждые 15–30 мин) и по шагам; для инферинга/батч‑ETL — по чанкам.
  • Атомарная запись и дублирование: сохраняйте во временный файл и атомарно переименовывайте в целевое имя, затем немедленно реплицируйте во внешнее хранилище (S3‑совместимое/объектное). Такой паттерн устойчив к прерыванию/крашу на середине записи.
  • Interruptible‑стойкость: при спотовых/прерываемых инстансах чекпоинт — обязательное условие выживания пайплайна.

Мини‑рецепты запуска

A) Torchrun + FSDP (один узел, 8 GPU):

				
					torchrun --standalone --nproc_per_node=8 train.py \
  --fsdp "full_shard auto_wrap" --mixed_precision bf16 \
  --use_checkpointing true

				
			

FSDP шардирует параметры/градиенты/состояния; BF16 снижает VRAM‑пик и ускоряет математику на Ampere+.

B) DeepSpeed ZeRO‑3 + Offload (4–8 GPU):

				
					{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 64,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "offload_param":     { "device": "cpu", "pin_memory": true }
  },
  "bf16": { "enabled": true }
}

				
			

ZeRO‑3 распределяет параметры/градиенты/состояния; offload переносит их на CPU/NVMe при дефиците VRAM.

C) Megatron‑подход (TP/PP) на многоузловом кластере:

				
					python pretrain_gpt.py \
  --tensor-model-parallel-size 4 \
  --pipeline-model-parallel-size 8 \
  --sequence-parallel ...

				
			

TP режет слои, PP — глубину; параметры подбираются по топологии/сети.

Точность и оптимизаторы

  • BF16/TF32 (Ampere+) — типовой выбор для стабильного обучения; FP8 (Hopper) через Transformer Engine/совместимые стеки — требует тщательной валидации качества.
  • Оптимизаторы: AdamW/Lion/AdaFactor; для ZeRO‑Offload — CPU‑варианты (например, DeepSpeedCPUAdam).
  • torch.compile — пробуйте как «бесплатный» буст производительности (особенно в инференсе/валидации).

Считаем пропускную способность и время

Обозначим:
G — число GPU (world size), μ — микро‑батч на GPU, acc — grad accumulation, L — длина контекста (токенов).
Тогда токенов за шагG × μ × acc × L.
Чтобы оценить время до заданного числа токенов, измерьте tokens/sec одного шага (по логам) и масштабируйте на целевой объём корпуса.

Проверка качества и регрессий (минимум)

  • Perplexity (на отложенном корпусе) как быстрый маркер; держите стабильный набор валидационных шардов.
  • Логируйте: loss/perplexity, через сколько часов/токенов достигнуты контрольные значения, версию датасета, гит‑хэш кода/конфига.
  • Для поведенческих свойств/инструкций — переносите оценку на этап finetune/RLHF. См. /solutions/llm-training/eval/.

Частые проблемы и быстрые решения

  • OOM / не влезает батч. Включите BF16/FP16, gradient checkpointing/accumulation; переходите на FSDP/ZeRO‑3 с offload.
  • GPU простаивает. Повышайте num_workers/prefetch_factor, используйте WebDataset‑шарды и локальный NVMe‑кэш.
  • Рассыпались чекпоинты при прерывании. Перейдите на атомарную запись + немедленную выгрузку во внешнее хранилище.
  • Медленная эпоха. Включите torch.compile, проверьте узкие места I/O vs compute.

Навигация по смежным страницам