Решения
Обучение LLM на JAX/XLA: планирование и память
ому подойдёт: тем, кто хочет тренировать/дообучать большие модели на NVIDIA GPU с минимальными накладными на фреймворк и максимальной отдачей от XLA‑компиляции. Мы используем стек JAX → XLA → PJRT (рантайм OpenXLA), а на уровне нейросетей — Flax (модули), Optax (оптимизаторы), Orbax (чекпоинты).
Мини‑обзор стека и когда JAX уместен
Стек:
- JAX — трансформации jit/vmap/grad/scan и единая модель массивов. Запуск на GPU/TPU через XLA и PJRT.
- Flax — лаконичные модули linen.*, готовые приёмы SPMD и аннотации шардирования.
- Optax — набор оптимизаторов и «градиентных трансформаций» (можно свободно комбинировать).
- Orbax — чекпоинты в мульти‑хост/мульти‑девайс окружениях, async‑сейвы.
Когда выбирать JAX:
- Нужны агрессивные оптимизации XLA и скорость инфры при простом Python‑коде.
Хотите SPMD‑распараллеливание с декларативными аннотациями шардирования (без ручной «обвязки» коллективных операций).
Быстрый старт: JAX + Flax + Optax (+ Orbax)
Зачем. Стандартное внимание держит большие временные тензоры (QKᵀ и пр.). Flash‑Attention пересчитывает их «плитками» и минимизирует HBM‑I/O, уменьшая пик памяти и ускоряя шаг.
Как включить в Transformers (рекомендуемый путь):
import jax, jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
class MLP(nn.Module):
features: int
@nn.compact
def __call__(self, x, train: bool = True):
x = nn.Dense(self.features)(x)
x = nn.relu(x)
x = nn.Dense(self.features)(x)
return x
def loss_fn(params, x, y):
preds = MLP(1024).apply({'params': params}, x, train=True)
return ((preds - y)**2).mean()
@jax.jit # компиляция XLA
def train_step(state, x, y):
grads = jax.grad(loss_fn)(state.params, x, y)
state = state.apply_gradients(grads=grads)
return state
- Optax подключается как tx = optax.adamw(1e-3) и state = train_state.TrainState.create(params=…, apply_fn=…, tx=tx). Идея «градиентных трансформаций» позволяет легко собирать пайплайн оптимизатора.
- Для чекпоинтов используйте Orbax (orbax.checkpoint.Checkpointer / AsyncCheckpointer) — он покрывает разные форматы, асинхронные сейвы и мульти‑хост.

От pmap/pjit к современному jit с шардированием

В актуальных гайдах Flax показано, что масштабирование выполняется через jax.jit с in_shardings/out_shardings (+ аннотации на массивах), что исторически эволюционировало из pjit. Рабочая модель — SPMD: вы задаёте как шардируются входы/выходы, а XLA сам разводит остальное и вставляет коллективы.
Ключевые сущности:
- Mesh — логическая сетка девайсов (напр. (‘data’,’model’)).
- PartitionSpec / NamedSharding — как делить размерности массивов по осям Mesh.
Пример: объявляем Mesh и шардируем батч по оси ‘data’, а параметры — по оси ‘model’.
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((2, 4)) # 8 GPU как (data=2, model=4)
mesh = Mesh(devices, ('data','model'))
def S(pspec): return NamedSharding(mesh, pspec)
@jax.jit(
in_shardings=(S(P('data', None)),), # вход: [batch, dim] шардим по data
out_shardings=S(P('data', None)) # выход тоже по data
)
def forward(x, params):
# при необходимости можно зафиксировать шардирование промежуточных тензоров:
from jax.lax import with_sharding_constraint as wsc
x = wsc(x, S(P('data', None)))
...
return x
Справки и тонкости: различия pmap/pjit, jit с шардированием и рекомендации по согласованию in_shardings/out_shardings для конвейеров.
Data‑parallel, model‑parallel, mixed
- Data parallel: делим батч по ‘data’, параметры/оптимайзер реплицируем (или шардируем частично).
- Model parallel: шардируем параметры/активации по ‘model’ (весы слоёв получают P(None,’model’), батч — P(‘data’,None)).
Комбинации: SPMD позволяет смешивать стратегии на разных слоях. Эти приёмы опираются на GSPMD (ядро авто‑распараллеливания в OpenXLA).
Мульти‑хост тренинг (кластер из нескольких узлов)
Для обучения на нескольких серверах (каждый — с набором GPU) используйте jax.distributed.initialize() и гайды по multi‑process/multi‑host. Там описано, как процессы «находят» друг друга, делятся топологией и как организуется распределённый чекпоинтинг.
import jax
jax.distributed.initialize() # вызывать ДО любых вычислений JAX
world = jax.process_count()
rank = jax.process_index()
print(world, rank)
Полезно знать jax.process_count(), jax.process_index(), jax.local_devices().


Память: remat/gradient checkpointing, donation, микро‑батчи
Gradient checkpointing: оборачиваем тяжёлые блоки в jax.checkpoint (aka jax.remat) или используем flax.linen.remat — экономим VRAM ценой повторных вычислений на бэкварде.
from flax import linen as nn
from flax.linen import remat # lifted-версия remat
class Block(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(4096)(x); x = nn.gelu(x)
return x
class BigModel(nn.Module):
@nn.compact
def __call__(self, x):
f = remat(Block) # checkpoint блока
for _ in range(12):
x = f()(x)
return x
Donation: указываем donate_argnums в jax.jit, чтобы XLA мог «перераспользовать» входные буферы и не делать лишних копий.
@jax.jit(donate_argnums=(0,)) # donate state
def train_step(state, batch):
...
return new_state
Градиентная аккумуляция: используйте lax.scan для микро‑батчей — помогает уместить крупные модели в память при фиксированном глобальном батче.
I/O и подкачка на устройство
Чтобы не простаивать на загрузках, предвыбирайте батчи на GPU через flax.jax_utils.prefetch_to_device(iterator, size=2). Это перекрывает копирование и вычисление, что особенно полезно на GPU.
Чекпоинты, возобновление и совместимость
Рекомендуем Orbax Checkpointing (в том числе async и восстановление сложных pytrees; поддерживаются мульти‑хост сценарии). Есть готовые рецепты «сохранить/восстановить несколько объектов» (параметры, state оптимизатора, метрики и т.п.).
Профилирование, дебаг шардирования и кэш компиляции
- Инспекция шардирования внутри jit‑функций: jax.debug.inspect_array_sharding(value, callback=print) — удобно проверять, действительно ли тензор нарезан как задумано.
- Кэш компиляции JAX: выставьте JAX_COMPILATION_CACHE_DIR/jax_compilation_cache_dir, чтобы ускорить последующие запуски/рестарты.
Продвинутые темы
- Кастомные GPU‑ядра через Pallas — когда нужны «фьюзы»/нестандартные операции (JAX→Triton на GPU). Подходит для узких мест в LLM (например, спец. attention‑варианты).
- Отладка планов шардирования и производительности — читайте разделы SPMD/mesh в Flax‑гайдах и обсуждения по объединению pjit→jit.
Рекомендации по выбору GPU (кратко)
- Дообучение/LoRA небольших/средних LLM — достаточно одной «H‑серии/А‑серии» с 24–48 GB VRAM.
- Полное обучение / крупные контексты — планируйте multi‑GPU с Mesh ((‘data’,’model’)) и распределённые чекпоинты через Orbax.
- Для потоков данных: держите I/O конвейер (prefetch, pinned‑host) и избегайте рекомпиляций (следите за статическими аргументами и формами).
Готовые шаблоны запуска на cloudcompute.ru
- JAX + Flax + Optax (Jupyter/SSH/Docker) — базовый образ с CUDA, cuDNN, JAX, Flax, Optax, Orbax, примеры multi‑GPU и remat. См. /solutions/templates/.
- LLM‑фокус (SPMD) — образ с примерами Mesh/PartitionSpec, inspect_array_sharding, кэшем компиляции и чекпоинт‑пайплайном. См. /solutions/templates/.
Частые ошибки и как их избежать
- Плавающие «статические» аргументы вызывают рекомпиляции. Помечайте их static_argnums/static_argnames и держите в узком наборе значений.
- Несогласованное шардирование между стадиями — следите, чтобы out_shardings предыдущей функции совпадали с in_shardings следующей.
- Версии CUDA/JAX/XLA — при несовпадении бэкенд может не подняться; используйте проверенные связки (см. релиз‑ноуты контейнеров NVIDIA/JAX).
Навигация по связанным разделам
Чек‑лист перед стартом