Training

gradient checkpointing

Техника обучения: перевычисление активаций при backward pass вместо их хранения в памяти.

Что такое gradient checkpointing

При обратном распространении ошибки PyTorch по умолчанию хранит активации всех слоёв, вычисленных в forward pass, — это необходимо для расчёта градиентов. При обучении трансформеров это может занимать до 60–80% VRAM (больше, чем сами веса).

Gradient checkpointing (также activation checkpointing) — компромисс: вместо хранения всех активаций перевычисляет их при backward pass «на лету». Активации хранятся только на «контрольных точках» (обычно на входах каждого трансформерного блока), остальное пересчитывается.

Компромисс: память vs вычисления

  • Экономия памяти: 5–10× снижение пика VRAM для активаций
  • Замедление: +20–30% времени обучения за счёт пересчёта
  • Итог: позволяет обучать с большим batch_size или обучать модель, не помещающуюся без checkpointing

Включение

# Hugging Face model
model.gradient_checkpointing_enable()

# Или через TrainingArguments
from transformers import TrainingArguments
args = TrainingArguments(
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},  # рекомендуется
    ...
)

# PyTorch native
from torch.utils.checkpoint import checkpoint
# Оборачивает слой для перевычисления при backward
output = checkpoint(layer_fn, input)

Когда использовать

  • QLoRA / LoRA: почти всегда — PEFT-библиотека вызывает prepare_model_for_kbit_training, который включает gradient checkpointing
  • Full fine-tuning: при нехватке VRAM или для увеличения batch_size
  • Pretraining: стандартная практика при обучении на 100+ GPU
# QLoRA: включается автоматически
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)  # включает gradient checkpointing

Связанные термины

  • VRAM — ресурс, экономимый gradient checkpointing
  • QLoRA и LoRA — применяются совместно
  • ZeRO — другая техника экономии VRAM
  • batch size — можно увеличить благодаря checkpointing

Готовы запустить GPU-задачу?

Запустить GPU-сервер