# Checkpointing

xorl uses PyTorch Distributed Checkpoint (DCP) as its native checkpoint format. DCP is designed for FSDP2: each rank writes only its local parameter shards, and the resulting checkpoint is topology-independent (it can be loaded on a different number of GPUs or with a different parallelism configuration).

DCP stores sharded tensors as a collection of per-rank shard files alongside a metadata.json that records the full logical tensor shapes and placement information. This is fundamentally different from a single-file checkpoint:

| Format | Structure | Load flexibility | Best for |
| --- | --- | --- | --- |
| DCP (default) | One `.distcp` shard per rank + `metadata.json` | Any topology (DCP handles redistribution) | FSDP2 distributed training |
| HuggingFace safetensors | Single or split `.safetensors` file | Load into any model, any framework | Inference, model sharing |

xorl can save both formats simultaneously with `save_hf_weights: true`.

Configure checkpointing frequency and behavior in the `train` config block:

```yaml
train:
  save_steps: 100        # save every N optimizer steps
  save_epochs: 1         # save every N epochs (alternative to save_steps)
  save_async: true       # async DCP write (non-blocking, overlaps with training)
  save_hf_weights: false # also save a HuggingFace-compatible safetensors checkpoint
  output_dir: outputs/my_run
```

Each checkpoint is written to a step directory under `output_dir`:

```text
outputs/my_run/weights/{run_id}/step_{N}/
├── metadata.json               # DCP metadata: logical shapes, dtypes, placements
├── __0_0.distcp                # rank 0 shard
├── __1_0.distcp                # rank 1 shard
├── ...
├── training_state/             # optimizer state, scheduler, RNG state
│   ├── optimizer_state.json
│   └── rng_state_rank{r}.pt
└── adapter_model.safetensors   # LoRA weights only (if save_lora_only: true)
```

When save_async: true, DCP writes happen on a background thread. The trainer continues with the next batch while the checkpoint is written. A lock prevents two async saves from running simultaneously. Async saves are flushed before the training run exits.

xorl automatically detects the latest checkpoint in `output_dir` when resuming:

```yaml
train:
  output_dir: outputs/my_run # xorl scans for the latest step_{N} directory
```
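The scan amounts to picking the `step_{N}` directory with the highest step number. A minimal sketch (the helper name is illustrative, not xorl's actual function):

```python
import re
from pathlib import Path

def find_latest_checkpoint(output_dir):
    """Return the step_{N} directory with the highest N, or None if absent."""
    best, best_step = None, -1
    for d in Path(output_dir).glob("weights/*/step_*"):
        m = re.fullmatch(r"step_(\d+)", d.name)
        if m and d.is_dir() and int(m.group(1)) > best_step:
            best, best_step = d, int(m.group(1))
    return best
```

For example, with `outputs/my_run/weights/run_abc123/step_100` and `.../step_500` on disk, the helper returns the `step_500` directory.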

To load from a specific path:

```yaml
train:
  load_checkpoint_path: outputs/my_run/weights/run_abc123/step_500
```

On first training, xorl loads the base model weights from the HuggingFace checkpoint specified in `model.model_path`. The `load_weights_mode` field controls how this initial load is read from disk and distributed across ranks:

Figure: weight loading modes. In `broadcast` (default), rank 0 performs a single disk read and NCCL-broadcasts each shard to all other ranks (safe on shared filesystems). In `all_ranks`, every rank reads its own shard from disk in parallel (fast on Lustre / object storage; required for EP).
| Value | Description | When to use |
| --- | --- | --- |
| `broadcast` (default) | Rank 0 reads the full checkpoint from disk, then broadcasts each parameter shard to the appropriate ranks | Default for most runs; safe on shared filesystems that cannot handle parallel reads |
| `all_ranks` | Every rank reads its own parameter shard directly from the checkpoint | Fast parallel storage (Lustre, object storage); required when EP expert shards differ per rank |

DCP is tightly integrated with FSDP2’s DTensor sharding. When saving:

  1. Each rank’s FSDPModule exposes its local parameter shards as DTensor objects with placement annotations (Shard(0), Shard(1), Replicate).
  2. DCP’s save API serializes each shard to __{rank}_{shard_idx}.distcp along with the logical tensor metadata.

When loading into a different topology (e.g., resuming a 4-GPU run from an 8-GPU checkpoint):

  1. DCP reads metadata.json to determine the logical tensor shapes and original shard placements.
  2. Each rank requests the portion of each tensor that maps to its new shard coordinates.
  3. DCP computes which source shard files to read and splices the correct byte ranges.

This reshard-on-load is handled automatically — no manual conversion is required.

With save_lora_only: true, DCP saves only the LoRA adapter parameters (lora_A, lora_B for each targeted module). The base model weights are not saved. The resulting checkpoint is compatible with the HuggingFace PEFT format.
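The selection reduces to a key predicate over the state dict. A sketch (the helper name and key layout are illustrative, not xorl's exact scheme):

```python
def lora_only(state_dict):
    """Keep only LoRA adapter parameters (lora_A / lora_B matrices)."""
    return {k: v for k, v in state_dict.items() if ".lora_A." in k or ".lora_B." in k}

sd = {
    "model.layers.0.self_attn.q_proj.weight": "base",      # base weight: dropped
    "model.layers.0.self_attn.q_proj.lora_A.weight": "A",  # adapter: kept
    "model.layers.0.self_attn.q_proj.lora_B.weight": "B",  # adapter: kept
}
assert sorted(lora_only(sd).values()) == ["A", "B"]
```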


The sections below cover internal optimizations. Most users don’t need to read further.

xorl applies several optimizations to minimize checkpoint latency and I/O overhead for large distributed models.

When loading model weights into GPU memory, xorl uses _MultiStreamDMA — a two-CUDA-stream manager that round-robins H2D (Host-to-Device) DMA transfers across streams to saturate the PCIe copy engine:

```text
Stream 0: pin → DMA tensor A → sync → place into model
Stream 1:       pin → DMA tensor B → sync → place into model
─────────────────────────────────────────────────►
Streams overlap, doubling effective H2D throughput vs single-stream
```

How it works:

  1. Each weight tensor is first pinned with pin_memory(), which places it in page-locked (non-pageable) host memory and makes it eligible for asynchronous DMA
  2. The transfer is issued with non_blocking=True into a GPU staging buffer on the current stream
  3. The staging buffer is added to a _pending list; the stream advances to the next tensor
  4. Before NCCL operations, stream.synchronize() ensures the DMA completed, then the result is placed into the model parameter or distributed via NCCL collective

The pending list is bounded to roughly one shard’s worth of GPU temporaries, keeping peak VRAM overhead small.
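The round-robin pattern can be sketched as follows (class and method names are illustrative, not xorl's actual `_MultiStreamDMA`; on a CPU-only build it degrades to plain synchronous copies):

```python
import torch

class MultiStreamCopier:
    """Round-robin H2D copies across multiple CUDA streams (sketch)."""

    def __init__(self, device, num_streams=2):
        self.device = torch.device(device)
        self.use_cuda = self.device.type == "cuda"
        self.streams = (
            [torch.cuda.Stream(self.device) for _ in range(num_streams)]
            if self.use_cuda else []
        )
        self.idx = 0

    def copy(self, host_tensor):
        if not self.use_cuda:
            return host_tensor.to(self.device)  # CPU fallback: plain copy
        stream = self.streams[self.idx]
        self.idx = (self.idx + 1) % len(self.streams)  # round-robin streams
        with torch.cuda.stream(stream):
            pinned = host_tensor.pin_memory()  # page-locked: DMA-eligible
            return pinned.to(self.device, non_blocking=True)  # async H2D

    def synchronize(self):
        for s in self.streams:
            s.synchronize()  # ensure all DMAs completed before NCCL / use

copier = MultiStreamCopier("cuda" if torch.cuda.is_available() else "cpu")
out = [copier.copy(torch.randn(4)) for _ in range(4)]  # alternates streams
copier.synchronize()
```

Consecutive tensors land on alternating streams, so one tensor's DMA overlaps the next tensor's pinning and issue.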


When save_async: true, DCP writes happen on a background thread using a dedicated Gloo process group:

```text
Training step N:   forward → backward → optim → trigger async DCP save (non-blocking)
Training step N+1: ──────────────────────────── (continues immediately)
Background:        write __{r}_0.distcp ───────────────────────────►
```

The Gloo process group is created lazily on the first async save and reused for all subsequent ones. Before starting a new async save, xorl waits for the previous one to complete (.result() + dist.barrier()), preventing two concurrent saves from racing.
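The trigger/wait discipline can be sketched with a plain background thread (a deliberate simplification of DCP's async machinery; names are illustrative):

```python
import threading

class AsyncSaver:
    """Trigger non-blocking saves; a new save first waits for the previous one."""

    def __init__(self):
        self._thread = None

    def trigger(self, write_fn):
        self.flush()  # wait for the previous async save before starting a new one
        self._thread = threading.Thread(target=write_fn)
        self._thread.start()  # training continues while the write runs

    def flush(self):
        if self._thread is not None:
            self._thread.join()  # also called once before the run exits
            self._thread = None

saver = AsyncSaver()
written = []
saver.trigger(lambda: written.append("step_100"))
saver.trigger(lambda: written.append("step_200"))  # implicitly waits for step_100
saver.flush()
assert written == ["step_100", "step_200"]
```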

Post-save memory cleanup: After each async DCP save, xorl runs an aggressive cleanup sequence to reclaim the GPU memory held by DCP’s internal state tensors:

```python
gc.collect()              # first GC pass
gc.collect()              # second pass for cyclic references
torch.cuda.empty_cache()  # release cached CUDA blocks
torch.cuda.synchronize()  # ensure the GPU is idle
```

This is especially important for large EP models where DCP briefly materializes full expert weight tensors.

DCP write config: `thread_count=1` per rank (avoids a PyTorch concurrent-write bug), `sync_files=False` (no `fsync()` on each write — faster, but relies on OS page-cache durability), `single_file_per_rank=True`.


## Prefetch + batched broadcast (rank-0 weight loading)


In broadcast mode, rank 0 reads from disk and broadcasts to all other ranks. xorl pipelines this with:

Background prefetch: A background thread pre-loads the next shard file while the current shard is being broadcast (prefetch_count=2). This overlaps NFS/SSD reads with NCCL communication, hiding disk latency behind the broadcast.

Batched metadata broadcast: Instead of one broadcast_object_list() call per tensor (each requiring a pickle + NCCL launch), xorl collects all tensor metadata (name, shape, dtype) into a single list and broadcasts it in one call. This reduces N broadcast launches to 1 for the entire shard, eliminating per-tensor NCCL overhead.
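The background prefetch can be sketched as a bounded producer/consumer pipeline (a simplification; `load_shard` and `broadcast_shard` stand in for the real disk read and NCCL broadcast):

```python
import queue
import threading

def pipeline(shard_files, load_shard, broadcast_shard, prefetch_count=2):
    """Overlap disk reads with broadcasts: a background thread pre-loads
    up to `prefetch_count` shards ahead of the consumer."""
    q = queue.Queue(maxsize=prefetch_count)

    def producer():
        for f in shard_files:
            q.put(load_shard(f))  # blocks once prefetch_count shards are queued
        q.put(None)               # sentinel: no more shards

    threading.Thread(target=producer, daemon=True).start()
    while (shard := q.get()) is not None:
        broadcast_shard(shard)    # broadcast overlaps with the next disk read

order = []
pipeline(
    ["shard0", "shard1", "shard2"],
    load_shard=lambda f: (order.append(f"read:{f}"), f)[1],
    broadcast_shard=lambda s: order.append(f"bcast:{s}"),
)
assert order[0] == "read:shard0" and order[-1] == "bcast:shard2"
```

Each shard's read always precedes its broadcast, but reads of later shards run concurrently with earlier broadcasts, which is exactly where the disk latency is hidden.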

Non-blocking CPU→GPU transfer before broadcast:

```python
tensor = broadcast_tensor.pin_memory().to(device, non_blocking=True)
dist.broadcast(tensor, src=0)
```

The pin_memory() + non_blocking=True initiates the H2D DMA while the NCCL library is setting up the collective, overlapping the two operations.


For Expert Parallel models, xorl skips disk reads for expert weight tensors that don’t belong to the current rank. With ep_size=8, each rank needs only 1/8 of all expert shards — the other 7/8 are read by other ranks. xorl’s checkpoint handler provides a skip_key_fn predicate that filters out out-of-range expert keys before any I/O occurs:

```text
Without EP filtering: every rank reads ALL expert shards     → 8× redundant reads
With EP filtering:    each rank reads only its 1/8 of shards → ~87% I/O reduction
```

Skipped keys are tracked in the handler’s completion counters so DCP’s progress reporting remains accurate.
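A sketch of such a predicate (the expert-key naming and factory name are illustrative, not xorl's exact scheme):

```python
import re

def make_skip_key_fn(ep_rank, num_local_experts):
    """Return a predicate that is True for expert keys OUTSIDE this rank's range."""
    lo = ep_rank * num_local_experts
    hi = lo + num_local_experts

    def skip_key(key):
        m = re.search(r"\.experts\.(\d+)\.", key)
        if m is None:
            return False  # non-expert tensor: always load
        return not (lo <= int(m.group(1)) < hi)  # skip out-of-range experts

    return skip_key

# ep_size=8, 64 experts → 8 local experts per rank; rank 1 owns experts 8..15
skip = make_skip_key_fn(ep_rank=1, num_local_experts=8)
assert skip("model.layers.0.mlp.experts.0.w1.weight") is True   # rank 0's expert
assert skip("model.layers.0.mlp.experts.9.w1.weight") is False  # local: load it
assert skip("model.layers.0.self_attn.q_proj.weight") is False  # non-expert
```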


Before materializing model parameters from meta device, xorl resizes expert tensors from their full shape [num_experts, K, N] down to the EP-local shape [num_local_experts, K, N] on the meta device itself. This prevents a brief OOM spike that would otherwise occur if FSDP2 tried to allocate a full-size expert tensor on a rank that will only ever hold a fraction of the experts.
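The pre-shrink can be sketched on the meta device, where tensors carry shape and dtype but allocate no storage (the dimensions below are illustrative):

```python
import torch

num_experts, num_local_experts, K, N = 64, 8, 4096, 1408

# Full logical expert tensor on meta: shape/dtype only, zero bytes allocated.
full = torch.empty(num_experts, K, N, device="meta")

# Pre-shrink to the EP-local shape BEFORE materialization, so nothing ever
# tries to allocate the full-size tensor on this rank.
local = torch.empty(num_local_experts, K, N, device="meta")

assert full.device.type == "meta"
assert tuple(local.shape) == (num_local_experts, K, N)
```

Materializing `local` instead of `full` caps the per-rank allocation at 1/8 of the expert weights in this example.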


| Optimization | Applies to | Benefit |
| --- | --- | --- |
| Multi-stream DMA | Weight loading | ~2× H2D throughput via overlapped PCIe DMA |
| Async saving | Checkpoint save | Non-blocking writes; training continues immediately |
| Background prefetch | Rank-0 broadcast | Overlaps disk read with NCCL broadcast |
| Batched metadata broadcast | Rank-0 broadcast | 1 NCCL launch vs N per-tensor launches |
| Non-blocking H2D before broadcast | Rank-0 broadcast | Overlaps DMA with NCCL setup |
| EP-aware filtered loading | EP weight loading | ~87% I/O reduction at ep_size=8 |
| Expert pre-shrinking | EP initialization | Prevents OOM during meta-device materialization |
| Aggressive post-save GC | Async save | Reclaims DCP temporary tensors promptly |
| File | Description |
| --- | --- |
| `src/xorl/checkpoint/checkpointer.py` | `DistributedCheckpointer` — DCP save/load, async save with Gloo process group, per-rank extra state |
| `src/xorl/models/module_utils.py` | `_MultiStreamDMA` — two-stream H2D DMA; `rank0_load_and_broadcast_weights` — prefetch + batched broadcast; EP-aware filtered loading; expert pre-shrinking |
| `src/xorl/arguments.py` | `save_steps`, `save_epochs`, `save_async`, `save_hf_weights`, `load_checkpoint_path`, `load_weights_mode` |