# Data Parallelism
Data parallelism (DP) replicates training across multiple GPUs, each processing a distinct mini-batch. xorl supports three DP modes — DDP, FSDP2, and HSDP — controlled by a single data_parallel_mode field. All DP modes compose freely with tensor parallelism (TP), pipeline parallelism (PP), expert parallelism (EP), and context parallelism (SP/CP).
The global world size constraint is:

```
world_size = PP × DP_shard × DP_replicate × TP × CP_ring × CP_ulysses
```

## Modes at a Glance

| Mode | Memory savings | Cross-node bandwidth | Best for |
|---|---|---|---|
| `ddp` | None (full params on each GPU) | All-reduce of gradients | Small models, prototyping |
| `fsdp2` | Params + grads + optimizer states | All-gather + reduce-scatter | Single-node or small clusters |
| `fsdp2` + HSDP | Same as FSDP2 within a node | All-gather + reduce-scatter within node; all-reduce across nodes | Multi-node (reduces expensive cross-node traffic) |
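All of these modes must tile the world-size equation given above. As a quick illustration, the check reduces to a single product — a standalone sketch, not xorl code (the helper name and defaults are ours):

```python
def check_world_size(world_size, pp=1, dp_shard=1, dp_replicate=1, tp=1,
                     cp_ring=1, cp_ulysses=1):
    """Return True iff the parallelism degrees exactly tile the world size."""
    return pp * dp_shard * dp_replicate * tp * cp_ring * cp_ulysses == world_size

print(check_world_size(16, dp_shard=8, dp_replicate=2))       # True  (HSDP example below)
print(check_world_size(8, dp_shard=2, cp_ulysses=4))          # True  (Ulysses example below)
print(check_world_size(8, pp=2, dp_shard=4, dp_replicate=2))  # False (over-subscribed)
```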
## 1. DDP — Distributed Data Parallel

### How it works

DDP wraps the model with PyTorch’s `DistributedDataParallel`. Every GPU holds a complete copy of the model in fp32 (or bf16 with mixed precision). During the backward pass, DDP fires an all-reduce over each parameter gradient bucket to average gradients across all ranks.
xorl applies DDP via `torch.nn.parallel.DistributedDataParallel` with `device_ids=[local_rank]`. When `enable_mixed_precision: true`, a `MixedPrecision` policy is applied:

```python
MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,  # gradients reduced in fp32
    buffer_dtype=torch.bfloat16,
)
```

### Configuration

```yaml
train:
  data_parallel_mode: ddp
  data_parallel_shard_size: 8       # all 8 GPUs are in one DDP group
  data_parallel_replicate_size: 1
  enable_mixed_precision: true
  init_device: cuda                 # DDP requires cuda or npu, not meta
```

Note: `init_device: meta` is not supported for DDP. Use `init_device: cuda`.
### When to use DDP

- Model fits comfortably in GPU memory on its own (typically < 3–4B params on an 80 GB GPU with bf16 + optimizer states).
- You want simple semantics and do not need memory savings.
- Prototyping or debugging where FSDP complexity is undesirable.
### Limitations

- Full parameter copy on every GPU: no memory reduction.
- Gradient all-reduce scales with model size, not local shard size, so communication cost grows linearly with parameters.
- Cannot be combined with `init_device: meta` (no lazy parameter materialization).
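To make the scaling concrete: a ring all-reduce moves roughly `2 × (N − 1) / N` times the gradient buffer per GPU per step, so the traffic stays near 2× the full gradient size regardless of GPU count. A rough back-of-envelope sketch — illustrative, not from xorl, and assuming fp32 gradient buckets:

```python
def ddp_allreduce_bytes_per_gpu(param_count, grad_bytes=4, n_gpus=8):
    """Approximate per-GPU traffic for one ring all-reduce of the gradients.

    A ring all-reduce sends and receives about 2 * (N - 1) / N times the
    payload size per GPU, so volume is ~2x the gradient buffer no matter
    how many GPUs participate.
    """
    payload = param_count * grad_bytes             # full gradient buffer
    return 2 * (n_gpus - 1) / n_gpus * payload     # bytes moved per GPU

# 7B-parameter model, fp32 gradients, 8 GPUs: ~49 GB moved per GPU per step
print(round(ddp_allreduce_bytes_per_gpu(7_000_000_000) / 1e9))  # 49
```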
## 2. FSDP2 — Fully Sharded Data Parallel (ZeRO-3)

FSDP2 is xorl’s primary training strategy. It shards parameters, gradients, and optimizer states across all ranks in the DP group, providing ZeRO-3-equivalent memory savings.
### How FSDP2 differs from FSDP1

| Aspect | FSDP1 (FullyShardedDataParallel) | FSDP2 (fully_shard composable API) |
|---|---|---|
| API style | Module wrapper | Per-module composable annotation |
| DTensor integration | Limited | Native — sharded params are DTensor objects |
| Composition with TP/EP | Difficult | First-class via device mesh |
| Grad hooks | Implicit via _post_backward_hook | Clean reduce-scatter on DTensor |
| Mixed precision | MixedPrecision wrapper class | MixedPrecisionPolicy dataclass |
| Min PyTorch version | 1.13 | 2.4 |
xorl uses the composable fully_shard API exclusively and requires PyTorch >= 2.4.
### Parameter sharding

`fully_shard` is applied at two granularities:

- Per decoder block — each transformer layer (e.g., `Qwen3DecoderLayer`) is wrapped independently. This keeps the all-gather scope to one layer’s parameters at a time, minimizing peak memory.
- Root model — the entire model is wrapped last, so FSDP2 handles the embedding and lm_head as a separate unit.
As a torchtitan-inspired optimization, the final norm + `lm_head` modules are grouped into a single FSDP unit with `reshard_after_forward=False`. This lets `lm_head` remain gathered after the last decoder layer’s forward, avoiding a redundant all-gather when the externally called `compute_loss` accesses `lm_head.weight`.
### full_shard vs no_shard

The `enable_full_shard` config field maps to FSDP2’s sharding behavior:

| enable_full_shard | FSDP2 behavior | Equivalent ZeRO level |
|---|---|---|
| `true` (default) | Shard params + grads + optimizer states | ZeRO-3 |
| `false` | No parameter sharding (only grad/optimizer state savings) | ZeRO-1/2 style |
In practice, always set `enable_full_shard: true` for large models. Setting it to `false` is primarily useful for debugging or models small enough that parameter memory is not the bottleneck.
### Configuration

```yaml
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8       # all 8 GPUs share parameters
  data_parallel_replicate_size: 1
  enable_full_shard: true
  enable_mixed_precision: true
  init_device: meta                 # required for FSDP2
  load_weights_mode: broadcast      # rank 0 loads, broadcasts to all ranks
```

`init_device: meta` is required for FSDP2. Parameters are initially created on the meta device (zero-cost), then materialized by FSDP2 after `fully_shard` is applied.
## 3. HSDP — Hybrid Sharding Data Parallel

HSDP is activated automatically when both `data_parallel_shard_size > 1` and `data_parallel_replicate_size > 1`. It creates a 2D device mesh:

```
mesh[dp_replicate, dp_shard]
```

- The `dp_shard` dimension (inner) performs full ZeRO-3 sharding via all-gather + reduce-scatter. All GPUs in a shard group collectively hold one copy of the model.
- The `dp_replicate` dimension (outer) replicates the model across shard groups. Gradients are all-reduced across shard-group replicas after the reduce-scatter.
### Device mesh layout

For 16 GPUs with `data_parallel_shard_size=8`, `data_parallel_replicate_size=2`:

```
Node 0: GPUs [0,1,2,3,4,5,6,7]        → dp_shard group 0 (shard group A)
Node 1: GPUs [8,9,10,11,12,13,14,15]  → dp_shard group 1 (shard group B)

dp_replicate groups: {GPU 0, GPU 8}, {GPU 1, GPU 9}, ..., {GPU 7, GPU 15}
```

All-gather and reduce-scatter happen within each node (high-bandwidth NVLink). The cross-node gradient synchronization (all-reduce across the replicate dimension) only carries the averaged gradients, which is much smaller than synchronizing full shards.
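The grouping above follows mechanically from reshaping the flat rank list into a `(dp_replicate, dp_shard)` grid — shard groups are the rows, replicate groups the columns. A small illustrative sketch (not xorl code) that reproduces the layout listed:

```python
def hsdp_groups(world_size, dp_shard, dp_replicate):
    """Enumerate shard groups (rows) and replicate groups (columns) of a
    (dp_replicate, dp_shard) mesh laid out over consecutive ranks."""
    assert dp_shard * dp_replicate == world_size
    mesh = [[r * dp_shard + s for s in range(dp_shard)] for r in range(dp_replicate)]
    shard_groups = mesh                                   # all-gather / reduce-scatter
    replicate_groups = [list(col) for col in zip(*mesh)]  # gradient all-reduce
    return shard_groups, replicate_groups

shard, repl = hsdp_groups(16, dp_shard=8, dp_replicate=2)
# shard[0] == [0..7] (node 0), shard[1] == [8..15] (node 1)
# repl[0] == [0, 8], repl[7] == [7, 15]
```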
### How the mesh is built in code

`init_parallel_state` in `parallel_state.py` constructs a multi-dimensional `DeviceMesh` that includes both `dp_replicate` and `dp_shard` axes. The FSDP mesh is then sliced as:

```python
# HSDP: 2D mesh (dp_replicate, dp_shard)
fsdp_mesh = device_mesh["dp_replicate", "dp_shard"]

# Pure FSDP2: 1D mesh (dp_shard only)
fsdp_mesh = device_mesh["dp_shard"]
```

FSDP2’s `fully_shard(module, mesh=fsdp_mesh)` interprets a 2D mesh as HSDP automatically — shard along the inner dim, replicate across the outer dim.
### When to use HSDP

- Training on multiple nodes where cross-node bandwidth (InfiniBand, RoCE) is significantly slower than intra-node bandwidth (NVLink).
- Model parameters fit within a single node’s GPU memory when sharded (i.e., `model_params / dp_shard_size` fits per GPU).
- You want to minimize cross-node communication by confining all-gather and reduce-scatter to fast intra-node links.
### Configuration

```yaml
# 2 nodes × 8 GPUs = 16 GPUs: shard within node, replicate across nodes
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8       # one full shard group per node
  data_parallel_replicate_size: 2   # 2 replicas (one per node)
  enable_full_shard: true
  enable_mixed_precision: true
  init_device: meta
  load_weights_mode: broadcast
```

The constraint that must hold: `data_parallel_shard_size × data_parallel_replicate_size == data_parallel_size` (where `data_parallel_size = world_size / (TP × PP × CP)`). xorl validates this in `TrainingArguments.__post_init__`.
## 4. Memory Breakdown

### DDP memory per GPU (model with P parameters, bf16 weights)

| Component | Size |
|---|---|
| Parameters (bf16) | 2P bytes |
| Gradients (fp32) | 4P bytes |
| Optimizer states (Adam: m + v + master weights) | 12P bytes |
| Total | ~18P bytes |
For a 7B parameter model: ~126 GB per GPU. This exceeds a single 80 GB GPU.
### FSDP2 memory per GPU (N GPUs in shard group)

| Component | Size per GPU |
|---|---|
| Parameter shards (bf16; layers all-gathered one at a time during forward/backward) | 2P/N bytes at rest (see peak row) |
| Gradient shard (fp32 reduce-scatter result) | 4P/N bytes |
| Optimizer state shard (Adam) | 12P/N bytes |
| Total at rest | ~(2 + 4 + 12) × P/N = 18P/N bytes |
| Peak (one layer all-gathered) | ~2P/N + one layer’s params in bf16 |
For 7B / 8 GPUs: ~15.75 GB at rest, plus one layer peak overhead — fits on an 80 GB GPU with ample headroom for activations.
### HSDP memory per GPU

Memory per GPU is identical to FSDP2 — the `dp_replicate` dimension only affects gradient communication, not per-GPU parameter storage.
## 5. Gradient Averaging

### The problem with FSDP’s default averaging

By default, PyTorch’s FSDP2 divides gradients by `fsdp_size` (the number of ranks in the FSDP mesh) during the reduce-scatter. This produces gradients averaged over the number of ranks, not over the number of actual training tokens. For variable-length sequence packing, where different ranks may process very different numbers of valid (non-padding) tokens, a simple rank-average produces incorrect gradient scaling.
### xorl’s approach: token-normalized gradients

xorl disables FSDP’s automatic gradient averaging by calling `set_gradient_divide_factor(1.0)` on every FSDP module immediately after `fully_shard`:

```python
# From src/xorl/distributed/torch_parallelize.py
from torch.distributed._composable.fsdp.fully_shard import FSDPModule

for module in model.modules():
    if isinstance(module, FSDPModule) and not getattr(module, "_is_ep_fsdp", False):
        module.set_gradient_divide_factor(1.0)
```

This means the reduce-scatter returns the sum of gradients across ranks, not the average.
Gradients are then normalized manually via `global_valid_tokens` — the total number of non-padding tokens across all ranks and all micro-batches in the current step:

```python
# From src/xorl/trainers/training_utils.py
global_valid_tokens = torch.tensor(0, device=device)
for mb in micro_batches:
    labels = mb.get("labels", mb.get("target_tokens"))
    if labels is not None:
        global_valid_tokens += (labels != IGNORE_INDEX).sum()
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.SUM, group=group)
```

### Non-PP path: gradient_accumulate_loss
For standard (non-pipeline-parallel) training, xorl uses the `GradientAccumulateLoss` autograd function in `src/xorl/distributed/gradient_accumulate_loss.py`. This custom function:

- Forward: scales each micro-batch’s mean loss by `local_valid_tokens`, all-reduces across ranks to get the total loss sum, then divides by `global_valid_tokens` to produce a per-token-normalized scalar.
- Backward: propagates gradients scaled by `local_valid_tokens / global_valid_tokens`, so each rank’s contribution to the gradient is proportional to its token count.

```python
ga_loss, _ = gradient_accumulate_loss(loss, local_valid_tokens, global_valid_tokens)
ga_loss.backward()
```

Since FSDP’s divide factor is 1.0, the reduce-scatter sums raw gradients across ranks. Combined with the `local_valid_tokens / global_valid_tokens` scaling applied in the backward pass, the net result is gradients averaged over all valid tokens globally.
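The scaling can be checked with plain numbers: summing per-rank mean gradients that were each scaled by `local_tokens / global_tokens` is identical to averaging over every valid token directly. A minimal sketch with made-up per-token gradients (not xorl code):

```python
# Two ranks with unequal valid-token counts and per-token gradient sums.
rank_tokens = [30, 10]              # local_valid_tokens per rank
rank_grad_sums = [60.0, 100.0]      # sum of per-token grads on each rank
global_tokens = sum(rank_tokens)    # 40

# Ground truth: average gradient over every valid token globally.
true_avg = sum(rank_grad_sums) / global_tokens  # 160 / 40 = 4.0

# xorl's scheme: each rank backprops its *mean* grad scaled by
# local_tokens / global_tokens, and FSDP (divide factor 1.0) just sums.
scaled = [
    (g / n) * (n / global_tokens)   # mean grad × local/global scaling
    for g, n in zip(rank_grad_sums, rank_tokens)
]
xorl_result = sum(scaled)           # what the reduce-scatter sum produces
assert xorl_result == true_avg == 4.0
```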
### PP path: explicit gradient scaling

In pipeline-parallel training, the loss function divides by `global_valid_tokens` directly:

```python
# loss = cross_entropy_sum / global_valid_tokens
return F.cross_entropy(..., reduction="sum") / global_valid_tokens
```

After the pipeline schedule completes its backward pass, gradients are already correctly normalized. The `set_gradient_divide_factor(1.0)` call ensures no additional averaging is applied.
### EP expert modules: separate divide factor

Expert modules sharded by EP use a different divide factor. Because EP ranks each hold a disjoint subset of experts, their gradients must be averaged across EP ranks (not summed). xorl sets:

```python
experts_mod.set_gradient_divide_factor(parallel_state.ep_size)
experts_mod._is_ep_fsdp = True  # prevents the global 1.0 reset from overriding
```

This is the only FSDP module that retains a non-unity divide factor.
## 6. Mixed Precision

When `enable_mixed_precision: true`, xorl applies a `MixedPrecisionPolicy` to all FSDP2 modules:

```python
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # parameters cast to bf16 for forward/backward
    reduce_dtype=torch.float32,   # gradients reduced in fp32 for numerical stability
)
```

What this means in practice:
- Stored parameter shards are in bf16 (halving parameter memory vs fp32).
- During the all-gather before each forward, shards are already stored in bf16, so no dtype cast is needed on the communication path.
- The reduce-scatter during backward accumulates in fp32 before storing the gradient shard in fp32.
- Optimizer states (Adam m, v, master weights) remain in fp32 throughout.
### Modules excluded from mixed precision

Some models define `get_ignore_modules_in_mixed_precision()` to return module types that must remain in fp32 during forward (e.g., MoE routing gate layers where fp32 precision affects load balancing). Those modules are wrapped with separate `fully_shard` calls using `fsdp_kwargs_without_mp` (no `mp_policy`).
### DDP mixed precision

DDP uses the older `MixedPrecision` wrapper class with identical semantics:

```python
MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
```

## 7. Expert Parallelism (EP) Interaction

When EP is enabled (`expert_parallel_size > 1`), the device mesh gains a separate `ep_fsdp` dimension for expert parameters, distinct from the main `dp_shard` dimension used for non-expert parameters.
### Separate meshes for expert vs non-expert parameters

- Main FSDP mesh: `device_mesh["dp_shard"]` (or `["dp_replicate", "dp_shard"]` for HSDP)
- EP-FSDP mesh: `ep_fsdp_device_mesh["ep_fsdp"]` (ranks within one EP group that share expert shards)

The EP mesh (`ep_fsdp_device_mesh`) is a 2D mesh of shape `(ep_size, ep_fsdp_size)`:

- `ep` dimension: routes tokens to different expert groups (AllToAll dispatch).
- `ep_fsdp` dimension: FSDP sharding of each EP group’s local expert tensors.
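Because the two EP mesh dimensions must tile the DP shard group, `ep_fsdp_size` follows directly from the shard size and `ep_size` — consistent with the EP=8-on-8-GPUs example later, where `ep_fsdp_size = 1`. An illustrative helper (our own sketch, not xorl code):

```python
def ep_fsdp_size(dp_shard_size, ep_size):
    """Ranks left over per EP group for FSDP-sharding the local experts."""
    assert dp_shard_size % ep_size == 0, "ep_size must divide the DP shard size"
    return dp_shard_size // ep_size

print(ep_fsdp_size(8, 8))   # 1 → no FSDP sharding of experts (EP=8 on 8 GPUs)
print(ep_fsdp_size(16, 4))  # 4 → each EP group shards its experts over 4 ranks
```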
### Dim-1 sharding for expert parameters

Expert weight tensors have shape `[num_local_experts, hidden_dim, intermediate_dim]` after EP slicing. Standard FSDP2 shards on dim-0 (expert count), but because EP has already consumed dim-0, xorl sets a custom shard placement function to shard on dim-1 (hidden_dim):

```python
def _experts_shard_placement_fn(param):
    return Shard(1)

expert_fsdp_kwargs["shard_placement_fn"] = _experts_shard_placement_fn
```

### Manual prefetching
When EP is active, the automatic FSDP2 prefetching heuristic cannot correctly schedule all-gathers across the heterogeneous set of FSDP units (attention, gate, experts) within each decoder layer. xorl configures explicit forward and backward prefetch lists per layer:

```python
# Forward: prefetch next layer's modules in order (attn, gate, experts)
current_block.set_modules_to_forward_prefetch(list(reversed(prefetch_modules)))
# Backward: prefetch previous layer's modules
current_block.set_modules_to_backward_prefetch(list(reversed(prefetch_modules)))
```

This ensures all-gathers for the next layer’s weights overlap with the current layer’s compute.
## 8. Configuration Examples

### Single node, 8 GPUs — Qwen3-8B dense model (pure FSDP2)

```yaml
model:
  model_path: Qwen/Qwen3-8B
  attn_implementation: flash_attention_3

train:
  output_dir: outputs/Qwen3-8B
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8
  data_parallel_replicate_size: 1
  tensor_parallel_size: 1
  expert_parallel_size: 1
  ulysses_parallel_size: 1

  micro_batch_size: 1
  gradient_accumulation_steps: 1

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Memory profile (Qwen3-8B, ~8B params):
- Params at rest: ~2 GB / GPU (bf16, sharded over 8)
- Gradients: ~4 GB / GPU (fp32 shard)
- Optimizer states: ~12 GB / GPU (Adam fp32)
- Total at rest: ~18 GB / GPU — comfortable on 80 GB GPUs with activations
### Single node, 8 GPUs — Qwen3-32B dense with Ulysses SP

```yaml
model:
  model_path: Qwen/Qwen3-32B
  attn_implementation: flash_attention_3

train:
  output_dir: outputs/Qwen3-32B
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 2
  data_parallel_replicate_size: 1
  ulysses_parallel_size: 4          # 4-way context parallelism within each shard group

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Here `world_size = dp_shard × ulysses = 2 × 4 = 8`. Each 2-GPU FSDP shard group uses 4-way Ulysses to handle long sequences that would not fit on 2 GPUs individually.
### Multi-node HSDP — 2 nodes × 8 GPUs (16 GPUs total), 70B model

```yaml
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8       # shard within each node (NVLink)
  data_parallel_replicate_size: 2   # 2 identical replicas, one per node

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Cross-node IB traffic is limited to the all-reduce of averaged gradients (not full shard all-gathers), which is typically 4–8x less than pure FSDP2 would require across nodes.
### Single node, 8 GPUs — Qwen3-30B-A3B MoE with EP=8

```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  attn_implementation: flash_attention_3
  moe_implementation: triton

train:
  output_dir: outputs/Qwen3-30B-A3B-ep8
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8
  data_parallel_replicate_size: 1
  expert_parallel_size: 8           # one expert group per GPU
  ulysses_parallel_size: 1
  ringattn_parallel_size: 1

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: all_ranks      # each rank reads its expert shard from disk
```

With EP=8 on 8 GPUs, `ep_fsdp_size = 1`, so there is no FSDP sharding of expert params (each GPU holds its local expert slice independently). Non-expert params are sharded over the full 8-GPU FSDP group as usual.
### PP=2 + EP=4 + FSDP, 8 GPUs — Qwen3-30B-A3B

```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  attn_implementation: flash_attention_3
  moe_implementation: triton

train:
  output_dir: outputs/Qwen3-30B-A3B-pp2-ep4
  data_parallel_mode: fsdp2
  pipeline_parallel_size: 2
  pipeline_parallel_schedule: 1F1B
  data_parallel_shard_size: 4
  data_parallel_replicate_size: 1
  expert_parallel_size: 4

  gradient_accumulation_steps: 4    # must be >= pipeline_parallel_size
  micro_batch_size: 1

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_full_shard: true
  reshard_after_forward: false      # PP: keep params gathered between fwd micro-batches
  init_device: meta
  load_weights_mode: broadcast
```

`world_size = PP × dp_shard = 2 × 4 = 8`. Each PP stage has 4 GPUs running FSDP2 over non-expert params and EP=4 over expert params.
## 9. Key Parameters Reference

### data_parallel_mode

| Value | Description |
|---|---|
| `fsdp2` | FSDP2 (default). Required for large models. Requires `init_device: meta`. |
| `ddp` | Standard DDP all-reduce. Requires `init_device: cuda` or `npu`. |
| `none` | No data parallelism. Used for single-GPU or when TP/PP already fills the world. |
### data_parallel_shard_size

Number of GPUs in each FSDP shard group. Parameters, gradients, and optimizer states are divided equally across these GPUs.

- Default: `-1` (auto-computed as `world_size / (TP × PP × CP)` when `data_parallel_replicate_size` is not set).
- For pure FSDP2: set to total DP size.
- For HSDP: set to the number of GPUs per node (intra-node shard group size).
### data_parallel_replicate_size

Number of FSDP shard-group replicas. Enables HSDP when > 1.

- Default: `-1` (auto-computed as `dp_size / dp_shard_size`).
- For pure FSDP2: leave at `1` (or omit).
- For HSDP: set to the number of nodes.

Constraint: `data_parallel_shard_size × data_parallel_replicate_size == data_parallel_size`.
### enable_full_shard

- `true` (default): full ZeRO-3 — shard params + grads + optimizer states.
- `false`: disable parameter sharding (ZeRO-1/2 style, less memory savings).
### reshard_after_forward

Controls whether FSDP2 frees the all-gathered parameter copies after each layer’s forward pass.

| Value | Behavior | Use case |
|---|---|---|
| `None` (default) | `True` for non-PP, `False` for PP | Automatic (recommended) |
| `true` | Reshard after forward (free bf16 copies) | Memory-constrained training |
| `false` | Keep params gathered across forward micro-batches | Pipeline parallelism (avoids re-gathering for each micro-batch) |
When PP is enabled, xorl sets reshard_after_forward=False automatically unless explicitly overridden. The 1F1B schedule runs multiple forward micro-batches before the backward, so keeping params gathered avoids repeated all-gathers.
### enable_forward_prefetch

- `true` (default): FSDP2 issues all-gathers for the next layer’s parameters while the current layer is still computing (overlaps communication with computation).
- `false`: sequential all-gathers (easier to debug, slower).

When EP is enabled, xorl replaces FSDP2’s automatic prefetching with explicit per-module prefetch lists via `set_modules_to_forward_prefetch` and `set_modules_to_backward_prefetch`.
### init_device

| Value | Description | Compatible with |
|---|---|---|
| `meta` | Parameters on meta device; materialized lazily by FSDP2 | FSDP2 only (required) |
| `cuda` | Parameters initialized directly on GPU | DDP, or FSDP2 debugging |
| `cpu` | Parameters on CPU (rank 0 only for broadcast) | DDP only; not supported with EP |
| `npu` | Ascend NPU device | DDP, FSDP2 |
### load_weights_mode

| Value | Description | When to use |
|---|---|---|
| `broadcast` (default) | Rank 0 reads checkpoint from disk, broadcasts shards to all ranks | Default; avoids N-way disk I/O bottleneck |
| `all_ranks` | Every rank reads the checkpoint independently | Fast parallel storage (e.g., Lustre, object storage), or when EP requires each rank to load its own expert shard |
## 10. Best Practices

### Always use init_device: meta with FSDP2

Meta-device initialization creates parameter tensors with zero CPU/GPU memory cost before `fully_shard` is applied. Without it, every rank would allocate full-precision parameters before sharding, transiently using N times more memory than necessary.

```yaml
init_device: meta
```

### Use load_weights_mode: broadcast by default
`broadcast` mode has rank 0 load the checkpoint from disk and distribute shards. This is the safest option when the storage system cannot handle parallel reads from all ranks simultaneously. Switch to `all_ranks` only on parallel filesystems with guaranteed per-rank I/O performance, or when loading EP shards that differ per rank.
### Set reshard_after_forward: false for pipeline parallelism

The 1F1B pipeline schedule runs multiple forward micro-batches before triggering the backward. Resharding after each micro-batch’s forward would cause the next micro-batch to re-gather the same parameters. Keep them gathered:

```yaml
pipeline_parallel_size: 2
reshard_after_forward: false      # or omit; xorl sets this automatically for PP
gradient_accumulation_steps: 4
```

### HSDP shard_size should match physical node boundaries
For HSDP to be effective, `data_parallel_shard_size` must equal (or divide evenly into) the number of GPUs per node. This confines all-gather and reduce-scatter collectives to fast NVLink, leaving only the gradient all-reduce to cross slower inter-node links.

```yaml
# 8 GPUs per node, 4 nodes = 32 GPUs
data_parallel_shard_size: 8       # intra-node FSDP (NVLink)
data_parallel_replicate_size: 4   # 4 node replicas (IB for grad all-reduce only)
```

### Gradient accumulation and PP pipeline depth
When PP is enabled, `gradient_accumulation_steps` determines the number of micro-batches in the pipeline. It must satisfy `gradient_accumulation_steps >= pipeline_parallel_size` (usually >= 2 × pp_size for a good bubble rate with 1F1B):

```yaml
pipeline_parallel_size: 4
gradient_accumulation_steps: 8    # bubble rate = (pp - 1) / ga = 3/8 ≈ 37%
```

### Enable gradient checkpointing to reduce activation memory
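The bubble-rate comment generalizes to a one-line formula for 1F1B — an approximation that ignores communication overlap:

```python
def bubble_rate(pp_size, grad_accum_steps):
    """Approximate 1F1B pipeline bubble fraction: (pp - 1) / micro-batches."""
    return (pp_size - 1) / grad_accum_steps

print(bubble_rate(4, 8))    # 0.375  — the ≈37% quoted above
print(bubble_rate(4, 16))   # 0.1875 — doubling micro-batches halves the bubble
```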
Gradient checkpointing (activation recomputation) is orthogonal to data parallelism but critical for fitting large models or long sequences:

```yaml
enable_gradient_checkpointing: true
```

This recomputes activations during the backward pass instead of storing them, trading compute for memory. Combined with FSDP2’s parameter resharding, this is the recommended configuration for all models larger than ~7B parameters.
### Prefer fsdp2 over ddp for any model that does not fit on a single GPU

Even for models that technically fit with DDP + bf16, FSDP2 provides:

- 4–8x reduction in optimizer state memory (sharded Adam states)
- Gradient sharding (no need to hold all gradients simultaneously)
- Compatibility with `init_device: meta` for fast startup
- First-class composition with TP, EP, SP, and PP

The overhead vs DDP is negligible for sufficiently large models, where communication is already the dominant cost.
## Source

| File | Description |
|---|---|
| `src/xorl/distributed/torch_parallelize.py` | `fully_shard` application, gradient divide factor, EP prefetch setup |
| `src/xorl/distributed/parallel_state.py` | Device mesh construction, DP/EP/TP group initialization |