Pipeline Parallelism
Pipeline parallelism (PP) partitions model layers across multiple GPUs along a pipeline axis. Each GPU holds a contiguous slice of layers (a stage) and processes a stream of micro-batches in an overlapping fashion. The approach is complementary to FSDP2 data parallelism and expert parallelism: stages communicate via P2P tensors while each stage’s parameters are independently sharded by FSDP2 within the PP-local FSDP group.
1. What Pipeline Parallelism Does
Layer partitioning
The full model is divided into pipeline_parallel_size contiguous slices called stages.
- Stage 0 (first): embedding layer + the first N transformer layers.
- Stage K (middle): a block of transformer layers with no embedding or head.
- Stage P-1 (last): the remaining transformer layers + layer norm + lm_head.
Because embed_tokens and lm_head live on different stages, PP requires
tie_word_embeddings: false in the model config. pipeline_module_split enforces this
at construction time and raises ValueError otherwise.
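The check can be pictured as a few lines at split time. This is an illustrative sketch only; the helper name check_untied_embeddings is invented here, while the real guard lives inside pipeline_module_split:

```python
def check_untied_embeddings(model_config):
    # PP places embed_tokens (stage 0) and lm_head (last stage) on different
    # GPUs, so they cannot alias the same weight tensor.
    if model_config.get("tie_word_embeddings", False):
        raise ValueError("pipeline parallelism requires tie_word_embeddings: false")

check_untied_embeddings({"tie_word_embeddings": False})   # passes silently
```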
Micro-batches
Each optimizer step is broken into gradient_accumulation_steps micro-batches (the
number of micro-batches equals n_microbatches passed to build_pipeline_schedule).
The PP schedule drives all stages through these micro-batches in lock-step, passing
activations from stage K to stage K+1 via NCCL P2P send/recv calls.
GPU layout example
4 GPUs, PP=2, FSDP=2 (data_parallel_shard_size=2)

```
      PP Stage 0              PP Stage 1
    +----------------+      +----------------+
    | GPU 0          | ---> | GPU 2          |
    | embed_tokens   |      | layers N..M    |
    | layers 0..N-1  |      | norm + lm_head |
    +----------------+      +----------------+
    | GPU 1          | ---> | GPU 3          |
    | embed_tokens   |      | layers N..M    |
    | layers 0..N-1  |      | norm + lm_head |
    +----------------+      +----------------+
```

Each row is a PP process group used for P2P activation transfer ({GPU 0, GPU 2} and {GPU 1, GPU 3}). GPUs 0 and 1 form the FSDP shard group for stage 0; GPUs 2 and 3 form the FSDP shard group for stage 1.
2. Schedules: GPipe and 1F1B
GPipe
All n_microbatches forwards run first (filling the pipeline), then all backwards run in reverse order (draining it).
```
Time →
Stage 0: F0 F1 F2 F3 .  .  .  .  B3 B2 B1 B0
Stage 1: .  F0 F1 F2 F3 .  .  B3 B2 B1 B0 .
```

Bubble fraction (fraction of time wasted waiting):

```
bubble_fraction = (P - 1) / (n_microbatches + P - 1)
```

where P = number of stages. With P=2, n_microbatches=4: bubble = 1/5 = 20%. The bubble shrinks as n_microbatches grows.
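The formula is easy to check numerically; a small helper (our own, for illustration):

```python
def bubble_fraction(num_stages, n_microbatches):
    # (P - 1) / (m + P - 1): idle fraction spent filling and draining the pipeline
    p = num_stages
    return (p - 1) / (n_microbatches + p - 1)

assert bubble_fraction(2, 4) == 0.2                      # the P=2, m=4 example above
assert bubble_fraction(2, 16) < bubble_fraction(2, 4)    # more micro-batches, smaller bubble
```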
Memory: All micro-batch activations from the forward phase are live simultaneously
before any backward runs, so peak activation memory scales with n_microbatches.
1F1B (one-forward-one-backward, default)
After the pipeline is filled, each stage immediately processes one backward for every forward, keeping activation memory bounded to P micro-batches regardless of n_microbatches.
```
Time →
Stage 0: F0 F1 F2 F3 B0 B1 B2 B3
Stage 1: .  F0 F1 F2 B0 F3 B1 B2 B3
```

Bubble fraction: same formula as GPipe, (P-1)/(n_microbatches + P-1).
Memory: bounded to P in-flight micro-batches (activation checkpointing further
reduces this).
Comparison
| Property | GPipe | 1F1B |
|---|---|---|
| Bubble fraction | (P-1)/(m+P-1) | (P-1)/(m+P-1) |
| Peak activation memory | O(m) | O(P) |
| Implementation complexity | simple | moderate |
| Default in xorl | no | yes |
For most training scenarios the 1F1B schedule is preferred because its
activation memory is independent of n_microbatches.
Configuring the schedule
```yaml
train:
  pipeline_parallel_schedule: "1F1B"   # or "GPipe"
```

Both schedules are backed by torch.distributed.pipelining via get_schedule_class.
The schedule is constructed with scale_grads=False because xorl normalizes
gradients manually after the backward (see Section 7).
3. Stage Construction
Entry point

```
build_parallelize_model()
 └─ generate_llm_fqn_per_model_part()   # compute FQN assignment
 └─ pipeline_module_split()             # prune + patch each stage
     └─ _recursive_prune()              # remove unneeded layers
     └─ _pp_forward patch               # HF dataclass → raw tensor I/O
 └─ (per model_part) FSDP2 wrapping
```

generate_llm_fqn_per_model_part
Assigns module fully-qualified names (FQNs) to each stage. The algorithm treats
input modules (embedding) as input_weight effective layers and output modules
(norm + lm_head) as output_weight effective layers, then distributes
num_layers + input_weight + output_weight effective layers evenly across stages,
with the first extra_layers stages each getting one additional layer.
Example for Qwen3-8B (36 transformer layers, PP=2):
```
num_effective_layers = 36 + 1 + 1 = 38
layers_per_stage     = 38 // 2    = 19
extra_layers         = 38 % 2     = 0
```

```
Stage 0: ["model.embed_tokens", "model.layers.0", ..., "model.layers.17"]      # 18 transformer layers
Stage 1: ["model.layers.18", ..., "model.layers.35", "model.norm", "lm_head"]  # 18 transformer layers
```

The model must implement get_pp_module_config() returning a dict with:

- num_layers: total transformer layer count
- input_fqns: list of FQNs for the embedding stage
- layer_prefix: FQN prefix for the transformer layer list (e.g., "model.layers")
- output_fqns: list of FQNs for the output stage
- always_keep_fqns (optional): FQNs to keep on every stage (e.g., "model.rotary_emb")
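A runnable sketch of the assignment logic described above (the function name and exact bookkeeping are ours; the real logic lives in generate_llm_fqn_per_model_part):

```python
def split_fqns(num_layers, pp_size, input_fqns, output_fqns, layer_prefix="model.layers"):
    # Count the embedding as one effective layer and the output modules as one
    # more, then split evenly; the first (total % pp_size) stages get one extra.
    total = num_layers + 1 + 1
    base, extra = divmod(total, pp_size)
    sizes = [base + (1 if s < extra else 0) for s in range(pp_size)]
    stages, next_layer = [], 0
    for s, size in enumerate(sizes):
        fqns = []
        if s == 0:
            fqns += input_fqns
            size -= 1                      # embedding consumed one effective layer
        if s == pp_size - 1:
            size -= 1                      # reserve one effective layer for outputs
        fqns += [f"{layer_prefix}.{i}" for i in range(next_layer, next_layer + size)]
        next_layer += size
        if s == pp_size - 1:
            fqns += output_fqns
        stages.append(fqns)
    return stages

# Reproduces the Qwen3-8B example: 36 layers, PP=2 → 18 layers per stage.
stages = split_fqns(36, 2, ["model.embed_tokens"], ["model.norm", "lm_head"])
assert stages[0][0] == "model.embed_tokens" and stages[0][-1] == "model.layers.17"
assert stages[1][0] == "model.layers.18" and stages[1][-2:] == ["model.norm", "lm_head"]
```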
pipeline_module_split
For each rank:

- Deep-copies the whole model.
- Calls _recursive_prune to set unused submodules to None, preserving original ModuleList indices so checkpoint keys remain valid.
- Patches the model's forward method to _pp_forward (raw tensor I/O).
- Constructs a PipelineStage from the pruned model, binding it to the PP process group for P2P communication.
The recursive pruner handles HF-style nested structures (e.g., Qwen3ForCausalLM.model.layers)
by walking the module tree and nullifying entries in ModuleList/ModuleDict that
are not assigned to this stage.
FSDP2 wrapping after stage construction
After splitting, build_parallelize_model iterates over model_parts and applies FSDP2 to each independently. Key behavior:

- reshard_after_forward defaults to False for PP (auto-set in parallelize_model_fsdp2). PP's 1F1B schedule holds micro-batch activations across forward passes before the matching backward; resharding between them would require re-gathering parameters during the backward, which is both wasteful and can interact poorly with the PP schedule's internal bookkeeping. Setting reshard_after_forward=True via config is supported for memory-constrained runs.
- The FSDP shard group is ps.fsdp_group (the DP-local group), not the full world group, so each stage's parameters are sharded only within the DP dimension.
- After FSDP wrapping, stages[i].submod is updated to point to the wrapped model so that P2P communication uses the wrapped module's forward.
4. Per-Microbatch Metadata: _pp_batch_metadata
The problem
PipelineStage drives the model forward with raw tensors (no keyword arguments
other than what the patched _pp_forward provides). Flash Attention varlen mode
requires cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k, and
position_ids for correct packed-sequence attention. These cannot be passed via
the P2P tensor stream.
The solution: a per-microbatch queue
Before calling pp_schedule.step(...), forward_backward_pp attaches a
collections.deque to each model_part:
```python
for model_part in model_parts:
    model_part._pp_batch_metadata = deque(pp_metadata_list)
```

pp_metadata_list is built by extracting position_ids, cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k from each micro-batch dict.
Inside _pp_forward, every time the model’s forward is called for a microbatch,
it pops one entry from the queue:
```python
if torch.is_grad_enabled():
    metadata_queue = getattr(self, "_pp_batch_metadata", None)
    if metadata_queue:
        metadata = metadata_queue.popleft()
        position_ids = metadata.pop("position_ids", None)
        extra_kwargs = {k: v.to(x.device) ... for k, v in metadata.items()}
```

The torch.is_grad_enabled() guard is critical: PipelineStage._shape_inference()
runs a full model forward under torch.no_grad() with zero-filled tensors on the
first step to determine output shapes for P2P buffer allocation. Without the guard,
shape inference would consume real metadata entries from the queue, causing subsequent
training forwards to pop stale or shifted metadata.
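The guard's effect can be modeled with a plain deque and a boolean standing in for torch.is_grad_enabled(); this toy class is illustrative only, not the real _pp_forward:

```python
from collections import deque

class StageModel:
    """Toy stand-in for a PP stage with a per-microbatch metadata queue."""
    def __init__(self, metadata_list):
        self._pp_batch_metadata = deque(metadata_list)

    def forward(self, grad_enabled):
        # Mirrors the torch.is_grad_enabled() guard: the no-grad shape-inference
        # pass must not consume a real metadata entry.
        if grad_enabled and self._pp_batch_metadata:
            return self._pp_batch_metadata.popleft()
        return None

m = StageModel([{"mb": 0}, {"mb": 1}])
assert m.forward(grad_enabled=False) is None       # shape inference: queue untouched
assert m.forward(grad_enabled=True) == {"mb": 0}   # first real microbatch
assert m.forward(grad_enabled=True) == {"mb": 1}   # second, still in order
```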
Flash attention kwargs flow
The popped extra_kwargs (containing cu_seq_lens_q, cu_seq_lens_k,
max_length_q, max_length_k) are passed directly to the model’s original
forward method, which forwards them to the attention layers:
```python
outputs = self._pp_original_forward(
    inputs_embeds=x,
    position_ids=position_ids,
    use_cache=False,
    output_hidden_states=False,
    **extra_kwargs,
)
```

This ensures every stage, including middle stages that receive hidden states rather than input IDs, has the correct sequence metadata for varlen flash attention.
Fallback for CP (context parallelism)
When pp_variable_seq_lengths is disabled or metadata is absent but cp_size > 1,
_pp_forward generates a sequential position_ids spanning seq_len * cp_size
tokens so that the RoPE cache is large enough for the full sequence.
5. Variable Sequence Lengths
The problem
PipelineStage allocates P2P send/recv buffers once from the first shape it
observes and reuses them for all subsequent micro-batches. With sample packing,
different training steps may produce batches of different total lengths, which
would cause shape mismatches between sender and receiver.
Static approach (legacy / pp_variable_seq_lengths=False)
All micro-batches are padded upfront to sample_packing_sequence_len / cp_size
before any PP communication begins. A single PipelineStage (and schedule) is
created once and reused for every step. This is simple but wastes compute and
memory on padding tokens.
Dynamic approach (pp_variable_seq_lengths=True, default)
Step 1 — Negotiate the global max sequence length.
negotiate_pp_seq_len runs an all-reduce MAX across all PP ranks:

```python
local_max = max(mb["input_ids"].shape[-1] for mb in micro_batches)
t = torch.tensor([local_max], dtype=torch.int64, device=device)
dist.all_reduce(t, op=dist.ReduceOp.MAX, group=pp_group)
seq_len = int(t.item())
```

Because different PP ranks may receive different data (each FSDP replica fetches independently), their local max lengths can differ. After the all-reduce, all ranks share the same target length.
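Semantically, the negotiation is a max-of-local-maxes; a single-process stand-in (no torch.distributed, names are ours):

```python
def negotiate_seq_len(per_rank_lengths):
    # Each inner list holds one PP rank's micro-batch sequence lengths.
    local_maxes = [max(lengths) for lengths in per_rank_lengths]  # per-rank local max
    return max(local_maxes)  # the value dist.all_reduce(..., MAX) converges on

# Two PP ranks whose FSDP replicas fetched different data:
assert negotiate_seq_len([[120, 512], [300, 480]]) == 512
```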
Step 2 — Pad to that length.
pad_micro_batches_for_pp pads input_ids, labels, attention_mask,
position_ids, and the flash-attention length metadata to the negotiated target.
Step 3 — Per-seq_len schedule cache.
_get_pp_schedule(seq_len) is keyed by seq_len. On a cache miss, a new
PipelineStage is built from the existing (FSDP-wrapped) model_part:
```python
stage = build_pp_stage(
    model_parts[0],
    pp_rank=ps.pp_rank,
    num_stages=ps.pp_size,
    device=get_device_type(),
    pp_group=ps.pp_group,
)
```

build_pp_stage calls PipelineStage(model_part, ...) without deepcopy, so the stage wraps the same FSDP-wrapped parameter tensors. Only the internal P2P buffer sizes differ between cache entries. This is cheap compared to pipeline_module_split.
After the first step at a given seq_len, subsequent steps at the same length hit
the cache and avoid buffer re-allocation entirely.
Server path note: ModelRunner._get_pp_schedule is keyed by
(n_microbatches, seq_len) because the server may issue forward-backward calls
with varying micro-batch counts across requests.
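The per-seq_len cache is an ordinary memo dict; a minimal sketch (class name and miss counter invented for illustration):

```python
class PPScheduleCache:
    """One cached pipeline stage/schedule per negotiated sequence length."""
    def __init__(self, build_fn):
        self._build = build_fn   # stand-in for build_pp_stage + build_pipeline_schedule
        self._cache = {}
        self.misses = 0

    def get(self, seq_len):
        if seq_len not in self._cache:   # first step at this length: new P2P buffers
            self.misses += 1
            self._cache[seq_len] = self._build(seq_len)
        return self._cache[seq_len]      # later steps: reuse, no re-allocation

cache = PPScheduleCache(lambda s: f"schedule(seq_len={s})")
cache.get(4096)
cache.get(4096)   # hit: same object, no rebuild
cache.get(8192)   # new length: one more build
assert cache.misses == 2
```

The server-side variant would simply use a `(n_microbatches, seq_len)` tuple as the key.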
pp_variable_seq_lengths
pp_variable_seq_lengths (default true) controls whether the pipeline pads sequences to a static worst-case length or negotiates a dynamic per-step maximum.
Without it (pp_variable_seq_lengths: false): every PP micro-batch is padded up front to the static sample_packing_sequence_len (divided by CP size). A single PipelineStage object and schedule are created once and reused for all steps. This is simple but wastes compute and memory whenever actual batch lengths are shorter than the static cap — which is common with variable-length packed data.
With it (default true): all PP ranks run an all-reduce MAX over the PP process group to find the actual maximum sequence length among all micro-batches at the current step, then pad only to that dynamic maximum. Steps with shorter batches pay proportionally less compute.
Each unique negotiated seq_len gets its own cached PipelineStage (keyed by seq_len in the local trainer, by (n_microbatches, seq_len) in the server runner). Because P2P send/recv buffers are allocated to match the stage’s expected input shape, the cache ensures buffer shapes always match the actual tensor shapes — no re-allocation penalty after the first occurrence of a given length.
Trade-off: the all-reduce adds a small synchronization overhead per step (one collective over the PP group, not the full world). For workloads with highly variable sequence lengths this overhead is negligible compared to the compute saved by avoiding padding to the static maximum.
6. Padding: pad_micro_batches_for_pp
Target length computation

```python
target_sharded = sample_packing_sequence_len // sp_size   # SP-local slice length
# round up to pad_to_multiple_of
target_sharded = ceil(target_sharded / pad_to_multiple_of) * pad_to_multiple_of
full_target = target_sharded * sp_size
```

When pp_variable_seq_lengths=True, sample_packing_sequence_len is set to seq_len * cp_size (the negotiated global max scaled by the context-parallel factor).
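The computation can be checked with a few lines of Python (the function name is ours; math.ceil stands in for the rounding above):

```python
import math

def pp_padding_target(sample_packing_sequence_len, sp_size, pad_to_multiple_of=128):
    target_sharded = sample_packing_sequence_len // sp_size           # SP-local slice
    target_sharded = math.ceil(target_sharded / pad_to_multiple_of) * pad_to_multiple_of
    return target_sharded * sp_size                                   # full padded length

assert pp_padding_target(8000, 1) == 8064   # 8000 rounds up to the next multiple of 128
assert pp_padding_target(8000, 4) == 8192   # 2000 per shard → 2048 → ×4 shards
```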
Token tensor padding
input_ids is padded with 0; labels with IGNORE_INDEX (-100) so padded
positions do not contribute to the loss; attention_mask with 0.
position_ids is padded with 0 (padded positions will be masked out by the labels
anyway).
Flash attention metadata (cu_seq_lens) — growing the last document
Naively inserting a new all-zero “padding document” at the end of cu_seq_lens
would create a degenerate sequence of length 0, which causes NaN in the FA3 varlen
backward because the max_length_q/max_length_k statistics become inconsistent.
Instead, xorl extends the last real document by increasing cu_seq_lens[-1]
to full_target:
```python
for key in ("cu_seq_lens_q", "cu_seq_lens_k"):
    if key in mb and mb[key][-1] < full_target:
        mb[key] = mb[key].clone()
        mb[key][-1] = full_target
```

This means the padding tokens are logically appended to the last document.
Their loss contribution is zero because the corresponding label positions are
IGNORE_INDEX. max_length_q / max_length_k are then recomputed from the
updated cu_seq_lens diff to stay consistent.
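A list-based sketch of the fix-up (the real code operates on tensors; helper name is ours):

```python
def extend_last_document(cu_seq_lens, full_target):
    # Grow the last real document instead of appending a degenerate length-0 one.
    if cu_seq_lens[-1] < full_target:
        cu_seq_lens = cu_seq_lens[:-1] + [full_target]
    # Recompute max_length from consecutive cu_seq_lens differences.
    max_len = max(b - a for a, b in zip(cu_seq_lens, cu_seq_lens[1:]))
    return cu_seq_lens, max_len

cu = [0, 1000, 2500, 3000]           # three packed documents
cu, max_len = extend_last_document(cu, 4096)
assert cu == [0, 1000, 2500, 4096]   # padding logically joins the last document
assert max_len == 1596               # 4096 - 2500
```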
7. Loss Normalization
The core challenge
With PP, only the last stage computes the loss. The PyTorch pipeline schedule
internally calls loss_fn(logits, labels).backward() — there is no natural place
to inject the / global_valid_tokens factor before the backward accumulates
gradients across micro-batches.
pp_loss_fn: raw cross-entropy sum
```python
@torch.compile
def pp_loss_fn(pred, labels):
    return F.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        ignore_index=IGNORE_INDEX,
        reduction="sum",  # CE sum, NOT mean
    )
```

reduction="sum" returns CE_sum — the total cross-entropy over all non-ignored
tokens. Because no normalization happens inside the loss function, gradients
accumulating inside pp_schedule.step() are proportional to token count, not
normalized.
Post-backward gradient rescaling
After pp_schedule.step() returns, the trainer multiplies all gradients by 1 / global_valid_tokens:

```python
scale = 1.0 / gvt
for model_part in self.model_parts:
    for p in model_part.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)
```

global_valid_tokens is pre-computed via count_valid_tokens (all-reduce SUM
across fsdp_group) before the forward-backward so every PP stage uses the same
denominator.
This two-step approach (CE_sum loss → manual grad.mul_(1/gvt)) is mathematically
equivalent to reduction="mean" but avoids having to compute global_valid_tokens
inside pp_loss_fn, which would require cross-rank communication inside the
loss function itself (impractical inside the schedule).
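A scalar sanity check of this equivalence (pure Python, no torch): accumulating unnormalized gradient terms and rescaling once by 1/N matches dividing each term by N up front.

```python
per_token_terms = [0.5, -1.0, 2.0, 0.3]   # stand-ins for per-token gradient contributions
N = len(per_token_terms)                  # plays the role of global_valid_tokens

sum_then_scale = sum(per_token_terms) * (1.0 / N)     # CE_sum backward, then grad.mul_(1/gvt)
mean_reduction = sum(t / N for t in per_token_terms)  # what reduction="mean" would produce
assert abs(sum_then_scale - mean_reduction) < 1e-12
```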
Reporting the loss scalar
The last stage computes total_loss = sum(losses). Middle and first stages have
total_loss = -1.0. The true loss value is broadcast to all stages via MAX
all-reduce over pp_group:
```python
dist.all_reduce(loss_tensor, op=dist.ReduceOp.MAX, group=pp_group)
```

A value of -1.0 is safe as the sentinel because CE_sum is always non-negative.
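The sentinel trick can be checked in isolation (toy function, no distributed backend):

```python
def report_loss(stage_losses):
    # Non-last stages contribute the sentinel -1.0; a MAX reduction recovers the
    # real CE_sum because cross-entropy sums are always >= 0.
    return max(stage_losses)

assert report_loss([-1.0, -1.0, 12.75]) == 12.75
assert report_loss([-1.0, 0.0]) == 0.0   # even a zero loss beats the sentinel
```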
Cross-DP-replica loss aggregation
After the PP step, _reduce_metrics in Trainer uses op="sum" over
fsdp_group (the DP replica group) to aggregate the loss:
```python
total_loss = all_reduce(total_loss, op="sum", group=self.ps.fsdp_group)
```

This differs from the non-PP path, which uses the default mean-reduce. The reason:
different DP replicas process different data, so their CE_sum values are additive
(each replica holds a disjoint subset of the global batch). A mean-reduce would
divide by the DP size, undoing the correct normalization by global_valid_tokens.
8. Interaction with Other Parallelism Dimensions
PP + FSDP2
Each PP stage is independently wrapped with FSDP2. The FSDP shard group is
ps.fsdp_group, which spans only the DP replicas of the same PP stage, not the
full world. This means:
- Parameters on stage 0 are sharded among GPUs {0, 1, 2, ...} that all hold stage 0; parameters on stage 1 are sharded among their own set.
- FSDP all-gather and reduce-scatter happen independently per stage, with no cross-stage communication.
- reshard_after_forward defaults to False when PP is active (set automatically in parallelize_model_fsdp2) to avoid redundant re-gathers during the 1F1B backward phase. Override with reshard_after_forward: true in config to trade memory for communication.
PP + Expert Parallelism (EP)
PP and EP are orthogonal: EP shards MoE expert weights within each PP stage, and
the micro-batches flowing through the PP pipeline are the same batches that EP
dispatches to experts. No special handling is needed in forward_backward_pp;
each stage’s FSDP+EP wrapping handles expert dispatch transparently.
With routing replay (R3), the _pp_forward function correctly switches the routing
stage from "replay_backward" to "record" (or "replay_forward") before the
forward and restores it afterward. The no_grad guard inside _pp_forward
ensures the shape-inference forward does not record stale routing decisions into
the replay cache.
PP + Context Parallelism (CP: Ulysses or ring-attention)
CP splits each sequence across multiple GPUs. The per-rank sequence length seen
by the PP pipeline is full_seq_len / cp_size. Two adjustments are needed:
- negotiate_pp_seq_len returns a length in the CP-sharded space (i.e., after CP splits the sequence). The padding target is therefore seq_len * cp_size, passed to pad_micro_batches_for_pp as sample_packing_sequence_len.
- When CP is active and position_ids are absent from the metadata queue, _pp_forward synthesizes position IDs covering the full seq_len * cp_size range so that the RoPE cache is initialized to sufficient size:

```python
if position_ids is None and ps.cp_size > 1:
    full_seq_len = x.shape[1] * ps.cp_size
    position_ids = torch.arange(full_seq_len, ...).expand(batch, -1)
```
9. Constraints
| Constraint | Reason |
|---|---|
| gradient_accumulation_steps >= pipeline_parallel_size | The pipeline must have enough micro-batches to fill all stages before any stage stalls waiting for work. n_microbatches = gradient_accumulation_steps. |
| tie_word_embeddings: false | embed_tokens (stage 0) and lm_head (last stage) cannot share weights across different GPUs. |
| Single model chunk per rank (no virtual stages) | The current implementation uses pipeline_module_split, which assigns exactly one stage per rank. Interleaved (looped) schedules that place multiple model chunks per rank are not supported. |
| Model must implement get_pp_module_config() | The stage-splitting algorithm requires structured metadata about input/layer/output FQNs. |
| Padding must be a multiple of pad_to_multiple_of | FA3 varlen kernels may require sequence length alignment (e.g., multiples of 8 or 64). |
PP is not compatible with torch.compile on the whole model | Per-stage torch.compile of individual decoder layers (via basic_modules) is supported, but whole-model compile is not. |
10. Configuration Examples
Example 1: PP=2, Qwen3-8B, 8 GPUs

```
GPUs: 8
PP=2, FSDP shard=4, DP replicate=1
Stage 0: embed_tokens + layers 0..17 (18 layers)
Stage 1: layers 18..35 + norm + lm_head (18 layers)
Micro-batches: 4 (gradient_accumulation_steps=4)
Bubble: (2-1)/(4+2-1) = 1/5 = 20%
```

```yaml
model:
  model_path: Qwen/Qwen3-8B
  attn_implementation: flash_attention_3

data:
  datasets:
    - path: dummy
      type: tokenized
      max_seq_len: 8000
  sample_packing_method: sequential
  sample_packing_sequence_len: 8000

train:
  data_parallel_mode: fsdp2
  pipeline_parallel_size: 2
  data_parallel_replicate_size: 1
  data_parallel_shard_size: 4
  micro_batch_size: 1
  gradient_accumulation_steps: 4
  # pp_variable_seq_lengths: true (default)
  enable_gradient_checkpointing: true
  enable_full_shard: true
  init_device: meta
  load_weights_mode: broadcast
```

Process group layout (8 GPUs, ranks 0–7):
- PP group 0 (stage 0): ranks 0–3
- PP group 1 (stage 1): ranks 4–7
- FSDP shard group for stage 0: ranks 0–3
- FSDP shard group for stage 1: ranks 4–7
Example 2: PP=2 + EP=4 + CP=4, Qwen3-30B-A3B, 8 GPUs
```
GPUs: 8
PP=2, EP=4, ring-attn CP=4, DP_shard=1, DP_replicate=1
world_size = PP × CP_ring × DP = 2 × 4 × 1 = 8
EP=4 folded onto each PP stage's 4 ranks (ep_fsdp_size=1)
Micro-batches: 2
Bubble: (2-1)/(2+2-1) = 1/3 ≈ 33% (small n_microbatches; increase GAS to reduce)
seq_len per CP rank: 96000 / 4 = 24000 tokens
```

```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  attn_implementation: flash_attention_3
  moe_implementation: triton

data:
  datasets:
    - path: dummy
      max_seq_len: 48001
  sample_packing_sequence_len: 96000

train:
  data_parallel_mode: fsdp2
  pipeline_parallel_size: 2
  ringattn_parallel_size: 4
  expert_parallel_size: 4
  data_parallel_shard_size: 1
  micro_batch_size: 1
  gradient_accumulation_steps: 2
  reshard_after_forward: true   # memory-saving for large model
  optimizer: muon
  enable_gradient_checkpointing: true
```

11. Parameter Reference
| Parameter | Type | Default | Description |
|---|---|---|---|
| train.pipeline_parallel_size | int | 1 | Number of pipeline stages. Set to 1 to disable PP. |
| train.pipeline_parallel_schedule | str | "1F1B" | Schedule algorithm: "1F1B" or "GPipe". |
| train.pp_variable_seq_lengths | bool | True | Negotiate per-step max seq_len across PP ranks and pad dynamically, instead of padding to the static sample_packing_sequence_len. |
| train.gradient_accumulation_steps | int | required | Sets n_microbatches. Must be >= pipeline_parallel_size. |
| train.reshard_after_forward | bool or None | None (auto) | FSDP2 reshard behavior. Auto-set to False for PP; set True to save memory at the cost of extra communication. |
| data.sample_packing_sequence_len | int | required for PP | Maximum sequence length used as the static padding target when pp_variable_seq_lengths=False. Also used as the upper bound for negotiate_pp_seq_len. |
| data.pad_to_multiple_of | int | 128 | Rounds the padded sequence length up to this multiple for GPU efficiency. |
Implementation File Map
| File | Role |
|---|---|
| src/xorl/distributed/pipeline_parallel.py | Stage construction (generate_llm_fqn_per_model_part, pipeline_module_split), forward patching (_pp_forward), schedule builder (build_pipeline_schedule, build_pp_stage) |
| src/xorl/trainers/training_utils.py | Shared PP utilities: forward_backward_pp, negotiate_pp_seq_len, pad_micro_batches_for_pp, pp_loss_fn |
| src/xorl/trainers/trainer.py | Offline trainer: _init_pp_schedule_cache, _get_pp_schedule, _forward_backward_pp, _reduce_metrics |
| src/xorl/server/runner/model_runner.py | Server (online) path: _get_pp_schedule (keyed by (n_microbatches, seq_len)), _forward_backward_pp |
| src/xorl/distributed/torch_parallelize.py | Orchestrates stage split + FSDP2 wrapping (build_parallelize_model) |
| src/xorl/arguments.py | pipeline_parallel_size, pipeline_parallel_schedule, pp_variable_seq_lengths, reshard_after_forward |
Source
| File | Description |
|---|---|
| src/xorl/distributed/pipeline_parallel.py | Stage construction, 1F1B/GPipe schedule, _pp_forward, build_pp_stage |