
# RL Training

xorl is designed from the ground up for online reinforcement learning on large language models. The server training API exposes every step of the training loop — forward pass, backward pass, optimizer step, weight sync — as individual HTTP endpoints, letting an external RL orchestrator drive training precisely when and how it needs to.
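Because each phase is its own endpoint, one training step is just an ordered sequence of calls the orchestrator controls. A minimal sketch of that call ordering (the `RecordingClient` here is a hypothetical stand-in for illustration, not part of xorl-client):

```python
def run_training_step(client, batch, sync_weights=False):
    """Drive one RL training step as explicit, separately timed calls."""
    fwd = client.forward_backward(batch)   # forward + backward pass
    opt = client.optim_step()              # optimizer update
    if sync_weights:
        client.sync_inference_weights()    # push new weights to inference
    return fwd, opt

class RecordingClient:
    """Stand-in client that records call order (illustration only)."""
    def __init__(self):
        self.calls = []
    def forward_backward(self, batch):
        self.calls.append("forward_backward")
        return "fwd"
    def optim_step(self):
        self.calls.append("optim_step")
        return "opt"
    def sync_inference_weights(self):
        self.calls.append("sync_inference_weights")

client = RecordingClient()
run_training_step(client, batch=[], sync_weights=True)
# calls ran in order: forward_backward, optim_step, sync_inference_weights
```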


| Algorithm | loss_fn | Key params | When to use |
| --- | --- | --- | --- |
| SFT / continued pretraining | `causallm_loss` | | Standard next-token prediction |
| PPO | `policy_loss` | `eps_clip=0.2` | Clipped policy gradient, most stable for RL |
| GRPO (simpler RL) | `importance_sampling` | | No clipping, simpler but less stable |
| PPO + stale-data correction | `policy_loss` | `use_tis=True`, `tis_clip_low=0.1` | Multiple epochs over the same rollout |

A typical xorl RL setup runs three components in parallel:

*Figure: xorl RL training architecture.* The RL Orchestrator talks to the xorl Training Server over HTTP via the xorl-client SDK (`forward_backward`, `optim_step`; the server handles FSDP2, EP, PP, and LoRA, and implements the PPO, GRPO, IcePop, and TIS losses), and to the xorl-sglang inference server (`generate(prompts)`, which returns rollouts with per-token logprobs and R3 routing data). Trained weights are synced to inference via NCCL (`sync_inference_weights`). Reward modeling, advantage computation, and data packing are handled by the orchestrator.

Components:

| Component | Role |
| --- | --- |
| RL Orchestrator | Drives the RL loop: prompts → rollouts → rewards → advantages → training steps → weight sync. Built with xorl-client. |
| xorl Training Server | Exposes training as an HTTP API; handles FSDP, EP, PP, and all other parallelism internally. |
| xorl-sglang | Modified SGLang inference server; generates rollouts and exports per-token logprobs and MoE routing decisions. |

Install both client libraries:

```sh
pip install git+https://github.com/togethercomputer/xorl-client.git
```

Start the training server:

```sh
python -m xorl.server.launcher \
  --config server_config.yaml \
  --api-port 5555
```

Start xorl-sglang (see the xorl-sglang page for details):

```sh
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --port 30000
```

A minimal PPO training loop with xorl_client:

```python
import xorl_client

service = xorl_client.ServiceClient(base_url="http://localhost:5555")
training = xorl_client.TrainingClient(
    holder=service.holder,
    model_id="default",
    base_model="Qwen/Qwen3-8B",
)

adam = xorl_client.AdamParams(learning_rate=1e-6, beta1=0.9, beta2=0.95, eps=1e-8)

for step in range(num_steps):
    # 1. Generate rollouts from inference
    completions = sglang_client.generate(prompts, max_new_tokens=512, return_logprobs=True)

    # 2. Score with reward model
    rewards = reward_model.score(completions)

    # 3. Compute per-token advantages (GAE or simple returns)
    advantages = compute_advantages(rewards, completions)

    # 4. Pack into Datum objects with RL fields
    data = [
        xorl_client.Datum(
            model_input=xorl_client.ModelInput.from_ints(c.token_ids),
            loss_fn_inputs={
                "labels": c.token_ids,
                "logprobs": c.logprobs,  # old-policy logprobs from rollout
                "advantages": advantages[i],
            },
        )
        for i, c in enumerate(completions)
    ]

    # 5. Training step (PPO loss)
    fwd = training.forward_backward(
        data,
        loss_fn="policy_loss",
        loss_fn_params={"eps_clip": 0.2, "compute_kl_stats": True},
    )
    opt = training.optim_step(adam)
    result = fwd.result()
    opt.result()

    # 6. Sync weights to inference every N steps
    if step % sync_every == 0:
        service.sync_inference_weights(
            master_address=TRAINING_HOST,
            master_port=29600,
        )
```
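`compute_advantages` above is left to the orchestrator. A minimal sketch, assuming one scalar reward per completion and whitening across the batch (this particular implementation is illustrative, not xorl's):

```python
import numpy as np

def compute_advantages(rewards, completions):
    """Broadcast whitened sequence-level rewards to per-token advantages."""
    r = np.asarray(rewards, dtype=np.float64)
    # whiten across the batch: zero mean, unit variance
    adv = (r - r.mean()) / (r.std() + 1e-8)
    # every token of a completion shares that sequence's advantage
    return [np.full(len(c.token_ids), a) for c, a in zip(completions, adv)]
```

Per-token GAE would replace the broadcast with a discounted-TD recursion over a value model's estimates; the packing into `Datum` objects is identical.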

Full PPO-style clipped policy gradient loss (src/xorl/ops/loss/policy_loss.py):

```
ratio = exp(new_logprobs - old_logprobs)
pg_loss = -min(ratio × A, clip(ratio, 1-ε, 1+ε_high) × A)
```
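The clipped loss and the clip-fraction metric can be sketched in NumPy (`ppo_policy_loss` is an illustrative name, not the server's implementation):

```python
import numpy as np

def ppo_policy_loss(new_logprobs, old_logprobs, advantages,
                    eps_clip=0.2, eps_clip_high=0.2):
    """Mean clipped policy-gradient loss plus the fraction of clipped tokens."""
    ratio = np.exp(np.asarray(new_logprobs) - np.asarray(old_logprobs))
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - eps_clip, 1.0 + eps_clip_high) * advantages
    pg_loss = -np.minimum(unclipped, clipped)  # negate: we maximize the surrogate
    clipfrac = np.mean((ratio < 1.0 - eps_clip) | (ratio > 1.0 + eps_clip_high))
    return pg_loss.mean(), clipfrac
```

When the new and old policies agree (`ratio = 1`), the loss reduces to `-mean(advantages)` and no tokens are clipped.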

Loss function parameters:

| Parameter | Default | Description |
| --- | --- | --- |
| `eps_clip` | 0.2 | Lower clip ratio for PPO |
| `eps_clip_high` | 0.2 | Upper clip ratio (can differ from the lower bound for asymmetric clipping) |
| `eps_clip_c` | null | Dual-clip ratio for negative advantages (set to e.g. 3.0 to prevent large negative updates) |
| `compute_kl_stats` | false | Return KL statistics in metrics (K3 estimator, entropy, ratio stats) |

Metrics returned:

| Metric | Description |
| --- | --- |
| `pg_clipfrac` | Fraction of tokens where the PPO ratio clip was active |
| `kl_sample_train_k3` | Schulman's K3 KL estimator: `mean(exp(log_ratio) - log_ratio - 1)` (non-negative, unbiased) |
| `entropy_sample` | Mean entropy of the old policy: `-mean(old_logprobs)` |
| `ratio_mean`/`min`/`max` | Importance-sampling ratio statistics |

```python
fwd = training.forward_backward(data, loss_fn="policy_loss", loss_fn_params={
    "eps_clip": 0.2,
    "eps_clip_high": 0.2,
    "eps_clip_c": 3.0,  # dual-clip negative advantages
    "compute_kl_stats": True,
})
```
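The K3 estimator can be checked numerically: each term `exp(x) - x - 1` is non-negative, and with samples drawn from the old policy it is an unbiased estimate of KL(old ‖ new). A small sketch (illustrative, not xorl's code):

```python
import numpy as np

def k3_kl(new_logprobs, old_logprobs):
    """Schulman's K3 estimator: mean(exp(log_ratio) - log_ratio - 1)."""
    log_ratio = np.asarray(new_logprobs) - np.asarray(old_logprobs)
    return np.mean(np.exp(log_ratio) - log_ratio - 1.0)
```

It returns exactly 0 when the policies agree, unlike the naive `-mean(log_ratio)` estimator, which can go negative on finite samples.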

Simpler importance-sampling loss for GRPO-style training (src/xorl/ops/loss/importance_sampling_loss.py):

```
ratio = exp(new_logprobs - old_logprobs)
loss = -(ratio × advantages).mean()
```

No clipping — relies on advantages being bounded. Suitable when the policy doesn’t drift far from the rollout policy.

```python
fwd = training.forward_backward(data, loss_fn="importance_sampling", loss_fn_params={
    "compute_kl_stats": True,
})
```
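The loss is simple enough to sketch directly (illustrative NumPy, not xorl's implementation). When the new and old policies agree, it reduces to plain REINFORCE on the advantages:

```python
import numpy as np

def importance_sampling_loss(new_logprobs, old_logprobs, advantages):
    """Unclipped importance-weighted policy-gradient loss."""
    ratio = np.exp(np.asarray(new_logprobs) - np.asarray(old_logprobs))
    return -(ratio * np.asarray(advantages)).mean()
```

Because `ratio` is unclipped, a few tokens with large log-ratio can dominate the batch, which is why this loss is reserved for settings where the policy stays close to the rollout policy.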

IcePop (from GLM-5, arXiv:2602.15763) is a hard masking technique that zeros gradients for tokens where the importance sampling ratio falls outside the band [1/β, β]. This prevents large policy updates on tokens where the current and old policies have already diverged significantly — complementary to PPO’s soft clipping.

*Figure: IcePop hard masking by importance-sampling ratio.* Gradients are zeroed where `ratio < 1/β` or `ratio > β`, and remain active (with PPO clipping still applied) for `1/β ≤ ratio ≤ β`. The `icepop_maskfrac` metric reports the fraction of valid tokens that were zeroed.
```python
fwd = training.forward_backward(data, loss_fn="policy_loss", loss_fn_params={
    "eps_clip": 0.2,
    "icepop_beta": 5.0,  # zero gradients when ratio < 0.2 or ratio > 5.0
})
```

IcePop and PPO clipping are complementary: PPO softly clips the loss value, IcePop hard-zeros the gradient for extreme ratios.
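The masking rule itself is a one-liner. A sketch of the keep-mask and the `icepop_maskfrac` metric (illustrative, not the server's code):

```python
import numpy as np

def icepop_mask(new_logprobs, old_logprobs, beta=5.0):
    """Boolean keep-mask plus the fraction of tokens zeroed (icepop_maskfrac)."""
    ratio = np.exp(np.asarray(new_logprobs) - np.asarray(old_logprobs))
    keep = (ratio >= 1.0 / beta) & (ratio <= beta)  # band [1/beta, beta]
    return keep, 1.0 - keep.mean()
```

In training, `keep` would multiply the per-token loss (or its gradient) so masked tokens contribute nothing to the update.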


When multiple training steps run on the same rollout batch (e.g. multiple epochs over data collected at step T), the policy drifts away from the rollout policy. TIS corrects for this by weighting each token’s gradient by exp(train_logprobs - rollout_logprobs):

```
tis_weight = clip(exp(train_logprobs - rollout_logprobs), tis_clip_low, tis_clip_high)
loss = (tis_weight × pg_loss).mean()
```
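The weight computation can be sketched as follows (illustrative NumPy; the parameter names mirror the `loss_fn_params` used with this loss):

```python
import numpy as np

def tis_weights(train_logprobs, rollout_logprobs,
                tis_clip_low=0.1, tis_clip_high=2.0):
    """Clipped truncated-importance-sampling weights, one per token."""
    w = np.exp(np.asarray(train_logprobs) - np.asarray(rollout_logprobs))
    return np.clip(w, tis_clip_low, tis_clip_high)
```

Tokens the trained policy now finds much less likely than at rollout time are down-weighted (floored at `tis_clip_low`), and tokens it finds much more likely are capped at `tis_clip_high`, bounding the variance of the correction.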

Requires passing rollout_logprobs separately from logprobs (old policy at training time):

```python
fwd = training.forward_backward(data, loss_fn="policy_loss", loss_fn_params={
    "use_tis": True,
    "tis_clip_low": 0.1,
    "tis_clip_high": 2.0,
})
```

Pass rollout logprobs as a separate field in loss_fn_inputs:

```python
datum = xorl_client.Datum(
    model_input=xorl_client.ModelInput.from_ints(token_ids),
    loss_fn_inputs={
        "labels": token_ids,
        "logprobs": logprobs_at_last_train_step,  # old policy: updated each train step
        "rollout_logprobs": logprobs_at_rollout,  # reference: fixed from inference
        "advantages": advantages,
    },
)
```

For MoE models, gradient checkpointing causes a problem: flash attention is non-deterministic on recompute, which produces different hidden states → different top-K routing → AllToAll shape mismatches during backward. R3 (Rollout Routing Replay) solves this by recording expert routing decisions and replaying them verbatim during backward recompute.

In RL training, R3 goes further: routing decisions from the inference rollout (xorl-sglang) can be pre-loaded into the training server so that the training policy routes tokens to exactly the same experts as inference did. This gives gradient-level consistency between rollout and training steps.

Pass routing data from xorl-sglang in the forward_backward request:

```python
datum = xorl_client.Datum(
    model_input=xorl_client.ModelInput.from_ints(token_ids),
    loss_fn_inputs={
        "labels": token_ids,
        "logprobs": logprobs,
        "advantages": advantages,
    },
)
fwd = training.forward_backward(
    [datum],
    loss_fn="policy_loss",
    routed_experts=rollout_routing_indices,        # [T, L, K] from sglang
    routed_expert_logits=rollout_routing_weights,  # [T, L, K] softmax weights
)
```

See the Router page for full R3 details.


| Page | What it covers |
| --- | --- |
| Server Training | REST API reference, xorl-client SDK usage, multi-node launch |
| Inference: xorl-sglang | Modified SGLang backend, what was changed, integration with xorl |
| Loss Functions | `causallm_loss`, `policy_loss`, `importance_sampling`, `per_token_ce` |