# Loss Functions

xorl supports multiple loss functions for different training objectives. Each is selected via the `loss_fn` field in the `forward_backward` request (server training) or configured automatically for local training.
## Supported Loss Functions

| `loss_fn` name | Use case | Returns |
|---|---|---|
| `causallm_loss` | Standard next-token prediction (SFT, continued pretraining) | Scalar cross-entropy loss |
| `policy_loss` | PPO-style policy gradient (RLHF, GRPO) | Scalar policy gradient loss |
| `importance_sampling` | Off-policy RL training with IS correction | Scalar IS-weighted loss |
### causallm_loss

Standard next-token prediction using cross-entropy. The loss is computed over all positions where `labels != IGNORE_INDEX` (-100) and normalized by the total number of valid tokens globally across all ranks and micro-batches:

```
loss = CE_sum(logits, labels) / global_valid_tokens
```

This is the default loss for local training and the most common choice for SFT and continued pretraining.
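The normalization can be sketched as follows. This is an illustrative numpy version, not xorl's actual implementation; in real training `global_valid_tokens` would be the all-reduced valid-token count across ranks and micro-batches.

```python
import numpy as np

IGNORE_INDEX = -100

def causallm_loss(logits, labels, global_valid_tokens):
    """logits: [N, vocab]; labels: [N] with IGNORE_INDEX marking masked positions."""
    valid = labels != IGNORE_INDEX
    # numerically stable log-softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # sum of -log p(label) over valid positions only
    ce_sum = -log_probs[valid, labels[valid]].sum()
    # normalize by the GLOBAL valid-token count, not the local batch size
    return ce_sum / global_valid_tokens

logits = np.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
labels = np.array([0, IGNORE_INDEX])      # second position is ignored
loss = causallm_loss(logits, labels, global_valid_tokens=1)
```

Note that the masked position contributes nothing to the sum; only the denominator changes when more valid tokens exist on other ranks.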
### policy_loss

PPO-style policy gradient loss for online RL training. Requires `logprobs_ref` (log probabilities from the reference model) in `loss_fn_params`, plus per-token advantages.

The loss applies clipped importance sampling:

```
ratio = exp(logprobs - logprobs_ref)
loss = -mean(min(ratio * advantages, clip(ratio, 1 - eps, 1 + eps) * advantages))
```

where `eps` is controlled by the `eps_clip` parameter (default 0.2). Used with PPO, GRPO, and related algorithms.
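A minimal numpy sketch of the clipped surrogate objective above (illustrative only, not xorl's fused implementation):

```python
import numpy as np

def policy_loss(logprobs, logprobs_ref, advantages, eps_clip=0.2):
    """Per-token arrays; returns the scalar clipped policy-gradient loss."""
    ratio = np.exp(logprobs - logprobs_ref)           # importance ratio
    clipped = np.clip(ratio, 1 - eps_clip, 1 + eps_clip)
    # pessimistic (elementwise min) of unclipped and clipped surrogates
    return -np.mean(np.minimum(ratio * advantages, clipped * advantages))

logprobs = np.array([-1.0, -0.5])
logprobs_ref = np.array([-1.2, -0.4])
advantages = np.array([1.0, -1.0])
loss = policy_loss(logprobs, logprobs_ref, advantages)
```

The `min` makes the objective conservative: a token whose ratio has drifted outside `[1 - eps, 1 + eps]` cannot contribute a larger gradient than the clipped version allows.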
### importance_sampling

Off-policy RL training where the behavior policy differs from the current policy. Requires per-token importance weights `iw` passed in `loss_fn_params`.

```
loss = -mean(iw * logprobs * advantages)
```

The `eps_clip` parameter bounds the importance weights to `[1 - eps, 1 + eps]` to prevent excessive policy updates from stale data.
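A sketch of this loss, assuming (as described above) that `eps_clip` is applied directly to the supplied importance weights; this is illustrative numpy code, not xorl's actual kernel:

```python
import numpy as np

def importance_sampling_loss(iw, logprobs, advantages, eps_clip=0.2):
    """iw: per-token behavior-policy importance weights passed by the caller."""
    iw = np.clip(iw, 1 - eps_clip, 1 + eps_clip)  # bound updates from stale data
    return -np.mean(iw * logprobs * advantages)

iw = np.array([0.5, 1.1])           # 0.5 will be clipped up to 0.8
logprobs = np.array([-1.0, -0.5])
advantages = np.array([2.0, 1.0])
loss = importance_sampling_loss(iw, logprobs, advantages)
```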
### Per-token outputs

The `causallm_loss` and `policy_loss` functions can return per-token log probabilities and elementwise cross-entropy values alongside the scalar loss. These are used for:

- Computing KL divergence against a reference model
- Logging token-level reward signals
- Debugging training dynamics

Per-token outputs are controlled by the `return_per_token` parameter in the model runner, not by a separate loss function name.
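As a hedged sketch of what the per-token output looks like (function and variable names here are illustrative, not xorl's API): gathering each label's log probability gives a per-token vector that can then feed a KL estimate against reference-model log probabilities.

```python
import numpy as np

def per_token_logprobs(logits, labels):
    """Return log p(label) for each position; [N, vocab] -> [N]."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return logp[np.arange(len(labels)), labels]

logits = np.array([[2.0, 0.0], [0.0, 2.0]])
labels = np.array([0, 1])
lp = per_token_logprobs(logits, labels)
# a common per-token KL estimator against a reference model is lp - lp_ref
```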
## Loss Function Backends

### Compiled cross-entropy (ce_mode)

For `causallm_loss`, xorl supports two cross-entropy backends controlled by `ce_mode`:

| Value | Description | When to use |
|---|---|---|
| `compiled` (default) | `torch.compile`-compiled chunked cross-entropy | Production training; fuses logit computation with CE for reduced peak memory |
| `eager` | Standard `F.cross_entropy` | Debugging; incompatible with `torch.compile` on the full model |
The compiled backend computes cross-entropy in chunks along the sequence dimension, avoiding materializing the full [B × S, vocab_size] float32 tensor at once. This is particularly important for large vocabulary models (Qwen3 has vocab_size=151,936) where the naive logit tensor can be 2–4 GB per micro-batch.
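The memory pattern can be illustrated with a small numpy sketch. The real backend fuses the lm_head matmul with the cross-entropy under `torch.compile`; this version only shows the chunking idea: at any moment only a `[chunk_size, vocab]` logit slice exists.

```python
import numpy as np

def chunked_ce_sum(hidden, lm_head, labels, chunk_size=2):
    """Cross-entropy summed over all positions, one sequence chunk at a time."""
    total = 0.0
    for i in range(0, hidden.shape[0], chunk_size):
        h, y = hidden[i:i + chunk_size], labels[i:i + chunk_size]
        logits = h @ lm_head                       # only a [chunk, vocab] slice
        m = logits.max(axis=-1, keepdims=True)     # stable log-softmax
        logp = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
        total += -logp[np.arange(len(y)), y].sum()
    return total

hidden = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # [seq, hidden]
lm_head = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]])    # [hidden, vocab]
labels = np.array([0, 1, 2])
loss_sum = chunked_ce_sum(hidden, lm_head, labels)
```

Because the per-chunk sums are exact, the result is identical to the unchunked computation; only peak memory changes.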
### Vocabulary-parallel cross-entropy (TP training)

When `tensor_parallel_size > 1`, the lm_head output is sharded across TP ranks: each rank holds logits for a `vocab_size / tp_size` slice of the vocabulary. xorl computes cross-entropy directly on these sharded logits using a fused vocab-parallel CE kernel:

1. Each TP rank computes the local log-sum-exp contribution from its vocab shard.
2. An all-reduce aggregates the global log-sum-exp across TP ranks.
3. Each rank computes the per-token CE using the correct global normalization.
This avoids an all-gather of the full logit tensor before CE, saving vocab_size × B × S × 4 bytes of cross-TP communication per forward pass.
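The scheme can be simulated in a single process. This is illustrative only: the shards stand in for per-rank tensors, and the combination of local log-sum-exps stands in for the all-reduce; the real kernel is fused and runs over `torch.distributed`.

```python
import numpy as np

def lse(x):
    """Numerically stable log-sum-exp along the last axis."""
    m = x.max(axis=-1, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True)))[..., 0]

def vocab_parallel_ce(shards, labels):
    """shards: per-"rank" logit slices [tokens, vocab/tp]; labels: global ids."""
    # Step 1: each rank's local log-sum-exp over its own vocab shard.
    local_lse = np.stack([lse(s) for s in shards], axis=-1)   # [tokens, ranks]
    # Step 2: combine local contributions (the all-reduce step in real TP).
    global_lse = lse(local_lse)
    # Step 3: the rank owning each label extracts the target logit locally,
    # so the full logit tensor is never gathered.
    shard_size = shards[0].shape[-1]
    target = np.empty(len(labels))
    for r, s in enumerate(shards):
        owned = (labels >= r * shard_size) & (labels < (r + 1) * shard_size)
        target[owned] = s[owned, labels[owned] - r * shard_size]
    return global_lse - target                 # per-token cross-entropy

rng = np.random.default_rng(0)
full_logits = rng.normal(size=(4, 8))          # 4 tokens, vocab 8, tp_size=2
shards = [full_logits[:, :4], full_logits[:, 4:]]
labels = np.array([0, 5, 3, 7])
ce = vocab_parallel_ce(shards, labels)
```

The result matches cross-entropy computed on the unsharded logits, since `log(sum_a exp + sum_b exp)` over shard-local sums equals the global log-sum-exp.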
## Loss Computation Flow

## Source

| File | Description |
|---|---|
| `src/xorl/ops/loss/` | Loss function implementations: `causallm_loss`, `policy_loss`, `importance_sampling` |
| `src/xorl/ops/loss/compiled_cross_entropy.py` | Compiled chunked cross-entropy |
| `src/xorl/ops/loss/vocab_parallel_cross_entropy.py` | Vocabulary-parallel cross-entropy for TP |
| `src/xorl/distributed/gradient_accumulate_loss.py` | `GradientAccumulateLoss`: token-normalized loss accumulation across micro-batches |