Training
gradient checkpointing
Техника обучения: перевычисление активаций при backward pass вместо их хранения в памяти.
Что такое gradient checkpointing
При обратном распространении ошибки PyTorch по умолчанию хранит активации всех слоёв, вычисленных в forward pass, — это необходимо для расчёта градиентов. При обучении трансформеров это может занимать до 60–80% VRAM (больше, чем сами веса).
Gradient checkpointing (также activation checkpointing) — компромисс: вместо хранения всех активаций перевычисляет их при backward pass «на лету». Активации хранятся только на «контрольных точках» (обычно на входах каждого трансформерного блока), остальное пересчитывается.
Компромисс: память vs вычисления
- Экономия памяти: 5–10× снижение пика VRAM для активаций
- Замедление: +20–30% времени обучения за счёт пересчёта
- Итог: позволяет обучать с большим batch_size или обучать модель, не помещающуюся без checkpointing
Включение
# Hugging Face model
model.gradient_checkpointing_enable()
# Или через TrainingArguments
from transformers import TrainingArguments
args = TrainingArguments(
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False}, # рекомендуется
...
)
# PyTorch native
from torch.utils.checkpoint import checkpoint
# Оборачивает слой для перевычисления при backward
output = checkpoint(layer_fn, input)
Когда использовать
- QLoRA / LoRA: почти всегда — PEFT-библиотека вызывает
prepare_model_for_kbit_training, который включает gradient checkpointing - Full fine-tuning: при нехватке VRAM или для увеличения batch_size
- Pretraining: стандартная практика при обучении на 100+ GPU
# QLoRA: включается автоматически
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model) # включает gradient checkpointing
Связанные термины
- VRAM — ресурс, экономимый gradient checkpointing
- QLoRA и LoRA — применяются совместно
- ZeRO — другая техника экономии VRAM
- batch size — можно увеличить благодаря checkpointing
Готовы запустить GPU-задачу?
Запустить GPU-сервер