Обучение LLM на JAX/XLA: планирование и память

ому подойдёт: тем, кто хочет тренировать/дообучать большие модели на NVIDIA GPU с минимальными накладными на фреймворк и максимальной отдачей от XLA‑компиляции. Мы используем стек JAX → XLA → PJRT (рантайм OpenXLA), а на уровне нейросетей — Flax (модули), Optax (оптимизаторы), Orbax (чекпоинты).

Мини‑обзор стека и когда JAX уместен

Стек:

  • JAX — трансформации jit/vmap/grad/scan и единая модель массивов. Запуск на GPU/TPU через XLA и PJRT.
  • Flax — лаконичные модули linen.*, готовые приёмы SPMD и аннотации шардирования.
  • Optax — набор оптимизаторов и «градиентных трансформаций» (можно свободно комбинировать).
  • Orbax — чекпоинты в мульти‑хост/мульти‑девайс окружениях, async‑сейвы.

Когда выбирать JAX:

  • Нужны агрессивные оптимизации XLA и скорость инфры при простом Python‑коде.

Хотите SPMD‑распараллеливание с декларативными аннотациями шардирования (без ручной «обвязки» коллективных операций).

Быстрый старт: JAX + Flax + Optax (+ Orbax)

Зачем. Стандартное внимание держит большие временные тензоры (QKᵀ и пр.). Flash‑Attention пересчитывает их «плитками» и минимизирует HBM‑I/O, уменьшая пик памяти и ускоряя шаг.

Как включить в Transformers (рекомендуемый путь):

import jax, jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
class MLP(nn.Module):
 features: int
 @nn.compact
 def __call__(self, x, train: bool = True):
 x = nn.Dense(self.features)(x)
 x = nn.relu(x)
 x = nn.Dense(self.features)(x)
 return x
def loss_fn(params, x, y):
 preds = MLP(1024).apply({'params': params}, x, train=True)
 return ((preds - y)**2).mean()
@jax.jit # компиляция XLA
def train_step(state, x, y):
 grads = jax.grad(loss_fn)(state.params, x, y)
 state = state.apply_gradients(grads=grads)
 return state
  • Optax подключается как tx = optax.adamw(1e-3) и state = train_state.TrainState.create(params=..., apply_fn=..., tx=tx). Идея «градиентных трансформаций» позволяет легко собирать пайплайн оптимизатора.
  • Для чекпоинтов используйте Orbax (orbax.checkpoint.Checkpointer / AsyncCheckpointer) — он покрывает разные форматы, асинхронные сейвы и мульти‑хост. От pmap/pjit к современному jit с шардированием В актуальных гайдах Flax показано, что масштабирование выполняется через jax.jit с in_shardings/out_shardings (+ аннотации на массивах), что исторически эволюционировало из pjit. Рабочая модель — SPMD: вы задаёте как шардируются входы/выходы, а XLA сам разводит остальное и вставляет коллективы.

Ключевые сущности:

  • Mesh — логическая сетка девайсов (напр. ('data','model')).
  • PartitionSpec / NamedSharding — как делить размерности массивов по осям Mesh.

Пример: объявляем Mesh и шардируем батч по оси 'data', а параметры — по оси 'model'.

from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((2, 4)) # 8 GPU как (data=2, model=4)
mesh = Mesh(devices, ('data','model'))
def S(pspec): return NamedSharding(mesh, pspec)
@jax.jit(
 in_shardings=(S(P('data', None)),), # вход: [batch, dim] шардим по data
 out_shardings=S(P('data', None)) # выход тоже по data
)
def forward(x, params):
 # при необходимости можно зафиксировать шардирование промежуточных тензоров:
 from jax.lax import with_sharding_constraint as wsc
 x = wsc(x, S(P('data', None)))
 ...
 return x

Справки и тонкости: различия pmap/pjit, jit с шардированием и рекомендации по согласованию in_shardings/out_shardings для конвейеров.

Data‑parallel, model‑parallel, mixed

  • Data parallel: делим батч по 'data', параметры/оптимайзер реплицируем (или шардируем частично).
  • Model parallel: шардируем параметры/активации по 'model' (весы слоёв получают P(None,'model'), батч — P('data',None)).

Комбинации: SPMD позволяет смешивать стратегии на разных слоях. Эти приёмы опираются на GSPMD (ядро авто‑распараллеливания в OpenXLA).

Мульти‑хост тренинг (кластер из нескольких узлов)

Для обучения на нескольких серверах (каждый — с набором GPU) используйте jax.distributed.initialize() и гайды по multi‑process/multi‑host. Там описано, как процессы «находят» друг друга, делятся топологией и как организуется распределённый чекпоинтинг.

import jax
jax.distributed.initialize() # вызывать ДО любых вычислений JAX
world = jax.process_count()
rank = jax.process_index()
print(world, rank)

Полезно знать jax.process_count(), jax.process_index(), jax.local_devices(). Память: remat/gradient checkpointing, donation, микро‑батчи

Gradient checkpointing: оборачиваем тяжёлые блоки в jax.checkpoint (aka jax.remat) или используем flax.linen.remat — экономим VRAM ценой повторных вычислений на бэкварде.

from flax import linen as nn
from flax.linen import remat # lifted-версия remat
class Block(nn.Module):
 @nn.compact
 def __call__(self, x):
 x = nn.Dense(4096)(x); x = nn.gelu(x)
 return x
class BigModel(nn.Module):
 @nn.compact
 def __call__(self, x):
 f = remat(Block) # checkpoint блока
 for _ in range(12):
 x = f()(x)
 return x

Donation: указываем donate_argnums в jax.jit, чтобы XLA мог «перераспользовать» входные буферы и не делать лишних копий.

@jax.jit(donate_argnums=(0,)) # donate state
def train_step(state, batch):
 ...
 return new_state

Градиентная аккумуляция: используйте lax.scan для микро‑батчей — помогает уместить крупные модели в память при фиксированном глобальном батче.

I/O и подкачка на устройство

Чтобы не простаивать на загрузках, предвыбирайте батчи на GPU через flax.jax_utils.prefetch_to_device(iterator, size=2). Это перекрывает копирование и вычисление, что особенно полезно на GPU.

Чекпоинты, возобновление и совместимость

Рекомендуем Orbax Checkpointing (в том числе async и восстановление сложных pytrees; поддерживаются мульти‑хост сценарии). Есть готовые рецепты «сохранить/восстановить несколько объектов» (параметры, state оптимизатора, метрики и т.п.).

Профилирование, дебаг шардирования и кэш компиляции

  • Инспекция шардирования внутри jit‑функций: jax.debug.inspect_array_sharding(value, callback=print) — удобно проверять, действительно ли тензор нарезан как задумано.
  • Кэш компиляции JAX: выставьте JAX_COMPILATION_CACHE_DIR/jax_compilation_cache_dir, чтобы ускорить последующие запуски/рестарты.

Продвинутые темы

  • Кастомные GPU‑ядра через Pallas — когда нужны «фьюзы»/нестандартные операции (JAX→Triton на GPU). Подходит для узких мест в LLM (например, спец. attention‑варианты).
  • Отладка планов шардирования и производительности — читайте разделы SPMD/mesh в Flax‑гайдах и обсуждения по объединению pjit→jit.

Рекомендации по выбору GPU (кратко)

  • Дообучение/LoRA небольших/средних LLM — достаточно одной «H‑серии/А‑серии» с 24–48 GB VRAM.
  • Полное обучение / крупные контексты — планируйте multi‑GPU с Mesh (('data','model')) и распределённые чекпоинты через Orbax.
  • Для потоков данных: держите I/O конвейер (prefetch, pinned‑host) и избегайте рекомпиляций (следите за статическими аргументами и формами).

Готовые шаблоны запуска на cloudcompute.ru

  • JAX + Flax + Optax (Jupyter/SSH/Docker) — базовый образ с CUDA, cuDNN, JAX, Flax, Optax, Orbax, примеры multi‑GPU и remat. См. /solutions/templates/.
  • LLM‑фокус (SPMD) — образ с примерами Mesh/PartitionSpec, inspect_array_sharding, кэшем компиляции и чекпоинт‑пайплайном. См. /solutions/templates/.

Частые ошибки и как их избежать

  • Плавающие «статические» аргументы вызывают рекомпиляции. Помечайте их static_argnums/static_argnames и держите в узком наборе значений.
  • Несогласованное шардирование между стадиями — следите, чтобы out_shardings предыдущей функции совпадали с in_shardings следующей.
  • Версии CUDA/JAX/XLA — при несовпадении бэкенд может не подняться; используйте проверенные связки (см. релиз‑ноуты контейнеров NVIDIA/JAX).

Навигация по связанным разделам

LLM обучение (база)

Предобучение

LoRA/QLoRA

RLHF/DPO

Оптимизаторы

Смешанная точность

Оптимизация памяти

Чекпоинтинг

Распределённый I/O

Чек‑лист перед стартом

  1. Убедитесь, что видите нужные GPU: jax.devices('gpu').
  2. Включите кэш компиляции (директория на быстрый диск).
  3. Соберите TrainState + Optax, добавьте Orbax чекпоинты.
  4. Задайте Mesh и PartitionSpec, проверьте inspect_array_sharding.
  5. Включите prefetch_to_device(size=2) и при необходимости remat/donation.

Готовы запустить?

Запустить GPU-сервер