
Trainer Architecture

This page covers how Trainer is structured for local training — the training loop, data pipeline, collator chain, and how data sharding integrates with context parallelism and gradient accumulation.


Local Training: Data → Model → Optimizer (diagram summary)

- StatefulDL yields packed tokens into the collator pipeline: ToTensor → Flatten → ShiftTokens → Pack → SP Shard (if CP).
- MicroBatchSplit divides each batch by grad_accum into a List[micro_batch].
- Model forward (FSDP2 / EP / PP) → GradAccumLoss → ∇ backward → optimizer (clip + step) → LR scheduler.

`train_step()` is one complete gradient-accumulation step:

- Valid tokens are counted globally; GradientAccumulateLoss normalizes gradients across all DP ranks.
- R3 routing replay wraps the entire step when EP + gradient checkpointing are active.
- SP gradient sync (all-reduce) runs before the optimizer step for Ulysses/Ring dims not folded into FSDP.

src/xorl/trainers/trainer.py

Trainer.__init__() runs eight timed setup phases in order:

| Phase | Method | What it does |
| --- | --- | --- |
| 1 | `_bootstrap()` | Init distributed (torchrun env), set device, seed, build `ParallelState` |
| 2 | `_build_model()` | Load foundation model config; inject LoRA/QLoRA if enabled |
| 3 | `_parallelize()` | Apply TP plan, FSDP2/EP mesh, PP split; broadcast weights from rank 0 |
| 4 | `_build_data()` | Load tokenizer, prepare dataset, build `DataLoaderBuilder` with collator chain |
| 5 | `_build_optimizer()` | Create AdamW/Muon optimizer; build LR scheduler |
| 6 | `_setup_observability()` | Configure structured logging, W&B |
| 7 | `_resume_checkpoint()` | Load DCP checkpoint if `load_checkpoint_path` is set |
| 8 | `_init_pp_schedule_cache()` | Pre-build PP schedule objects for each microbatch count |

After __init__, call trainer.train() to run the full training loop.
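The timed-phase pattern can be illustrated with a minimal sketch (the `run_timed_phases` helper and the lambdas are hypothetical; the real phases are the eight bound methods in the table above):

```python
import time

def run_timed_phases(phases):
    """Run each (name, fn) pair in order, recording wall-clock durations."""
    durations = {}
    for name, fn in phases:
        start = time.perf_counter()
        fn()
        durations[name] = time.perf_counter() - start
    return durations

# Stand-in no-op phases; Trainer.__init__() would pass its real methods.
timings = run_timed_phases([
    ("_bootstrap", lambda: None),
    ("_build_model", lambda: None),
])
```

Recording per-phase durations up front makes slow startup steps (e.g. checkpoint loading vs. dataset preparation) easy to attribute in logs.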


train_step(micro_batches) is the core of the training loop. It executes one full gradient-accumulation step:

```python
def train_step(self, micro_batches):
    global_valid_tokens = self._count_valid_tokens(micro_batches)
    self.optimizer.zero_grad()

    # R3: set up routing replay for MoE + gradient checkpointing
    if self._use_routing_replay:
        set_replay_stage("replay_backward")

    # Forward + backward over all micro-batches
    if self.pp_enabled:
        total_loss = self._forward_backward_pp(micro_batches, global_valid_tokens)
    else:
        total_loss = self._forward_backward(micro_batches, global_valid_tokens)

    # R3: clear replay state
    if self._use_routing_replay:
        set_replay_stage(None)
        RoutingReplay.clear_all()

    self._sync_sp_gradients()          # all-reduce for CP/Ulysses dims not in FSDP
    grad_norm = self._clip_and_step()  # gradient clip + optimizer.step() + scheduler.step()
    self._maybe_merge_lora()           # periodic LoRA merge (if merge_lora_interval > 0)

    total_loss, grad_norm = self._reduce_metrics(total_loss, grad_norm)
    return total_loss, grad_norm
```

xorl normalizes gradients by total valid tokens across all DP ranks, not by the number of FSDP ranks. This is critical for correct training with variable-length sequence packing where different ranks may process very different token counts.

FSDP’s automatic averaging is disabled (set_gradient_divide_factor(1.0)). Instead, GradientAccumulateLoss scales each micro-batch’s loss by local_valid_tokens / global_valid_tokens before calling .backward(). The reduce-scatter then sums raw gradients across ranks, resulting in a correctly token-normalized gradient.

See src/xorl/distributed/gradient_accumulate_loss.py.
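The scaling rule can be checked with a small numeric sketch (plain Python; `scale_loss` is a hypothetical stand-in for the logic inside GradientAccumulateLoss):

```python
# Each micro-batch's summed per-token loss is divided by the *global*
# valid-token count, not the local one. Because the cross-rank reduction
# SUMs raw gradients, the result is the true global mean per-token loss
# regardless of how unevenly tokens are packed across ranks.

def scale_loss(local_token_loss_sum, global_valid_tokens):
    return local_token_loss_sum / global_valid_tokens

# Rank 0 packs 30 valid tokens (loss sum 30.0); rank 1 packs 20 (loss sum 10.0).
global_valid_tokens = 50
rank0 = scale_loss(30.0, global_valid_tokens)
rank1 = scale_loss(10.0, global_valid_tokens)
# Summing across ranks gives 40/50 = 0.8, the global mean per-token loss.
```

Dividing by the number of ranks instead would over-weight ranks that happened to pack fewer valid tokens.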

After the backward pass, gradients for context-parallel (CP/Ulysses) dimensions that are not folded into the FSDP mesh are all-reduced via _sync_sp_gradients(). This ensures that all ranks within a CP group agree on the gradient before the optimizer step.
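The effect of that all-reduce can be simulated in plain Python (a toy sketch; the real code uses a collective over the CP process group, and a SUM op is consistent with the token normalization described above):

```python
def sync_sp_gradients(per_rank_grads):
    """Simulated all-reduce (SUM): every rank ends up holding the
    elementwise sum of all ranks' partial gradients."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(summed) for _ in per_rank_grads]

# Two CP ranks computed partial gradients over different sequence chunks.
synced = sync_sp_gradients([[1.0, 2.0], [3.0, 4.0]])
# After the sync, both ranks hold the same gradient, so optimizer.step()
# produces identical updates on every rank in the CP group.
```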


src/xorl/data/data_loader.py

DataLoaderBuilder assembles a collator pipeline and wraps it in a MicroBatchCollator that splits each batch into micro-batches for gradient accumulation.

Default collator chain (applied in order):

| Step | Collator | Input → Output |
| --- | --- | --- |
| 1 | `ToTensorCollator` | raw Python lists → `torch.Tensor` |
| 2 | `FlattenCollator` | list-of-lists → flat list |
| 3 | `ShiftTokensCollator` | aligns `labels` = `input_ids` shifted by 1 (causal LM) |
| 4 | `PackingConcatCollator` | concatenates sequences into packed bins; pads to `pad_to_multiple_of` (default 128) |
| 5 | `TextSequenceShardCollator` | (only if CP enabled) shards the packed sequence across SP ranks with zigzag reordering |

After the pipeline, MicroBatchCollator splits the result into gradient_accumulation_steps equal micro-batches, each of size micro_batch_size.

Custom collators can be appended or prepended:

```python
builder = DataLoaderBuilder(dataset, micro_batch_size=1, gradient_accumulation_steps=4)
builder.add_collator(MyCustomCollator(), position="after_packing")
dataloader = builder.build()
```

MicroBatchCollator receives a flat batch of micro_batch_size × gradient_accumulation_steps samples and returns a List[Dict] of gradient_accumulation_steps micro-batches. Each micro-batch is processed independently in _forward_backward().
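The split itself is a simple slicing operation, sketched here with a hypothetical helper (not the real MicroBatchCollator):

```python
def split_into_micro_batches(batch, micro_batch_size, grad_accum_steps):
    """Split a flat list of micro_batch_size * grad_accum_steps samples
    into grad_accum_steps micro-batches of micro_batch_size each."""
    assert len(batch) == micro_batch_size * grad_accum_steps
    return [batch[i * micro_batch_size:(i + 1) * micro_batch_size]
            for i in range(grad_accum_steps)]

micro = split_into_micro_batches(list(range(8)),
                                 micro_batch_size=2,
                                 grad_accum_steps=4)
# → [[0, 1], [2, 3], [4, 5], [6, 7]]
```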


When context parallelism (Ulysses or Ring Attention) is active, TextSequenceShardCollator runs as the final collator step.

Each rank in the Ulysses group receives the full sequence but handles only a subset of attention heads. No sequence slicing is needed in the data pipeline — the attention heads are split inside the model by gather_seq_scatter_heads() and gather_heads_scatter_seq().

For Ring Attention, each rank must receive a balanced set of tokens from every document so that causal masking works correctly across the ring. A naive split (rank 0 gets the first S/N tokens) would give rank 0 only the early part of each document, breaking causal consistency.

Zigzag reordering assigns each rank two non-contiguous sub-chunks per document — one from the beginning and one from the end:

Zigzag reordering for Ring Attention (4 ranks, 1 document):

```
Full doc:  tokens 0–S, split into 8 equal chunks: c0 c1 c2 c3 c4 c5 c6 c7

Rank 0: c0 + c7
Rank 1: c1 + c6
Rank 2: c2 + c5
Rank 3: c3 + c4
```

Each rank holds an early chunk plus its symmetrically paired late chunk, keeping the causal-mask workload balanced. Document lengths must be divisible by 2 × ringattn_parallel_size; this is enforced before packing. Implemented in src/xorl/data/collators/sequence_shard_collator.py → zigzag_reorder_packed_sequence().

Document padding requirement: Before packing, each document must be individually padded so its length is divisible by 2 × ringattn_parallel_size. This ensures the zigzag split yields equal-sized chunks. TextSequenceShardCollator pads the full packed sequence to a multiple of 2 × cp_size (where cp_size = ringattn_size × ulysses_size), using sequential position IDs for the pad region so Flash Attention treats padding tokens as a separate dummy sequence. The pad_to_multiple_of field in the data config should be set accordingly — or left at its default of 128, which is a safe common multiple for typical ring sizes.
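The zigzag assignment can be sketched in a few lines (a simplified stand-in for `zigzag_reorder_packed_sequence()`, assuming a single document already padded to a multiple of 2 × ring size):

```python
def zigzag_shards(tokens, ring_size):
    """Split tokens into 2*ring_size equal chunks; rank r receives chunk r
    plus the symmetric chunk (2*ring_size - 1 - r)."""
    n_chunks = 2 * ring_size
    assert len(tokens) % n_chunks == 0, "pad to a multiple of 2 * ring_size first"
    step = len(tokens) // n_chunks
    chunks = [tokens[i * step:(i + 1) * step] for i in range(n_chunks)]
    return [chunks[r] + chunks[n_chunks - 1 - r] for r in range(ring_size)]

shards = zigzag_shards(list(range(16)), ring_size=4)
# Rank 0 holds the first and last chunks: [0, 1, 14, 15]
```

Pairing chunk r with chunk 2N−1−r gives every rank the same total number of causally visible key/value positions, which is what balances the ring.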


xorl packs multiple short documents into a single training bin to maximize GPU utilization and avoid padding waste. Two algorithms are available, controlled by sample_packing_method:

sample_packing_method: sequential fills bins greedily in dataset order:

```
for each document (in order):
    if it fits in the current bin → append
    else → close current bin, open a new bin
```

Properties:

  • Preserves dataset order within each bin — documents that appear together in the dataset stay together
  • Fast: O(N) single pass, no sorting or bin search
  • Bin utilization is good for datasets with similar sequence lengths; can leave gaps for high-variance datasets
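The greedy pass above can be written out directly (a hypothetical helper mirroring the pseudocode; the real code lives in src/xorl/data/prepare/packing.py):

```python
def pack_sequential(lengths, bin_capacity):
    """Fill bins greedily in dataset order; returns bins as lists of
    document indices."""
    bins, current, used = [], [], 0
    for idx, length in enumerate(lengths):
        if used + length > bin_capacity and current:
            bins.append(current)          # close the current bin
            current, used = [], 0
        current.append(idx)
        used += length
    if current:
        bins.append(current)
    return bins

bins = pack_sequential([3, 4, 2, 5, 1], bin_capacity=8)
# → [[0, 1], [2, 3, 4]] — dataset order preserved within each bin
```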

sample_packing_method: multipack processes documents in groups of sample_packing_group_size (default 100,000) and sorts each group by length descending before packing:

```
for each group:
    sort by length descending
    for each document:
        find first bin with enough remaining capacity → place it there
        if no bin fits → open a new bin
```

Properties:

  • Better bin utilization — long documents placed first leave fewer large gaps
  • Does not preserve document order within a group (cross-document attention is masked anyway)
  • Slightly slower due to sorting; parallelized across CPU cores for large datasets
  • Implemented with @numba.njit for speed
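A plain-Python sketch of the first-fit-decreasing loop (the real path is numba-compiled and applied per group of sample_packing_group_size documents):

```python
def pack_multipack(lengths, bin_capacity):
    """First-fit decreasing: place longest documents first, each into the
    first bin with enough remaining capacity."""
    order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
    bins, remaining = [], []
    for idx in order:
        length = lengths[idx]
        for b, rem in enumerate(remaining):
            if rem >= length:             # first bin that fits
                bins[b].append(idx)
                remaining[b] -= length
                break
        else:                             # no bin fits → open a new one
            bins.append([idx])
            remaining.append(bin_capacity - length)
    return bins

ffd_bins = pack_multipack([6, 6, 4, 4], bin_capacity=10)
# → [[0, 2], [1, 3]] — two full bins; a sequential pass over the same
#   lengths would need three bins.
```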

When to use each:

| | Sequential | Multipack |
| --- | --- | --- |
| Dataset order sensitivity | Yes — preserves order | No |
| Packing efficiency | Good | Better (~5–10% more tokens/bin) |
| Speed | Fast | Slower (sorting + numba JIT) |
| Best for | Datasets with uniform lengths, ordered corpora | Mixed-length datasets, maximizing utilization |

```yaml
data:
  sample_packing_method: sequential     # or multipack
  sample_packing_sequence_len: 8192     # bin capacity in tokens
  sample_packing_group_size: 100000     # multipack only: group size for FFD
```

Packed bins are cached to disk (keyed by a hash of the packing config) so subsequent runs skip re-packing. See src/xorl/data/prepare/packing.py.
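Keying a cache by a config hash is straightforward; a minimal sketch (the `packing_cache_key` helper is hypothetical — the real cache logic lives in src/xorl/data/prepare/packing.py):

```python
import hashlib
import json

def packing_cache_key(config: dict) -> str:
    """Stable hash over packing-relevant config fields, so runs with
    identical settings hit the same on-disk cache entry."""
    canonical = json.dumps(config, sort_keys=True)  # key order must not matter
    return hashlib.sha256(canonical.encode()).hexdigest()[:16]

key_a = packing_cache_key({"method": "multipack", "seq_len": 8192})
key_b = packing_cache_key({"seq_len": 8192, "method": "multipack"})
# Same settings in a different order → same cache key.
```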


src/xorl/distributed/torch_parallelize.py

_parallelize() applies parallelism in strict order to avoid device mesh conflicts:

1. TP plan — annotate linear layers with ColwiseParallel / RowwiseParallel
2. EP sharding — slice expert tensors to [E/ep_size, K, N] on each rank
3. Expert FSDP — fully_shard(experts, mesh=ep_fsdp_mesh, shard_placement_fn=Shard(1))
4. Non-expert FSDP — fully_shard(layer, mesh=fsdp_mesh) per decoder block
5. Root FSDP — fully_shard(model, mesh=fsdp_mesh) embeddings + lm_head
6. PP split — pipeline_module_split() → build_pp_stage() per rank
7. Weight loading — broadcast from rank 0 (or all_ranks read independently)

For TP, the plan is read from the model config’s base_model_tp_plan dict (maps FQN patterns → colwise/rowwise/embedding style strings). xorl resolves these to PyTorch ParallelStyle objects and calls parallelize_module().

For FSDP2, every decoder block is wrapped independently (reshard_after_forward=True by default), keeping peak memory bounded to one layer’s parameters at a time.


| File | Description |
| --- | --- |
| `src/xorl/trainers/trainer.py` | `Trainer` class — full lifecycle, `train_step`, `_forward_backward`, `_forward_backward_pp` |
| `src/xorl/trainers/training_utils.py` | `_count_valid_tokens`, `_sync_sp_gradients`, `_clip_and_step`, `_reduce_metrics` |
| `src/xorl/data/data_loader.py` | `DataLoaderBuilder`, `MicroBatchCollator` |
| `src/xorl/data/collators/` | `ToTensorCollator`, `PackingConcatCollator`, `TextSequenceShardCollator`, etc. |
| `src/xorl/data/collators/sequence_shard_collator.py` | `zigzag_reorder_packed_sequence()`, `TextSequenceShardCollator` |
| `src/xorl/distributed/torch_parallelize.py` | `parallelize_model_fsdp2()`, `_build_tp_plan()`, FSDP2 wrapping order |
| `src/xorl/distributed/gradient_accumulate_loss.py` | `GradientAccumulateLoss` — token-normalized gradient scaling |