# Trainer Architecture
This page covers how Trainer is structured for local training — the training loop, data pipeline, collator chain, and how data sharding integrates with context parallelism and gradient accumulation.
## Overview

### Trainer Lifecycle

`Trainer.__init__()` runs eight timed setup phases in order:
| Phase | Method | What it does |
|---|---|---|
| 1 | _bootstrap() | Init distributed (torchrun env), set device, seed, build ParallelState |
| 2 | _build_model() | Load foundation model config; inject LoRA/QLoRA if enabled |
| 3 | _parallelize() | Apply TP plan, FSDP2/EP mesh, PP split; broadcast weights from rank 0 |
| 4 | _build_data() | Load tokenizer, prepare dataset, build DataLoaderBuilder with collator chain |
| 5 | _build_optimizer() | Create AdamW/Muon optimizer; build LR scheduler |
| 6 | _setup_observability() | Configure structured logging, W&B |
| 7 | _resume_checkpoint() | Load DCP checkpoint if load_checkpoint_path is set |
| 8 | _init_pp_schedule_cache() | Pre-build PP schedule objects for each microbatch count |
After `__init__()`, call `trainer.train()` to run the full training loop.
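The eight-phase setup can be pictured with a minimal sketch of the timed-phase pattern. The phase names come from the table above; the `PhaseTimer` helper is an illustrative assumption, not xorl's actual API.

```python
import time

class PhaseTimer:
    """Runs named setup phases in order and records wall-clock duration."""
    def __init__(self):
        self.durations = {}

    def run(self, name, fn):
        start = time.perf_counter()
        fn()
        self.durations[name] = time.perf_counter() - start

timer = PhaseTimer()
for name in ["_bootstrap", "_build_model", "_parallelize", "_build_data",
             "_build_optimizer", "_setup_observability",
             "_resume_checkpoint", "_init_pp_schedule_cache"]:
    timer.run(name, lambda: None)  # each phase would be a bound Trainer method

assert len(timer.durations) == 8
```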
### train_step()

`train_step(micro_batches)` is the core of the training loop. It executes one full gradient-accumulation step:
```python
def train_step(self, micro_batches):
    global_valid_tokens = self._count_valid_tokens(micro_batches)
    self.optimizer.zero_grad()

    # R3: set up routing replay for MoE + gradient checkpointing
    if self._use_routing_replay:
        set_replay_stage("replay_backward")

    # Forward + backward over all micro-batches
    if self.pp_enabled:
        total_loss = self._forward_backward_pp(micro_batches, global_valid_tokens)
    else:
        total_loss = self._forward_backward(micro_batches, global_valid_tokens)

    # R3: clear replay state
    if self._use_routing_replay:
        set_replay_stage(None)
        RoutingReplay.clear_all()

    self._sync_sp_gradients()          # all-reduce for CP/Ulysses dims not in FSDP
    grad_norm = self._clip_and_step()  # gradient clip + optimizer.step() + scheduler.step()
    self._maybe_merge_lora()           # periodic LoRA merge (if merge_lora_interval > 0)

    total_loss, grad_norm = self._reduce_metrics(total_loss, grad_norm)
    return total_loss, grad_norm
```

### Gradient normalization
xorl normalizes gradients by total valid tokens across all DP ranks, not by the number of FSDP ranks. This is critical for correct training with variable-length sequence packing, where different ranks may process very different token counts.
FSDP’s automatic averaging is disabled (set_gradient_divide_factor(1.0)). Instead, GradientAccumulateLoss scales each micro-batch’s loss by local_valid_tokens / global_valid_tokens before calling .backward(). The reduce-scatter then sums raw gradients across ranks, resulting in a correctly token-normalized gradient.
See `src/xorl/distributed/gradient_accumulate_loss.py`.
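The scaling rule above can be sketched in a few lines. The `local_valid_tokens / global_valid_tokens` factor comes from the text; the function shape is an assumption for illustration, not `GradientAccumulateLoss`'s actual interface.

```python
import torch

def scale_loss_for_accumulation(loss: torch.Tensor,
                                local_valid_tokens: int,
                                global_valid_tokens: int) -> torch.Tensor:
    # With FSDP's divide factor set to 1.0, the reduce-scatter SUMS raw
    # gradients across ranks, so pre-scaling each micro-batch's loss by its
    # share of valid tokens yields a token-normalized global gradient.
    return loss * (local_valid_tokens / global_valid_tokens)

# Two micro-batches with 100 and 300 valid tokens out of 400 total:
l1 = scale_loss_for_accumulation(torch.tensor(2.0), 100, 400)
l2 = scale_loss_for_accumulation(torch.tensor(4.0), 300, 400)
# Summed scaled loss = (2*100 + 4*300) / 400 = 3.5 — a true token-weighted mean
```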
### SP gradient sync

After the backward pass, gradients for context-parallel (CP/Ulysses) dimensions that are not folded into the FSDP mesh are all-reduced via `_sync_sp_gradients()`. This ensures that all ranks within a CP group agree on the gradient before the optimizer step.
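Conceptually, the sync is an all-reduce over the CP group for the affected parameters. The helper below is a hedged sketch (group handle and averaging choice are assumptions); the single-process tail simulates what the collective produces.

```python
import torch
import torch.distributed as dist

def sync_sp_gradients(params, sp_group):
    """Average gradients across the CP/Ulysses group for params FSDP skips."""
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=sp_group)

# Single-process illustration of the result for a CP group of size 2:
rank_grads = [torch.tensor([1.0, 3.0]), torch.tensor([3.0, 5.0])]
avg = torch.stack(rank_grads).mean(dim=0)  # what all_reduce(AVG) leaves on every rank
```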
## Data Pipeline

### DataLoaderBuilder

`DataLoaderBuilder` assembles a collator pipeline and wraps it in a `MicroBatchCollator` that splits each batch into micro-batches for gradient accumulation.
Default collator chain (applied in order):
| Step | Collator | Input → Output |
|---|---|---|
| 1 | ToTensorCollator | raw Python lists → torch.Tensor |
| 2 | FlattenCollator | list-of-lists → flat list |
| 3 | ShiftTokensCollator | aligns labels = input_ids shifted by 1 (causal LM) |
| 4 | PackingConcatCollator | concatenates sequences into packed bins; pads to pad_to_multiple_of (default 128) |
| 5 | TextSequenceShardCollator | (only if CP enabled) shards the packed sequence across SP ranks with zigzag reordering |
After the pipeline, `MicroBatchCollator` splits the result into `gradient_accumulation_steps` equal micro-batches, each of size `micro_batch_size`.
Custom collators can be appended or prepended:

```python
builder = DataLoaderBuilder(dataset, micro_batch_size=1, gradient_accumulation_steps=4)
builder.add_collator(MyCustomCollator(), position="after_packing")
dataloader = builder.build()
```

### MicroBatchCollator

Receives a flat batch of `micro_batch_size × gradient_accumulation_steps` samples and returns a `List[Dict]` of `gradient_accumulation_steps` micro-batches. Each micro-batch is independently processed in `_forward_backward()`.
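The split described above amounts to slicing each tensor along the batch dimension. A minimal sketch, assuming illustrative field names (not xorl's actual collator code):

```python
import torch

def split_into_micro_batches(batch, micro_batch_size, grad_accum_steps):
    """Split one flat batch dict into grad_accum_steps micro-batch dicts."""
    assert batch["input_ids"].shape[0] == micro_batch_size * grad_accum_steps
    return [
        {k: v[i * micro_batch_size:(i + 1) * micro_batch_size]
         for k, v in batch.items()}
        for i in range(grad_accum_steps)
    ]

batch = {"input_ids": torch.arange(8).reshape(4, 2)}
micro = split_into_micro_batches(batch, micro_batch_size=1, grad_accum_steps=4)
# → 4 micro-batches, each with an input_ids tensor of shape (1, 2)
```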
## Context Parallel Data Sharding

When context parallelism (Ulysses or Ring Attention) is active, `TextSequenceShardCollator` runs as the final collator step.
### Ulysses sharding

Each rank in the Ulysses group receives the full sequence but handles only a subset of attention heads. No sequence slicing is needed in the data pipeline — the attention heads are split inside the model by `gather_seq_scatter_heads()` and `gather_heads_scatter_seq()`.
### Ring Attention sharding (zigzag)

For Ring Attention, each rank must receive a balanced set of tokens from every document so that causal masking works correctly across the ring. A naive split (rank 0 gets the first S/N tokens) would give rank 0 only the early part of each document, breaking causal consistency.
Zigzag reordering instead assigns each rank two non-contiguous sub-chunks per document — one from the beginning and the mirrored one from the end — so every rank sees a balanced mix of early and late positions.
**Document padding requirement:** Before packing, each document must be individually padded so its length is divisible by 2 × ringattn_parallel_size. This ensures the zigzag split yields equal-sized chunks. `TextSequenceShardCollator` pads the full packed sequence to a multiple of 2 × cp_size (where cp_size = ringattn_size × ulysses_size), using sequential position IDs for the pad region so Flash Attention treats padding tokens as a separate dummy sequence. The `pad_to_multiple_of` field in the data config should be set accordingly — or left at its default of 128, which is a safe common multiple for typical ring sizes.
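The zigzag split can be sketched as follows. This mirrors the common zigzag layout for ring attention (rank *r* gets chunk *r* from the front and chunk *r* from the back); it is a hedged illustration and may differ in detail from `zigzag_reorder_packed_sequence()`.

```python
def zigzag_shard(tokens, ring_size):
    """Split a document into per-rank shards: one front chunk + its mirror."""
    # Length must be divisible by 2 * ring_size — the padding requirement above.
    assert len(tokens) % (2 * ring_size) == 0
    chunk = len(tokens) // (2 * ring_size)
    shards = []
    for rank in range(ring_size):
        front = tokens[rank * chunk:(rank + 1) * chunk]
        back = tokens[len(tokens) - (rank + 1) * chunk:len(tokens) - rank * chunk]
        shards.append(front + back)  # each rank: balanced early + late tokens
    return shards

# 8 tokens on a ring of 2: rank 0 → [0, 1, 6, 7], rank 1 → [2, 3, 4, 5]
shards = zigzag_shard(list(range(8)), ring_size=2)
```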
## Sequence Packing Algorithms

xorl packs multiple short documents into a single training bin to maximize GPU utilization and avoid padding waste. Two algorithms are available, controlled by `sample_packing_method`:
### Sequential (default)

`sample_packing_method: sequential` fills bins greedily in dataset order:

```
for each document (in order):
    if it fits in the current bin → append
    else → close current bin, open a new bin
```

Properties:
- Preserves dataset order within each bin — documents that appear together in the dataset stay together
- Fast: O(N) single pass, no sorting or bin search
- Bin utilization is good for datasets with similar sequence lengths; can leave gaps for high-variance datasets
### Multipack / First-Fit Decreasing (FFD)

`sample_packing_method: multipack` processes documents in groups of `sample_packing_group_size` (default 100,000) and sorts each group by length descending before packing:

```
for each group:
    sort by length descending
    for each document:
        find first bin with enough remaining capacity → place it there
        if no bin fits → open a new bin
```

Properties:
- Better bin utilization — long documents placed first leave fewer large gaps
- Does not preserve document order within a group (cross-document attention is masked anyway)
- Slightly slower due to sorting; parallelized across CPU cores for large datasets
- Implemented with `@numba.njit` for speed
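The FFD loop above in plain Python (the real implementation is numba-jitted and processes groups; this sketch handles a single group and uses illustrative names):

```python
def pack_ffd(doc_lengths, bin_capacity):
    """First-fit decreasing: sort by length desc, place each doc in the
    first bin with enough remaining capacity, else open a new bin."""
    order = sorted(range(len(doc_lengths)), key=lambda i: -doc_lengths[i])
    bins, remaining = [], []
    for i in order:
        for b, rem in enumerate(remaining):
            if doc_lengths[i] <= rem:
                bins[b].append(i)
                remaining[b] -= doc_lengths[i]
                break
        else:  # no existing bin fits → open a new one
            bins.append([i])
            remaining.append(bin_capacity - doc_lengths[i])
    return bins

# Capacity 10, lengths [4, 5, 3, 7]: placed in order 7, 5, 4, 3
# → [[3, 2], [1, 0]] — two full bins, better utilization than strict order
bins = pack_ffd([4, 5, 3, 7], bin_capacity=10)
```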
When to use each:
| | Sequential | Multipack |
|---|---|---|
| Dataset order sensitivity | Yes — preserves order | No |
| Packing efficiency | Good | Better (~5-10% more tokens/bin) |
| Speed | Fast | Slower (sorting + numba JIT) |
| Best for | Datasets with uniform lengths, ordered corpora | Mixed-length datasets, maximizing utilization |
```yaml
data:
  sample_packing_method: sequential   # or multipack
  sample_packing_sequence_len: 8192   # bin capacity in tokens
  sample_packing_group_size: 100000   # multipack only: group size for FFD
```

Packed bins are cached to disk (keyed by a hash of the packing config) so subsequent runs skip re-packing. See `src/xorl/data/prepare/packing.py`.
## Model Parallelization

Source: `src/xorl/distributed/torch_parallelize.py`
`_parallelize()` applies parallelism in strict order to avoid device mesh conflicts:

1. **TP plan** — annotate linear layers with `ColwiseParallel` / `RowwiseParallel`
2. **EP sharding** — slice expert tensors to `[E/ep_size, K, N]` on each rank
3. **Expert FSDP** — `fully_shard(experts, mesh=ep_fsdp_mesh, shard_placement_fn=Shard(1))`
4. **Non-expert FSDP** — `fully_shard(layer, mesh=fsdp_mesh)` per decoder block
5. **Root FSDP** — `fully_shard(model, mesh=fsdp_mesh)` for embeddings + lm_head
6. **PP split** — `pipeline_module_split()` → `build_pp_stage()` per rank
7. **Weight loading** — broadcast from rank 0 (or all ranks read independently)

For TP, the plan is read from the model config's `base_model_tp_plan` dict (maps FQN patterns → colwise/rowwise/embedding style strings). xorl resolves these to PyTorch `ParallelStyle` objects and calls `parallelize_module()`.
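Resolving a `base_model_tp_plan`-style dict can be sketched as glob matching over module FQNs. The plan entries below are illustrative examples, and the real code maps the style strings to `ParallelStyle` objects rather than returning them raw:

```python
from fnmatch import fnmatch

# Hypothetical plan: FQN glob patterns → parallel-style strings
tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    "model.embed_tokens": "embedding",
}

def resolve_style(fqn, plan):
    """Return the style string for the first pattern matching this module."""
    for pattern, style in plan.items():
        if fnmatch(fqn, pattern):
            return style
    return None  # module is not tensor-parallelized

style = resolve_style("model.layers.3.self_attn.q_proj", tp_plan)  # "colwise"
```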
For FSDP2, every decoder block is wrapped independently (reshard_after_forward=True by default), keeping peak memory bounded to one layer’s parameters at a time.
## Source

| File | Description |
|---|---|
| `src/xorl/trainers/trainer.py` | `Trainer` class — full lifecycle, `train_step`, `_forward_backward`, `_forward_backward_pp` |
| `src/xorl/trainers/training_utils.py` | `_count_valid_tokens`, `_sync_sp_gradients`, `_clip_and_step`, `_reduce_metrics` |
| `src/xorl/data/data_loader.py` | `DataLoaderBuilder`, `MicroBatchCollator` |
| `src/xorl/data/collators/` | `ToTensorCollator`, `PackingConcatCollator`, `TextSequenceShardCollator`, etc. |
| `src/xorl/data/collators/sequence_shard_collator.py` | `zigzag_reorder_packed_sequence()`, `TextSequenceShardCollator` |
| `src/xorl/distributed/torch_parallelize.py` | `parallelize_model_fsdp2()`, `_build_tp_plan()`, FSDP2 wrapping order |
| `src/xorl/distributed/gradient_accumulate_loss.py` | `GradientAccumulateLoss` — token-normalized gradient scaling |