Context Parallelism
Context parallelism (CP) distributes long sequences across multiple GPUs during training, allowing context lengths that would not fit on a single device. There are two independent SP mechanisms — Ulysses SP and Ring Attention — which can be used alone or composed into a Hybrid CP configuration.
The combined context-parallel (CP) group size is:
```
cp_size = ulysses_size × ringattn_size
```

It is controlled by these config fields in TrainArguments:
| Config field | Default | Description |
|---|---|---|
| `ulysses_parallel_size` | 1 | Number of ranks in the Ulysses SP group |
| `ringattn_parallel_size` | 1 | Number of ranks in the ring attention group |
| `cp_fsdp_mode` | `"all"` | How SP dimensions fold into FSDP sharding |
Ulysses Context Parallelism
How It Works
Ulysses SP (originally from DeepSpeed) distributes attention heads across ranks while keeping each rank responsible for a slice of the sequence. The key insight is that attention is separable across heads, so gathering the full sequence per rank and computing a subset of heads is equivalent to full attention.
Data layout entering the attention layer:
- Each Ulysses rank holds a shard of shape `[B, S/ulysses_size, num_heads, head_dim]`.
- Within the CP group, ranks hold the same batch items but different, complementary sequence slices.
Pre-attention all-to-all (phase 1, project_qkv):
```
Input:  [B, S/P, H, D]   (S/P tokens, H heads, per rank)
        scatter dim = head, gather dim = seq
Output: [B, S, H/P, D]   (S tokens, H/P heads, per rank)
```

The all-to-all transposes the sharding axis from sequence to heads. After this operation, each rank holds all tokens but only `H/ulysses_size` heads, enabling standard local flash attention.
Post-attention all-to-all (phase 3, project_output):
```
Input:  [B, S, H/P, D]   (output of attention: full sequence, local heads)
        scatter dim = seq, gather dim = head
Output: [B, S/P, H, D]   (back to sequence-sharded)
```

This restores the original sequence-sharded layout for the subsequent MLP and layer-norm computations, which are position-wise and run locally.
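The two all-to-alls can be sketched on a single process with plain nested lists (batch and head_dim omitted for clarity; `ulysses_a2a` is an illustrative stand-in for the library's `_SeqAllToAll`, not its API):

```python
def ulysses_a2a(shards, scatter_dim, gather_dim):
    """Single-process sketch of the Ulysses all-to-all on [seq][head] nested
    lists. Each rank splits its shard into P pieces along scatter_dim,
    'sends' piece j to rank j, and concatenates received pieces along
    gather_dim."""
    P = len(shards)

    def split(x, dim):
        if dim == 0:  # split rows (sequence)
            k = len(x) // P
            return [x[i * k:(i + 1) * k] for i in range(P)]
        k = len(x[0]) // P  # split columns (heads)
        return [[row[i * k:(i + 1) * k] for row in x] for i in range(P)]

    def cat(pieces, dim):
        if dim == 0:
            return [row for p in pieces for row in p]
        return [sum((p[i] for p in pieces), []) for i in range(len(pieces[0]))]

    pieces = [split(s, scatter_dim) for s in shards]
    return [cat([pieces[r][j] for r in range(P)], gather_dim) for j in range(P)]

# Global "activations": 8 tokens × 4 heads, each cell tagged (token, head).
S, H, P = 8, 4, 2
full = [[(s, h) for h in range(H)] for s in range(S)]

# Phase 1: start sequence-sharded; scatter heads (dim 1), gather seq (dim 0).
seq_shards = [full[:S // P], full[S // P:]]
head_shards = ulysses_a2a(seq_shards, scatter_dim=1, gather_dim=0)
assert len(head_shards[0]) == S and len(head_shards[0][0]) == H // P

# Phase 3: the reverse a2a restores the sequence-sharded layout exactly.
restored = ulysses_a2a(head_shards, scatter_dim=0, gather_dim=1)
assert restored == seq_shards
```

After phase 1, rank j holds every token but only head slice j, which is exactly what lets each rank run ordinary local flash attention on its head subset.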
Communication Cost
Each Ulysses attention layer requires two all-to-all collectives. The data volume transferred per all-to-all is:
```
bytes per all-to-all = B × S × H × D × dtype_bytes
```

The all-to-all sends 1/P of the data to each of the P ranks, for a total communication volume proportional to S × H × D. Critically, Ulysses communication cost does not grow with sequence length relative to compute: both compute (O(S²H)) and communication (O(S×H)) scale with S, but communication is lower order.
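A quick back-of-envelope helper for this volume (illustrative, not part of the codebase):

```python
def ulysses_a2a_bytes(batch, seqlen, heads, head_dim, dtype_bytes=2):
    """Total payload of one Ulysses all-to-all across the group,
    per the formula above (dtype_bytes=2 for bf16)."""
    return batch * seqlen * heads * head_dim * dtype_bytes

# Example: 64K-token sequence, 64 heads of dim 128, bf16:
vol = ulysses_a2a_bytes(batch=1, seqlen=65536, heads=64, head_dim=128)
print(vol / 2**30, "GiB per all-to-all")  # → 1.0 GiB
```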
For GQA models where `ulysses_size > num_kv_heads`, xorl expands KV heads with `repeat_kv` before the all-to-all so that each rank receives at least one KV head after scattering.
Constraints
- Head divisibility: `num_attention_heads % ulysses_size == 0`. For GQA models with few KV heads, the sync strategy handles the case where `ulysses_size > num_kv_heads` by expanding KV with `repeat_kv` before the a2a, which requires `ulysses_size % num_kv_heads == 0`.
- Sequence divisibility: the sequence length must be divisible by `ulysses_size` after padding. The collator pads to `cp_size` (or `2 * cp_size` for ring+zigzag).
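These constraints can be captured in a minimal pre-flight check (a hypothetical helper mirroring the rules above, not the library API):

```python
def check_ulysses_constraints(num_heads, num_kv_heads, ulysses_size, seq_len):
    """Hypothetical sanity check for a Ulysses SP configuration."""
    assert num_heads % ulysses_size == 0, "Q heads must split evenly across ranks"
    if ulysses_size > num_kv_heads:
        # the sync strategy expands KV with repeat_kv; the expansion must be integral
        assert ulysses_size % num_kv_heads == 0, "KV head expansion must be integral"
    assert seq_len % ulysses_size == 0, "pad the sequence before slicing"

# Qwen3-8B-like: 64 Q heads, 8 KV heads, 8-way Ulysses, 64K padded length
check_ulysses_constraints(num_heads=64, num_kv_heads=8, ulysses_size=8, seq_len=65536)
```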
Sync vs. Async Ulysses
The strategy resolver (`get_cp_strategy` in `strategy.py`) automatically selects
between two Ulysses implementations:
- `UlyssesAsyncStrategy` (used when `ulysses_size <= num_kv_heads`): overlaps each linear projection (Q, K, V) with its corresponding all-to-all on a separate NCCL stream. This is implemented in `AsyncUlyssesQKVProjection.forward()`, which launches the Q projection, then the Q all-to-all asynchronously, then the K projection, K all-to-all, and so on; each communication overlaps the next compute step.
- `UlyssesSyncStrategy` (used when `ulysses_size > num_kv_heads`, i.e., GQA with few KV heads): performs the projections first, then synchronous all-to-alls. It also uses a fused K/V all-to-all: K and V are interleaved into a single `[S, 2*H_kv, D]` tensor to halve the number of collectives.
For QLoRA, both strategies fall back to sync-style (weights are packed in quantized buffers, so the manual matmul path in `AsyncUlyssesQKVProjection` is unavailable).
Gradient Synchronization
Ulysses all-to-all collectives have correct gradients baked in via autograd. The `_SeqAllToAll` autograd function swaps the scatter and gather dims in the backward pass, so parameter gradients from the QKV and output projections are correct locally.
However, if Ulysses is not folded into the FSDP mesh (controlled by
cp_fsdp_mode), gradients from different Ulysses ranks must be explicitly summed.
This is handled by sync_sp_gradients in training_utils.py:
```python
def sync_sp_gradients(model, sp_grad_sync_group):
    if sp_grad_sync_group is not None:
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=sp_grad_sync_group)
```

The key point: SP ranks compute gradients for complementary, non-overlapping sequence shards, so the correct operation is SUM (not AVG). FSDP's reduce-scatter already averages gradients within its group; `sync_sp_gradients` adds the contributions from ranks not in that group.
Ring Attention
How It Works
Ring attention distributes attention computation by assigning each rank a contiguous shard of the sequence. Unlike Ulysses (which uses all-to-all to gather the full sequence per rank), ring attention keeps each rank’s Q, K, V local and rotates the KV blocks around a logical ring of GPUs.
Forward pass (`RingAttentionP2PFunc.forward`):

- Pack local K and V into a contiguous buffer: `kv_buf = cat([k.reshape(-1), v.reshape(-1)])`.
- Allocate two ping-pong buffers for double-buffered P2P communication.
- Loop `ringattn_size` times:
  - Wait for the P2P recv of the current KV buffer.
  - Immediately post the next P2P send/recv on a separate CUDA stream (this overlaps with the attention compute below).
  - Extract `k_step`, `v_step` from the current buffer.
  - Run flash attention between local Q and `(k_step, v_step)`.
  - Merge partial outputs using online log-sum-exp (numerically stable via softplus).
- At the end of the loop, each rank holds the correct attention output for its local Q tokens, having attended to all K/V tokens across all ring ranks.
Backward pass:
The backward does not use P2P rings. Instead, it re-gathers all KV via `all_gather_kv` (stacks K and V into a single buffer, performs one `all_gather_into_tensor`, then splits). This is memory-efficient because only the local K, V are saved in the forward context. The `dk` and `dv` gradients accumulated across all steps are returned to their owning ranks via `reduce_scatter_grads`.
P2P deadlock avoidance: Even-rank GPUs send first then receive; odd-rank GPUs receive first then send. This avoids circular deadlocks in NCCL’s point-to-point operations.
Zigzag Load Balancing
Causal attention creates an asymmetric workload: tokens near the end of the sequence attend to many more KV tokens than tokens near the start. Without any reordering, ring ranks holding late sequence chunks perform far more FLOPs than ranks holding early chunks.
Zigzag reordering solves this by assigning each rank a pair of sub-chunks: one from the early part of each document and one from the late part. After zigzag, every rank performs the same total amount of causal attention work.
Chunk assignment: Split each document into 2 * ringattn_size equal-sized
sub-chunks numbered 0 to 2N-1. Rank r receives sub-chunks r (early) and
2N-1-r (late). This means rank 0 gets chunks [0, 2N-1], rank 1 gets [1, 2N-2],
etc. — the symmetry ensures each rank’s early+late pair spans the same total causal
attention workload.
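The assignment and its balance can be checked in a few lines (an illustrative sketch; `zigzag_chunks` and `rank_work` are not library functions):

```python
def zigzag_chunks(ringattn_size):
    """Early/late sub-chunk pair assigned to each ring rank."""
    n2 = 2 * ringattn_size
    return [(r, n2 - 1 - r) for r in range(ringattn_size)]

def rank_work(ringattn_size):
    """Relative causal-attention cost per rank: sub-chunk i attends to
    sub-chunks 0..i, so its cost is proportional to i + 1."""
    return [(a + 1) + (b + 1) for a, b in zigzag_chunks(ringattn_size)]

print(zigzag_chunks(4))  # [(0, 7), (1, 6), (2, 5), (3, 4)]
print(rank_work(4))      # [9, 9, 9, 9]: every rank does equal work
```

Each pair sums to a constant cost of 2N + 1 chunk-units, which is exactly the balance property the text describes.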
Ring step sections: During the ring loop, each step is classified into one of three sections based on the relative positions of the query rank and the current KV source rank:
"diagonal": The rank is computing attention with its own KV (step 0). Full causal attention applies within the combined chunk."lower": Q’s early sub-chunk is later than KV’s early sub-chunk. All Q tokens can attend to the early half of the KV (no mask needed)."upper": Q’s early sub-chunk is earlier. Only Q’s late half attends to all KV (no mask needed).
This three-case logic eliminates causal masking overhead for all non-diagonal steps while maintaining correctness.
Zigzag Reordering in the Data Collator
The reordering happens in `zigzag_reorder_packed_sequence` (called from `TextSequenceShardCollator.__call__`). The full algorithm:
- Detect document boundaries by finding positions where `position_id == 0`.
- For each document of length `L`, split into `2 * ringattn_size` equal chunks.
- For rank `r`: collect chunk `r` (early) and chunk `2N-1-r` (late) from every document.
- Concatenate all rank 0 data first, then rank 1, etc., into a single reordered sequence.
- Each CP rank then takes its contiguous slice via `sp_slice`.
Constraint: Every document must have length divisible by 2 * ringattn_size.
If any document fails this check, zigzag_reorder_packed_sequence raises a
ValueError. The SP padding step ensures the total sequence length satisfies this,
but individual documents in a packed sequence must already be individually aligned.
If documents are not naturally aligned, the upstream packing step must enforce this.
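The steps above can be sketched as a toy reorder over a flat token list (illustrative only; the real `zigzag_reorder_packed_sequence` operates on tensors and reorders several fields at once):

```python
def zigzag_reorder(tokens, doc_lens, ringattn_size):
    """Toy zigzag reorder of a packed sequence, following the algorithm above."""
    n2 = 2 * ringattn_size
    docs, off = [], 0
    for L in doc_lens:
        if L % n2 != 0:
            raise ValueError(f"doc length {L} not divisible by {n2}")
        docs.append(tokens[off:off + L])
        off += L
    out = []
    for r in range(ringattn_size):          # rank 0's data first, then rank 1, ...
        for d in docs:
            c = len(d) // n2
            out += d[r * c:(r + 1) * c]               # early chunk r
            out += d[(n2 - 1 - r) * c:(n2 - r) * c]   # late chunk 2N-1-r
    return out

# Two documents of length 8, ring size 2 (so 4 sub-chunks of 2 tokens each):
tokens = list(range(16))
print(zigzag_reorder(tokens, [8, 8], 2))
```

Rank 0's contiguous `sp_slice` of the result holds each document's first and last sub-chunks, matching the balanced assignment described above.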
SP Padding
Before slicing, the collator pads the sequence to the required multiple:

```
pad_multiple = 2 * cp_size if ringattn_size > 1 else cp_size
```

This ensures:
- For pure ring attention: divisible by `2 * ringattn_size` for zigzag sub-chunks.
- For hybrid CP: divisible by `2 * ringattn_size * ulysses_size = 2 * cp_size`.
- For pure Ulysses: divisible by `ulysses_size = cp_size`.
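The padding rule can be expressed directly (a sketch of the arithmetic, not the collator code):

```python
import math

def sp_pad_length(seq_len, cp_size, ringattn_size):
    """Tokens of padding the collator adds, per the rule above."""
    pad_multiple = 2 * cp_size if ringattn_size > 1 else cp_size
    return math.ceil(seq_len / pad_multiple) * pad_multiple - seq_len

print(sp_pad_length(1000, cp_size=8, ringattn_size=1))  # 1000 = 125 × 8 → 0
print(sp_pad_length(1000, cp_size=8, ringattn_size=2))  # pads to 1008 (63 × 16) → 8
```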
After padding, sp_slice takes rank cp_rank’s contiguous chunk of size
padded_len / cp_size:
```python
def sp_slice(self, tensor, dim=-1):
    seq_length = tensor.size(dim)
    cp_chunk_size = (seq_length + self.cp_size - 1) // self.cp_size
    return tensor.narrow(dim, self.cp_rank * cp_chunk_size, cp_chunk_size)
```

cu_seq_lens Handling for Ring Attention
Flash attention’s varlen API requires cumulative sequence lengths (cu_seqlens) to
locate document boundaries within packed sequences. After zigzag reordering, the
position_ids tensor has resets (position_id == 0) at every sub-chunk boundary —
not just at true document boundaries. Using zigzag-reordered position IDs to compute
cu_seqlens would create hundreds of spurious tiny “documents” and produce wrong
cu_seqlens, leading to NaN in flash attention backward.
Solution: The collator stores the original (pre-zigzag) position IDs in
_original_position_ids before reordering, then computes cu_seqlens from those:
```python
# In TextSequenceShardCollator.__call__:
if "_original_position_ids" not in batch:
    batch["_original_position_ids"] = position_ids.clone()

# ... zigzag reorder happens here ...

if self.ringattn_size > 1 and "_original_position_ids" in batch:
    orig_pos = batch["_original_position_ids"]
    (cu_q, cu_k), (max_q, max_k) = prepare_fa_kwargs_from_position_ids(orig_pos)
    batch["cu_seq_lens_q"] = cu_q
    batch["cu_seq_lens_k"] = cu_k
    batch["max_length_q"] = max_q
    batch["max_length_k"] = max_k
```

These cu_seqlens reflect the full (unsharded) document structure. The ring attention
strategy (_scale_cu_seqlens_for_ringattn in strategy.py) then divides them by
ringattn_size to convert to per-rank local offsets:
```python
cu_seqlens_q = (cu_seqlens_q // ringattn_size).to(torch.int32)
max_seqlen_q = max_seqlen_q // ringattn_size
```

This works because zigzag guarantees each rank holds exactly `L / ringattn_size` tokens from each document of length `L`.
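A concrete worked example of the scaling (illustrative sketch):

```python
def scale_cu_seqlens(cu_seqlens, ringattn_size):
    """Convert global document offsets to per-rank local offsets. Valid because
    zigzag gives every rank exactly L / ringattn_size tokens of each document."""
    return [c // ringattn_size for c in cu_seqlens]

# Two packed documents of 32 and 64 tokens, ring size 4:
print(scale_cu_seqlens([0, 32, 96], 4))  # → [0, 8, 24]
```

Each rank's local packed sequence then contains an 8-token shard of the first document followed by a 16-token shard of the second, exactly as the scaled offsets describe.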
Note: position_ids itself is not sliced — the full padded tensor is forwarded to
the model. RoPE embedding uses position_ids to index the pre-computed (cos, sin)
tables, and ring attention needs each rank to use the correct positional indices for
its local tokens, which the zigzag assignment already ensures.
Hybrid CP: Ulysses + Ring
When both ulysses_size > 1 and ringattn_size > 1, the two mechanisms compose into
a hybrid strategy (HybridUlyssesRingStrategy).
Total CP Group Size
```
cp_size = ulysses_size × ringattn_size
```

Each batch is sharded across `cp_size` ranks. After SP padding, each rank holds `S / cp_size` tokens.
Data Flow in Hybrid Mode
The three-phase strategy maps to:
```
Input shard: [B, S/cp_size, H, D]

Phase 1 (project_qkv), Ulysses all-to-all:
  scatter heads, gather seq (within the Ulysses group)
  → [B, S/ringattn_size, H/ulysses_size, D]

Phase 2 (compute_attention), ring attention:
  KV rotated across the ring group (ringattn_size ranks)
  ring loop over ringattn_size steps, attending to all S tokens
  → [B, S/ringattn_size, H/ulysses_size, D]

Phase 3 (project_output), Ulysses all-to-all (reverse):
  scatter seq, gather heads
  → [B, S/cp_size, H, D]
```

After the Ulysses pre-attention a2a, each rank holds `S / ringattn_size` tokens with `H / ulysses_size` heads. The ring attention then rotates KV across the `ringattn_size`-rank ring group, giving each rank attention over the full sequence for its local head subset.
Device Mesh Layout
The device mesh is constructed in `init_parallel_state` with dimension order:

```
[pp, dp_replicate, dp_shard, ringattn, ulysses, tp]
```

Only dimensions with size > 1 (plus dp_shard) are materialized. Flat mesh aliases
are created for compound groups:
| Alias | Covers |
|---|---|
| `dp` | dp_replicate × dp_shard |
| `sp` | ringattn × ulysses (unified SP group) |
| `dp_shard_sp` | dp_shard × [ring] × [ulysses], depending on `cp_fsdp_mode` |
| `dp_sp` | dp_replicate × dp_shard × ringattn × ulysses (loss group) |
The FSDP2 mesh (fsdp_mesh) is the dp_shard_sp submesh, which determines which
ranks shard model parameters together. Its size (fsdp_size) is used by FSDP2’s
reduce-scatter to compute the averaging factor for gradient reduction:
```python
@property
def fsdp_size(self) -> int:
    size = self.dp_size
    if self.cp_fsdp_mode == "all":
        size *= self.ringattn_size * self.ulysses_size
    elif self.cp_fsdp_mode == "ulysses_only":
        size *= self.ulysses_size
    elif self.cp_fsdp_mode == "ring_only":
        size *= self.ringattn_size
    return size
```

cp_fsdp_mode: SP–FSDP Interaction
`cp_fsdp_mode` controls which SP dimensions are folded into the FSDP sharding mesh.
When an SP dimension is folded into FSDP, its reduce-scatter automatically handles
gradient averaging for that dimension, so no separate sync_sp_gradients is needed.
| Value | FSDP mesh includes | Extra grad sync needed |
|---|---|---|
"all" (default) | dp_shard + ring + ulysses | None |
"ulysses_only" | dp_shard + ulysses | Ring group (ringattn_group) |
"ring_only" | dp_shard + ring | Ulysses group (ulysses_group) |
"none" | dp_shard only | Unified SP group (ring × ulysses) |
The sp_grad_sync_group property in ParallelState returns the appropriate group:
```python
@property
def sp_grad_sync_group(self):
    if self.cp_fsdp_mode == "all":
        return None  # no-op
    if self.cp_fsdp_mode == "ulysses_only":
        return ringattn_group
    if self.cp_fsdp_mode == "ring_only":
        return ulysses_group
    if self.cp_fsdp_mode == "none":
        return sp_group  # full unified SP
```

When to use non-`"all"` modes:

- `"ulysses_only"` can improve memory efficiency when ring attention is used without FSDP parameter sharding across ring ranks (e.g., ring provides long-context support but the model fits in memory per ring group of ranks).
- `"none"` decouples SP entirely from FSDP, at the cost of requiring a manual all-reduce across the full SP group after every backward.
Interaction with Other Parallelism Dimensions
CP + DP (FSDP2)
DP and CP coexist naturally. The full FSDP sharding mesh is:
```
dp_shard × [ringattn] × [ulysses]   (depends on cp_fsdp_mode)
```

The data loader shards batches across `dp_size` ranks (the `batch_mesh`). Each
DP rank feeds a distinct sample to its CP group. Within a CP group, all ranks process
the same batch item but different sequence shards.
The loss is reduced across dp_replicate × dp_shard × ringattn × ulysses (the
loss_mesh / dp_sp group), because all of these ranks compute partial losses on
different parts of the same logical training batch.
CP + PP (Pipeline Parallelism)
PP and CP are orthogonal: PP stages split model layers; CP splits the sequence within each stage. Each PP stage runs the same CP topology independently.
Key consideration: PP requires fixed-shape P2P buffers across microbatches. Since CP
pads sequences to 2 * cp_size multiples, the padded sequence length passed through
the PP pipeline is deterministic. pad_micro_batches_for_pp in training_utils.py
pads each microbatch to sample_packing_sequence_len / sp_size (rounded up to
pad_to_multiple_of), ensuring buffer shape consistency.
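Assuming that behavior, the pad target can be computed as follows (a hypothetical sketch of the arithmetic, not the `pad_micro_batches_for_pp` implementation):

```python
import math

def pp_microbatch_len(packing_len, sp_size, pad_to_multiple_of):
    """Fixed per-rank microbatch length for PP P2P buffers:
    packing_len / sp_size, rounded up to pad_to_multiple_of."""
    target = packing_len // sp_size
    return math.ceil(target / pad_to_multiple_of) * pad_to_multiple_of

# 96K packing length, cp_size = 4, ring active → pad_to_multiple_of = 2 * cp_size = 8
print(pp_microbatch_len(96000, sp_size=4, pad_to_multiple_of=8))  # → 24000
```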
cu_seqlens in PP are handled by extending the last real document to cover the
padded tokens (rather than adding a zero-length padding document), which avoids
degenerate varlen inputs in flash attention backward.
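A sketch of that rule (illustrative, not the library function):

```python
def extend_last_doc(cu_seqlens, padded_len):
    """Absorb pad tokens into the last real document instead of appending a
    zero-length padding document, per the PP padding rule described above."""
    out = list(cu_seqlens)
    out[-1] = padded_len  # the final boundary moves to the padded end
    return out

# Two documents ending at 100 and 180; sequence padded to 192 tokens:
print(extend_last_doc([0, 100, 180], 192))  # → [0, 100, 192]
```

The document count stays the same, so flash attention's varlen backward never sees an empty or near-empty document created purely by padding.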
CP + EP (Expert Parallelism)
Expert parallelism (EP) and SP share the same physical GPUs. They are managed with
separate device meshes: the main device_mesh covers PP/DP/SP/TP dimensions, while
ep_fsdp_device_mesh covers EP and its internal FSDP group.
For the Qwen3-30B-A3B MoE model configuration with PP+EP+CP, both EP and CP are folded onto the same physical ranks:
```yaml
pipeline_parallel_size: 2
ringattn_parallel_size: 4
expert_parallel_size: 4
```

With pp=2, ringattn=4, ep=4 on 32 GPUs: each PP stage has 16 GPUs; within a
stage, the 16 GPUs handle both ep=4 expert routing and ringattn=4 sequence shards.
EP operates on the MoE expert layers while CP operates on the attention layers within
the same forward pass. They do not interfere because their collectives act on
different tensors at different points in the model.
For the pure Ulysses + EP configuration (qwen3_30b_a3b_cp1_sp8.yaml):
```yaml
ulysses_parallel_size: 8
expert_parallel_size: 8
```

On 8 GPUs, every GPU is in both the Ulysses group (for attention) and the EP group (for MoE layers). This is the most memory-efficient configuration for MoE models with many KV heads.
TextSequenceShardCollator
`TextSequenceShardCollator` in `src/xorl/data/collators/sequence_shard_collator.py`
is the entry point for preparing batches for CP training. It is activated whenever
cp_size > 1.
Processing Pipeline
For each batch, the collator performs the following steps in order:
1. Normalize inputs to 2D tensors `[1, S]` (handle list inputs, ensure `ndim == 2`).
2. Store original position IDs before any modification: `batch["_original_position_ids"] = position_ids.clone()`.
3. Compute the padding amount: `pad_multiple = 2 * cp_size` if ring attention is active, else `cp_size`; `pad_length = ceil(S / pad_multiple) * pad_multiple - S`.
4. Pad tensors:
   - `input_ids`: padded with `pad_token_id`
   - `labels`: padded with `IGNORE_INDEX` (-100)
   - `attention_mask`: padded with 1
   - `position_ids`: padded with a chunked sequential arange (1024-token chunks per fake document, to avoid creating one enormous padding document in cu_seqlens)
5. Zigzag reorder (only when `ringattn_size > 1`): reorder `input_ids`, `labels`, `attention_mask`, `position_ids`, and any RL fields (`target_tokens`, `logprobs`, `advantages`, `rollout_logprobs`) using `zigzag_reorder_packed_sequence`, with the pre-pad position IDs for boundary detection.
6. Slice per rank (`sp_slice`): `input_ids` and `labels` are sliced, so rank `r` gets tokens `[r*chunk, (r+1)*chunk)`; `position_ids` is NOT sliced, and the full padded tensor is forwarded.
7. Compute cu_seqlens: for ring attention, use `_original_position_ids` (pre-zigzag, pre-pad boundaries) to produce cu_seqlens reflecting the true document structure; for Ulysses only, use the standard `add_flash_attention_kwargs_from_position_ids`.
Why position_ids Is Not Sliced
The ring attention _scale_cu_seqlens_for_ringattn function divides cu_seqlens by
ringattn_size, producing correct per-rank boundaries. The model’s RoPE layer uses
position_ids for indexing into (cos, sin) tables; the zigzag assignment already
places the correct tokens at each rank, so the full position_ids tensor provides
the right indices. Slicing position_ids would cause incorrect RoPE embedding
positions for zigzag-reordered tokens.
For Ulysses-only mode (no ring), position_ids is also kept full because the Ulysses
all-to-all gathers the full sequence before attention, and the strategy’s
prepare_position_embeddings slices the RoPE (cos, sin) tensors at that point
using slice_position_embedding.
gather_outputs and Loss Computation
After the model forward pass, the logits from each SP rank cover only its local sequence shard. Before computing the cross-entropy loss, the full sequence logits must be gathered.
gather_outputs in src/xorl/distributed/sequence_parallel/data.py gathers along
the sequence dimension across the unified SP group:
```python
def gather_outputs(x, gather_dim, padding_dim=None, unpad_dim_size=None, scale_grad=True, group=None):
    group = get_unified_sequence_parallel_group() if group is None else group
    x = _Gather.apply(group, x, gather_dim, scale_grad)
    if padding_dim is not None:
        x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size, group)
    return x
```

The `_Gather` autograd function uses all_gather in the forward pass and splits the gradient back to the local shard in the backward pass (optionally scaling by `seq_world_size` for gradient normalization).
The loss is then reduced across the full dp_sp group (DP ranks × SP ranks), since
all of these ranks compute partial losses on different shards of the training batch.
Configuration Reference
Parameter Table
| Parameter | Type | Default | Description |
|---|---|---|---|
| `ulysses_parallel_size` | int | 1 | Ulysses SP group size. Constraint: `num_heads % ulysses_size == 0` |
| `ringattn_parallel_size` | int | 1 | Ring attention group size. Constraint: doc lengths divisible by `2 * ringattn_size` |
| `cp_fsdp_mode` | str | `"all"` | SP-FSDP overlap mode: `"all"`, `"ulysses_only"`, `"ring_only"`, `"none"` |
The total CP size and data parallel size are derived automatically:
```
cp_size = ulysses_size × ringattn_size
dp_size = world_size / (pp_size × cp_size × tp_size)
```

Example Configurations
Pure Ulysses SP (Qwen3-8B, 8 GPUs, 64K context):
```yaml
train:
  ulysses_parallel_size: 8
  ringattn_parallel_size: 1
  data_parallel_replicate_size: 1
  data_parallel_shard_size: 1
  sample_packing_sequence_len: 64000
```

All 8 GPUs form one Ulysses group. No DP. Each attention layer performs two all-to-all collectives. Requires `64 % 8 == 0` (Qwen3-8B has 64 Q heads and 8 KV heads; the async strategy is used since `8 <= 8`).
Pure Ring Attention (Qwen3-8B, 8 GPUs, 64K context):
```yaml
train:
  ulysses_parallel_size: 1
  ringattn_parallel_size: 8
  data_parallel_replicate_size: 1
  data_parallel_shard_size: 1
  sample_packing_sequence_len: 64000
```

All 8 GPUs form one ring. Each ring step sends `S/8 × H_kv × D × 2` bytes (K+V). Documents must split into `2*8 = 16` equal sub-chunks (length divisible by 16). No head divisibility constraint.
Hybrid CP (Qwen3-8B, 8 GPUs: ring=2, ulysses=4):
```yaml
train:
  ulysses_parallel_size: 4
  ringattn_parallel_size: 2
  data_parallel_replicate_size: 1
  data_parallel_shard_size: 1
  sample_packing_sequence_len: 64000
```

`cp_size = 8`. Each of the 8 GPUs is in one 4-rank Ulysses group and one 2-rank ring group. Padding is to multiples of `2*8 = 16`. Documents must have length divisible by `2*2 = 4`. After the Ulysses a2a, each rank holds `S/2` tokens and `H/4` heads; ring then provides full attention coverage.
Pure Ulysses, 128K context (Qwen3-8B, 4 GPUs):
```yaml
train:
  ulysses_parallel_size: 4
  sample_packing_sequence_len: 128000
```

4 GPUs, no ring. 128K tokens per sample. Each GPU processes 32K tokens locally; the Ulysses a2a provides full attention at 128K length.
MoE + Ulysses + EP (Qwen3-30B-A3B, 8 GPUs: sp=8, ep=8):
```yaml
train:
  ulysses_parallel_size: 8
  ringattn_parallel_size: 1
  expert_parallel_size: 8
  sample_packing_sequence_len: 128000
```

All 8 GPUs serve double duty: the Ulysses group for attention layers, the EP group for MoE routing. Requires `num_attention_heads % 8 == 0` (Qwen3-30B-A3B has 16 Q heads and 8 KV heads; `8 <= 8`, so the async strategy applies).
PP + Ring + EP (Qwen3-30B-A3B, 32 GPUs: pp=2, ring=4, ep=4):
```yaml
train:
  pipeline_parallel_size: 2
  ringattn_parallel_size: 4
  expert_parallel_size: 4
  sample_packing_sequence_len: 96000
```

32 GPUs total. Each PP stage has 16 GPUs forming 4-rank ring attention groups and 4-rank EP groups. Documents must be divisible by `2*4 = 8` for zigzag. PP buffers are padded to `96000 / 4 = 24000` tokens per rank.
Common Pitfalls and Debugging
“Document length not divisible by 2*ringattn_size”
The zigzag reorder requires each packed document’s length to be divisible by
2 * ringattn_size. This is enforced in zigzag_reorder_packed_sequence. Check
that:
- Individual source documents, after tokenization and packing, have lengths aligned to `2 * ringattn_size`.
- SP padding is applied before the zigzag reorder (the collator handles this correctly).
- If using custom collators or data pipelines, ensure documents are pre-aligned.
NaN in flash attention backward with ring attention
Usually caused by incorrect cu_seqlens. Verify that _original_position_ids is
present in the batch and that cu_seq_lens_q/cu_seq_lens_k are computed from it
(not from the zigzag-reordered position_ids).
FSDP size mismatch / gradient scale errors
Ensure cp_fsdp_mode matches the intended FSDP mesh. If ring ranks are not in the
FSDP mesh (cp_fsdp_mode="ulysses_only" or "none"), sync_sp_gradients must be
called explicitly after backward and before the optimizer step.
Head count assertion failures
Ulysses requires num_heads % ulysses_size == 0. For GQA models, the sync strategy
additionally requires ulysses_size % num_kv_heads == 0 when
ulysses_size > num_kv_heads (for KV head expansion). Check model architecture
docs for head counts before setting ulysses_parallel_size.
PP + CP sequence length alignment
With PP enabled, all microbatches in a gradient accumulation step are padded to
sample_packing_sequence_len / cp_size (per-rank). This target length is rounded
up to align with pad_to_multiple_of (which is set to 2 * cp_size when ring is
active). Ensure sample_packing_sequence_len is divisible by 2 * cp_size for
deterministic PP buffer shapes.
Source
| File | Description |
|---|---|
| `src/xorl/distributed/sequence_parallel/` | Ulysses all-to-all strategy, ring attention P2P, collator, zigzag reorder |