# Local Training Config

Local training uses a nested YAML config with `model`, `data`, `train`, and `lora` sections, passed to:

```sh
torchrun --nproc_per_node=8 -m xorl.cli.train config.yaml
```

Any field can be overridden on the command line with `--section.field value`:

```sh
torchrun --nproc_per_node=8 -m xorl.cli.train config.yaml \
  --train.lr 2e-5 \
  --train.output_dir outputs/my_run \
  --train.pipeline_parallel_size 4 \
  --model.attn_implementation flash_attention_4 \
  --data.sample_packing_sequence_len 16384 \
  --lora.enable_lora true \
  --lora.lora_rank 32
```
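The nested layout that command expects might look like the following minimal sketch. Field names come from the reference tables below; the values (including the model and dataset paths) are purely illustrative:

```yaml
# config.yaml — illustrative skeleton, not a recommended recipe
model:
  model_path: org/base-model            # hypothetical HF Hub ID
  attn_implementation: flash_attention_3

data:
  datasets:
    - path: org/my-tokenized-dataset    # hypothetical dataset
      type: tokenized
  sample_packing_sequence_len: 32000

train:
  output_dir: outputs/my_run
  lr: 5e-5
  micro_batch_size: 1

lora:
  enable_lora: false
```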

## `model`

| Field | Default | Description |
| --- | --- | --- |
| `model_path` | `null` | HF Hub ID or local path to pre-trained weights. If `null`, the model is randomly initialized. |
| `config_path` | same as `model_path` | Path to the model config. Useful when the config and weights are in separate locations. |
| `tokenizer_path` | same as `config_path` | Path to the tokenizer. |
| `attn_implementation` | `flash_attention_3` | Attention backend: `eager`, `sdpa`, `native` (PyTorch SDPA + cuDNN, no extra deps, Hopper + Blackwell), `flash_attention_3` (FA3, Hopper), `flash_attention_4` (FA4 CUTE, Hopper + Blackwell). |
| `moe_implementation` | `null` | MoE kernel: `null` (auto), `eager`, `triton` (Triton group GEMM), `native` (`torch._grouped_mm`), `quack`. |
| `ep_dispatch` | `alltoall` | Expert-parallel dispatch: `alltoall` or `deepep` (NVLink-optimized). |
| `deepep_buffer_size_gb` | `2.0` | DeepEP NVLink buffer size per GPU in GB. Only active when `ep_dispatch: deepep`. |
| `deepep_num_sms` | `20` | SMs assigned to DeepEP communication kernels. Must be even. Lower values leave more SMs for overlapped compute. |
| `deepep_async_combine` | `false` | Overlap DeepEP combine with the next layer's compute (experimental). |
| `merge_qkv` | `true` | Keep Q/K/V projections fused as `qkv_proj`. Set `false` for tensor parallelism or per-projection LoRA. |
| `basic_modules` | `[]` | Additional module names (beyond `_no_split_modules`) to shard as separate FSDP units. |
| `foundation` | `{}` | Extra foundation-model config (dict). |
| `encoders` | `{}` | Multimodal encoder configs, keyed by type (`image`, `video`, `audio`). Each value must have `model_path` and optionally `config_path`. |
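As one illustration of how these fields combine, a MoE model using DeepEP dispatch might be configured as follows (model path hypothetical, values illustrative):

```yaml
model:
  model_path: org/my-moe-model     # hypothetical MoE checkpoint
  attn_implementation: flash_attention_3
  moe_implementation: triton       # Triton group GEMM
  ep_dispatch: deepep              # NVLink-optimized dispatch
  deepep_buffer_size_gb: 2.0
  deepep_num_sms: 20               # must be even
  merge_qkv: false                 # unfused Q/K/V, e.g. for TP or per-projection LoRA
```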

## `data`

| Field | Default | Description |
| --- | --- | --- |
| `datasets` | required | List of dataset configs (see Dataset entry fields). |
| `test_datasets` | `[]` | Optional list of evaluation dataset configs. Same format as `datasets`. |
| `dataset_prepared_path` | `last_prepared_dataset` | Path where prepared/cached datasets are stored. |
| `select_columns` | `null` (all columns) | Columns to keep from each dataset (e.g., `[input_ids, labels]`). |
| `sample_packing_method` | `sequential` | Packing strategy: `sequential` (fast, good packing) or `multipack` (FFD-based, maximizes bin utilization). |
| `sample_packing_sequence_len` | `32000` | Target packed bin length in tokens. |
| `sample_packing_group_size` | `100000` | Number of samples packed together in one group. Larger values improve packing slightly. |
| `sample_packing_sequentially` | `null` | Force sequential packing regardless of method. |
| `sample_packing_mp_start_method` | `null` | Multiprocessing start method for packing: `fork`, `spawn`, or `forkserver`. |
| `eval_sample_packing` | `null` | Set to `false` to disable packing during evaluation if errors occur. |
| `dataloader_num_workers` | `8` | DataLoader worker processes. |
| `dataloader_prefetch_factor` | `2` | Batches to prefetch per worker. Set to `null` when `num_workers=0`. |
| `dataloader_pin_memory` | `true` | Pin CPU memory for faster GPU transfer. |
| `dataloader_drop_last` | `true` | Drop the last incomplete batch. |
| `pad_to_multiple_of` | `128` | Pad packed sequences to a multiple of this value for GPU efficiency. |
| `val_set_size` | `null` | Validation split size. Integer = number of samples; float = fraction (e.g., `0.05`). |
| `shuffle_merged_datasets` | `true` | Shuffle the merged dataset before training. |
| `shuffle_before_merging_datasets` | `true` | Shuffle each dataset individually before merging. |
| `dataset_num_proc` | CPU count | Processes for dataset preprocessing. Defaults to the `XORL_DATASET_NUM_PROC` env var or the CPU count. |
| `dataset_shard_num` | `null` | Number of shards to split the dataset into (for parallel preprocessing). |
| `dataset_shard_idx` | `null` | Which shard to use (used with `dataset_shard_num`). |
| `num_dataset_shards_to_save` | `null` | Number of shards to save the prepared dataset to. Default: single file. |
| `skip_prepare_dataset` | `false` | Skip preparation and load directly from `dataset_prepared_path`. |
| `push_dataset_to_hub` | `null` | Push the prepared dataset to the HF Hub (`org/repo-name`). Requires `hf_use_auth_token: true`. |
| `hf_use_auth_token` | `null` | Use an HF auth token for private datasets or Hub pushes. |
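A `data` section combining packing and a held-out validation split might look like this sketch (dataset path hypothetical, values illustrative):

```yaml
data:
  datasets:
    - path: org/sft-mix              # hypothetical tokenized dataset
      type: tokenized
      split: train
  sample_packing_method: multipack   # FFD-based, maximizes bin utilization
  sample_packing_sequence_len: 16384
  val_set_size: 0.05                 # float = fraction → 5% validation split
  dataloader_num_workers: 8
  pad_to_multiple_of: 128
```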

### Dataset entry fields

Each entry in `datasets` (or `test_datasets`) is a dict:

| Field | Default | Description |
| --- | --- | --- |
| `path` | required | HF Hub ID (`org/name`), `s3://`, `gs://`, `abfs://`, `https://`, or a local path. Use `dummy` for synthetic data. |
| `type` | `tokenized` | Dataset type. Only `tokenized` is currently supported. |
| `name` | `null` | HF dataset config name (subset). |
| `split` | `null` | HF dataset split (e.g., `train`, `validation`). |
| `revision` | `null` | HF Hub commit hash or tag. |
| `trust_remote_code` | `false` | Allow remote code execution for custom HF datasets. |
| `data_files` | `null` | Specific files to load (string or list). Requires `ds_type` when set. |
| `ds_type` | `null` | File format when using `data_files`: `json`, `csv`, `parquet`, `arrow`, `text`. |
| `max_seq_len` | `null` | Truncate and filter samples longer than this. |
| `shards` | `null` | Split the dataset into N pieces (use with `shards_idx`). |
| `shards_idx` | `null` | Index of the shard to use (0-based). |
| `preprocess_shards` | `null` | Process the dataset in N sequential chunks for memory efficiency. Mutually exclusive with `shards`. |
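For instance, mixing a cloud-storage Parquet dataset with a Hub dataset could look like this sketch (paths and file names hypothetical). Note that `data_files` requires `ds_type` to be set:

```yaml
datasets:
  - path: s3://my-bucket/sft-data    # hypothetical S3 prefix
    ds_type: parquet                 # required because data_files is set
    data_files:
      - shard-00000.parquet
      - shard-00001.parquet
    max_seq_len: 32768               # drop/truncate longer samples
  - path: org/other-dataset          # hypothetical Hub dataset
    split: train
    revision: main
```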

## `train`

### Parallelism

| Field | Default | Description |
| --- | --- | --- |
| `data_parallel_mode` | `fsdp2` | Data parallelism: `none`, `ddp`, `fsdp2` (ZeRO-3). FSDP2 requires `init_device: meta`. |
| `data_parallel_shard_size` | `-1` (world size) | Number of GPUs per FSDP shard group. `-1` = full world. |
| `data_parallel_replicate_size` | `-1` (`1`) | Number of data replicas for HSDP (Hybrid Sharded DP). `-1` = auto. `dp_size = replicate × shard`. |
| `tensor_parallel_size` | `1` | TP degree. Shards weight matrices column-/row-wise across GPUs. Requires `merge_qkv: false`. |
| `pipeline_parallel_size` | `1` | PP stages. Splits model layers across GPUs. |
| `pipeline_parallel_schedule` | `1F1B` | PP schedule: `1F1B` (interleaved, lower memory) or `GPipe` (simpler). |
| `pp_variable_seq_lengths` | `true` | Dynamically negotiate the max sequence length per PP step via all-reduce, avoiding padding to a static max. |
| `expert_parallel_size` | `1` | EP degree for MoE models. Distributes experts across GPUs. |
| `ulysses_parallel_size` | `1` | Ulysses context-parallelism degree. |
| `ringattn_parallel_size` | `1` | Ring Attention degree. |
| `cp_fsdp_mode` | `all` | How context parallelism interacts with FSDP: `all` (both Ulysses + Ring), `ulysses_only`, `ring_only`, `none`. |
| `reshard_after_forward` | `null` | FSDP2 reshard after forward. `true` = save memory; `false` = save communication (used for PP by default). `null` = auto. |
| `ep_outside` | `false` | Place EP outside the EP-FSDP mesh. |
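As an illustration of how the degrees compose: with `dp_size = replicate × shard`, a hybrid layout on 64 GPUs could be sketched as below (values illustrative; the product of dp, TP, and PP degrees must match the world size):

```yaml
train:
  data_parallel_mode: fsdp2
  data_parallel_replicate_size: 2   # HSDP: 2 replicas
  data_parallel_shard_size: 4       # dp_size = 2 × 4 = 8
  tensor_parallel_size: 4           # needs model.merge_qkv: false
  pipeline_parallel_size: 2         # 8 × 4 × 2 = 64 GPUs total
  init_device: meta                 # required for FSDP2
```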

### Optimizer & LR schedule

| Field | Default | Description |
| --- | --- | --- |
| `optimizer` | `adamw` | Optimizer: `adamw`, `anyprecision_adamw`, `sgd`, `muon`. |
| `optimizer_dtype` | `bf16` | Dtype for optimizer states in `anyprecision_adamw` and `muon`: `fp32` or `bf16`. BF16 halves optimizer memory. |
| `lr` | `5e-5` | Peak learning rate. |
| `lr_min` | `1e-7` | Minimum learning rate at the end of decay. |
| `lr_start` | `0.0` | Initial learning rate at the start of warmup. |
| `lr_warmup_ratio` | `0.0` | Fraction of total steps used for linear LR warmup. |
| `lr_decay_style` | `constant` | LR schedule after warmup: `constant`, `linear`, `cosine`. |
| `lr_decay_ratio` | `1.0` | Fraction of total steps to apply LR decay over. |
| `weight_decay` | `0.0` | L2 regularization (AdamW weight decay). |
| `no_decay_modules` | `[]` | Module-name substrings to exclude from weight decay (e.g., `[norm]`). |
| `no_decay_params` | `[]` | Parameter-name substrings to exclude from weight decay (e.g., `[bias]`). |
| `max_grad_norm` | `1.0` | Gradient-clipping threshold. |
| `muon_lr` | `0.02` | Learning rate for Muon matrix parameter groups. Only used when `optimizer: muon`. |
| `muon_momentum` | `0.95` | Muon momentum coefficient. |
| `muon_nesterov` | `true` | Use Nesterov momentum in Muon. |
| `muon_ns_steps` | `5` | Newton–Schulz iterations for Muon orthogonalization. |
| `muon_adjust_lr_fn` | `null` | Muon LR scaling: `original` (scale by `sqrt(max(1, A/B))`), `match_rms_adamw` (lets Muon reuse the AdamW LR/WD). |
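A typical warmup-then-cosine schedule with decoupled weight decay might be sketched as (values illustrative):

```yaml
train:
  optimizer: adamw
  lr: 2e-5                   # peak LR after warmup
  lr_start: 0.0
  lr_min: 1e-7               # floor at the end of decay
  lr_warmup_ratio: 0.03      # linear warmup over 3% of total steps
  lr_decay_style: cosine
  weight_decay: 0.1
  no_decay_modules: [norm]   # exclude norm layers from decay
  no_decay_params: [bias]    # exclude biases from decay
  max_grad_norm: 1.0
```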

### Batching

| Field | Default | Description |
| --- | --- | --- |
| `micro_batch_size` | `1` | Per-GPU batch size per step. |
| `gradient_accumulation_steps` | `1` | Steps before each optimizer update. Effective per-device batch = `micro_batch_size × gradient_accumulation_steps`. |
| `num_train_epochs` | `1` | Number of passes over the dataset. |
| `max_steps` | `null` | Maximum total training steps. Overrides epoch-based stopping and caps the LR scheduler length. |
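Spelling out the batch arithmetic with illustrative values:

```yaml
train:
  micro_batch_size: 2
  gradient_accumulation_steps: 8
  # effective per-device batch = 2 × 8 = 16
  # global batch = 16 × dp_size (e.g. 16 × 8 = 128 with 8 data-parallel ranks)
  num_train_epochs: 1
  max_steps: 10000   # hard cap; also bounds the LR schedule length
```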

### Memory & performance

| Field | Default | Description |
| --- | --- | --- |
| `enable_mixed_precision` | `true` | BF16 mixed-precision training. |
| `enable_gradient_checkpointing` | `true` | Activation recomputation to reduce memory. |
| `enable_reentrant` | `false` | Use reentrant gradient checkpointing. The default (non-reentrant) is generally preferred. |
| `recompute_modules` | `null` | Selective checkpointing by submodule: `[self_attn]`, `[mlp]`, or `[self_attn, mlp]`. `null` = whole-layer recompute. |
| `moe_checkpoint_method` | `null` | MoE-specific checkpointing: `null` (full recompute, including EP communication), `moe_act` (recompute only gate/up activations and skip EP-communication recompute; faster). |
| `enable_full_shard` | `true` | FSDP2 full parameter sharding (ZeRO-3). Set `false` for ZeRO-2. |
| `enable_forward_prefetch` | `true` | Prefetch the next FSDP unit's parameters during the forward pass. |
| `enable_activation_offload` | `false` | Offload activations to CPU during the forward pass. |
| `activation_gpu_limit` | `0.0` | GB of activations to keep on GPU when offloading. `0.0` = offload all. |
| `enable_compile` | `false` | `torch.compile` for the model forward pass. |
| `init_device` | `cuda` | Device for weight initialization: `cpu` (rank 0 only), `cuda`, `meta` (required for FSDP2), `npu`. |
| `load_weights_mode` | `broadcast` | `broadcast`: rank 0 reads the weights and broadcasts them to the other ranks (reduces disk I/O). `all_ranks`: every rank reads from disk. |
| `enable_full_determinism` | `false` | Full-determinism mode. Requires `allow_cuda_launch_blocking: true`. Degrades performance. |
| `allow_cuda_launch_blocking` | `false` | Allow `CUDA_LAUNCH_BLOCKING=1`. Off by default to prevent accidental performance degradation. |
| `empty_cache_steps` | `500` | Call `torch.cuda.empty_cache()` every N steps. |
| `gc_steps` | `500` | Call `gc.collect()` every N steps. Python GC is disabled between calls. |
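A memory-constrained setup trading compute for memory might combine these knobs as follows (values illustrative):

```yaml
train:
  enable_gradient_checkpointing: true
  recompute_modules: [self_attn]   # selective: recompute attention only
  enable_activation_offload: true
  activation_gpu_limit: 4.0        # keep ~4 GB of activations on GPU
  init_device: meta                # required for FSDP2
  reshard_after_forward: true      # favor memory over communication
  empty_cache_steps: 500
```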

### Checkpointing

| Field | Default | Description |
| --- | --- | --- |
| `output_dir` | required | Base directory for checkpoints, logs, and model assets. Must be on a shared filesystem for multi-node training. |
| `ckpt_manager` | `dcp` | Checkpoint format: `dcp` (PyTorch Distributed Checkpoint). |
| `save_steps` | `0` | Save a checkpoint every N global steps. `0` = disabled. |
| `save_epochs` | `1` | Save every N epochs (fractional OK: `0.25` saves 4× per epoch). |
| `save_async` | `false` | Write checkpoints asynchronously (non-blocking training). |
| `save_hf_weights` | `true` | Also save HF-format weights (`.safetensors`) to the last checkpoint directory. |
| `load_checkpoint_path` | `null` | Path to a checkpoint to resume from. Set to `auto` to auto-detect the latest checkpoint in `output_dir`. |
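A resumable run saving four checkpoints per epoch could be sketched as (values illustrative):

```yaml
train:
  output_dir: outputs/my_run   # shared filesystem for multi-node runs
  save_epochs: 0.25            # 4 checkpoints per epoch
  save_async: true             # don't block training on checkpoint writes
  save_hf_weights: true        # .safetensors in the last checkpoint dir
  load_checkpoint_path: auto   # resume from the latest checkpoint, if any
```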

### Logging

| Field | Default | Description |
| --- | --- | --- |
| `log_format` | `progress_bar` | `progress_bar` (tqdm) or `structured` (key=value lines for parsing). |
| `use_wandb` | `true` | Enable Weights & Biases logging. |
| `wandb_project` | `Xorl` | W&B project name. |
| `wandb_name` | `null` | W&B run name. |
| `wandb_tags` | `null` | W&B run tags (list of strings). |
| `wandb_log_interval` | `1` | Log metrics to W&B every N steps. |

### Profiling

| Field | Default | Description |
| --- | --- | --- |
| `enable_profiling` | `false` | Enable the PyTorch profiler. |
| `profile_start_step` | `1` | Step to start profiling at. |
| `profile_end_step` | `2` | Step to stop profiling at. |
| `profile_trace_dir` | `./trace` | Directory to write profiler trace files. |
| `profile_record_shapes` | `true` | Record input tensor shapes in the trace. |
| `profile_profile_memory` | `true` | Record memory usage in the trace. |
| `profile_with_stack` | `true` | Record Python stack traces. |
| `profile_rank0_only` | `true` | Only profile rank 0. Set `false` to profile all ranks (produces many large files). |
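A short profiling window placed after warmup could be sketched as (step numbers illustrative):

```yaml
train:
  enable_profiling: true
  profile_start_step: 10    # skip early steps so warmup doesn't dominate the trace
  profile_end_step: 12      # a 2-step window keeps trace files manageable
  profile_trace_dir: ./trace
  profile_rank0_only: true  # one trace instead of one per rank
```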

### Miscellaneous

| Field | Default | Description |
| --- | --- | --- |
| `seed` | `42` | Global random seed for reproducibility. |

## `lora`

| Field | Default | Description |
| --- | --- | --- |
| `enable_lora` | `false` | Enable LoRA fine-tuning. |
| `lora_rank` | `16` | LoRA rank (r). |
| `lora_alpha` | `16` | LoRA scaling factor (alpha). Effective scale = `alpha / rank`. |
| `lora_target_modules` | `null` | Module names to inject LoRA into. `null` = the default linear projections for the architecture. |
| `save_lora_only` | `false` | Only save the LoRA adapter weights in HF checkpoints (not the full model). |
| `enable_qlora` | `false` | Quantize the base weights and train LoRA on top. Implies `enable_lora: true`. |
| `quant_format` | `nvfp4` | Quantization format: `nvfp4` (4-bit, Hopper+), `block_fp8` (8-bit blocks), `nf4` (4-bit NormalFloat). |
| `quant_group_size` | `16` | Quantization group size. Recommended: 16 for `nvfp4`, 128 for `block_fp8`, 64 for `nf4`. |
| `exclude_modules` | `null` | Module names to exclude from QLoRA quantization (kept in BF16). `null` = auto-detect from the checkpoint config. |
| `merge_lora_interval` | `0` | Merge the LoRA delta into the base weights every N steps. For QLoRA this also re-quantizes. `0` = disabled. |
| `reset_optimizer_on_merge` | `false` | ReLoRA-style optimizer-state reset after each merge. Requires `merge_lora_interval > 0`. |
| `enable_aqn` | `false` | Adaptive Quantization Noise: adds calibrated noise to quantized weights during the forward pass to reduce quantization bias. |
| `aqn_alpha` | `1.0` | Noise-magnitude scale for AQN. |
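Putting the QLoRA-related fields together, an adapter-only run might be sketched as (values illustrative):

```yaml
lora:
  enable_qlora: true         # implies enable_lora: true
  lora_rank: 32
  lora_alpha: 32             # effective scale = 32 / 32 = 1.0
  quant_format: nvfp4
  quant_group_size: 16       # recommended group size for nvfp4
  merge_lora_interval: 500   # merge delta (and re-quantize) every 500 steps
  save_lora_only: true       # save only the adapter in HF checkpoints
```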