Решения

Оптимизация памяти LLM: Flash‑Attention, checkpointing, paged‑KV

Цель. Дать чёткие рецепты, как уместить модель/батч в доступную VRAM и повысить пропускную способность без деградации качества.

TL;DR

  • Flash‑Attention 2/3 снижает обращения к HBM за счёт тайлинга/фьюзинга и даёт ускорение с меньшим пиком памяти (FA3 добавляет Hopper‑специфику и FP8‑режимы). Рекомендуется как «дефолт» для внимания на Ampere/Hopper.
  • Gradient (activation) checkpointing меняет память на вычисления: хранит меньше активаций и пересчитывает их на бэкварде. Даёт крупную экономию VRAM, особенно на длинных контекстах.
  • Для инференса используйте PagedAttention/paged‑KV (vLLM): страничный KV‑кэш уменьшает фрагментацию и «паразитные» потери памяти; дополнительно поможет FP8‑квантизация KV и prefix caching общих префиксов.

Что именно «ест» вашу VRAM

  1. Параметры модели и стейты оптимизатора (решается FSDP/ZeRO/8‑bit оптимизаторами — см. отдельные страницы).
  2. Активации (каждый слой × длина контекста × батч): главный кандидат на экономию за счёт checkpointing и эффективного внимания.
  3. KV‑кэш на инференсе: растёт линейно с контекстом и числом одновременных запросов — оптимизируется через PagedAttention/квантизацию.

Flash‑Attention (FA2/FA3): меньше памяти на внимание

Зачем. Стандартное внимание держит большие временные тензоры (QKᵀ и пр.). Flash‑Attention пересчитывает их «плитками» и минимизирует HBM‑I/O, уменьшая пик памяти и ускоряя шаг.

Как включить в Transformers (рекомендуемый путь):

				
					from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3-8B",
    attn_implementation="flash_attention_2",  # FA2
    torch_dtype="bfloat16"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3-8B")

				
			

В современных версиях HF можно выбирать реализацию внимания через attn_implementation (например, sdpa, flash_attention_2).

Как контролировать бэкенд SDPA в «чистом» PyTorch:

				
					# Контекстный выбор ядра внимания
from torch.nn.attention import sdpa_kernel
with sdpa_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None)

				
			

PyTorch SDPA поддерживает FlashAttention‑2 и Memory‑Efficient Attention; выбор ядра влияет на память и скорость.

FA3 (Hopper/H100). Новая версия использует асинхронность Tensor Cores/TMA и FP8, повышая утилизацию на H100 (до ~1.2 PFLOP/s в FP8) при сохранении точности. Для обучения/инференса на H100 рекомендуется попробовать FA3‑совместимые стеки.

Gradient (Activation) Checkpointing: RAM ↔ Compute

Идея. Не хранить все активации для бэкварда, а пересчитать часть на обратном проходе. Это резко снижает пик памяти и позволяет увеличить batch_size/seq_len.

PyTorch (базово):

				
					import torch, torch.utils.checkpoint as cp

def block(x):
    # ваш крупный подграф: attention + MLP
    return model_block(x)

y = cp.checkpoint(block, x)  # активации внутри block не сохраняются

				
			

Transformers (упрощённо):

				
					model.gradient_checkpointing_enable()
model.config.use_cache = False  # важно при обучении с checkpointing

				
			

use_cache=True (хранение past_key_values) конфликтует с gradient checkpointing на обучении — отключайте.

DeepSpeed‑расширения. Есть варианты partition activations, CPU‑checkpointing и offload — помогают в «крайних» случаях, когда VRAM критически мала.

Paged‑KV для инференса: длинный контекст без «раздувания»

Проблема. KV‑кэш растёт с контекстом и числом запросов, создавая фрагментацию и «мертвое» место.
Решение. PagedAttention (vLLM) разбивает KV‑кэш на блоки‑страницы и управляет ими как «виртуальной памятью», почти устраняя потери и повышая throughput при длинных контекстах и батчинге.

Практика (vLLM):

				
					vllm serve meta-llama/Llama-3-8B \
  --max-model-len 32768 \
  --enable-prefix-caching \
  --kv-cache-dtype fp8

				
			
  • —enable-prefix-caching даёт автокэш префиксов (общие подсказки/системные промпты не пересчитываются заново).
  • —kv-cache-dtype fp8 уменьшает объём KV‑кэша и увеличивает число одновременно обслуживаемых запросов; для FP8 есть нюансы форматов (e4m3/e5m2).

Flash‑Attention + «пэкинг» последовательностей: меньше паддинга — меньше памяти

Суть. Собирать несколько коротких примеров в один фиксированный блок токенов (packing) вместо классического паддинга до макс‑длины батча, соблюдая маскирование границ. Это уменьшает «пустые» токены в вычислениях/памяти и хорошо сочетается с Flash‑Attention‑ядрами.

Где включить. В SFT‑пайплайнах HF есть packing (например, через ConstantLengthDataset/packing=True), плюс коллаторы для корректных масок и varlen‑ядра FA2.

Дополнительные приёмы экономии памяти

  • Mixed precision (BF16/FP16) — меньше активаций/градиентов; на Ampere/Hopper это «база». См. страницу про смешанную точность.
  • 8‑бит оптимизаторы/стейты и ZeRO/FSDP — уводят память оптимизатора/градиентов с одного GPU и/или шардинг по узлам. См. страницу по FSDP/DeepSpeed.
  • Аллокатор PyTorch: уменьшение фрагментации через PYTORCH_CUDA_ALLOC_CONF (например, expandable_segments:True, max_split_size_mb=…) в ряде случаев улучшает «сборку» больших блоков. Используйте осознанно, тестируйте на своей версии PyTorch.

Мини‑рецепты под задачи

  • Attention‑ядра. Современные реализации (cuDNN attention / FlashAttention‑2/3) дают основной прирост; они работают в BF16/FP16, а на H100 — и в FP8.
  • KV‑кэш и длина контекста. Память и время растут с контекстом; смешанная точность снижает объём активаций и требования к HBM, но контролируйте стабильность при длинных L (наблюдайте перплексию/регрессии).
  • Комбинации с FSDP/ZeRO. Смешанная точность хорошо сочетается с шардированием параметров/стейтов; выбирайте BF16 как «базу», добавляйте grad‑checkpointing. См. страницу по FSDP/DeepSpeed.

A) Максимально длинный контекст на одном A/H‑GPU (обучение)

  1. Включите BF16 + Flash‑Attention‑2 (или FA3 на H100).
  2. Включите gradient checkpointing и отключите use_cache.
  3. Примените packing или сортировку по длинам (bucket by length), чтобы уменьшить паддинг.

B) Массовый инференс с длинными промптами

  1. Сервируйте через vLLM с PagedAttention.
  2. Включите prefix caching и FP8 KV‑кэш; валидируйте качество/стабильность под вашу модель.

Траблшутинг

  • CUDA OOM при переменной длине батча. Попробуйте PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (снизить фрагментацию), уменьшите max_split_size_mb, проверьте, что SDPA выбирает нужное ядро.
  • grad_norm = NaN после включения FA2. Проверьте версии PyTorch/FA/xFormers/HF; временно переключите attn_implementation на sdpa и сравните.
  • Ошибка с checkpointing и кешем. На обучении убедитесь, что model.config.use_cache = False.

Навигация по смежным страницам

Точность/скорость

Масштабирование/шардирование: /solutions/llm-training/fsdp-deepspeed/, /solutions/multi-gpu/

Оптимизаторы и обучение: /solutions/llm-training/optimizers/, /solutions/llm-training/finetune-lora/

Инференс и квантизация/наблюдаемость: /solutions/llm-inference/, /solutions/llm-inference/vllm/, /solutions/llm-inference/quantization/