
Pipeline Parallelism

Pipeline parallelism (PP) partitions model layers across multiple GPUs along a pipeline axis. Each GPU holds a contiguous slice of layers (a stage) and processes a stream of micro-batches in an overlapping fashion. The approach is complementary to FSDP2 data parallelism and expert parallelism: stages communicate via P2P tensors while each stage’s parameters are independently sharded by FSDP2 within the PP-local FSDP group.

[Figure: 1F1B schedule (PP=3, 4 micro-batches) — each stage interleaves forwards F0–F3 and backwards B0–B3; only warmup and drain slots are idle.]

[Figure: Layer distribution — stage 0 (GPU 0): embed + layers 0..L/2; stage 1 (GPU 1): layers L/2..L + lm_head; activations of shape [B,S,H] flow between them.]

Model layers are split across PP stages. Activations pass between stages via P2P.

The full model is divided into pipeline_parallel_size contiguous slices called stages.

  • Stage 0 (first): embedding layer + the first N transformer layers.
  • Stage K (middle): a block of transformer layers with no embedding or head.
  • Stage P-1 (last): the remaining transformer layers + layer norm + lm_head.

Because embed_tokens and lm_head live on different stages, PP requires tie_word_embeddings: false in the model config. pipeline_module_split enforces this at construction time and raises ValueError otherwise.

Each optimizer step is broken into gradient_accumulation_steps micro-batches (the number of micro-batches equals n_microbatches passed to build_pipeline_schedule). The PP schedule drives all stages through these micro-batches in lock-step, passing activations from stage K to stage K+1 via NCCL P2P send/recv calls.

```
4 GPUs, PP=2, FSDP=2 (data_parallel_shard_size=2)

      PP Stage 0            PP Stage 1
  +----------------+    +----------------+
  | GPU 0          | -> | GPU 2          |   pipeline 0 (P2P activations)
  | embed_tokens   |    | layers N..M    |
  | layers 0..N-1  |    | norm + lm_head |
  +----------------+    +----------------+
  | GPU 1          | -> | GPU 3          |   pipeline 1 (P2P activations)
  | embed_tokens   |    | layers N..M    |
  | layers 0..N-1  |    | norm + lm_head |
  +----------------+    +----------------+
   FSDP shard group      FSDP shard group
   {GPU 0, GPU 1}        {GPU 2, GPU 3}
```

All four GPUs participate in PP P2P activation exchange. GPUs 0 and 1 form one FSDP shard group (stage 0); GPUs 2 and 3 form another (stage 1).


GPipe Schedule

All n_microbatches forwards run first (filling the pipeline), then all backwards run in reverse order (draining it).

```
Time →
Stage 0: F0 F1 F2 F3 .  .  B3 B2 B1 B0
Stage 1: .  F0 F1 F2 F3 B3 B2 B1 B0 .
```

Bubble fraction (fraction of time wasted waiting):

bubble_fraction = (P - 1) / (n_microbatches + P - 1)

where P = number of stages. With P=2, n_microbatches=4: bubble = 1/5 = 20%. The bubble shrinks as n_microbatches grows.
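The arithmetic is easy to sanity-check in a few lines of Python (`bubble_fraction` is an illustrative helper, not part of xorl):

```python
def bubble_fraction(p: int, m: int) -> float:
    """Idle fraction of an m-microbatch pipeline with p stages: (P-1)/(m+P-1)."""
    return (p - 1) / (m + p - 1)

print(bubble_fraction(2, 4))    # P=2, m=4  → 0.2 (the 20% bubble above)
print(bubble_fraction(2, 16))   # more micro-batches shrink the bubble
```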

Memory: All micro-batch activations from the forward phase are live simultaneously before any backward runs, so peak activation memory scales with n_microbatches.

1F1B Schedule (default)

After the pipeline is filled, each stage immediately processes one backward for every forward, keeping activation memory bounded to P micro-batches regardless of n_microbatches.

```
Time →
Stage 0: F0 F1 F2 B0 F3 B1 .  B2 .  B3
Stage 1: .  F0 B0 F1 B1 F2 B2 F3 B3 .
```

Bubble fraction: same formula as GPipe — (P-1)/(n_microbatches + P-1).

Memory: bounded to P in-flight micro-batches (activation checkpointing further reduces this).

| Property | GPipe | 1F1B |
| --- | --- | --- |
| Bubble fraction | (P-1)/(m+P-1) | (P-1)/(m+P-1) |
| Peak activation memory | O(m) | O(P) |
| Implementation complexity | simple | moderate |
| Default in xorl | no | yes |

For most training scenarios the 1F1B schedule is preferred because its activation memory is independent of n_microbatches.

```yaml
train:
  pipeline_parallel_schedule: "1F1B"   # or "GPipe"
```

Both schedules are backed by torch.distributed.pipelining via get_schedule_class. The schedule is constructed with scale_grads=False because xorl normalizes gradients manually after the backward (see Section 7).

3. Stage Construction
```
build_parallelize_model()
├─ generate_llm_fqn_per_model_part()   # compute FQN assignment
├─ pipeline_module_split()             # prune + patch each stage
│  ├─ _recursive_prune()               # remove unneeded layers
│  └─ _pp_forward patch                # HF dataclass → raw tensor I/O
└─ (per model_part) FSDP2 wrapping
```

generate_llm_fqn_per_model_part()

Assigns module fully-qualified names (FQNs) to each stage. The algorithm treats input modules (embedding) as input_weight effective layers and output modules (norm + lm_head) as output_weight effective layers, then distributes num_layers + input_weight + output_weight effective layers evenly across stages, with the first extra_layers stages each getting one additional layer.
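The even-split arithmetic can be sketched as a small standalone function (a hypothetical helper for illustration; the real generate_llm_fqn_per_model_part also materializes the FQN lists):

```python
def split_effective_layers(num_layers, pp_size, input_weight=1, output_weight=1):
    """Return per-stage [start, end) transformer-layer ranges, counting the
    embedding as input_weight and norm + lm_head as output_weight layers."""
    total = num_layers + input_weight + output_weight
    base, extra = divmod(total, pp_size)           # first `extra` stages get +1
    ranges, start = [], 0
    for s in range(pp_size):
        size = base + (1 if s < extra else 0)      # effective layers on stage s
        if s == 0:
            size -= input_weight                   # embedding lives on stage 0
        if s == pp_size - 1:
            size -= output_weight                  # norm + lm_head on the last
        ranges.append((start, start + size))
        start += size
    return ranges

print(split_effective_layers(36, 2))   # → [(0, 18), (18, 36)]
```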

Example for Qwen3-8B (36 transformer layers, PP=2):

```
num_effective_layers = 36 + 1 + 1 = 38
layers_per_stage     = 38 // 2 = 19
extra_layers         = 38 % 2  = 0

Stage 0: ["model.embed_tokens",
          "model.layers.0", ..., "model.layers.17"]   # 18 transformer layers
Stage 1: ["model.layers.18", ..., "model.layers.35",
          "model.norm", "lm_head"]                    # 18 transformer layers
```

The model must implement get_pp_module_config() returning a dict with:

  • num_layers: total transformer layer count
  • input_fqns: list of FQNs for the embedding stage
  • layer_prefix: FQN prefix for the transformer layer list (e.g., "model.layers")
  • output_fqns: list of FQNs for the output stage
  • always_keep_fqns (optional): FQNs to keep on every stage (e.g., "model.rotary_emb")

pipeline_module_split()

For each rank:

  1. Deep-copies the whole model.
  2. Calls _recursive_prune to set unused submodules to None, preserving original ModuleList indices so checkpoint keys remain valid.
  3. Patches the model’s forward method to _pp_forward (raw tensor I/O).
  4. Constructs a PipelineStage from the pruned model, binding it to the PP process group for P2P communication.

The recursive pruner handles HF-style nested structures (e.g., Qwen3ForCausalLM.model.layers) by walking the module tree and nullifying entries in ModuleList/ModuleDict that are not assigned to this stage.
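The pruning idea can be illustrated with plain dicts standing in for the module tree (a sketch only; the real _recursive_prune walks nn.Module children):

```python
def prune_tree(tree, keep, prefix=""):
    """Set subtrees that are neither kept nor ancestors/descendants of a kept
    FQN to None, preserving keys so checkpoint indices stay valid."""
    for name in list(tree):
        fqn = f"{prefix}.{name}" if prefix else name
        child = tree[name]
        if any(k == fqn or k.startswith(fqn + ".") for k in keep):
            if isinstance(child, dict):
                prune_tree(child, keep, fqn)      # descend: parts are kept
        elif any(fqn.startswith(k + ".") for k in keep):
            pass                                  # inside a fully kept subtree
        else:
            tree[name] = None                     # pruned, but index preserved

stage0 = {"model": {"embed_tokens": {}, "layers": {"0": {}, "1": {}}},
          "lm_head": {}}
prune_tree(stage0, ["model.embed_tokens", "model.layers.0"])
# lm_head and model.layers["1"] become None; kept entries are untouched
```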

After splitting, build_parallelize_model iterates over model_parts and applies FSDP2 to each independently. Key behavior:

  • reshard_after_forward defaults to False for PP (auto-set in parallelize_model_fsdp2). PP’s 1F1B schedule holds micro-batch activations across forward passes before the matching backward; resharding between them would require re-gathering parameters during the backward, which is both wasteful and can interact poorly with the PP schedule’s internal bookkeeping. Setting reshard_after_forward=True via config is supported for memory-constrained runs.
  • The FSDP shard group is ps.fsdp_group (the DP-local group), not the full world group, so each stage’s parameters are sharded only within the DP dimension.
  • After FSDP wrapping, stages[i].submod is updated to point to the wrapped model so that P2P communication uses the wrapped module’s forward.

4. Per-Microbatch Metadata: _pp_batch_metadata


PipelineStage drives the model forward with raw tensors (no keyword arguments other than what the patched _pp_forward provides). Flash Attention varlen mode requires cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, and position_ids for correct packed-sequence attention. These cannot be passed via the P2P tensor stream.

Before calling pp_schedule.step(...), forward_backward_pp attaches a collections.deque to each model_part:

```python
for model_part in model_parts:
    model_part._pp_batch_metadata = deque(pp_metadata_list)
```

pp_metadata_list is built by extracting position_ids, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k from each micro-batch dict.

Inside _pp_forward, every time the model’s forward is called for a microbatch, it pops one entry from the queue:

```python
if torch.is_grad_enabled():
    metadata_queue = getattr(self, "_pp_batch_metadata", None)
    if metadata_queue:
        metadata = metadata_queue.popleft()
        position_ids = metadata.pop("position_ids", None)
        extra_kwargs = {k: v.to(x.device) ... for k, v in metadata.items()}
```

The torch.is_grad_enabled() guard is critical: PipelineStage._shape_inference() runs a full model forward under torch.no_grad() with zero-filled tensors on the first step to determine output shapes for P2P buffer allocation. Without the guard, shape inference would consume real metadata entries from the queue, causing subsequent training forwards to pop stale or shifted metadata.

The popped extra_kwargs (containing cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) are passed directly to the model’s original forward method, which forwards them to the attention layers:

```python
outputs = self._pp_original_forward(
    inputs_embeds=x, position_ids=position_ids,
    use_cache=False, output_hidden_states=False,
    **extra_kwargs,
)
```

This ensures every stage, including middle stages that receive hidden states rather than input IDs, has the correct sequence metadata for varlen flash attention.

When pp_variable_seq_lengths is disabled or metadata is absent but cp_size > 1, _pp_forward generates a sequential position_ids spanning seq_len * cp_size tokens so that the RoPE cache is large enough for the full sequence.

5. Variable Sequence Lengths

PipelineStage allocates P2P send/recv buffers once from the first shape it observes and reuses them for all subsequent micro-batches. With sample packing, different training steps may produce batches of different total lengths, which would cause shape mismatches between sender and receiver.

Static approach (legacy / pp_variable_seq_lengths=False)


All micro-batches are padded upfront to sample_packing_sequence_len / cp_size before any PP communication begins. A single PipelineStage (and schedule) is created once and reused for every step. This is simple but wastes compute and memory on padding tokens.

Dynamic approach (pp_variable_seq_lengths=True, default)


Step 1 — Negotiate the global max sequence length.

negotiate_pp_seq_len runs an all-reduce MAX across all PP ranks:

```python
local_max = max(mb["input_ids"].shape[-1] for mb in micro_batches)
t = torch.tensor([local_max], dtype=torch.int64, device=device)
dist.all_reduce(t, op=dist.ReduceOp.MAX, group=pp_group)
seq_len = int(t.item())
```

Because different PP ranks may receive different data (each FSDP replica fetches independently), their local max lengths can differ. After the all-reduce, all ranks share the same target length.

Step 2 — Pad to that length.

pad_micro_batches_for_pp pads input_ids, labels, attention_mask, position_ids, and the flash-attention length metadata to the negotiated target.

Step 3 — Per-seq_len schedule cache.

_get_pp_schedule(seq_len) is keyed by seq_len. On a cache miss, a new PipelineStage is built from the existing (FSDP-wrapped) model_part:

```python
stage = build_pp_stage(
    model_parts[0],
    pp_rank=ps.pp_rank,
    num_stages=ps.pp_size,
    device=get_device_type(),
    pp_group=ps.pp_group,
)
```

build_pp_stage calls PipelineStage(model_part, ...) without deepcopy, so the stage wraps the same FSDP-wrapped parameter tensors. Only the internal P2P buffer sizes differ between cache entries. This is cheap compared to pipeline_module_split.

After the first step at a given seq_len, subsequent steps at the same length hit the cache and avoid buffer re-allocation entirely.
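The caching pattern itself is tiny; a sketch with illustrative names (the real _get_pp_schedule also wraps the stage in a schedule object):

```python
class PerSeqLenCache:
    """Memoize one pipeline stage/schedule per negotiated seq_len (sketch)."""

    def __init__(self, build_fn):
        self._build = build_fn        # e.g. seq_len -> (stage, schedule)
        self._entries = {}

    def get(self, seq_len):
        if seq_len not in self._entries:       # first step at this length
            self._entries[seq_len] = self._build(seq_len)
        return self._entries[seq_len]          # later steps: cache hit

cache = PerSeqLenCache(lambda s: object())     # stand-in builder
assert cache.get(8192) is cache.get(8192)      # repeat lengths reuse buffers
```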

Server path note: ModelRunner._get_pp_schedule is keyed by (n_microbatches, seq_len) because the server may issue forward-backward calls with varying micro-batch counts across requests.

pp_variable_seq_lengths (default true) controls whether the pipeline pads sequences to a static worst-case length or negotiates a dynamic per-step maximum.

Without it (pp_variable_seq_lengths: false): every PP micro-batch is padded up front to the static sample_packing_sequence_len (divided by CP size). A single PipelineStage object and schedule are created once and reused for all steps. This is simple but wastes compute and memory whenever actual batch lengths are shorter than the static cap — which is common with variable-length packed data.

With it (default true): all PP ranks run an all-reduce MAX over the PP process group to find the actual maximum sequence length among the current step's micro-batches, then pad only to that dynamic maximum. Steps with shorter batches pay proportionally less compute.

Each unique negotiated seq_len gets its own cached PipelineStage (keyed by seq_len in the local trainer, by (n_microbatches, seq_len) in the server runner). Because P2P send/recv buffers are allocated to match the stage’s expected input shape, the cache ensures buffer shapes always match the actual tensor shapes — no re-allocation penalty after the first occurrence of a given length.

Trade-off: the all-reduce adds a small synchronization overhead per step (one collective over the PP group, not the full world). For workloads with highly variable sequence lengths this overhead is negligible compared to the compute saved by avoiding padding to the static maximum.

6. Padding Micro-Batches

```python
target_sharded = sample_packing_sequence_len // sp_size   # SP-local slice length
# round up to pad_to_multiple_of
target_sharded = ceil(target_sharded / pad_to_multiple_of) * pad_to_multiple_of
full_target = target_sharded * sp_size
```

When pp_variable_seq_lengths=True, sample_packing_sequence_len is set to seq_len * cp_size (the negotiated global max scaled by context-parallel factor).

input_ids is padded with 0; labels with IGNORE_INDEX (-100) so padded positions do not contribute to the loss; attention_mask with 0.

position_ids is padded with 0 (padded positions will be masked out by the labels anyway).
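With Python lists standing in for tensors, the per-key pad values can be sketched as follows (pad_micro_batch is illustrative; the real pad_micro_batches_for_pp operates on batched tensors):

```python
IGNORE_INDEX = -100

def pad_micro_batch(mb, target):
    """Right-pad each field of one micro-batch dict to `target`."""
    pad_values = {"input_ids": 0, "labels": IGNORE_INDEX,
                  "attention_mask": 0, "position_ids": 0}
    for key, pad in pad_values.items():
        if key in mb and len(mb[key]) < target:
            mb[key] = mb[key] + [pad] * (target - len(mb[key]))

mb = {"input_ids": [11, 12], "labels": [11, 12], "attention_mask": [1, 1]}
pad_micro_batch(mb, 4)
print(mb["labels"])   # → [11, 12, -100, -100]: padding never contributes to loss
```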

Flash attention metadata (cu_seq_lens) — growing the last document


Naively inserting a new all-zero “padding document” at the end of cu_seq_lens would create a degenerate sequence of length 0, which causes NaN in the FA3 varlen backward because the max_length_q/max_length_k statistics become inconsistent.

Instead, xorl extends the last real document by increasing cu_seq_lens[-1] to full_target:

```python
for key in ("cu_seq_lens_q", "cu_seq_lens_k"):
    if key in mb and mb[key][-1] < full_target:
        mb[key] = mb[key].clone()
        mb[key][-1] = full_target
```

This means the padding tokens are logically appended to the last document. Their loss contribution is zero because the corresponding labels positions are IGNORE_INDEX. max_length_q / max_length_k are then recomputed from the updated cu_seq_lens diff to stay consistent.
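Recomputing the max lengths from the cumulative-length diffs can be sketched with plain lists (the real code uses tensor ops on cu_seq_lens):

```python
def recompute_max_lengths(mb):
    """Keep max_length_q/k consistent with (possibly extended) cu_seq_lens."""
    for cu_key, max_key in (("cu_seq_lens_q", "max_length_q"),
                            ("cu_seq_lens_k", "max_length_k")):
        if cu_key in mb:
            cu = mb[cu_key]
            mb[max_key] = max(b - a for a, b in zip(cu, cu[1:]))

mb = {"cu_seq_lens_q": [0, 5, 12, 20]}   # last doc grew from 12 to 20 via padding
recompute_max_lengths(mb)
print(mb["max_length_q"])   # → 8: the extended last document is now the longest
```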

7. Loss Normalization

With PP, only the last stage computes the loss. The PyTorch pipeline schedule internally calls loss_fn(logits, labels).backward() — there is no natural place to inject the / global_valid_tokens factor before the backward accumulates gradients across micro-batches.

```python
@torch.compile
def pp_loss_fn(pred, labels):
    return F.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        ignore_index=IGNORE_INDEX,
        reduction="sum",   # CE sum, NOT mean
    )
```

reduction="sum" returns CE_sum — the total cross-entropy over all non-ignored tokens. Because no normalization happens inside the loss function, gradients accumulating inside pp_schedule.step() are proportional to token count, not normalized.

After pp_schedule.step() returns, the trainer multiplies all gradients by 1 / global_valid_tokens:

```python
scale = 1.0 / gvt
for model_part in self.model_parts:
    for p in model_part.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)
```

global_valid_tokens is pre-computed via count_valid_tokens (all-reduce SUM across fsdp_group) before the forward-backward so every PP stage uses the same denominator.

This two-step approach (CE_sum loss → manual grad.mul_(1/gvt)) is mathematically equivalent to reduction="mean" but avoids having to compute global_valid_tokens inside pp_loss_fn, which would require cross-rank communication inside the loss function itself (impractical inside the schedule).

The last stage computes total_loss = sum(losses). Middle and first stages have total_loss = -1.0. The true loss value is broadcast to all stages via MAX all-reduce over pp_group:

```python
dist.all_reduce(loss_tensor, op=dist.ReduceOp.MAX, group=pp_group)
```

A value of -1.0 is safe as the sentinel because CE_sum is always non-negative.
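Since CE_sum is non-negative, the sentinel can never win the MAX; the broadcast reduces to:

```python
# Stage-local values entering the MAX all-reduce (PP=3 example):
stage_losses = [-1.0, -1.0, 3.7]   # only the last stage holds the real CE_sum
broadcast = max(stage_losses)       # what ReduceOp.MAX computes over pp_group
assert broadcast == 3.7             # every stage ends up with the true loss
```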

After the PP step, _reduce_metrics in Trainer uses op="sum" over fsdp_group (the DP replica group) to aggregate the loss:

```python
total_loss = all_reduce(total_loss, op="sum", group=self.ps.fsdp_group)
```

This differs from the non-PP path which uses the default mean-reduce. The reason: different DP replicas process different data, so their CE_sum values are additive (each replica holds a disjoint subset of the global batch). A mean-reduce would divide by the DP size, undoing the correct normalization by global_valid_tokens.


8. Interaction with Other Parallelism Dimensions


Each PP stage is independently wrapped with FSDP2. The FSDP shard group is ps.fsdp_group, which spans only the DP replicas of the same PP stage, not the full world. This means:

  • Parameters on stage 0 are sharded among GPUs {0, 1, 2, ...} that all hold stage 0; parameters on stage 1 are sharded among their own set.
  • FSDP all-gather and reduce-scatter happen independently per stage, with no cross-stage communication.
  • reshard_after_forward defaults to False when PP is active (set automatically in parallelize_model_fsdp2) to avoid redundant re-gathers during the 1F1B backward phase. Override with reshard_after_forward: true in config to trade memory for communication.

PP and EP are orthogonal: EP shards MoE expert weights within each PP stage, and the micro-batches flowing through the PP pipeline are the same batches that EP dispatches to experts. No special handling is needed in forward_backward_pp; each stage’s FSDP+EP wrapping handles expert dispatch transparently.

With routing replay (R3), the _pp_forward function correctly switches the routing stage from "replay_backward" to "record" (or "replay_forward") before the forward and restores it afterward. The no_grad guard inside _pp_forward ensures the shape-inference forward does not record stale routing decisions into the replay cache.

PP + Context Parallelism (CP: Ulysses or ring-attention)


CP splits each sequence across multiple GPUs. The per-rank sequence length seen by the PP pipeline is full_seq_len / cp_size. Two adjustments are needed:

  1. negotiate_pp_seq_len returns a length in the CP-sharded space (i.e., after CP splits the sequence). The padding target is therefore seq_len * cp_size passed to pad_micro_batches_for_pp as sample_packing_sequence_len.
  2. When CP is active and position_ids are absent from the metadata queue, _pp_forward synthesizes position IDs covering the full seq_len * cp_size range so that the RoPE cache is initialized to sufficient size:
    ```python
    if position_ids is None and ps.cp_size > 1:
        full_seq_len = x.shape[1] * ps.cp_size
        position_ids = torch.arange(full_seq_len, ...).expand(batch, -1)
    ```

| Constraint | Reason |
| --- | --- |
| gradient_accumulation_steps >= pipeline_parallel_size | The pipeline must have enough micro-batches to fill all stages before any stage stalls waiting for work. n_microbatches = gradient_accumulation_steps. |
| tie_word_embeddings: false | embed_tokens (stage 0) and lm_head (last stage) cannot share weights across different GPUs. |
| Single model chunk per rank (no virtual stages) | The current implementation uses pipeline_module_split, which assigns exactly one stage per rank. Multi-stage (virtual) schedules like ScheduleGPipe with num_microbatches > num_stages in the looped sense are not supported. |
| Model must implement get_pp_module_config() | The stage-splitting algorithm requires structured metadata about input/layer/output FQNs. |
| Padding must be a multiple of pad_to_multiple_of | FA3 varlen kernels may require sequence-length alignment (e.g., multiples of 8 or 64). |
| No whole-model torch.compile | Per-stage torch.compile of individual decoder layers (via basic_modules) is supported, but whole-model compile is not. |

Example 1: PP=2 + FSDP=4, Qwen3-8B, 8 GPUs

```
GPUs: 8
PP=2, FSDP shard=4, DP replicate=1
Stage 0: embed_tokens + layers 0..17 (18 layers)
Stage 1: layers 18..35 + norm + lm_head (18 layers)
Micro-batches: 4 (gradient_accumulation_steps=4)
Bubble: (2-1)/(4+2-1) = 1/5 = 20%
```
examples/local/dummy/configs/full/qwen3_8b_pp2.yaml:

```yaml
model:
  model_path: Qwen/Qwen3-8B
  attn_implementation: flash_attention_3
data:
  datasets:
    - path: dummy
      type: tokenized
  max_seq_len: 8000
  sample_packing_method: sequential
  sample_packing_sequence_len: 8000
train:
  data_parallel_mode: fsdp2
  pipeline_parallel_size: 2
  data_parallel_replicate_size: 1
  data_parallel_shard_size: 4
  micro_batch_size: 1
  gradient_accumulation_steps: 4
  # pp_variable_seq_lengths: true (default)
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Process group layout (8 GPUs, ranks 0–7):

  • Stage 0 occupies ranks 0–3; stage 1 occupies ranks 4–7.
  • PP (P2P) groups pair each stage-0 rank with its stage-1 peer: {0, 4}, {1, 5}, {2, 6}, {3, 7}.
  • FSDP shard group for stage 0: ranks 0–3.
  • FSDP shard group for stage 1: ranks 4–7.

Example 2: PP=2 + EP=4 + CP=4, Qwen3-30B-A3B, 8 GPUs

```
GPUs: 8
PP=2, EP=4, ring-attn CP=4, DP_shard=1, DP_replicate=1
world_size = PP × CP_ring × DP = 2 × 4 × 1 = 8
EP=4 folded onto each PP stage's 4 ranks (ep_fsdp_size=1)
Micro-batches: 2
Bubble: (2-1)/(2+2-1) = 1/3 ≈ 33% (small n_microbatches; increase GAS to reduce)
seq_len per CP rank: 96000 / 4 = 24000 tokens
```
examples/local/dummy/configs/full/qwen3_30b_a3b_pp2_ep4_cp4_muon.yaml:

```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  attn_implementation: flash_attention_3
  moe_implementation: triton
data:
  datasets:
    - path: dummy
  max_seq_len: 48001
  sample_packing_sequence_len: 96000
train:
  data_parallel_mode: fsdp2
  pipeline_parallel_size: 2
  ringattn_parallel_size: 4
  expert_parallel_size: 4
  data_parallel_shard_size: 1
  micro_batch_size: 1
  gradient_accumulation_steps: 2
  reshard_after_forward: true   # memory-saving for large model
  optimizer: muon
  enable_gradient_checkpointing: true
```

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| train.pipeline_parallel_size | int | 1 | Number of pipeline stages. Set to 1 to disable PP. |
| train.pipeline_parallel_schedule | str | "1F1B" | Schedule algorithm: "1F1B" or "GPipe". |
| train.pp_variable_seq_lengths | bool | True | Negotiate per-step max seq_len across PP ranks and pad dynamically, instead of padding to the static sample_packing_sequence_len. |
| train.gradient_accumulation_steps | int | required | Sets n_microbatches. Must be >= pipeline_parallel_size. |
| train.reshard_after_forward | bool or None | None (auto) | FSDP2 reshard behavior. Auto-set to False for PP; set True to save memory at the cost of extra communication. |
| data.sample_packing_sequence_len | int | required for PP | Maximum sequence length used as the static padding target when pp_variable_seq_lengths=False; also the upper bound for negotiate_pp_seq_len. |
| data.pad_to_multiple_of | int | 128 | Rounds the padded sequence length up to this multiple for GPU efficiency. |

| File | Role |
| --- | --- |
| src/xorl/distributed/pipeline_parallel.py | Stage construction (generate_llm_fqn_per_model_part, pipeline_module_split), forward patching (_pp_forward), schedule builders (build_pipeline_schedule, build_pp_stage) |
| src/xorl/trainers/training_utils.py | Shared PP utilities: forward_backward_pp, negotiate_pp_seq_len, pad_micro_batches_for_pp, pp_loss_fn |
| src/xorl/trainers/trainer.py | Offline trainer: _init_pp_schedule_cache, _get_pp_schedule, _forward_backward_pp, _reduce_metrics |
| src/xorl/server/runner/model_runner.py | Server (online) path: _get_pp_schedule (keyed by (n_microbatches, seq_len)), _forward_backward_pp |
| src/xorl/distributed/torch_parallelize.py | Orchestrates stage split + FSDP2 wrapping (build_parallelize_model) |
| src/xorl/arguments.py | pipeline_parallel_size, pipeline_parallel_schedule, pp_variable_seq_lengths, reshard_after_forward |