Решения

Обучение LLM на JAX/XLA: планирование и память

ому подойдёт: тем, кто хочет тренировать/дообучать большие модели на NVIDIA GPU с минимальными накладными на фреймворк и максимальной отдачей от XLA‑компиляции. Мы используем стек JAX → XLA → PJRT (рантайм OpenXLA), а на уровне нейросетей — Flax (модули), Optax (оптимизаторы), Orbax (чекпоинты).

Мини‑обзор стека и когда JAX уместен

Стек:

  • JAX — трансформации jit/vmap/grad/scan и единая модель массивов. Запуск на GPU/TPU через XLA и PJRT.
  • Flax — лаконичные модули linen.*, готовые приёмы SPMD и аннотации шардирования.
  • Optax — набор оптимизаторов и «градиентных трансформаций» (можно свободно комбинировать).
  • Orbax — чекпоинты в мульти‑хост/мульти‑девайс окружениях, async‑сейвы.

Когда выбирать JAX:

  • Нужны агрессивные оптимизации XLA и скорость инфры при простом Python‑коде.

Хотите SPMD‑распараллеливание с декларативными аннотациями шардирования (без ручной «обвязки» коллективных операций).

Быстрый старт: JAX + Flax + Optax (+ Orbax)

Зачем. Стандартное внимание держит большие временные тензоры (QKᵀ и пр.). Flash‑Attention пересчитывает их «плитками» и минимизирует HBM‑I/O, уменьшая пик памяти и ускоряя шаг.

Как включить в Transformers (рекомендуемый путь):

				
					import jax, jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

class MLP(nn.Module):
    features: int
    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        x = nn.Dense(self.features)(x)
        return x

def loss_fn(params, x, y):
    preds = MLP(1024).apply({'params': params}, x, train=True)
    return ((preds - y)**2).mean()

@jax.jit  # компиляция XLA
def train_step(state, x, y):
    grads = jax.grad(loss_fn)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state

				
			
  • Optax подключается как tx = optax.adamw(1e-3) и state = train_state.TrainState.create(params=…, apply_fn=…, tx=tx). Идея «градиентных трансформаций» позволяет легко собирать пайплайн оптимизатора.
  • Для чекпоинтов используйте Orbax (orbax.checkpoint.Checkpointer / AsyncCheckpointer) — он покрывает разные форматы, асинхронные сейвы и мульти‑хост.

От pmap/pjit к современному jit с шардированием

В актуальных гайдах Flax показано, что масштабирование выполняется через jax.jit с in_shardings/out_shardings (+ аннотации на массивах), что исторически эволюционировало из pjit. Рабочая модель — SPMD: вы задаёте как шардируются входы/выходы, а XLA сам разводит остальное и вставляет коллективы.

Ключевые сущности:

  • Mesh — логическая сетка девайсов (напр. (‘data’,’model’)).
  • PartitionSpec / NamedSharding — как делить размерности массивов по осям Mesh.

Пример: объявляем Mesh и шардируем батч по оси ‘data’, а параметры — по оси ‘model’.

				
					from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils

devices = mesh_utils.create_device_mesh((2, 4))  # 8 GPU как (data=2, model=4)
mesh = Mesh(devices, ('data','model'))

def S(pspec): return NamedSharding(mesh, pspec)

@jax.jit(
  in_shardings=(S(P('data', None)),),              # вход: [batch, dim] шардим по data
  out_shardings=S(P('data', None))                 # выход тоже по data
)
def forward(x, params):
  # при необходимости можно зафиксировать шардирование промежуточных тензоров:
  from jax.lax import with_sharding_constraint as wsc
  x = wsc(x, S(P('data', None)))
  ...
  return x

				
			

Справки и тонкости: различия pmap/pjit, jit с шардированием и рекомендации по согласованию in_shardings/out_shardings для конвейеров.

Data‑parallel, model‑parallel, mixed

  • Data parallel: делим батч по ‘data’, параметры/оптимайзер реплицируем (или шардируем частично).
  • Model parallel: шардируем параметры/активации по ‘model’ (весы слоёв получают P(None,’model’), батч — P(‘data’,None)).

Комбинации: SPMD позволяет смешивать стратегии на разных слоях. Эти приёмы опираются на GSPMD (ядро авто‑распараллеливания в OpenXLA).

Мульти‑хост тренинг (кластер из нескольких узлов)

Для обучения на нескольких серверах (каждый — с набором GPU) используйте jax.distributed.initialize() и гайды по multi‑process/multi‑host. Там описано, как процессы «находят» друг друга, делятся топологией и как организуется распределённый чекпоинтинг.

				
					import jax
jax.distributed.initialize()  # вызывать ДО любых вычислений JAX
world = jax.process_count()
rank  = jax.process_index()
print(world, rank)

				
			

Полезно знать jax.process_count(), jax.process_index(), jax.local_devices().

Память: remat/gradient checkpointing, donation, микро‑батчи

Gradient checkpointing: оборачиваем тяжёлые блоки в jax.checkpoint (aka jax.remat) или используем flax.linen.remat — экономим VRAM ценой повторных вычислений на бэкварде.

				
					from flax import linen as nn
from flax.linen import remat  # lifted-версия remat

class Block(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4096)(x); x = nn.gelu(x)
        return x

class BigModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        f = remat(Block)  # checkpoint блока
        for _ in range(12):
            x = f()(x)
        return x

				
			

Donation: указываем donate_argnums в jax.jit, чтобы XLA мог «перераспользовать» входные буферы и не делать лишних копий.

				
					@jax.jit(donate_argnums=(0,))   # donate state
def train_step(state, batch):
    ...
    return new_state

				
			

Градиентная аккумуляция: используйте lax.scan для микро‑батчей — помогает уместить крупные модели в память при фиксированном глобальном батче.

I/O и подкачка на устройство

Чтобы не простаивать на загрузках, предвыбирайте батчи на GPU через flax.jax_utils.prefetch_to_device(iterator, size=2). Это перекрывает копирование и вычисление, что особенно полезно на GPU.

Чекпоинты, возобновление и совместимость

Рекомендуем Orbax Checkpointing (в том числе async и восстановление сложных pytrees; поддерживаются мульти‑хост сценарии). Есть готовые рецепты «сохранить/восстановить несколько объектов» (параметры, state оптимизатора, метрики и т.п.).

Профилирование, дебаг шардирования и кэш компиляции

  • Инспекция шардирования внутри jit‑функций: jax.debug.inspect_array_sharding(value, callback=print) — удобно проверять, действительно ли тензор нарезан как задумано.
  • Кэш компиляции JAX: выставьте JAX_COMPILATION_CACHE_DIR/jax_compilation_cache_dir, чтобы ускорить последующие запуски/рестарты.

Продвинутые темы

  • Кастомные GPU‑ядра через Pallas — когда нужны «фьюзы»/нестандартные операции (JAX→Triton на GPU). Подходит для узких мест в LLM (например, спец. attention‑варианты).
  • Отладка планов шардирования и производительности — читайте разделы SPMD/mesh в Flax‑гайдах и обсуждения по объединению pjit→jit.

Рекомендации по выбору GPU (кратко)

  • Дообучение/LoRA небольших/средних LLM — достаточно одной «H‑серии/А‑серии» с 24–48 GB VRAM.
  • Полное обучение / крупные контексты — планируйте multi‑GPU с Mesh ((‘data’,’model’)) и распределённые чекпоинты через Orbax.
  • Для потоков данных: держите I/O конвейер (prefetch, pinned‑host) и избегайте рекомпиляций (следите за статическими аргументами и формами).

Готовые шаблоны запуска на cloudcompute.ru

  • JAX + Flax + Optax (Jupyter/SSH/Docker) — базовый образ с CUDA, cuDNN, JAX, Flax, Optax, Orbax, примеры multi‑GPU и remat. См. /solutions/templates/.
  • LLM‑фокус (SPMD) — образ с примерами Mesh/PartitionSpec, inspect_array_sharding, кэшем компиляции и чекпоинт‑пайплайном. См. /solutions/templates/.

Частые ошибки и как их избежать

  • Плавающие «статические» аргументы вызывают рекомпиляции. Помечайте их static_argnums/static_argnames и держите в узком наборе значений.
  • Несогласованное шардирование между стадиями — следите, чтобы out_shardings предыдущей функции совпадали с in_shardings следующей.
  • Версии CUDA/JAX/XLA — при несовпадении бэкенд может не подняться; используйте проверенные связки (см. релиз‑ноуты контейнеров NVIDIA/JAX).

Чек‑лист перед стартом

  1. Убедитесь, что видите нужные GPU: jax.devices(‘gpu’).
  2. Включите кэш компиляции (директория на быстрый диск).
  3. Соберите TrainState + Optax, добавьте Orbax чекпоинты.
  4. Задайте Mesh и PartitionSpec, проверьте inspect_array_sharding.
  5. Включите prefetch_to_device(size=2) и при необходимости remat/donation.