Решения

Чекпоинты в обучении LLM: форматы и перезапуск

Задача страницы. Дать практические стратегии сохранения‑перезапуска обучения LLM на облачных GPU: что именно сохранять, как писать атомарно, какие форматы выбирать (safetensors/шардированные state‑dict), как возобновляться после прерывания (в т.ч. на interruptible‑инстансах), и как оставаться совместимыми с FSDP/DeepSpeed/LoRA.

Что сохранять (полный список)

Минимальный комплект для безболезненного рестарта:

  • Веса модели.
  • Состояние оптимизатора (в т.ч. 8‑бит/CPU‑offload, если используется).
  • Состояние LR‑планировщика.
  • Шаги обучения: global_step, epoch, samples_seen.
  • RNG‑состояния: Python/NumPy/CUDA.
  • Состояние даталоадера/самплера (позиция в эпохе, сид, порядок выборки).
  • Градиентный аккьюмулятор (если хранится отдельно).
  • Мета‑инфо: версия кода/конфигов, токенайзер, формат точности (bf16/fp16), аттестация последнего валидного чекпоинта.

Для LoRA/QLoRA часто достаточно адаптера (малый чекпоинт). Для прод‑инференса можно дополнительно сохранять «мердженный» вариант весов.

Форматы и совместимость

safetensors

  • Безопасная сериализация тензоров, ускоренная загрузка, детерминированные заголовки. Рекомендуется для финальных/прод‑весов.
  • В Transformers можно сохранять safe_serialization=True.

PyTorch state‑dict

  • Универсальный способ для обучения/ресюма. Поддерживает полные и шардированные состояния (особенно важно для FSDP/ZeRO).

Шардированные чекпоинты (мульти‑GPU)

  • Снижают время/память на сохранение и упрощают «перешардинг» при изменении числа/GPU на рестарте.
  • Для FSDP используйте «sharded state‑dict» или распределённый чекпоинт. Для DeepSpeed — встроенные шард‑чекпоинты.

LoRA/PEFT

  • Сохраняйте отдельно адаптер (малый объём) → удобно для итераций.

По необходимости «сливайте» адаптер в базовую модель для инференса на чистом рантайме.

Базовый рецепт (PyTorch, один узел)

Сохранение (атомарно на локальный диск):

				
					import os, json, torch, tempfile, shutil, random
import numpy as np

def save_checkpoint_atomic(out_dir, model, optimizer, scheduler, train_state):
    # 1) пишем во временную папку
    tmp_dir = out_dir + ".tmp"
    os.makedirs(tmp_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(tmp_dir, "model.pt"))
    torch.save(optimizer.state_dict(), os.path.join(tmp_dir, "optim.pt"))
    if scheduler is not None:
        torch.save(scheduler.state_dict(), os.path.join(tmp_dir, "sched.pt"))

    state = {
        "global_step": train_state["step"],
        "epoch": train_state["epoch"],
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch_cpu": torch.get_rng_state().tolist(),
            "torch_cuda": torch.cuda.get_rng_state_all()
        },
        "meta": train_state.get("meta", {})
    }
    with open(os.path.join(tmp_dir, "trainer_state.json"), "w", encoding="utf-8") as f:
        json.dump(state, f, ensure_ascii=False, indent=2)

    # 2) атомарная ротация каталогов
    bk_dir = out_dir + ".bak"
    if os.path.isdir(out_dir):
        if os.path.isdir(bk_dir):
            shutil.rmtree(bk_dir, ignore_errors=True)
        os.rename(out_dir, bk_dir)
    os.rename(tmp_dir, out_dir)
    shutil.rmtree(bk_dir, ignore_errors=True)

				
			

Загрузка:

				
					def load_checkpoint(out_dir, model, optimizer=None, scheduler=None):
    model.load_state_dict(torch.load(os.path.join(out_dir,"model.pt"), map_location="cpu"))

    if optimizer is not None and os.path.exists(os.path.join(out_dir,"optim.pt")):
        optimizer.load_state_dict(torch.load(os.path.join(out_dir,"optim.pt"), map_location="cpu"))

    if scheduler is not None and os.path.exists(os.path.join(out_dir,"sched.pt")):
        scheduler.load_state_dict(torch.load(os.path.join(out_dir,"sched.pt"), map_location="cpu"))

    with open(os.path.join(out_dir, "trainer_state.json"), "r", encoding="utf-8") as f:
        return json.load(f)

				
			

FSDP: полный vs шардированный стейт

Полный стейт (удобно для экспорта/инференса):

				
					from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
):
    full_state = model.state_dict()
torch.save(full_state, "ckpt/model_full.pt")

				
			

Шардированный стейт (рекомендуется для больших моделей):

				
					from torch.distributed.fsdp import StateDictType, ShardedStateDictConfig
with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=False)
):
    sharded_state = model.state_dict()
    # Сохраните набор файлов (по ранкам) в общий каталог (один файл на ранк)
    torch.save(sharded_state, f"ckpt/shard_rank{rank:02d}.pt")

				
			

Важные примечания

  • Для рестарта на другом числе GPU выбирайте шардированные форматы; при загрузке модель может «перешардить» параметры.
  • Следите за согласованностью конфигурации FSDP (wrap‑политика, use_orig_params и т.п.) между сохранением и загрузкой.

DeepSpeed ZeRO: сохранение и загрузка

Сохранение:

				
					# engine = deepspeed.initialize(...)[0]
tag = f"global_step{global_step}"
engine.save_checkpoint("ckpt_deepspeed", tag=tag)

				
			

Загрузка:

				
					load_tag = "global_step12345"
engine.load_checkpoint("ckpt_deepspeed", load_tag)

				
			

Практика

  • При ZeRO‑3/Offload/Infinity убедитесь, что пути на NVMe существуют и хватает места.
  • Для долгих тренировок лучше включить ротацию чекпоинтов (см. политику ниже).

Hugging Face Trainer / PEFT

Сохранение «в стиле HF»:

				
					model.save_pretrained("ckpt_hf", safe_serialization=True)
tokenizer.save_pretrained("ckpt_hf")

				
			

LoRA/QLoRA (адаптер):

				
					# model = PeftModel(...)
model.save_pretrained("ckpt_lora")  # малый чекпоинт

				
			

Смерджить LoRA в базовую модель:

				
					# base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)
# model = PeftModel.from_pretrained(base, "ckpt_lora")
merged = model.merge_and_unload()
merged.save_pretrained("ckpt_merged", safe_serialization=True)

				
			

Частота и политика сохранений

  • По времени: каждые N минут (надежнее против preempt).
  • По шагам: каждые K шагов.
  • Двойная защита: best по валидации + last по времени/шагам.
  • Ротация: хранить последние M чеков + по одному «опорному» на эпоху.

Каталоги:

				
					 /checkpoints/
  - last/           # всегда последний успешный
  - best/           # лучший по метрике
  - epoch-0001/
  - step-010000/
  - meta.json       # центральный реестр (пути, метрики, дата)

				
			

Атомарность и целостность (в т.ч. объектное хранилище)

  • На локальном FS используйте «папка.tmp → rename» (как в примере выше).
  • В объектном хранилище (S3‑совместимое):
    1. загрузите файлы в каталог …/staging/step-XXXXX/;
    2. загрузите COMMIT‑маркер (commit.json с хешами и датой);
    3. только после успешного коммита переключайте «указатель» last/ на новый шаг (обновите last.json).
  • Проверяйте контрольные суммы (md5/sha256) и проверочную загрузку мини‑итерации после сохранения «лучшего» чекпоинта.

Устойчивость к прерываниям (interruptible)

SIGTERM‑ловушка: сохраняйте «быстрый» чекпоинт при сигнале завершения.

				
					import signal, functools
def on_sigterm(*_):
    try: save_checkpoint_atomic("ckpt/last", model, optimizer, scheduler, train_state)
    finally: exit(0)
signal.signal(signal.SIGTERM, on_sigterm)

				
			
  • Частые сейвы + быстрый диск (NVMe): пишите локально → асинхронно выгружайте в объектное хранилище.
  • Атомарность: никакой частично записанной директории «в прод». Жёстко проверяйте наличие commit.json.

Совместимость форматов: обучение ↔ инференс

  • Для инференса предпочтительны safetensors и «мердженные» веса без адаптеров.

  • Для обучения — шардированные state‑dict + отдельные адаптеры (LoRA).
  • Для квантизации/экспорта (например, INT4/FP8/gguf) сохраняйте «чистые» fp16/bf16 веса и токенайзер как исходник. См. /solutions/llm-inference/quantization/.

Состояние даталоадера и воспроизводимость

  • Логируйте позицию в эпохе, сиды и порядок шардов.
  • Храните sampler_state.json (номер шага, текущий индекс, seed).
  • При рестарте восстанавливайте порядок чтения, иначе появится «двойной просмотр» или пропуск данных.

Типовые сценарии

A) LoRA‑SFT на 1×GPU (24–48 GB)

  • Сохраняем адаптер каждые 15–30 минут + last и best.
  • Для прод‑теста — периодически «мерджим» и сохраняем safetensors.

B) Pretrain на 8×GPU (FSDP)

  • Каждые 30–60 минут — шардированный чекпоинт, раз в N часов — полный «экспортный».
  • Восстановление допускает изменение world‑size (при корректном формате).

C) ZeRO‑3 + Offload (многоузловой)

  • Чекпоинты DeepSpeed по тегам, с ротацией.
  • Хранилище — объектное, staged‑загрузка и commit.json.

Траблшутинг

  • Чекпоинт загружается, но качество «прыгнуло». Проверьте LR‑scheduler state, RNG‑состояния, позицию даталоадера.
  • OOM при сохранении полного стейта (FSDP). Перейдите на шардированный формат; временно выгружайте на CPU.
  • Неполный чекпоинт на объектном хранилище. Включите протокол staging → commit → switch last; не читайте чек без commit.json.
  • Несовпадение ключей при загрузке. Сравните версии модели/конфига; используйте «строгую» загрузку и явный лог пропущенных/лишних ключей.
  • Долгая выгрузка. Пишите локально, затем асинхронно реплицируйте; отключите лишнюю компрессию, увеличьте параллелизм загрузки.

Чек‑лист перед запуском обучения

  • Выбран формат весов: safetensors для экспорта, шардированный state‑dict для больших тренировок.
  • Определена частота и ротация чекпоинтов (last/best/every‑N).
  • Реализована атомарность записи (tmp‑директории/commit‑маркер).
  • Сохраняются optimizer/scheduler/RNG + sampler_state.
  • Есть обработчик SIGTERM и асинхронная выгрузка в объектное хранилище.
  • Тест «холодного» рестарта: загрузка → 1 шаг вперёд/назад → метрики совпадают.
  • Зафиксированы версии токенайзера, формата точности, конфигурации параллелизма.