RL Training
xorl is designed from the ground up for online reinforcement learning on large language models. The server training API exposes every step of the training loop — forward pass, backward pass, optimizer step, weight sync — as individual HTTP endpoints, letting an external RL orchestrator drive training precisely when and how it needs to.
Choosing a Loss Function
| Algorithm | loss_fn | Key params | When to use |
|---|---|---|---|
| SFT / continued pretraining | causallm_loss | — | Standard next-token prediction |
| PPO | policy_loss | eps_clip=0.2 | Clipped policy gradient, most stable for RL |
| GRPO (simpler RL) | importance_sampling | — | No clipping, simpler but less stable |
| PPO + stale data correction | policy_loss | use_tis=True, tis_clip_low=0.1 | Multiple epochs over same rollout |
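The practical difference between the clipped PPO loss and the unclipped importance-sampling loss can be sketched with toy numbers. This is an illustrative NumPy sketch; `ppo_loss` and `is_loss` are hypothetical helper names, not xorl APIs:

```python
import numpy as np

def ppo_loss(ratio, adv, eps=0.2):
    # Clipped policy gradient: per token, max(-ratio*A, -clip(ratio)*A).
    return np.maximum(-ratio * adv, -np.clip(ratio, 1 - eps, 1 + eps) * adv).mean()

def is_loss(ratio, adv):
    # Unclipped importance sampling, as in GRPO-style training.
    return -(ratio * adv).mean()

ratio = np.array([0.5, 1.0, 3.0])   # new/old policy probability ratios
adv = np.array([1.0, 1.0, 1.0])     # all-positive advantages

# The outlier ratio 3.0 is capped at 1 + eps by the clipped loss,
# so ppo_loss gives -0.9 while the unclipped is_loss gives -1.5.
```

The outlier token dominates the unclipped loss, which is why the unclipped variant is simpler but less stable.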
Architecture
A typical xorl RL setup runs three components in parallel:
| Component | Role |
|---|---|
| RL Orchestrator | Drives the RL loop: prompts → rollouts → rewards → advantages → training steps → weight sync. Built with xorl-client. |
| xorl Training Server | Exposes training as an HTTP API; handles FSDP, EP, PP, and all other parallelism internally |
| xorl-sglang | Modified SGLang inference server; generates rollouts, exports per-token logprobs and MoE routing decisions |
Getting Started
Install the xorl-client library:
```sh
pip install git+https://github.com/togethercomputer/xorl-client.git
```

Start the training server:
```sh
python -m xorl.server.launcher \
  --config server_config.yaml \
  --api-port 5555
```

Start xorl-sglang (see the xorl-sglang page for details):
```sh
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --port 30000
```

A minimal PPO training loop with xorl_client:
```python
import xorl_client

# Assumes sglang_client, reward_model, compute_advantages, prompts,
# num_steps, sync_every, and TRAINING_HOST are defined elsewhere.
service = xorl_client.ServiceClient(base_url="http://localhost:5555")
training = xorl_client.TrainingClient(
    holder=service.holder,
    model_id="default",
    base_model="Qwen/Qwen3-8B",
)
adam = xorl_client.AdamParams(learning_rate=1e-6, beta1=0.9, beta2=0.95, eps=1e-8)

for step in range(num_steps):
    # 1. Generate rollouts from inference
    completions = sglang_client.generate(prompts, max_new_tokens=512, return_logprobs=True)

    # 2. Score with reward model
    rewards = reward_model.score(completions)

    # 3. Compute per-token advantages (GAE or simple returns)
    advantages = compute_advantages(rewards, completions)

    # 4. Pack into Datum objects with RL fields
    data = [
        xorl_client.Datum(
            model_input=xorl_client.ModelInput.from_ints(c.token_ids),
            loss_fn_inputs={
                "labels": c.token_ids,
                "logprobs": c.logprobs,  # old policy logprobs from rollout
                "advantages": advantages[i],
            },
        )
        for i, c in enumerate(completions)
    ]

    # 5. Training step (PPO loss)
    fwd = training.forward_backward(
        data,
        loss_fn="policy_loss",
        loss_fn_params={"eps_clip": 0.2, "compute_kl_stats": True},
    )
    opt = training.optim_step(adam)
    fwd.result()
    opt.result()

    # 6. Sync weights to inference every N steps
    if step % sync_every == 0:
        service.sync_inference_weights(
            master_address=TRAINING_HOST,
            master_port=29600,
        )
```
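The loop above calls a compute_advantages helper it never defines. One minimal sketch, assuming scalar per-completion rewards and GRPO-style batch normalization broadcast over each completion's tokens:

```python
import numpy as np

def compute_advantages(rewards, completions, eps=1e-6):
    # Normalize each scalar reward against the batch (GRPO-style),
    # then broadcast that value over every token of its completion.
    r = np.asarray(rewards, dtype=np.float64)
    norm = (r - r.mean()) / (r.std() + eps)
    return [np.full(len(c.token_ids), a) for c, a in zip(completions, norm)]
```

A GAE implementation would instead need per-token values and rewards; this sketch only covers the simple outcome-reward case.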
RL Features Implemented

PPO Policy Loss
Full PPO-style clipped policy gradient loss (src/xorl/ops/loss/policy_loss.py):
```
ratio   = exp(new_logprobs - old_logprobs)
pg_loss = max(-ratio × A, -clip(ratio, 1-ε, 1+ε_high) × A)
```

Loss function parameters:
| Parameter | Default | Description |
|---|---|---|
| eps_clip | 0.2 | Lower clip ratio for PPO |
| eps_clip_high | 0.2 | Upper clip ratio (can differ from the lower bound for asymmetric clipping) |
| eps_clip_c | null | Dual-clip ratio for negative advantages (set to e.g. 3.0 to prevent large negative updates) |
| compute_kl_stats | false | Return KL statistics in metrics (K3 estimator, entropy, ratio stats) |
Metrics returned:
| Metric | Description |
|---|---|
| pg_clipfrac | Fraction of tokens where the PPO ratio clip was active |
| kl_sample_train_k3 | Schulman’s K3 KL estimator, mean(exp(log_ratio) - log_ratio - 1); non-negative and unbiased |
| entropy_sample | Mean entropy of the old policy, estimated as -mean(old_logprobs) |
| ratio_mean/min/max | Importance-sampling ratio statistics |
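The K3 estimator listed above is simple to illustrate numerically. A toy NumPy sketch (kl_k3 is a hypothetical helper, not the server's implementation):

```python
import numpy as np

def kl_k3(log_ratio):
    # Schulman's K3 estimator: each term (r - 1) - log(r) is >= 0,
    # so the batch estimate is always non-negative.
    return np.mean(np.exp(log_ratio) - log_ratio - 1.0)

# Identical policies (all ratios 1) give exactly 0; any drift gives > 0.
assert kl_k3(np.zeros(4)) == 0.0
assert kl_k3(np.array([-0.2, 0.0, 0.1, 0.3])) > 0.0
```

This non-negativity makes kl_sample_train_k3 a safe quantity to monitor for policy drift during training.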
Example usage:

```python
fwd = training.forward_backward(
    data,
    loss_fn="policy_loss",
    loss_fn_params={
        "eps_clip": 0.2,
        "eps_clip_high": 0.2,
        "eps_clip_c": 3.0,  # dual-clip negative advantages
        "compute_kl_stats": True,
    },
)
```
GRPO / Importance Sampling Loss

Simpler importance-sampling loss for GRPO-style training (src/xorl/ops/loss/importance_sampling_loss.py):
```
ratio = exp(new_logprobs - old_logprobs)
loss  = -(ratio × advantages).mean()
```

No clipping is applied; the loss relies on the advantages being bounded. Suitable when the policy doesn’t drift far from the rollout policy.
```python
fwd = training.forward_backward(
    data,
    loss_fn="importance_sampling",
    loss_fn_params={"compute_kl_stats": True},
)
```
IcePop

IcePop (from GLM-5, arXiv:2602.15763) is a hard masking technique that zeros gradients for tokens where the importance-sampling ratio falls outside the band [1/β, β], preventing large policy updates on tokens where the current and old policies have already diverged significantly.
```python
fwd = training.forward_backward(
    data,
    loss_fn="policy_loss",
    loss_fn_params={
        "eps_clip": 0.2,
        "icepop_beta": 5.0,  # zero gradients when ratio < 0.2 or ratio > 5.0
    },
)
```

IcePop and PPO clipping are complementary: PPO softly clips the loss value, while IcePop hard-zeros the gradient for extreme ratios.
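The masking rule itself is a one-liner. A toy NumPy sketch of the idea (icepop_mask is a hypothetical helper, not the server's kernel):

```python
import numpy as np

def icepop_mask(ratio, beta=5.0):
    # Gradient mask: True where the importance-sampling ratio lies inside
    # [1/beta, beta]; tokens outside the band contribute zero gradient.
    return (ratio >= 1.0 / beta) & (ratio <= beta)

ratio = np.array([0.1, 0.5, 1.0, 4.0, 8.0])
mask = icepop_mask(ratio)  # with beta=5.0: [False, True, True, True, False]
```

In a real loss this mask would multiply the per-token gradient contribution, so the two extreme tokens are dropped entirely rather than merely down-weighted.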
TIS — Temporal Importance Sampling
When multiple training steps run on the same rollout batch (e.g. multiple epochs over data collected at step T), the policy drifts away from the rollout policy. TIS corrects for this by weighting each token’s gradient by exp(train_logprobs - rollout_logprobs):
```
tis_weight = clip(exp(train_logprobs - rollout_logprobs), tis_clip_low, tis_clip_high)
loss       = (tis_weight × pg_loss).mean()
```

This requires passing rollout_logprobs separately from logprobs (the old-policy logprobs at training time):
```python
fwd = training.forward_backward(
    data,
    loss_fn="policy_loss",
    loss_fn_params={
        "use_tis": True,
        "tis_clip_low": 0.1,
        "tis_clip_high": 2.0,
    },
)
```

Pass rollout logprobs as a separate field in loss_fn_inputs:
```python
datum = xorl_client.Datum(
    model_input=xorl_client.ModelInput.from_ints(token_ids),
    loss_fn_inputs={
        "labels": token_ids,
        "logprobs": logprobs_at_last_train_step,  # old policy: updated each train step
        "rollout_logprobs": logprobs_at_rollout,  # reference: fixed from inference
        "advantages": advantages,
    },
)
```
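The clipped weight is easy to illustrate per token. A toy NumPy sketch mirroring the formula above (tis_weights is a hypothetical helper):

```python
import numpy as np

def tis_weights(train_logprobs, rollout_logprobs, clip_low=0.1, clip_high=2.0):
    # Per-token ratio of the current training policy to the fixed rollout
    # policy, clipped so stale-data corrections stay bounded.
    w = np.exp(np.asarray(train_logprobs) - np.asarray(rollout_logprobs))
    return np.clip(w, clip_low, clip_high)

# Token 1: policies agree (weight 1.0). Token 2: training policy assigns
# less mass (weight exp(-1)). Token 3: raw weight exp(1.5) is clipped to 2.0.
w = tis_weights([-1.0, -2.0, -0.5], [-1.0, -1.0, -2.0])
```

The clip bounds keep a single badly drifted token from dominating the batch gradient, at the cost of some bias in the correction.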
R3 — Routing Replay for MoE

For MoE models, gradient checkpointing causes a problem: flash attention is non-deterministic on recompute, which produces different hidden states → different top-K routing → AllToAll shape mismatches during backward. R3 (Rollout Routing Replay) solves this by recording expert routing decisions and replaying them verbatim during the backward recompute.
In RL training, R3 goes further: routing decisions from the inference rollout (xorl-sglang) can be pre-loaded into the training server so that the training policy routes tokens to exactly the same experts as inference did. This gives gradient-level consistency between rollout and training steps.
Pass routing data from xorl-sglang in the forward_backward request:
```python
datum = xorl_client.Datum(
    model_input=xorl_client.ModelInput.from_ints(token_ids),
    loss_fn_inputs={
        "labels": token_ids,
        "logprobs": logprobs,
        "advantages": advantages,
    },
)

fwd = training.forward_backward(
    [datum],
    loss_fn="policy_loss",
    routed_experts=rollout_routing_indices,        # [T, L, K] from sglang
    routed_expert_logits=rollout_routing_weights,  # [T, L, K] softmax weights
)
```

See the Router page for full R3 details.
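Conceptually, replay swaps the router's freshly computed top-K selection for the stored indices. An illustrative NumPy sketch of the idea (route_with_replay is hypothetical, not xorl's router code):

```python
import numpy as np

def route_with_replay(router_logits, replay_indices=None, k=2):
    """Pick top-K experts, or replay stored indices so recompute can't change routing."""
    if replay_indices is None:
        # Fresh routing: top-K experts by logit.
        idx = np.argsort(router_logits, axis=-1)[..., -k:]
    else:
        # Replay: reuse indices recorded during the rollout / original forward.
        idx = replay_indices
    gathered = np.take_along_axis(router_logits, idx, axis=-1)
    weights = np.exp(gathered) / np.exp(gathered).sum(axis=-1, keepdims=True)
    return idx, weights

logits = np.array([[0.1, 2.0, -1.0, 0.5]])  # one token, four experts
fresh_idx, fresh_w = route_with_replay(logits, k=2)  # top-2 by logit: experts 3 and 1
replay_idx, _ = route_with_replay(logits, replay_indices=fresh_idx, k=2)
# replay_idx == fresh_idx: routing is identical even if the logits were recomputed
```

Because the replayed indices are fixed, the AllToAll shapes in backward match those from forward, which is exactly the consistency R3 provides.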
In This Section
| Page | What it covers |
|---|---|
| Server Training | REST API reference, xorl-client SDK usage, multi-node launch |
| Inference: xorl-sglang | Modified SGLang backend, what was changed, integration with xorl |
| Loss Functions | causallm_loss, policy_loss, importance_sampling, per_token_ce |