
Loss Functions

xorl supports multiple loss functions for different training objectives. Each is selected via the loss_fn field in the forward_backward request (server training) or configured automatically for local training.

| `loss_fn` name | Use case | Returns |
| --- | --- | --- |
| `causallm_loss` | Standard next-token prediction (SFT, continued pretraining) | Scalar cross-entropy loss |
| `policy_loss` | PPO-style policy gradient (RLHF, GRPO) | Scalar policy gradient loss |
| `importance_sampling` | Off-policy RL training with IS correction | Scalar IS-weighted loss |

causallm_loss

Standard next-token prediction using cross-entropy. The loss is computed over all positions where labels != IGNORE_INDEX (-100) and normalized by the total number of valid tokens globally, across all ranks and micro-batches:

loss = CE_sum(logits, labels) / global_valid_tokens

This is the default loss for local training and the most common choice for SFT and continued pretraining.
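The normalization above can be illustrated with a minimal sketch. The function name and signature here are illustrative, not xorl's actual API; on a single rank, `global_valid_tokens` is simply the local count of non-ignored labels, while multi-rank training would all-reduce it first.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100

def causallm_loss(logits, labels, global_valid_tokens):
    """Sum of cross-entropy over valid positions, divided by the
    global valid-token count (all-reduced across ranks in real training)."""
    # Flatten [B, S, V] -> [B*S, V] and [B, S] -> [B*S].
    ce_sum = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=IGNORE_INDEX,
        reduction="sum",
    )
    return ce_sum / global_valid_tokens

# Single-rank usage: the global count is just the local count.
logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
labels[0, 0] = IGNORE_INDEX          # masked position
gvt = (labels != IGNORE_INDEX).sum()
loss = causallm_loss(logits, labels, gvt)
```

On one rank this reduces to the familiar token-mean cross-entropy; the point of dividing by a *global* count is that micro-batches with different numbers of valid tokens still contribute proportionally to the accumulated gradient.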

policy_loss

PPO-style policy gradient loss for online RL training. Requires logprobs_ref (log probabilities from the reference model) in loss_fn_params, along with per-token advantages.

The loss applies clipped importance sampling:

ratio = exp(logprobs - logprobs_ref)
loss = -mean(min(ratio * advantages, clip(ratio, 1-eps, 1+eps) * advantages))

where eps is controlled by the eps_clip parameter (default 0.2). Used with PPO, GRPO, and related algorithms.
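The clipped surrogate above maps directly to a few lines of PyTorch. This is a sketch of the standard PPO objective as written in the formula, not xorl's exact implementation; the function name and argument order are assumptions.

```python
import torch

def policy_loss(logprobs, logprobs_ref, advantages, eps_clip=0.2):
    """Clipped-importance-sampling policy gradient (PPO surrogate)."""
    ratio = torch.exp(logprobs - logprobs_ref)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
    # Negate: optimizers minimize, but PPO maximizes the surrogate.
    return -torch.min(unclipped, clipped).mean()

# Usage: when the policy has not moved (logprobs == logprobs_ref),
# the ratio is 1 and the loss is just -mean(advantages).
lp = torch.tensor([0.1, -0.2, 0.3])
adv = torch.tensor([1.0, -1.0, 0.5])
loss = policy_loss(lp, lp, adv)
```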

importance_sampling

Off-policy RL training, where the behavior policy that generated the data differs from the current policy. Requires per-token importance weights iw passed in loss_fn_params.

loss = -mean(iw * logprobs * advantages)

The eps_clip parameter bounds the importance weights to [1-eps, 1+eps] to prevent excessive policy updates from stale data.
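A minimal sketch of this objective, with the weight clamping described above. The function name is illustrative; whether xorl clamps before or inside the product is an implementation detail not stated here, so this assumes clamping first.

```python
import torch

def importance_sampling_loss(logprobs, advantages, iw, eps_clip=0.2):
    """IS-weighted policy gradient with importance weights clamped to
    [1 - eps, 1 + eps] to limit updates from stale behavior-policy data."""
    iw = torch.clamp(iw, 1 - eps_clip, 1 + eps_clip)
    return -(iw * logprobs * advantages).mean()

# With unit weights this degenerates to a plain REINFORCE-style loss.
lp = torch.tensor([-1.0, -2.0])
adv = torch.tensor([0.5, 1.0])
loss = importance_sampling_loss(lp, adv, torch.ones(2))
```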

Per-token outputs

The causallm_loss and policy_loss functions can return per-token log probabilities and elementwise cross-entropy values alongside the scalar loss. These are used for:

  • Computing KL divergence against a reference model
  • Logging token-level reward signals
  • Debugging training dynamics

Per-token outputs are controlled by the return_per_token parameter in the model runner, not by a separate loss function name.
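Extracting per-token log probabilities alongside the loss can be sketched as follows. The helper name is hypothetical; it only illustrates how per-token log-probs and elementwise CE relate to the labels and mask.

```python
import torch
import torch.nn.functional as F

def per_token_outputs(logits, labels, ignore_index=-100):
    """Return per-token log-probs and elementwise CE, zeroed at
    ignored positions (illustrative helper, not xorl's API)."""
    logp = F.log_softmax(logits, dim=-1)              # [B, S, V]
    safe = labels.clamp_min(0)                        # avoid gather on -100
    tok_logp = logp.gather(-1, safe.unsqueeze(-1)).squeeze(-1)  # [B, S]
    mask = labels != ignore_index
    tok_logp = tok_logp * mask
    tok_ce = -tok_logp                                # elementwise CE
    return tok_logp, tok_ce
```

A per-token KL estimate against a reference model is then just a masked difference of two such log-prob tensors.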

Cross-entropy backends

For causallm_loss, xorl supports two cross-entropy backends, controlled by ce_mode:

| Value | Description | When to use |
| --- | --- | --- |
| `compiled` (default) | torch.compile-compiled chunked cross-entropy | Production training; fuses logit computation with CE for reduced peak memory |
| `eager` | Standard `F.cross_entropy` | Debugging; incompatible with torch.compile on the full model |

The compiled backend computes cross-entropy in chunks along the sequence dimension, avoiding materializing the full [B × S, vocab_size] float32 tensor at once. This is particularly important for large vocabulary models (Qwen3 has vocab_size=151,936) where the naive logit tensor can be 2–4 GB per micro-batch.
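The chunking idea can be shown without torch.compile: project one chunk of hidden states to logits, reduce immediately, and move on, so the full [B × S, vocab_size] float32 tensor never exists. This is a sketch of the technique, not xorl's compiled kernel; the function name and chunk size are assumptions.

```python
import torch
import torch.nn.functional as F

def chunked_ce_sum(hidden, lm_head_weight, labels, chunk_size=1024,
                   ignore_index=-100):
    """Sum-reduced CE computed chunk by chunk along the flattened
    sequence, keeping at most [chunk_size, V] logits alive at once."""
    hidden = hidden.view(-1, hidden.size(-1))   # [B*S, d]
    labels = labels.view(-1)                    # [B*S]
    total = hidden.new_zeros(())
    for i in range(0, hidden.size(0), chunk_size):
        logits = hidden[i:i + chunk_size] @ lm_head_weight.t()  # [chunk, V]
        total = total + F.cross_entropy(
            logits.float(), labels[i:i + chunk_size],
            ignore_index=ignore_index, reduction="sum")
    return total
```

With vocab_size around 150k, shrinking the live logit slice from the whole batch to one chunk is exactly what turns a multi-GB allocation into a bounded one.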

Vocabulary-parallel cross-entropy (TP training)


When tensor_parallel_size > 1, the lm_head output is sharded across TP ranks: each rank holds logits for vocab_size / tp_size tokens. xorl computes cross-entropy directly on these sharded logits using a fused vocab-parallel CE kernel:

  1. Each TP rank computes the local log-sum-exp contribution from its vocab shard.
  2. An all-reduce aggregates the global log-sum-exp across TP ranks.
  3. Each rank computes the per-token CE using the correct global normalization.

This avoids an all-gather of the full logit tensor before CE, saving vocab_size × B × S × 4 bytes of cross-TP communication per forward pass.
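The three steps above can be simulated in a single process by treating each vocab shard as one TP rank's logits, with the all-reduce replaced by a sum over shards. This is a numerically faithful sketch of the algorithm, not xorl's fused kernel.

```python
import torch

def vocab_parallel_ce(logit_shards, labels):
    """Per-token CE from vocab-sharded logits. logit_shards is a list of
    [N, V/tp] tensors; labels is [N] with global vocab ids."""
    # Step 1: each "rank" contributes its local max / sum-exp; the global
    # max is used for a numerically stable log-sum-exp.
    maxes = torch.stack(
        [s.max(dim=-1).values for s in logit_shards]).max(dim=0).values
    sumexp = sum(torch.exp(s - maxes.unsqueeze(-1)).sum(dim=-1)
                 for s in logit_shards)
    # Step 2: the "all-reduced" global log-sum-exp.
    lse = maxes + torch.log(sumexp)
    # Step 3: fetch the target logit from whichever shard owns it.
    shard_size = logit_shards[0].size(-1)
    shard_idx, local_idx = labels // shard_size, labels % shard_size
    gathered = torch.stack(
        [s.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
         for s in logit_shards], dim=0)            # [tp, N]
    target_logit = gathered.gather(0, shard_idx.unsqueeze(0)).squeeze(0)
    return lse - target_logit                      # CE = lse - logit[target]
```

Only the scalar log-sum-exp and target-logit terms ever cross ranks, which is the communication saving the text describes.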

Loss Computation Flow

tokens [B, S] → model fwd → lm_head → logits [B, S, V] → cross_entropy (compiled / VP-CE) → scalar loss / gvt → gradient via .backward()

gvt = global valid tokens (all-reduced across all ranks)
| File | Description |
| --- | --- |
| `src/xorl/ops/loss/` | Loss function implementations: causallm_loss, policy_loss, importance_sampling |
| `src/xorl/ops/loss/compiled_cross_entropy.py` | Compiled chunked cross-entropy |
| `src/xorl/ops/loss/vocab_parallel_cross_entropy.py` | Vocabulary-parallel cross-entropy for TP |
| `src/xorl/distributed/gradient_accumulate_loss.py` | GradientAccumulateLoss — token-normalized loss accumulation across micro-batches |