# Checkpointing
xorl uses PyTorch Distributed Checkpoint (DCP) as its native checkpoint format. DCP is designed for FSDP2: each rank writes only its local parameter shards, and the resulting checkpoint is topology-independent (it can be loaded on a different number of GPUs or with a different parallelism configuration).
## Checkpoint Format: DCP

DCP stores sharded tensors as a collection of per-rank shard files alongside a `metadata.json` that records the full logical tensor shapes and placement information. This is fundamentally different from a single-file checkpoint:
| Format | Structure | Load flexibility | Best for |
|---|---|---|---|
| DCP (default) | One `.distcp` shard per rank + `metadata.json` | Any topology (DCP handles redistribution) | FSDP2 distributed training |
| HuggingFace safetensors | Single or split `.safetensors` file | Load into any model, any framework | Inference, model sharing |
xorl can save both formats simultaneously using `save_hf_weights: true`.
## Saving

Configure checkpointing frequency and behavior in the `train` config block:
```yaml
train:
  save_steps: 100         # save every N optimizer steps
  save_epochs: 1          # save every N epochs (alternative to save_steps)
  save_async: true        # async DCP write (non-blocking, overlaps with training)
  save_hf_weights: false  # also save a HuggingFace-compatible safetensors checkpoint
  output_dir: outputs/my_run
```
### Checkpoint directory layout

```
outputs/my_run/weights/{run_id}/step_{N}/
├── metadata.json                 # DCP metadata: logical shapes, dtypes, placements
├── __0_0.distcp                  # rank 0 shard
├── __1_0.distcp                  # rank 1 shard
├── ...
├── training_state/               # optimizer state, scheduler, RNG state
│   ├── optimizer_state.json
│   └── rng_state_rank{r}.pt
└── adapter_model.safetensors     # LoRA weights only (if save_lora_only: true)
```
### Async saving

When `save_async: true`, DCP writes happen on a background thread. The trainer continues with the next batch while the checkpoint is written. A lock prevents two async saves from running simultaneously. Async saves are flushed before the training run exits.
## Loading

### Auto-detection

xorl automatically detects the latest checkpoint in `output_dir` when resuming:

```yaml
train:
  output_dir: outputs/my_run  # xorl scans for the latest step_{N} directory
```

To load from a specific path:

```yaml
train:
  load_checkpoint_path: outputs/my_run/weights/run_abc123/step_500
```
### Loading from a HuggingFace model

On the first training run, xorl loads the base model weights from the HuggingFace checkpoint specified in `model.model_path`. The `load_weights_mode` field controls how this initial load is distributed.
## Weight Loading Optimization

The `load_weights_mode` field controls how base model weights are loaded from disk and distributed across ranks:
| Value | Description | When to use |
|---|---|---|
| `broadcast` (default) | Rank 0 reads the full checkpoint from disk, then broadcasts each parameter shard to the appropriate ranks | Default for most runs; safe on shared filesystems that cannot handle parallel reads |
| `all_ranks` | Every rank reads its own parameter shard directly from the checkpoint | Fast parallel storage (Lustre, object storage); required when EP expert shards differ per rank |
## FSDP Sharding and DCP

DCP is tightly integrated with FSDP2’s DTensor sharding. When saving:

- Each rank’s `FSDPModule` exposes its local parameter shards as `DTensor` objects with placement annotations (`Shard(0)`, `Shard(1)`, `Replicate`).
- DCP’s `save` API serializes each shard to `__{rank}_{shard_idx}.distcp` along with the logical tensor metadata.
When loading into a different topology (e.g., resuming a 4-GPU run from an 8-GPU checkpoint):
- DCP reads `metadata.json` to determine the logical tensor shapes and original shard placements.
- Each rank requests the portion of each tensor that maps to its new shard coordinates.
- DCP computes which source shard files to read and splices the correct byte ranges.
This reshard-on-load is handled automatically — no manual conversion is required.
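Conceptually, the reshard computation looks like this. The sketch below is simplified to a 1-D `Shard(0)` layout with even chunking (DCP's real planner handles arbitrary placements and byte offsets), but it shows how a new rank's row range is mapped back onto the old shard files:

```python
def shard_range(dim_size: int, rank: int, world: int) -> tuple[int, int]:
    """Contiguous Shard(0)-style split of dim_size rows across `world` ranks."""
    per = -(-dim_size // world)  # ceil division, matching even chunking
    return min(rank * per, dim_size), min((rank + 1) * per, dim_size)


def reads_for_new_rank(dim_size: int, old_world: int, new_rank: int, new_world: int):
    """Which old shard files (and row ranges inside them) a new rank must read."""
    lo, hi = shard_range(dim_size, new_rank, new_world)
    reads = []
    for old_rank in range(old_world):
        olo, ohi = shard_range(dim_size, old_rank, old_world)
        start, end = max(lo, olo), min(hi, ohi)
        if start < end:  # overlap: read rows [start-olo, end-olo) from old_rank's file
            reads.append((old_rank, start - olo, end - olo))
    return reads


# Resume a 4-GPU run from an 8-GPU checkpoint of a tensor with 16 rows:
# new rank 1 owns rows [4, 8), which live in old shards 2 ([4, 6)) and 3 ([6, 8)).
assert reads_for_new_rank(16, old_world=8, new_rank=1, new_world=4) == [(2, 0, 2), (3, 0, 2)]
```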
## LoRA checkpoints

With `save_lora_only: true`, DCP saves only the LoRA adapter parameters (`lora_A`, `lora_B` for each targeted module). The base model weights are not saved. The resulting checkpoint is compatible with the HuggingFace PEFT format.
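The effect of `save_lora_only` can be pictured as a key filter over the state dict. This is an illustrative sketch only; the key names follow common PEFT conventions and are not taken from xorl's internals:

```python
def lora_state_dict(state_dict: dict) -> dict:
    """Keep only LoRA adapter parameters, dropping frozen base weights."""
    return {k: v for k, v in state_dict.items() if "lora_A" in k or "lora_B" in k}


full = {
    "model.layers.0.q_proj.weight": "base",      # frozen base weight: dropped
    "model.layers.0.q_proj.lora_A.weight": "A",  # adapter: kept
    "model.layers.0.q_proj.lora_B.weight": "B",  # adapter: kept
}
assert set(lora_state_dict(full)) == {
    "model.layers.0.q_proj.lora_A.weight",
    "model.layers.0.q_proj.lora_B.weight",
}
```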
## Advanced: Checkpoint Optimization

The sections below cover internal optimizations. Most users don’t need to read further.
xorl applies several optimizations to minimize checkpoint latency and I/O overhead for large distributed models.
### Multi-stream DMA (weight loading)

When loading model weights into GPU memory, xorl uses `_MultiStreamDMA` — a two-CUDA-stream manager that round-robins H2D (host-to-device) DMA transfers across streams to saturate the PCIe copy engine:

```
Stream 0: pin → DMA tensor A → sync → place into model
Stream 1:       pin → DMA tensor B → sync → place into model
─────────────────────────────────────────────────────────►
Streams overlap, doubling effective H2D throughput vs. a single stream
```

How it works:
- Each weight tensor is first pinned with `pin_memory()` to lock it in page-locked (non-pageable) host memory, which makes it eligible for asynchronous DMA.
- The transfer is issued with `non_blocking=True` into a GPU staging buffer on the current stream.
- The staging buffer is added to a `_pending` list; the stream advances to the next tensor.
- Before NCCL operations, `stream.synchronize()` ensures the DMA has completed; the result is then placed into the model parameter or distributed via an NCCL collective.
The pending list is bounded to roughly one shard’s worth of GPU temporaries, keeping peak VRAM overhead small.
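The bounded pending list can be modeled without any CUDA at all. The toy sketch below (not xorl's code; real code synchronizes CUDA streams rather than popping a deque) shows the invariant: at most `bound` transfers are in flight, and issuing past the bound first drains the oldest:

```python
from collections import deque


class BoundedPipeline:
    """Toy model of the bounded _pending list: at most `bound` in-flight
    staging buffers; issuing past the bound first drains the oldest one."""

    def __init__(self, bound: int):
        self.bound = bound
        self.pending: deque = deque()
        self.placed: list = []

    def issue(self, tensor_name: str) -> None:
        if len(self.pending) >= self.bound:  # would exceed the VRAM budget
            self.drain_one()                 # stand-in for stream.synchronize()
        self.pending.append(tensor_name)     # async H2D copy now "in flight"

    def drain_one(self) -> None:
        self.placed.append(self.pending.popleft())  # DMA done: place into model

    def flush(self) -> None:
        while self.pending:
            self.drain_one()


pipe = BoundedPipeline(bound=2)
for name in ["w1", "w2", "w3", "w4"]:
    pipe.issue(name)
pipe.flush()
assert pipe.placed == ["w1", "w2", "w3", "w4"]  # order preserved, at most 2 in flight
```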
### Async checkpoint saving

When `save_async: true`, DCP writes happen on a background thread using a dedicated Gloo process group:

```
Training step N:   forward → backward → optim → trigger async DCP save (non-blocking)
Training step N+1: ─────────────────────────────────────── (continues immediately)
Background:        write rank_{r}.distcp ───────────────────────────►
```

The Gloo process group is created lazily on the first async save and reused for all subsequent ones. Before starting a new async save, xorl waits for the previous one to complete (`.result()` + `dist.barrier()`), preventing two concurrent saves from racing.
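The wait-for-previous-then-launch pattern reduces to a few lines of threading. A toy sketch (hypothetical class, not xorl's `DistributedCheckpointer`; the real writer calls DCP and a `dist.barrier()` where this joins a thread):

```python
import threading


class AsyncSaver:
    """Toy version of the async-save pattern: at most one background save
    at a time, each new save waits for the previous, flush before exit."""

    def __init__(self):
        self._thread: threading.Thread | None = None
        self.saved: list[int] = []

    def save_async(self, step: int) -> None:
        if self._thread is not None:  # stand-in for .result() + dist.barrier()
            self._thread.join()
        self._thread = threading.Thread(target=self._write, args=(step,))
        self._thread.start()          # training continues immediately

    def _write(self, step: int) -> None:
        self.saved.append(step)       # stand-in for the DCP shard write

    def flush(self) -> None:          # called before the run exits
        if self._thread is not None:
            self._thread.join()


saver = AsyncSaver()
for step in (100, 200, 300):
    saver.save_async(step)
saver.flush()
assert saver.saved == [100, 200, 300]  # saves never overlap, order preserved
```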
**Post-save memory cleanup:** After each async DCP save, xorl runs an aggressive cleanup sequence to reclaim the GPU memory held by DCP’s internal state tensors:

```python
gc.collect()              # first GC pass
gc.collect()              # second pass for cyclic references
torch.cuda.empty_cache()  # release cached CUDA blocks
torch.cuda.synchronize()  # ensure the GPU is idle
```

This is especially important for large EP models, where DCP briefly materializes full expert weight tensors.
**DCP write config:** `thread_count=1` per rank (avoids a PyTorch concurrent-write bug), `sync_files=False` (no `fsync()` on each write — faster but relies on OS page cache durability), `single_file_per_rank=True`.
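These options map onto the parameters of PyTorch's `FileSystemWriter`. A sketch of constructing such a writer (parameter names are from the public `torch.distributed.checkpoint` API; whether xorl sets them exactly this way is an assumption):

```python
import tempfile

from torch.distributed.checkpoint import FileSystemWriter

with tempfile.TemporaryDirectory() as ckpt_dir:
    writer = FileSystemWriter(
        ckpt_dir,
        single_file_per_rank=True,  # one __{rank}_0.distcp file per rank
        sync_files=False,           # skip per-write fsync(); rely on the page cache
        thread_count=1,             # serialize writes within each rank
    )
```

The writer would then be passed as `storage_writer=` to `torch.distributed.checkpoint.save()` in place of the default.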
### Prefetch + batched broadcast (rank-0 weight loading)

In `broadcast` mode, rank 0 reads from disk and broadcasts to all other ranks. xorl pipelines this with:
**Background prefetch:** A background thread pre-loads the next shard file while the current shard is being broadcast (`prefetch_count=2`). This overlaps NFS/SSD reads with NCCL communication, hiding disk latency behind the broadcast.
**Batched metadata broadcast:** Instead of one `broadcast_object_list()` call per tensor (each requiring a pickle + an NCCL launch), xorl collects all tensor metadata (name, shape, dtype) into a single list and broadcasts it in one call. This reduces N broadcast launches to 1 for the entire shard, eliminating per-tensor NCCL overhead.
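The launch-count arithmetic behind the batching can be shown without a process group (pure-Python illustration; in the real path `dist.broadcast_object_list` does the pickling internally):

```python
import pickle

# Metadata for every tensor in one shard: (name, shape, dtype-string).
metas = [(f"layers.{i}.weight", (4096, 4096), "bfloat16") for i in range(32)]

# Per-tensor approach: one pickle + one collective launch per tensor.
per_tensor_launches = len(metas)  # 32 broadcast calls for this shard

# Batched approach: pickle the whole list once, broadcast a single buffer.
payload = pickle.dumps(metas)     # one serialized blob
batched_launches = 1              # one broadcast call

assert pickle.loads(payload) == metas  # receivers recover every entry
assert (per_tensor_launches, batched_launches) == (32, 1)
```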
**Non-blocking CPU→GPU transfer before broadcast:**

```python
tensor = broadcast_tensor.pin_memory().to(device, non_blocking=True)
dist.broadcast(tensor, src=0)
```

The `pin_memory()` + `non_blocking=True` combination initiates the H2D DMA while the NCCL library is still setting up the collective, overlapping the two operations.
### EP-aware filtered loading

For Expert Parallel (EP) models, xorl skips disk reads for expert weight tensors that don’t belong to the current rank. With `ep_size=8`, each rank needs only 1/8 of all expert shards — the other 7/8 are read by other ranks. xorl’s checkpoint handler provides a `skip_key_fn` predicate that filters out out-of-range expert keys before any I/O occurs:

```
Without EP filtering: every rank reads ALL expert shards → 8× redundant reads
With EP filtering:    each rank reads only its 1/8 of expert shards → ~87% I/O reduction
```

Skipped keys are tracked in the handler’s completion counters so DCP’s progress reporting remains accurate.
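A sketch of what such a predicate might look like, assuming experts are assigned contiguously and keys contain `.experts.{i}.` (both assumptions; the real key layout and assignment are model- and xorl-specific):

```python
import re


def make_skip_key_fn(ep_rank: int, ep_size: int, num_experts: int):
    """Return a predicate that skips expert tensors owned by other EP ranks."""
    per_rank = num_experts // ep_size
    lo, hi = ep_rank * per_rank, (ep_rank + 1) * per_rank
    expert_re = re.compile(r"\.experts\.(\d+)\.")

    def skip_key(key: str) -> bool:
        m = expert_re.search(key)
        if m is None:
            return False  # non-expert tensor: every rank loads it
        return not (lo <= int(m.group(1)) < hi)

    return skip_key


skip = make_skip_key_fn(ep_rank=1, ep_size=8, num_experts=64)  # owns experts 8..15
assert skip("model.layers.0.mlp.experts.3.w1.weight") is True    # another rank's expert
assert skip("model.layers.0.mlp.experts.10.w1.weight") is False  # ours: read it
assert skip("model.layers.0.self_attn.q_proj.weight") is False   # dense tensor
```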
### Expert parameter pre-shrinking

Before materializing model parameters from the meta device, xorl resizes expert tensors from their full shape `[num_experts, K, N]` down to the EP-local shape `[num_local_experts, K, N]` on the meta device itself. This prevents a brief OOM spike that would otherwise occur if FSDP2 tried to allocate a full-size expert tensor on a rank that will only ever hold a fraction of the experts.
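The shape arithmetic is simple but worth seeing concretely (illustrative numbers; the example shapes are hypothetical, not taken from any specific model):

```python
def ep_local_shape(full_shape: tuple, ep_size: int) -> tuple:
    """Shrink [num_experts, K, N] to [num_experts // ep_size, K, N]."""
    num_experts, *rest = full_shape
    assert num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    return (num_experts // ep_size, *rest)


full = (64, 4096, 14336)  # [num_experts, K, N]
local = ep_local_shape(full, ep_size=8)
assert local == (8, 4096, 14336)

# Transient allocation avoided if FSDP2 never materializes the full tensor (bf16):
bytes_full = 64 * 4096 * 14336 * 2
bytes_local = 8 * 4096 * 14336 * 2
assert bytes_full // bytes_local == 8  # 8x smaller peak footprint per expert tensor
```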
## Summary

| Optimization | Applies to | Benefit |
|---|---|---|
| Multi-stream DMA | Weight loading | ~2× H2D throughput via overlapped PCIe DMA |
| Async saving | Checkpoint save | Non-blocking writes; training continues immediately |
| Background prefetch | Rank-0 broadcast | Overlaps disk read with NCCL broadcast |
| Batched metadata broadcast | Rank-0 broadcast | 1 NCCL launch vs N per-tensor launches |
| Non-blocking H2D before broadcast | Rank-0 broadcast | Overlaps DMA with NCCL setup |
| EP-aware filtered loading | EP weight loading | ~87% I/O reduction at ep_size=8 |
| Expert pre-shrinking | EP initialization | Prevents OOM on meta device materialization |
| Aggressive post-save GC | Async save | Reclaims DCP temporary tensors promptly |
## Source

| File | Description |
|---|---|
| `src/xorl/checkpoint/checkpointer.py` | `DistributedCheckpointer` — DCP save/load, async save with Gloo process group, per-rank extra state |
| `src/xorl/models/module_utils.py` | `_MultiStreamDMA` — two-stream H2D DMA; `rank0_load_and_broadcast_weights` — prefetch + batched broadcast; EP-aware filtered loading; expert pre-shrinking |
| `src/xorl/arguments.py` | `save_steps`, `save_epochs`, `save_async`, `save_hf_weights`, `load_checkpoint_path`, `load_weights_mode` |