Решения
Оптимизация памяти LLM: Flash‑Attention, checkpointing, paged‑KV
Цель. Дать чёткие рецепты, как уместить модель/батч в доступную VRAM и повысить пропускную способность без деградации качества.
TL;DR
- Flash‑Attention 2/3 снижает обращения к HBM за счёт тайлинга/фьюзинга и даёт ускорение с меньшим пиком памяти (FA3 добавляет Hopper‑специфику и FP8‑режимы). Рекомендуется как «дефолт» для внимания на Ampere/Hopper.
- Gradient (activation) checkpointing меняет память на вычисления: хранит меньше активаций и пересчитывает их на бэкварде. Даёт крупную экономию VRAM, особенно на длинных контекстах.
- Для инференса используйте PagedAttention/paged‑KV (vLLM): страничный KV‑кэш уменьшает фрагментацию и «паразитные» потери памяти; дополнительно поможет FP8‑квантизация KV и prefix caching общих префиксов.
Что именно «ест» вашу VRAM
- Параметры модели и стейты оптимизатора (решается FSDP/ZeRO/8‑bit оптимизаторами — см. отдельные страницы).
- Активации (каждый слой × длина контекста × батч): главный кандидат на экономию за счёт checkpointing и эффективного внимания.
- KV‑кэш на инференсе: растёт линейно с контекстом и числом одновременных запросов — оптимизируется через PagedAttention/квантизацию.
Flash‑Attention (FA2/FA3): меньше памяти на внимание
Зачем. Стандартное внимание держит большие временные тензоры (QKᵀ и пр.). Flash‑Attention пересчитывает их «плитками» и минимизирует HBM‑I/O, уменьшая пик памяти и ускоряя шаг.
Как включить в Transformers (рекомендуемый путь):
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3-8B",
attn_implementation="flash_attention_2", # FA2
torch_dtype="bfloat16"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")
В современных версиях HF можно выбирать реализацию внимания через attn_implementation (например, sdpa, flash_attention_2).
Как контролировать бэкенд SDPA в «чистом» PyTorch:
# Контекстный выбор ядра внимания
from torch.nn.attention import sdpa_kernel
with sdpa_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None)
PyTorch SDPA поддерживает FlashAttention‑2 и Memory‑Efficient Attention; выбор ядра влияет на память и скорость.
FA3 (Hopper/H100). Новая версия использует асинхронность Tensor Cores/TMA и FP8, повышая утилизацию на H100 (до ~1.2 PFLOP/s в FP8) при сохранении точности. Для обучения/инференса на H100 рекомендуется попробовать FA3‑совместимые стеки.

Gradient (Activation) Checkpointing: RAM ↔ Compute

Идея. Не хранить все активации для бэкварда, а пересчитать часть на обратном проходе. Это резко снижает пик памяти и позволяет увеличить batch_size/seq_len.
PyTorch (базово):
import torch, torch.utils.checkpoint as cp
def block(x):
# ваш крупный подграф: attention + MLP
return model_block(x)
y = cp.checkpoint(block, x) # активации внутри block не сохраняются
Transformers (упрощённо):
model.gradient_checkpointing_enable()
model.config.use_cache = False # важно при обучении с checkpointing
use_cache=True (хранение past_key_values) конфликтует с gradient checkpointing на обучении — отключайте.
DeepSpeed‑расширения. Есть варианты partition activations, CPU‑checkpointing и offload — помогают в «крайних» случаях, когда VRAM критически мала.
Paged‑KV для инференса: длинный контекст без «раздувания»
Проблема. KV‑кэш растёт с контекстом и числом запросов, создавая фрагментацию и «мертвое» место.
Решение. PagedAttention (vLLM) разбивает KV‑кэш на блоки‑страницы и управляет ими как «виртуальной памятью», почти устраняя потери и повышая throughput при длинных контекстах и батчинге.
Практика (vLLM):
vllm serve meta-llama/Llama-3-8B \
--max-model-len 32768 \
--enable-prefix-caching \
--kv-cache-dtype fp8
- —enable-prefix-caching даёт автокэш префиксов (общие подсказки/системные промпты не пересчитываются заново).
- —kv-cache-dtype fp8 уменьшает объём KV‑кэша и увеличивает число одновременно обслуживаемых запросов; для FP8 есть нюансы форматов (e4m3/e5m2).
Flash‑Attention + «пэкинг» последовательностей: меньше паддинга — меньше памяти
Суть. Собирать несколько коротких примеров в один фиксированный блок токенов (packing) вместо классического паддинга до макс‑длины батча, соблюдая маскирование границ. Это уменьшает «пустые» токены в вычислениях/памяти и хорошо сочетается с Flash‑Attention‑ядрами.
Где включить. В SFT‑пайплайнах HF есть packing (например, через ConstantLengthDataset/packing=True), плюс коллаторы для корректных масок и varlen‑ядра FA2.


Дополнительные приёмы экономии памяти
- Mixed precision (BF16/FP16) — меньше активаций/градиентов; на Ampere/Hopper это «база». См. страницу про смешанную точность.
- 8‑бит оптимизаторы/стейты и ZeRO/FSDP — уводят память оптимизатора/градиентов с одного GPU и/или шардинг по узлам. См. страницу по FSDP/DeepSpeed.
- Аллокатор PyTorch: уменьшение фрагментации через PYTORCH_CUDA_ALLOC_CONF (например, expandable_segments:True, max_split_size_mb=…) в ряде случаев улучшает «сборку» больших блоков. Используйте осознанно, тестируйте на своей версии PyTorch.
Мини‑рецепты под задачи
- Attention‑ядра. Современные реализации (cuDNN attention / FlashAttention‑2/3) дают основной прирост; они работают в BF16/FP16, а на H100 — и в FP8.
- KV‑кэш и длина контекста. Память и время растут с контекстом; смешанная точность снижает объём активаций и требования к HBM, но контролируйте стабильность при длинных L (наблюдайте перплексию/регрессии).
- Комбинации с FSDP/ZeRO. Смешанная точность хорошо сочетается с шардированием параметров/стейтов; выбирайте BF16 как «базу», добавляйте grad‑checkpointing. См. страницу по FSDP/DeepSpeed.
A) Максимально длинный контекст на одном A/H‑GPU (обучение)
B) Массовый инференс с длинными промптами
- Сервируйте через vLLM с PagedAttention.
- Включите prefix caching и FP8 KV‑кэш; валидируйте качество/стабильность под вашу модель.
Траблшутинг
- CUDA OOM при переменной длине батча. Попробуйте PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (снизить фрагментацию), уменьшите max_split_size_mb, проверьте, что SDPA выбирает нужное ядро.
- grad_norm = NaN после включения FA2. Проверьте версии PyTorch/FA/xFormers/HF; временно переключите attn_implementation на sdpa и сравните.
- Ошибка с checkpointing и кешем. На обучении убедитесь, что model.config.use_cache = False.
Навигация по смежным страницам
Точность/скорость
Масштабирование/шардирование: /solutions/llm-training/fsdp-deepspeed/, /solutions/multi-gpu/
Оптимизаторы и обучение: /solutions/llm-training/optimizers/, /solutions/llm-training/finetune-lora/
Инференс и квантизация/наблюдаемость: /solutions/llm-inference/, /solutions/llm-inference/vllm/, /solutions/llm-inference/quantization/