Чекпоинты в обучении LLM: форматы и перезапуск

Задача страницы. Дать практические стратегии сохранения‑перезапуска обучения LLM на облачных GPU: что именно сохранять, как писать атомарно, какие форматы выбирать (safetensors/шардированные state‑dict), как возобновляться после прерывания (в т.ч. на interruptible‑инстансах), и как оставаться совместимыми с FSDP/DeepSpeed/LoRA.

Что сохранять (полный список)

Минимальный комплект для безболезненного рестарта:

  • **Веса модели.
  • Состояние оптимизатора (в т.ч. 8‑бит/CPU‑offload, если используется).
  • Состояние LR‑планировщика.
  • Шаги обучения: global_step, epoch, samples_seen.
  • RNG‑состояния: Python/NumPy/CUDA.
  • Состояние даталоадера/самплера (позиция в эпохе, сид, порядок выборки).
  • Градиентный аккьюмулятор (если хранится отдельно).
  • Мета‑инфо: версия кода/конфигов, токенайзер, формат точности (bf16/fp16), аттестация последнего валидного чекпоинта.

Для LoRA/QLoRA часто достаточно адаптера (малый чекпоинт). Для прод‑инференса можно дополнительно сохранять «мердженный» вариант весов.

Форматы и совместимость

safetensors

  • Безопасная сериализация тензоров, ускоренная загрузка, детерминированные заголовки. Рекомендуется для финальных/прод‑весов.
  • В Transformers можно сохранять safe_serialization=True.

PyTorch state‑dict

  • Универсальный способ для обучения/ресюма. Поддерживает полные и шардированные состояния (особенно важно для FSDP/ZeRO).

Шардированные чекпоинты (мульти‑GPU)

  • Снижают время/память на сохранение и упрощают «перешардинг» при изменении числа/GPU на рестарте.
  • Для FSDP используйте «sharded state‑dict» или распределённый чекпоинт. Для DeepSpeed — встроенные шард‑чекпоинты.

LoRA/PEFT

  • Сохраняйте отдельно адаптер (малый объём) → удобно для итераций.

По необходимости «сливайте» адаптер в базовую модель для инференса на чистом рантайме. ## Базовый рецепт (PyTorch, один узел) Сохранение (атомарно на локальный диск):

import os, json, torch, tempfile, shutil, random
import numpy as np
def save_checkpoint_atomic(out_dir, model, optimizer, scheduler, train_state):
 # 1) пишем во временную папку
 tmp_dir = out_dir + ".tmp"
 os.makedirs(tmp_dir, exist_ok=True)
 torch.save(model.state_dict(), os.path.join(tmp_dir, "model.pt"))
 torch.save(optimizer.state_dict(), os.path.join(tmp_dir, "optim.pt"))
 if scheduler is not None:
 torch.save(scheduler.state_dict(), os.path.join(tmp_dir, "sched.pt"))
 state = {
 "global_step": train_state["step"],
 "epoch": train_state["epoch"],
 "rng": {
 "python": random.getstate(),
 "numpy": np.random.get_state(),
 "torch_cpu": torch.get_rng_state().tolist(),
 "torch_cuda": torch.cuda.get_rng_state_all()
 },
 "meta": train_state.get("meta", {})
 }
 with open(os.path.join(tmp_dir, "trainer_state.json"), "w", encoding="utf-8") as f:
 json.dump(state, f, ensure_ascii=False, indent=2)
 # 2) атомарная ротация каталогов
 bk_dir = out_dir + ".bak"
 if os.path.isdir(out_dir):
 if os.path.isdir(bk_dir):
 shutil.rmtree(bk_dir, ignore_errors=True)
 os.rename(out_dir, bk_dir)
 os.rename(tmp_dir, out_dir)
 shutil.rmtree(bk_dir, ignore_errors=True)

Загрузка:

def load_checkpoint(out_dir, model, optimizer=None, scheduler=None):
 model.load_state_dict(torch.load(os.path.join(out_dir,"model.pt"), map_location="cpu"))
 if optimizer is not None and os.path.exists(os.path.join(out_dir,"optim.pt")):
 optimizer.load_state_dict(torch.load(os.path.join(out_dir,"optim.pt"), map_location="cpu"))
 if scheduler is not None and os.path.exists(os.path.join(out_dir,"sched.pt")):
 scheduler.load_state_dict(torch.load(os.path.join(out_dir,"sched.pt"), map_location="cpu"))
 with open(os.path.join(out_dir, "trainer_state.json"), "r", encoding="utf-8") as f:
 return json.load(f)

FSDP: полный vs шардированный стейт

Полный стейт (удобно для экспорта/инференса):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
with FSDP.state_dict_type(
 model, 
 StateDictType.FULL_STATE_DICT,
 FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
):
 full_state = model.state_dict()
torch.save(full_state, "ckpt/model_full.pt")

Шардированный стейт (рекомендуется для больших моделей):

from torch.distributed.fsdp import StateDictType, ShardedStateDictConfig
with FSDP.state_dict_type(
 model,
 StateDictType.SHARDED_STATE_DICT,
 ShardedStateDictConfig(offload_to_cpu=False)
):
 sharded_state = model.state_dict()
 # Сохраните набор файлов (по ранкам) в общий каталог (один файл на ранк)
 torch.save(sharded_state, f"ckpt/shard_rank{rank:02d}.pt")

Важные примечания

  • Для рестарта на другом числе GPU выбирайте шардированные форматы; при загрузке модель может «перешардить» параметры.
  • Следите за согласованностью конфигурации FSDP (wrap‑политика, use_orig_params и т.п.) между сохранением и загрузкой.

DeepSpeed ZeRO: сохранение и загрузка

Сохранение:

# engine = deepspeed.initialize(...)[0]
tag = f"global_step{global_step}"
engine.save_checkpoint("ckpt_deepspeed", tag=tag)

Загрузка:

load_tag = "global_step12345"
engine.load_checkpoint("ckpt_deepspeed", load_tag)

Практика

  • При ZeRO‑3/Offload/Infinity убедитесь, что пути на NVMe существуют и хватает места.
  • Для долгих тренировок лучше включить ротацию чекпоинтов (см. политику ниже). Hugging Face Trainer / PEFT

Сохранение «в стиле HF»:

model.save_pretrained("ckpt_hf", safe_serialization=True)
tokenizer.save_pretrained("ckpt_hf")

LoRA/QLoRA (адаптер):

# model = PeftModel(...)
model.save_pretrained("ckpt_lora") # малый чекпоинт

Смерджить LoRA в базовую модель:

# base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)
# model = PeftModel.from_pretrained(base, "ckpt_lora")
merged = model.merge_and_unload()
merged.save_pretrained("ckpt_merged", safe_serialization=True)

Частота и политика сохранений

  • По времени: каждые N минут (надежнее против preempt).
  • По шагам: каждые K шагов.
  • Двойная защита: best по валидации + last по времени/шагам.
  • Ротация: хранить последние M чеков + по одному «опорному» на эпоху.

**Каталоги:

/checkpoints/
 - last/ # всегда последний успешный
 - best/ # лучший по метрике
 - epoch-0001/
 - step-010000/
 - meta.json # центральный реестр (пути, метрики, дата)

Атомарность и целостность (в т.ч. объектное хранилище)

  • На локальном FS используйте «папка.tmp → rename» (как в примере выше).
  • В объектном хранилище (S3‑совместимое):
  1. загрузите файлы в каталог .../staging/step-XXXXX/;
  2. загрузите COMMIT‑маркер (commit.json с хешами и датой);
  3. только после успешного коммита переключайте «указатель» last/ на новый шаг (обновите last.json).
  • Проверяйте контрольные суммы (md5/sha256) и проверочную загрузку мини‑итерации после сохранения «лучшего» чекпоинта.

Устойчивость к прерываниям (interruptible)

SIGTERM‑ловушка: сохраняйте «быстрый» чекпоинт при сигнале завершения.

import signal, functools
def on_sigterm(*_):
 try: save_checkpoint_atomic("ckpt/last", model, optimizer, scheduler, train_state)
 finally: exit(0)
signal.signal(signal.SIGTERM, on_sigterm)
  • Частые сейвы + быстрый диск (NVMe): пишите локально → асинхронно выгружайте в объектное хранилище.
  • Атомарность: никакой частично записанной директории «в прод». Жёстко проверяйте наличие commit.json.

Совместимость форматов: обучение ↔ инференс

  • Для инференса предпочтительны safetensors и «мердженные» веса без адаптеров.
  • Для обучения — шардированные state‑dict + отдельные адаптеры (LoRA).
  • Для квантизации/экспорта (например, INT4/FP8/gguf) сохраняйте «чистые» fp16/bf16 веса и токенайзер как исходник. См. /solutions/llm-inference/quantization/.

Состояние даталоадера и воспроизводимость

  • Логируйте позицию в эпохе, сиды и порядок шардов.
  • Храните sampler_state.json (номер шага, текущий индекс, seed).
  • При рестарте восстанавливайте порядок чтения, иначе появится «двойной просмотр» или пропуск данных.

Типовые сценарии

A) LoRA‑SFT на 1×GPU (24–48 GB)

  • Сохраняем адаптер каждые 15–30 минут + last и best.
  • Для прод‑теста — периодически «мерджим» и сохраняем safetensors.

B) Pretrain на 8×GPU (FSDP)

  • Каждые 30–60 минут — шардированный чекпоинт, раз в N часов — полный «экспортный».
  • Восстановление допускает изменение world‑size (при корректном формате).

C) ZeRO‑3 + Offload (многоузловой)

  • Чекпоинты DeepSpeed по тегам, с ротацией.
  • Хранилище — объектное, staged‑загрузка и commit.json.

Траблшутинг

  • Чекпоинт загружается, но качество «прыгнуло». Проверьте LR‑scheduler state, RNG‑состояния, позицию даталоадера.
  • OOM при сохранении полного стейта (FSDP). Перейдите на шардированный формат; временно выгружайте на CPU.
  • Неполный чекпоинт на объектном хранилище. Включите протокол staging → commit → switch last; не читайте чек без commit.json.
  • Несовпадение ключей при загрузке. Сравните версии модели/конфига; используйте «строгую» загрузку и явный лог пропущенных/лишних ключей.
  • Долгая выгрузка. Пишите локально, затем асинхронно реплицируйте; отключите лишнюю компрессию, увеличьте параллелизм загрузки.

Чек‑лист перед запуском обучения

  • Выбран формат весов: safetensors для экспорта, шардированный state‑dict для больших тренировок.
  • Определена частота и ротация чекпоинтов (last/best/every‑N).
  • Реализована атомарность записи (tmp‑директории/commit‑маркер).
  • Сохраняются optimizer/scheduler/RNG + sampler_state.
  • Есть обработчик SIGTERM и асинхронная выгрузка в объектное хранилище.
  • Тест «холодного» рестарта: загрузка → 1 шаг вперёд/назад → метрики совпадают.
  • Зафиксированы версии токенайзера, формата точности, конфигурации параллелизма.

Связанные страницы

База по обучению

Масштабирование: /solutions/llm-training/fsdp-deepspeed/, /solutions/multi-gpu/

Память/точность: /solutions/llm-training/memory-opt/, /solutions/llm-training/mixed-precision/

Данные/I‑O/хранилище: /solutions/llm-training/distributed-io/, /solutions/storage-data/

Стоимость/режимы

Готовы запустить?

Запустить GPU-сервер