
Data Parallelism

Data parallelism (DP) replicates training across multiple GPUs, each processing a distinct mini-batch. xorl supports three DP modes — DDP, FSDP2, and HSDP — controlled by a single data_parallel_mode field. All DP modes compose freely with tensor parallelism (TP), pipeline parallelism (PP), expert parallelism (EP), and context parallelism (SP/CP).

The global world size constraint is:

world_size = PP × DP_shard × DP_replicate × TP × CP_ring × CP_ulysses
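This constraint can be sketched as a tiny validator. `check_world_size` below is a hypothetical helper for illustration, not a xorl API:

```python
# Hypothetical validator (not a xorl API) for the constraint above.
def check_world_size(world_size, pp=1, dp_shard=1, dp_replicate=1,
                     tp=1, cp_ring=1, cp_ulysses=1):
    """True iff the parallelism degrees exactly tile the GPU world."""
    return world_size == pp * dp_shard * dp_replicate * tp * cp_ring * cp_ulysses

assert check_world_size(16, pp=2, dp_shard=4, dp_replicate=2)   # 2 * 4 * 2 = 16
assert not check_world_size(16, pp=2, dp_shard=4)               # tiles only 8 of 16
```

xorl validates the analogous product in `TrainingArguments.__post_init__` (see the HSDP section below).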

| Mode | Memory savings | Cross-node bandwidth | Best for |
| --- | --- | --- | --- |
| `ddp` | None (full params on each GPU) | All-reduce gradients | Small models, prototyping |
| `fsdp2` | Params + grads + optimizer states | All-gather + reduce-scatter | Single-node or small clusters |
| `fsdp2` + HSDP | Same as FSDP2 within a node | All-gather + reduce-scatter within node; all-reduce across nodes | Multi-node (reduces expensive cross-node traffic) |
[Diagram: per-mode memory layout — DDP keeps full W, G, O on every GPU (all-reduce grads, ~18P bytes/GPU, no memory savings); FSDP2 shards params + grads + optimizer states so each of N GPUs holds one slice of W (all-gather in forward, reduce-scatter in backward, ~18P/N bytes/GPU for shard group size N); HSDP shards within each node over NVLink (reduce-scatter within node) and all-reduces gradients across nodes over IB (~18P/N bytes/GPU, minimal cross-node traffic).]

[Diagram: data distribution — the global batch B is split into exclusive shards: rank 0 processes [0, B/4), rank 1 [B/4, B/2), rank 2 [B/2, 3B/4), rank 3 [3B/4, B). Each DP rank processes an exclusive data shard; gradients are synchronized after backward.]
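The contiguous batch split shown in the data-distribution diagram can be expressed with simple index arithmetic. `dp_shard_range` is an illustrative helper, not part of xorl:

```python
# Illustrative helper (not a xorl API) mirroring the diagram's batch split.
def dp_shard_range(global_batch, dp_size, rank):
    """Half-open [start, end) slice of the global batch owned by one DP rank.
    Assumes global_batch divides evenly by dp_size."""
    per_rank = global_batch // dp_size
    return rank * per_rank, (rank + 1) * per_rank

# A global batch of 128 over 4 DP ranks -> four exclusive shards of 32
shards = [dp_shard_range(128, 4, r) for r in range(4)]
assert shards == [(0, 32), (32, 64), (64, 96), (96, 128)]
```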

1. DDP — Distributed Data Parallel

DDP wraps the model with PyTorch’s DistributedDataParallel. Every GPU holds a complete copy of the model in fp32 (or bf16 with mixed precision). During the backward pass, DDP fires an all-reduce over each parameter gradient bucket to average gradients across all ranks.

xorl applies DDP via torch.nn.parallel.DistributedDataParallel with device_ids=[local_rank]. When enable_mixed_precision: true, a MixedPrecision policy is applied:

```python
MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,   # gradients reduced in fp32
    buffer_dtype=torch.bfloat16,
)
```
```yaml
train:
  data_parallel_mode: ddp
  data_parallel_shard_size: 8      # all 8 GPUs are in one DDP group
  data_parallel_replicate_size: 1
  enable_mixed_precision: true
  init_device: cuda                # DDP requires cuda or npu, not meta
```

Note: init_device: meta is not supported for DDP. Use init_device: cuda.

Use DDP when:

  • The model fits comfortably in GPU memory on its own (typically < 3–4B params on an 80 GB GPU with bf16 + optimizer states).
  • You want simple semantics and do not need memory savings.
  • You are prototyping or debugging and FSDP complexity is undesirable.

Limitations:

  • Full parameter copy on every GPU: no memory reduction.
  • Gradient all-reduce scales with model size, not local shard size, so communication cost grows linearly with parameters.
  • Cannot be combined with init_device: meta (no lazy parameter materialization).
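The communication cost noted above can be roughly quantified. Assuming a ring all-reduce (NCCL's default algorithm), each GPU sends and receives about 2(N−1)/N times the gradient payload regardless of any sharding; the helper below is a back-of-envelope sketch, not a xorl utility:

```python
# Back-of-envelope estimate (not a xorl utility): per-GPU traffic of one
# ring all-reduce over the full gradient, which is what DDP performs.
def ddp_allreduce_bytes_per_gpu(num_params, num_gpus, bytes_per_grad=4):
    """A ring all-reduce moves ~2*(N-1)/N times the payload per GPU, so
    DDP's communication scales with total parameter count, not shard size."""
    payload = num_params * bytes_per_grad            # fp32 gradient bytes
    return 2 * (num_gpus - 1) / num_gpus * payload

# 7B params, fp32 grads, 8 GPUs -> ~49 GB sent + received per GPU per step
traffic_gb = ddp_allreduce_bytes_per_gpu(7e9, 8) / 1e9
```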

2. FSDP2 — Fully Sharded Data Parallel (ZeRO-3)


FSDP2 is xorl’s primary training strategy. It shards parameters, gradients, and optimizer states across all ranks in the DP group, providing ZeRO-3-equivalent memory savings.

| Aspect | FSDP1 (`FullyShardedDataParallel`) | FSDP2 (`fully_shard` composable API) |
| --- | --- | --- |
| API style | Module wrapper | Per-module composable annotation |
| DTensor integration | Limited | Native — sharded params are DTensor objects |
| Composition with TP/EP | Difficult | First-class via device mesh |
| Grad hooks | Implicit via `_post_backward_hook` | Clean reduce-scatter on DTensor |
| Mixed precision | `MixedPrecision` wrapper class | `MixedPrecisionPolicy` dataclass |
| Min PyTorch version | 1.13 | 2.4 |

xorl uses the composable fully_shard API exclusively and requires PyTorch >= 2.4.

fully_shard is applied at two granularities:

  1. Per decoder block — each transformer layer (e.g., Qwen3DecoderLayer) is wrapped independently. This keeps the all-gather scope to one layer’s parameters at a time, minimizing peak memory.
  2. Root model — the entire model is wrapped last, so FSDP2 handles the embedding and lm_head as a separate unit.

As a torchtitan-inspired optimization, the final norm + lm_head modules are grouped into a single FSDP unit with reshard_after_forward=False. This lets lm_head remain gathered after the last decoder layer’s forward, avoiding a redundant all-gather when the externally called compute_loss accesses lm_head.weight.

The enable_full_shard config field maps to FSDP2’s sharding behavior:

| `enable_full_shard` | FSDP2 behavior | Equivalent ZeRO level |
| --- | --- | --- |
| `true` (default) | Shard params + grads + optimizer states | ZeRO-3 |
| `false` | No parameter sharding (only grad/optimizer state savings) | ZeRO-1/2 style |

In practice, always set enable_full_shard: true for large models. Setting it to false is primarily useful for debugging or models small enough that parameter memory is not the bottleneck.

```yaml
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8      # all 8 GPUs share parameters
  data_parallel_replicate_size: 1
  enable_full_shard: true
  enable_mixed_precision: true
  init_device: meta                # required for FSDP2
  load_weights_mode: broadcast     # rank0 loads, broadcasts to all ranks
```

init_device: meta is required for FSDP2. Parameters are initially created on the meta device (zero-cost), then materialized by FSDP2 after fully_shard is applied.


3. HSDP — Hybrid Sharded Data Parallel

HSDP is activated automatically when both data_parallel_shard_size > 1 and data_parallel_replicate_size > 1. It creates a 2D device mesh:

mesh[dp_replicate, dp_shard]
  • The dp_shard dimension (inner) performs full ZeRO-3 sharding via all-gather + reduce-scatter. All GPUs in a shard group collectively hold one copy of the model.
  • The dp_replicate dimension (outer) replicates the model across shard groups. Gradients are all-reduced across shard group replicas after the reduce-scatter.

For 16 GPUs with data_parallel_shard_size=8, data_parallel_replicate_size=2:

Node 0: GPUs [0,1,2,3,4,5,6,7] → dp_shard group 0 (shard group A)
Node 1: GPUs [8,9,10,11,12,13,14,15] → dp_shard group 1 (shard group B)
dp_replicate groups: {GPU 0, GPU 8}, {GPU 1, GPU 9}, ..., {GPU 7, GPU 15}
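The rank-to-group assignment above can be reproduced with index arithmetic. `hsdp_groups` below is a hypothetical helper sketching the layout, not a xorl function:

```python
# Index-arithmetic sketch (hypothetical helper) of the HSDP group layout
# for a mesh [dp_replicate, dp_shard].
def hsdp_groups(world_size, shard_size):
    """Return (shard_groups, replicate_groups) over contiguous ranks."""
    n_rep = world_size // shard_size
    shard_groups = [list(range(g * shard_size, (g + 1) * shard_size))
                    for g in range(n_rep)]
    # Ranks with the same local index across shard groups replicate each other.
    replicate_groups = [[g * shard_size + i for g in range(n_rep)]
                        for i in range(shard_size)]
    return shard_groups, replicate_groups

shards, reps = hsdp_groups(16, 8)
assert shards == [list(range(0, 8)), list(range(8, 16))]   # one group per node
assert reps[0] == [0, 8] and reps[7] == [7, 15]            # cross-node pairs
```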

All-gather and reduce-scatter happen within each node (high-bandwidth NVLink). The cross-node gradient synchronization (all-reduce across the replicate dimension) only carries the averaged gradients, which is much smaller than synchronizing full shards.

init_parallel_state in parallel_state.py constructs a multi-dimensional DeviceMesh that includes both dp_replicate and dp_shard axes. The FSDP mesh is then sliced as:

```python
# HSDP: 2D mesh (dp_replicate, dp_shard)
fsdp_mesh = device_mesh["dp_replicate", "dp_shard"]

# Pure FSDP2: 1D mesh (dp_shard only)
fsdp_mesh = device_mesh["dp_shard"]
```

FSDP2’s fully_shard(module, mesh=fsdp_mesh) interprets a 2D mesh as HSDP automatically — shard along the inner dim, replicate across the outer dim.

Use HSDP when:

  • You are training on multiple nodes where cross-node bandwidth (InfiniBand, RoCE) is significantly slower than intra-node bandwidth (NVLink).
  • Model parameters fit within a single node’s GPU memory when sharded (i.e., model_params / dp_shard_size fits per GPU).
  • You want to minimize cross-node communication by confining all-gather and reduce-scatter to fast intra-node links.
```yaml
# 2-node × 8 GPU = 16 GPUs: shard within node, replicate across nodes
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8      # one full shard group per node
  data_parallel_replicate_size: 2  # 2 replicas (one per node)
  enable_full_shard: true
  enable_mixed_precision: true
  init_device: meta
  load_weights_mode: broadcast
```

The constraint that must hold: data_parallel_shard_size × data_parallel_replicate_size == data_parallel_size (where data_parallel_size = world_size / (TP × PP × CP)). xorl validates this in TrainingArguments.__post_init__.


DDP memory per GPU (model with P parameters, bf16 weights)

| Component | Size |
| --- | --- |
| Parameters (bf16) | 2P bytes |
| Gradients (fp32) | 4P bytes |
| Optimizer states (Adam: m + v + master weights) | 12P bytes |
| Total | ~18P bytes |

For a 7B parameter model: ~126 GB per GPU. This exceeds a single 80 GB GPU.

FSDP2 memory per GPU (N GPUs in shard group)

| Component | Size per GPU |
| --- | --- |
| Parameter shard (bf16; gathered one FSDP unit at a time during forward/backward) | 2P/N bytes at rest |
| Gradient shard (fp32 reduce-scatter result) | 4P/N bytes |
| Optimizer state shard (Adam) | 12P/N bytes |
| Total at rest | ~(2 + 4 + 12) × P/N = 18P/N bytes |
| Peak (one layer all-gathered) | ~2P/N + one layer’s params in bf16 |

For 7B / 8 GPUs: ~15.75 GB at rest, plus one layer peak overhead — fits on an 80 GB GPU with ample headroom for activations.
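Both memory tables collapse to one formula — roughly 18 bytes per parameter, divided by the shard group size. The helper name below is illustrative:

```python
# Both tables collapse to one formula; the helper name is illustrative.
def bytes_per_gpu(num_params, shard_size=1):
    """~18 bytes/param (2 bf16 param + 4 fp32 grad + 12 Adam state),
    divided by the FSDP shard group size (use shard_size=1 for DDP)."""
    return 18 * num_params / shard_size

assert bytes_per_gpu(7e9) / 1e9 == 126.0       # DDP: exceeds one 80 GB GPU
assert bytes_per_gpu(7e9, 8) / 1e9 == 15.75    # FSDP2 over 8 GPUs, at rest
```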

Memory per GPU is identical to FSDP2 — the dp_replicate dimension only affects gradient communication, not per-GPU parameter storage.


The problem with FSDP’s default averaging


By default, PyTorch’s FSDP2 divides gradients by fsdp_size (the number of ranks in the FSDP mesh) during the reduce-scatter. This produces gradients averaged over the number of ranks, not over the number of actual training tokens. For variable-length sequence packing, where different ranks may process very different numbers of valid (non-padding) tokens, a simple rank-average produces incorrect gradient scaling.

xorl’s approach: token-normalized gradients


xorl disables FSDP’s automatic gradient averaging by calling set_gradient_divide_factor(1.0) on every FSDP module immediately after fully_shard:

```python
# From src/xorl/distributed/torch_parallelize.py
from torch.distributed._composable.fsdp.fully_shard import FSDPModule

for module in model.modules():
    if isinstance(module, FSDPModule) and not getattr(module, "_is_ep_fsdp", False):
        module.set_gradient_divide_factor(1.0)
```

This means the reduce-scatter returns the sum of gradients across ranks, not the average.

Gradients are then normalized manually via global_valid_tokens — the total number of non-padding tokens across all ranks and all micro-batches in the current step:

```python
# From src/xorl/trainers/training_utils.py
global_valid_tokens = torch.tensor(0, device=device)
for mb in micro_batches:
    labels = mb.get("labels", mb.get("target_tokens"))
    if labels is not None:
        global_valid_tokens += (labels != IGNORE_INDEX).sum()
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.SUM, group=group)
```

For standard (non-pipeline-parallel) training, xorl uses the GradientAccumulateLoss autograd function in src/xorl/distributed/gradient_accumulate_loss.py. This custom function:

  1. Forward: scales each micro-batch’s mean loss by local_valid_tokens, all-reduces across ranks to get the total loss sum, then divides by global_valid_tokens to produce a per-token-normalized scalar.
  2. Backward: propagates gradients scaled by local_valid_tokens / global_valid_tokens, so each rank’s contribution to the gradient is proportional to its token count.

```python
ga_loss, _ = gradient_accumulate_loss(loss, local_valid_tokens, global_valid_tokens)
ga_loss.backward()
```

Since FSDP’s divide factor is 1.0, the reduce-scatter sums raw gradients across ranks. Combined with the local_valid_tokens / global_valid_tokens scaling applied in the backward pass, the net result is gradients averaged over all valid tokens globally.
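This interaction can be verified with a plain-Python simulation (no distributed runtime needed); the toy values below are illustrative:

```python
# Plain-Python simulation of token-normalized gradients (toy values).
# Rank 1 sees 9x more valid tokens than rank 0.
rank_tokens = [10, 90]                          # local_valid_tokens per rank
per_token_grads = [[1.0] * 10, [3.0] * 90]      # per-token gradients per rank
T = sum(rank_tokens)                            # global_valid_tokens = 100

# Each rank backprops its token-MEAN gradient scaled by local/global tokens,
# mirroring what GradientAccumulateLoss does in backward.
scaled = [(sum(g) / len(g)) * (len(g) / T) for g in per_token_grads]

# With set_gradient_divide_factor(1.0), the reduce-scatter SUMS the ranks:
global_grad = sum(scaled)

# The result is the mean over all 100 tokens globally (2.8), not a rank mean:
true_token_mean = sum(sum(g) for g in per_token_grads) / T
assert abs(global_grad - true_token_mean) < 1e-9

# FSDP's default rank-averaging would instead give (1.0 + 3.0) / 2 = 2.0,
# over-weighting the 10-token rank.
naive_rank_mean = sum(sum(g) / len(g) for g in per_token_grads) / len(rank_tokens)
```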

In pipeline-parallel training, the loss function divides by global_valid_tokens directly:

```python
# loss = cross_entropy_sum / global_valid_tokens
return F.cross_entropy(..., reduction="sum") / global_valid_tokens
```

After the pipeline schedule completes its backward pass, gradients are already correctly normalized. The set_gradient_divide_factor(1.0) ensures no additional averaging is applied.

Expert modules sharded by EP use a different divide factor. Because EP ranks each hold a disjoint subset of experts, their gradients must be averaged across EP ranks (not summed). xorl sets:

```python
experts_mod.set_gradient_divide_factor(parallel_state.ep_size)
experts_mod._is_ep_fsdp = True  # prevents the global 1.0 reset from overriding
```

This is the only FSDP module that retains a non-unity divide factor.


When enable_mixed_precision: true, xorl applies a MixedPrecisionPolicy to all FSDP2 modules:

```python
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # parameters cast to bf16 for forward/backward
    reduce_dtype=torch.float32,  # gradients reduced in fp32 for numerical stability
)
```

What this means in practice:

  • Stored parameter shards are in bf16 (halving parameter memory vs fp32).
  • During the all-gather before each forward, shards are communicated in bf16 directly; no cast is needed because they are already stored in bf16.
  • The reduce-scatter during backward accumulates in fp32 before storing the gradient shard in fp32.
  • Optimizer states (Adam m, v, master weights) remain in fp32 throughout.

Modules excluded from mixed precision

Some models define get_ignore_modules_in_mixed_precision() to return module types that must remain in fp32 during forward (e.g., MoE routing gate layers where fp32 precision affects load balancing). Those modules are wrapped with separate fully_shard calls using fsdp_kwargs_without_mp (no mp_policy).

DDP mixed precision

DDP uses the older MixedPrecision wrapper class with identical semantics:

```python
MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
```

When EP is enabled (expert_parallel_size > 1), the device mesh gains a separate ep_fsdp dimension for expert parameters, distinct from the main dp_shard dimension used for non-expert parameters.

Separate meshes for expert vs non-expert parameters

Section titled “Separate meshes for expert vs non-expert parameters”
  • Main FSDP mesh: device_mesh["dp_shard"] (or ["dp_replicate", "dp_shard"] for HSDP)
  • EP-FSDP mesh: ep_fsdp_device_mesh["ep_fsdp"] (ranks within one EP group that share expert shards)

The EP mesh (ep_fsdp_device_mesh) is a 2D mesh of shape (ep_size, ep_fsdp_size):

  • ep dimension: routes tokens to different expert groups (AllToAll dispatch).
  • ep_fsdp dimension: FSDP sharding of each EP group’s local expert tensors.

Expert weight tensors have shape [num_local_experts, hidden_dim, intermediate_dim] after EP slicing. Standard FSDP2 shards on dim-0 (expert count), but because EP has already consumed dim-0, xorl sets a custom shard placement function to shard on dim-1 (hidden_dim):

```python
from torch.distributed.tensor import Shard  # DTensor placement type

def _experts_shard_placement_fn(param):
    return Shard(1)  # shard on hidden_dim; dim 0 was already consumed by EP

expert_fsdp_kwargs["shard_placement_fn"] = _experts_shard_placement_fn
```
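The resulting local shapes can be sketched with example dimensions (4096 hidden, 11008 intermediate — illustrative values, not tied to any specific model):

```python
# Shape sketch with illustrative dimensions (not tied to a specific model).
def expert_shard_shape(num_local_experts, hidden, intermediate, ep_fsdp_size):
    """Local shard of an expert weight [E_local, H, I] under Shard(1):
    FSDP splits hidden_dim because EP already consumed dim 0."""
    return (num_local_experts, hidden // ep_fsdp_size, intermediate)

assert expert_shard_shape(16, 4096, 11008, 4) == (16, 1024, 11008)
# The default Shard(0) would instead have split the 16 local experts themselves.
```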

When EP is active, the automatic FSDP2 prefetching heuristic cannot correctly schedule all-gathers across the heterogeneous set of FSDP units (attention, gate, experts) within each decoder layer. xorl configures explicit forward and backward prefetch lists per layer:

```python
# Forward: prefetch next layer's modules in order (attn, gate, experts)
current_block.set_modules_to_forward_prefetch(list(reversed(prefetch_modules)))
# Backward: prefetch previous layer's modules
current_block.set_modules_to_backward_prefetch(list(reversed(prefetch_modules)))
```

This ensures all-gathers for the next layer’s weights overlap with the current layer’s compute.


Single node, 8 GPUs — Qwen3-8B dense model (pure FSDP2)

```yaml
model:
  model_path: Qwen/Qwen3-8B
  attn_implementation: flash_attention_3

train:
  output_dir: outputs/Qwen3-8B
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8
  data_parallel_replicate_size: 1
  tensor_parallel_size: 1
  expert_parallel_size: 1
  ulysses_parallel_size: 1
  micro_batch_size: 1
  gradient_accumulation_steps: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Memory profile (Qwen3-8B, ~8B params):

  • Params at rest: ~2 GB / GPU (bf16, sharded over 8)
  • Gradients: ~4 GB / GPU (fp32 shard)
  • Optimizer states: ~12 GB / GPU (Adam fp32)
  • Total at rest: ~18 GB / GPU — comfortable on 80 GB GPUs with activations

Single node, 8 GPUs — Qwen3-32B dense with Ulysses SP

```yaml
model:
  model_path: Qwen/Qwen3-32B
  attn_implementation: flash_attention_3

train:
  output_dir: outputs/Qwen3-32B
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 2
  data_parallel_replicate_size: 1
  ulysses_parallel_size: 4       # 4-way context parallelism within each shard group
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Here world_size = dp_shard × ulysses = 2 × 4 = 8. Each 2-GPU FSDP shard group uses 4-way Ulysses to handle long sequences that would not fit on 2 GPUs individually.

Multi-node HSDP — 2 nodes × 8 GPUs (16 GPUs total), 70B model

```yaml
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8      # shard within each node (NVLink)
  data_parallel_replicate_size: 2  # 2 identical replicas, one per node
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Cross-node IB traffic is limited to the all-reduce of averaged gradients (not full shard all-gathers), which is typically 4–8x less than pure FSDP2 would require across nodes.

Single node, 8 GPUs — Qwen3-30B-A3B MoE with EP=8

```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  attn_implementation: flash_attention_3
  moe_implementation: triton

train:
  output_dir: outputs/Qwen3-30B-A3B-ep8
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8
  data_parallel_replicate_size: 1
  expert_parallel_size: 8          # one expert group per GPU
  ulysses_parallel_size: 1
  ringattn_parallel_size: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: all_ranks     # each rank reads its expert shard from disk
```

With EP=8 on 8 GPUs, ep_fsdp_size = 1, so there is no FSDP sharding of expert params (each GPU holds its local expert slice independently). Non-expert params are sharded over the full 8-GPU FSDP group as usual.

PP=2 + EP=4 + FSDP, 8 GPUs — Qwen3-30B-A3B

```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  attn_implementation: flash_attention_3
  moe_implementation: triton

train:
  output_dir: outputs/Qwen3-30B-A3B-pp2-ep4
  data_parallel_mode: fsdp2
  pipeline_parallel_size: 2
  pipeline_parallel_schedule: 1F1B
  data_parallel_shard_size: 4
  data_parallel_replicate_size: 1
  expert_parallel_size: 4
  gradient_accumulation_steps: 4   # must be >= pipeline_parallel_size
  micro_batch_size: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  reshard_after_forward: false     # PP: keep params gathered between fwd micro-batches
  init_device: meta
  load_weights_mode: broadcast
```

world_size = PP × dp_shard = 2 × 4 = 8. Each PP stage has 4 GPUs running FSDP2 over non-expert params and EP=4 over expert params.


`data_parallel_mode`

| Value | Description |
| --- | --- |
| `fsdp2` | FSDP2 (default). Required for large models. Requires `init_device: meta`. |
| `ddp` | Standard DDP all-reduce. Requires `init_device: cuda` or `npu`. |
| `none` | No data parallelism. Used for single-GPU or when TP/PP already fills the world. |

`data_parallel_shard_size`

Number of GPUs in each FSDP shard group. Parameters, gradients, and optimizer states are divided equally across these GPUs.

  • Default: -1 (auto-computed as world_size / (TP × PP × CP) when data_parallel_replicate_size is not set).
  • For pure FSDP2: set to total DP size.
  • For HSDP: set to the number of GPUs per node (intra-node shard group size).

`data_parallel_replicate_size`

Number of FSDP shard-group replicas. Enables HSDP when > 1.

  • Default: -1 (auto-computed as dp_size / dp_shard_size).
  • For pure FSDP2: leave at 1 (or omit).
  • For HSDP: set to number of nodes.

Constraint: data_parallel_shard_size × data_parallel_replicate_size == data_parallel_size.

`enable_full_shard`

  • true (default): full ZeRO-3 — shard params + grads + optimizer states.
  • false: disable parameter sharding (ZeRO-1/2 style, less memory savings).

`reshard_after_forward`

Controls whether FSDP2 frees the all-gathered parameter copies after each layer’s forward pass.

| Value | Behavior | Use case |
| --- | --- | --- |
| None (default) | True for non-PP, False for PP | Automatic (recommended) |
| `true` | Reshard after forward (free bf16 copies) | Memory-constrained training |
| `false` | Keep params gathered across forward micro-batches | Pipeline parallelism (avoids re-gathering for each micro-batch) |

When PP is enabled, xorl sets reshard_after_forward=False automatically unless explicitly overridden. The 1F1B schedule runs multiple forward micro-batches before the backward, so keeping params gathered avoids repeated all-gathers.

  • true (default): FSDP2 issues all-gathers for the next layer’s parameters while the current layer is still computing (overlaps communication with computation).
  • false: sequential all-gathers (easier to debug, slower).

When EP is enabled, xorl replaces FSDP2’s automatic prefetching with explicit per-module prefetch lists via set_modules_to_forward_prefetch and set_modules_to_backward_prefetch.

`init_device`

| Value | Description | Compatible with |
| --- | --- | --- |
| `meta` | Parameters on meta device; materialized lazily by FSDP2 | FSDP2 only (required) |
| `cuda` | Parameters initialized directly on GPU | DDP, or FSDP2 debugging |
| `cpu` | Parameters on CPU (rank 0 only for broadcast) | DDP only; not supported with EP |
| `npu` | Ascend NPU device | DDP, FSDP2 |

`load_weights_mode`

| Value | Description | When to use |
| --- | --- | --- |
| `broadcast` (default) | Rank 0 reads checkpoint from disk, broadcasts shards to all ranks | Default; avoids N-way disk I/O bottleneck |
| `all_ranks` | Every rank reads the checkpoint independently | Fast parallel storage (e.g., Lustre, object storage), or when EP requires each rank to load its own expert shard |

Meta-device initialization creates parameter tensors with zero CPU/GPU memory cost before fully_shard is applied. Without it, every rank would allocate full-precision parameters before sharding, transiently using N times more memory than necessary.

```yaml
init_device: meta
```

Use load_weights_mode: broadcast by default


broadcast mode has rank 0 load the checkpoint from disk and distribute shards. This is the safest option when the storage system cannot handle parallel reads from all ranks simultaneously. Switch to all_ranks only on parallel filesystems with guaranteed per-rank I/O performance, or when loading EP shards that differ per rank.

Set reshard_after_forward: false for pipeline parallelism


The 1F1B pipeline schedule runs multiple forward micro-batches before triggering the backward. Resharding after each micro-batch’s forward would cause the next micro-batch to re-gather the same parameters. Keep them gathered:

```yaml
pipeline_parallel_size: 2
reshard_after_forward: false     # or omit; xorl sets this automatically for PP
gradient_accumulation_steps: 4
```

HSDP shard_size should match physical node boundaries


For HSDP to be effective, data_parallel_shard_size must equal (or divide evenly into) the number of GPUs per node. This confines all-gather and reduce-scatter collectives to fast NVLink, leaving only the gradient all-reduce to cross slower inter-node links.

```yaml
# 8 GPUs per node, 4 nodes = 32 GPUs
data_parallel_shard_size: 8      # intra-node FSDP (NVLink)
data_parallel_replicate_size: 4  # 4 node replicas (IB for grad all-reduce only)
```

Gradient accumulation and PP pipeline depth


When PP is enabled, gradient_accumulation_steps determines the number of micro-batches in the pipeline. It must satisfy gradient_accumulation_steps >= pipeline_parallel_size (usually >= 2 × pp_size for good bubble rate with 1F1B):

```yaml
pipeline_parallel_size: 4
gradient_accumulation_steps: 8   # bubble rate = (pp-1)/(ga) = 3/8 ≈ 37%
```
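The bubble-rate arithmetic follows the (pp − 1) / ga approximation used above; a minimal sketch:

```python
# 1F1B bubble fraction, following the (pp - 1) / ga approximation above.
def one_f_one_b_bubble(pp_size, num_microbatches):
    return (pp_size - 1) / num_microbatches

assert one_f_one_b_bubble(4, 8) == 0.375     # the ~37% quoted above
assert one_f_one_b_bubble(4, 16) == 0.1875   # more micro-batches shrink the bubble
```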

Enable gradient checkpointing to reduce activation memory


Gradient checkpointing (activation recomputation) is orthogonal to data parallelism but critical for fitting large models or long sequences:

```yaml
enable_gradient_checkpointing: true
```

This recomputes activations during the backward pass instead of storing them, trading compute for memory. Combined with FSDP2’s parameter resharding, this is the recommended configuration for all models larger than ~7B parameters.

Prefer fsdp2 over ddp for any model that does not fit on a single GPU


Even for models that technically fit with DDP + bf16, FSDP2 provides:

  • 4–8x reduction in optimizer state memory (sharded Adam states)
  • Gradient sharding (no need to hold all gradients simultaneously)
  • Compatibility with init_device: meta for fast startup
  • First-class composition with TP, EP, SP, and PP

The overhead vs DDP is negligible for sufficiently large models where communication is already latency-bound.

| File | Description |
| --- | --- |
| `src/xorl/distributed/torch_parallelize.py` | `fully_shard` application, gradient divide factor, EP prefetch setup |
| `src/xorl/distributed/parallel_state.py` | Device mesh construction, DP/EP/TP group initialization |