
Inference: xorl-sglang

xorl-sglang is Together’s fork of SGLang, the primary supported inference backend for xorl RL training. It extends SGLang with the APIs and data export capabilities needed to integrate tightly with the xorl training server.


The core RL training loop requires capabilities that upstream SGLang does not provide:

  1. Weight updates from training — the training server must push new weights to the inference server after each policy update. Upstream SGLang does not expose a distributed weight-update API.

  2. Per-token logprob export — RL loss functions require logprobs from the rollout (old policy). SGLang must return per-token log probabilities alongside completions.

  3. MoE routing data export (R3) — for MoE models, the inference server’s routing decisions must be exported so training can replay them exactly, ensuring gradient consistency. Upstream SGLang does not export expert routing indices.

  4. Numerical alignment flags — training and inference must produce identical logits for the same input. xorl-sglang supports the same numerical alignment flags as the xorl training server (router_fp32, lm_head_fp32, rmsnorm_mode, attention_cast_bf16, etc.).


xorl-sglang adds ~2,500 lines across 32 files on top of upstream SGLang. The changes fall into five categories. The xorl-rl-target branch in the repo contains all xorl-specific commits.

The weight sync API is the largest change (~1,700 lines across model_runner.py, scheduler_update_weights_mixin.py, tokenizer_communicator_mixin.py, and http_server.py). xorl-sglang adds several HTTP endpoints for NCCL-based weight synchronization:

| Endpoint | Description |
| --- | --- |
| POST /init_weights_update_group | Join an NCCL process group for weight sync. Forces eager NCCL communicator creation with device_id and NCCL_CUMEM_ENABLE=0 to match xorl’s training side. |
| POST /update_weights_from_distributed | Receive broadcasted weight tensors via NCCL dist.broadcast. Called per weight bucket. |
| POST /prepare_weights_update | Phase 1 of two-phase protocol: starts background recv threads ready to receive NCCL broadcasts. |
| POST /complete_weights_update | Phase 2: waits for recv threads to complete and applies the received weights. |
| POST /receive_weights | Single-call weight receive via NCCL broadcast (pause/broadcast/resume protocol). |
| POST /receive_weights_ep_scatter | Receive expert weights from multiple EP training ranks via NCCL P2P scatter. Receives directly into param.data slices (zero-copy). |
| POST /list_weights | List all weight names/shapes. Used to verify weight name mapping between training and inference. |
| POST /destroy_weights_update_group | Tear down the NCCL group after sync completes. |

Key implementation details:

  • Eager NCCL init: Both sides must use device_id to force eager communicator creation. Without this, SGLang uses lazy init and xorl’s rank 0 hangs waiting for peers.
  • Two-phase protocol: prepare_weights_update starts background recv threads, then complete_weights_update applies them after the training-side broadcast finishes. This avoids blocking the scheduler.
  • EP scatter: For MoE models with expert parallelism, each EP training rank sends its local experts directly into the corresponding slice of param.data via P2P ops — no intermediate buffers.
  • Health check bypass: During weight updates, the /health endpoint returns 200 immediately instead of running a test generation, avoiding timeouts during NCCL operations.
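The EP scatter bullet above can be pictured with a small sketch. It assumes experts are split evenly and contiguously across EP training ranks, which the source does not confirm; `local_expert_slice` and the list standing in for param.data are hypothetical names used only for illustration.

```python
def local_expert_slice(rank: int, ep_size: int, num_experts: int) -> slice:
    """Return the slice of the expert dimension owned by one EP training rank.

    Assumption (not confirmed by the xorl source): experts are partitioned
    evenly and contiguously, so rank r owns experts [r * n, (r + 1) * n).
    """
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    if not 0 <= rank < ep_size:
        raise ValueError("rank out of range")
    per_rank = num_experts // ep_size
    return slice(rank * per_rank, (rank + 1) * per_rank)


# Stand-in for the expert dimension of a parameter: 8 experts, 4 EP ranks.
param_data = [None] * 8
for rank in range(4):
    s = local_expert_slice(rank, ep_size=4, num_experts=8)
    # Each rank's payload lands directly in its own slice; no staging buffer.
    param_data[s] = [f"rank{rank}"] * (s.stop - s.start)
```

In the real endpoint the write target would be a tensor slice filled by an NCCL P2P receive; the list here only shows which rank owns which region.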

The training server’s nccl_broadcast backend drives these calls — see Backend: nccl_broadcast for the full protocol.
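As a rough illustration of the two-phase ordering, the sketch below records the sequence of calls a driver might make. The endpoint paths come from the table above; the function name, the injected `post` and `broadcast_all_buckets` callables, and the choice to tear the group down after every sync are assumptions, and request bodies are omitted because this page does not document them.

```python
from typing import Callable, List


def run_two_phase_sync(
    post: Callable[[str], None],
    broadcast_all_buckets: Callable[[], None],
) -> None:
    """Drive one weight sync using the two-phase endpoints (hypothetical driver)."""
    post("/init_weights_update_group")       # join the NCCL group
    try:
        post("/prepare_weights_update")      # phase 1: start background recv threads
        broadcast_all_buckets()              # training side broadcasts the weights
        post("/complete_weights_update")     # phase 2: wait for recv and apply
    finally:
        post("/destroy_weights_update_group")  # tear down even if a step failed


# Record the call order instead of talking to a real server.
calls: List[str] = []
run_two_phase_sync(calls.append, lambda: calls.append("<training-side NCCL broadcast>"))
```

Injecting `post` keeps the ordering logic testable without a running inference server; a real driver would wrap an HTTP client there.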

Modified files: routed_experts_capturer.py, qwen3_moe.py, qwen2_moe.py, io_struct.py, detokenizer_manager.py, scheduler_output_processor_mixin.py

For MoE models, xorl-sglang records which experts each token was routed to during generation and returns this alongside completions:

# Returned in meta_info
routed_experts = completions[i].meta_info["routed_experts"]
# Shape: [num_tokens, num_layers, top_k], dtype: int32, base64-encoded

The training server decodes and replays these routing decisions via R3, ensuring training uses the same expert assignments as inference did. This is critical for gradient consistency in GRPO/PPO training with MoE models.

Expert routing weights export (--enable-return-expert-logits): In addition to expert indices, xorl-sglang can also capture and return topk_weights (the softmax routing probabilities). The RoutedExpertsCapturer allocates a separate topk_weights_buffer in float32 to record these per layer.

Format: xorl-sglang encodes routing data as base64 int32 arrays with shape metadata to minimize transfer overhead:

{
  "routed_experts": "<base64_encoded_int32_array>",
  "routed_expert_meta": {"shape": [512, 28, 2], "dtype": "int32"}
}

The RoutingReplayHandler in xorl decodes and distributes this data across context-parallel and packing dimensions automatically.
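For readers who want to inspect the routing payload by hand, here is a minimal decoding sketch using only the standard library. It assumes the int32 values are little-endian and laid out in row-major order ([num_tokens, num_layers, top_k]); neither detail is stated on this page. `decode_routed_experts` is a hypothetical helper, not the RoutingReplayHandler.

```python
import base64
import struct
from typing import List


def decode_routed_experts(payload: dict) -> List[List[List[int]]]:
    """Decode base64 int32 routing data into nested [token][layer][k] lists.

    Assumptions: little-endian int32, row-major layout.
    """
    num_tokens, num_layers, top_k = payload["routed_expert_meta"]["shape"]
    raw = base64.b64decode(payload["routed_experts"])
    count = num_tokens * num_layers * top_k
    if len(raw) != 4 * count:
        raise ValueError("byte length does not match shape metadata")
    flat = struct.unpack(f"<{count}i", raw)
    return [
        [
            list(flat[(t * num_layers + l) * top_k:(t * num_layers + l + 1) * top_k])
            for l in range(num_layers)
        ]
        for t in range(num_tokens)
    ]


# Build a tiny payload (2 tokens, 2 layers, top_k=2) and decode it again.
expert_ids = [3, 7, 1, 4, 0, 5, 2, 6]
example = {
    "routed_experts": base64.b64encode(struct.pack("<8i", *expert_ids)).decode("ascii"),
    "routed_expert_meta": {"shape": [2, 2, 2], "dtype": "int32"},
}
decoded = decode_routed_experts(example)
```

The length check catches a payload whose shape metadata and byte count disagree before any indices are handed to training.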

Modified files: qwen3_moe.py, qwen3.py, qwen2_moe.py, qwen2.py, layernorm.py, server_args.py

xorl and xorl-sglang must produce identical logits for the same input. Mismatched logits lead to incorrect importance sampling ratios and degraded RL performance. xorl-sglang adds:

  • --enable-fp32-router: MoE router gate computation in FP32. In Qwen3MoeSparseMoeBlock, when rl_on_policy_target is set, the router uses F.linear in float32 with explicit F.softmax + torch.topk instead of the fused TopK kernel.
  • --enable-fp32-lm-head: LM head logits in FP32 (already upstream, but used together with the router flag).
  • RMSNorm alignment: When rl_on_policy_target is set, Q/K norms use cast_x_before_out_mul=True and fp32_residual=False to match xorl’s training-side norm behavior.
  • Fused kernel bypass: Disables fused_qk_norm_rope and fused_kv_buffer when rl_on_policy_target is set, since fused kernels can produce slightly different numerics.
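To see why small logit mismatches matter, consider the per-token importance ratio exp(logp_train - logp_rollout). The numbers below are made up for illustration; `importance_ratio` is a hypothetical helper, not part of xorl.

```python
import math


def importance_ratio(logp_train: float, logp_rollout: float) -> float:
    """Per-token importance sampling ratio between training and rollout policies."""
    return math.exp(logp_train - logp_rollout)


# Same weights, numerically aligned: both sides report the same logprob.
aligned = importance_ratio(-1.2500, -1.2500)

# Same weights, but inference numerics drifted by 0.015 nats on this token.
drifted = importance_ratio(-1.2500, -1.2650)
```

With aligned numerics the ratio is exactly 1 before any policy update. With the drifted logprob it is about 1.015, so the loss sees an off-policy signal that comes purely from kernel differences rather than from a change in the policy.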

Configure both sides identically:

# xorl server config (server_config.yaml)
router_fp32: true
lm_head_fp32: true
rmsnorm_mode: compile
activation_native: false
rope_native: false
attention_cast_bf16: false

# xorl-sglang launch
python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B \
--enable-fp32-router \
--enable-fp32-lm-head \
--rl-on-policy-target xorl \
--port 30000
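A mismatch between the two configurations is easy to introduce by accident, so a launcher script could cross-check them before starting a run. The sketch below covers only the two key/flag pairs this page documents (router_fp32 and lm_head_fp32); `find_mismatches` and the dict/list inputs are hypothetical, not part of xorl or SGLang.

```python
from typing import Dict, List

# Training config key -> inference launch flag, as documented on this page.
FLAG_FOR_CONFIG_KEY = {
    "router_fp32": "--enable-fp32-router",
    "lm_head_fp32": "--enable-fp32-lm-head",
}


def find_mismatches(server_config: Dict[str, bool], launch_args: List[str]) -> List[str]:
    """Return a description of every FP32 setting that differs between sides."""
    mismatches = []
    for key, flag in FLAG_FOR_CONFIG_KEY.items():
        training_on = bool(server_config.get(key, False))
        inference_on = flag in launch_args
        if training_on != inference_on:
            state = "set" if inference_on else "missing"
            mismatches.append(f"{key}={training_on} but {flag} is {state}")
    return mismatches


config = {"router_fp32": True, "lm_head_fp32": True}
ok = find_mismatches(config, ["--enable-fp32-router", "--enable-fp32-lm-head"])
bad = find_mismatches(config, ["--enable-fp32-router"])
```

Here `ok` is empty and `bad` reports the missing LM head flag. The other alignment keys (rmsnorm_mode and so on) have no documented one-to-one launch flag on this page, so they are left out.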

Modified files: model_runner.py, batch_invariant_ops.py

When --rl-on-policy-target xorl-batch-invariant is set, xorl-sglang enables batch-invariant mode. This ensures that the model produces identical outputs regardless of how requests are batched together — a requirement for correct importance sampling in RL. Without this, different batching of the same request can produce different logits due to padding interactions in attention and MoE kernels.
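As general background, and not a description of the specific kernels involved, one reason the same request can produce different numbers under different batching is that floating-point addition is not associative: a reduction that groups its terms differently can round differently. A minimal demonstration:

```python
values = [0.1, 0.2, 0.3]

# Two mathematically equivalent ways to sum the same three numbers.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

# The results differ in the last bit of the double-precision mantissa.
difference = abs(left_to_right - right_to_left)
```

The difference here is tiny, but in lower-precision formats such as BF16 and across many layers such discrepancies can become visible in the logits, which is what batch-invariant mode is meant to rule out.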

Two additional fixes on top of the upstream merge:

  • req_to_token_pool slot leak (schedule_batch.py, scheduler.py): When a max_new_tokens=0 (prefill-only) request arrives during an idle window, its ScheduleBatch gets is_prefill_only=True. If normal generation requests are later merged in, merge_batch() never cleared this flag, so get_next_batch_to_run() skipped the decode path. Requests allocated pool slots during prefill but never decoded, never finished, and never freed their slots — exhausting the pool. Fixed by clearing is_prefill_only on merge and recomputing it from actual request state.
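The shape of the fix can be shown with a toy model. The classes below are simplified stand-ins, not the actual SGLang ScheduleBatch or request types; only the idea of recomputing is_prefill_only from request state on merge is taken from the description above.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyReq:
    max_new_tokens: int


@dataclass
class ToyBatch:
    reqs: List[ToyReq] = field(default_factory=list)
    is_prefill_only: bool = False

    def recompute_prefill_only(self) -> None:
        # A batch is prefill-only only if every request asks for zero new tokens.
        self.is_prefill_only = bool(self.reqs) and all(
            r.max_new_tokens == 0 for r in self.reqs
        )

    def merge_batch(self, other: "ToyBatch") -> None:
        self.reqs.extend(other.reqs)
        # The fix: never carry over a stale flag; derive it from the merged requests.
        self.recompute_prefill_only()


# A prefill-only batch absorbs a normal generation request.
batch = ToyBatch(reqs=[ToyReq(max_new_tokens=0)], is_prefill_only=True)
batch.merge_batch(ToyBatch(reqs=[ToyReq(max_new_tokens=128)]))
```

After the merge the flag is False, so a scheduler that checks it would take the decode path and the generation request can finish and release its slot.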

After a weight sync, xorl-sglang flushes its KV cache. The training server sends flush_cache=True on the last weight bucket to trigger this. Without flushing, cached key/value tensors from the old policy weights would be reused with new weights, producing incorrect logits.
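The "flush on the last bucket" rule can be sketched as a small generator. The flush_cache flag comes from the text above; the function name and the string bucket placeholders are hypothetical, and how the flag is attached to the actual request is not shown here.

```python
from typing import Iterator, List, Tuple


def buckets_with_flush_flag(buckets: List[str]) -> Iterator[Tuple[str, bool]]:
    """Yield (bucket, flush_cache) pairs, with flush_cache=True only on the last."""
    last = len(buckets) - 1
    for i, bucket in enumerate(buckets):
        yield bucket, i == last


plan = list(buckets_with_flush_flag(["bucket-0", "bucket-1", "bucket-2"]))
```

Flushing only once, after the final bucket, avoids clearing the cache repeatedly while still guaranteeing that no key/value tensors computed with the old weights survive the sync.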


xorl-sglang is included as a git submodule under submodules/xorl-sglang. If you cloned with --recurse-submodules, it’s already checked out.

pip install -e "submodules/xorl-sglang/python[all]"

Or use pyproject.sglang.toml to install xorl, xorl-client, and xorl-sglang together (pins PyTorch to 2.9.1):

cp pyproject.sglang.toml pyproject.toml
uv sync # or: pip install -e .

See the installation guide for full details.

python -m sglang.launch_server \
--model-path Qwen/Qwen3-8B-FP8 \
--port 30000 \
--rl-on-policy-target xorl \
--enable-fp32-router \
--enable-fp32-lm-head \
--mem-fraction-static 0.88

CUDA_VISIBLE_DEVICES=4,5 python -m sglang.launch_server \
--model-path Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8 \
--port 30000 \
--tp-size 2 \
--rl-on-policy-target xorl \
--enable-fp32-router \
--enable-fp32-lm-head \
--mem-fraction-static 0.88

Tensor Parallel (Qwen3-235B FP8, 4 GPUs, remote node)

python -m sglang.launch_server \
--model-path Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 \
--port 30000 \
--host 0.0.0.0 \
--tp-size 4 \
--rl-on-policy-target xorl \
--enable-fp32-router \
--enable-fp32-lm-head \
--mem-fraction-static 0.88

| Flag | Description |
| --- | --- |
| --rl-on-policy-target xorl | Enable xorl weight sync endpoints and routing capture. Use xorl-batch-invariant for batch-invariant mode. |
| --enable-fp32-router | MoE router gate in FP32 (must match training config router_fp32: true) |
| --enable-fp32-lm-head | LM head logits in FP32 (must match training config lm_head_fp32: true) |
| --enable-return-routed-experts | Return expert routing indices in response metadata (for R3) |
| --enable-return-expert-logits | Return expert routing weights alongside indices |
| --tp-size N | Tensor parallelism across N GPUs |
| --mem-fraction-static 0.88 | Fraction of GPU memory reserved for static allocation (model weights and KV cache pool); leave headroom for weight sync buffers |

import time

import requests

# Poll /health until the server reports ready.
while True:
    try:
        r = requests.get("http://localhost:30000/health", timeout=5)
        if r.status_code == 200:
            break
    except requests.RequestException:
        # Server is not accepting connections yet; retry.
        pass
    time.sleep(2)

Before weight sync can happen, register the xorl-sglang instance with the training server:

import requests

requests.post("http://training-server:6000/add_inference_endpoint", json={
    "host": "inference-node-01",
    "port": 30000,
    "worker_port": 30000,
    "world_size": 4,  # match --tp-size
})

Multiple replicas can be registered — weight sync broadcasts to all of them in parallel:

# inference_replicas: iterable of (host, port, tp_size) tuples, one per replica
for host, port, tp_size in inference_replicas:
    requests.post("http://training-server:6000/add_inference_endpoint", json={
        "host": host,
        "port": port,
        "worker_port": port,
        "world_size": tp_size,
    })

On nodes where training and inference share GPUs, use sleep/wake to hand off GPU memory:

# Before a large training step: free inference GPU memory
requests.post("http://inference-node:30000/sleep")

# Run training steps
for _ in range(n_train_steps):
    training.forward_backward(...)
    training.optim_step(...)

# Sync new weights and resume inference
training.sync_inference_weights(master_address=TRAIN_HOST, master_port=29600).result()
requests.post("http://inference-node:30000/wake_up")

xorl-sglang tracks upstream SGLang closely and aims to rebase regularly. The modifications touch 32 files but are concentrated in a few areas:

| Area | Files modified | Nature of change |
| --- | --- | --- |
| Weight sync protocol | model_runner.py, scheduler_update_weights_mixin.py, tokenizer_communicator_mixin.py, http_server.py | New endpoints and NCCL recv logic (~1,700 lines added) |
| Routing data export | routed_experts_capturer.py, io_struct.py, detokenizer_manager.py, scheduler_output_processor_mixin.py | Additions to existing generation path |
| Numerical alignment | qwen3_moe.py, qwen3.py, qwen2_moe.py, qwen2.py, layernorm.py | Conditional paths gated on rl_on_policy_target |
| CLI args | server_args.py | New flags: --enable-fp32-router, --enable-return-expert-logits, --enable-rdma-weight-updates, --rl-on-policy-target xorl |
| Bug fixes | schedule_batch.py, scheduler.py | Prefill-only slot leak fix |

The main branch is periodically rebased onto upstream SGLang. The xorl-rl-target branch contains the xorl-specific commits on top. If you need a feature from a newer upstream SGLang release, file an issue at the xorl-sglang repo.

| Repo / file | Description |
| --- | --- |
| togethercomputer/xorl-sglang | xorl’s SGLang fork: weight update API, R3 export, numerical alignment |
| src/xorl/server/weight_sync/backends/nccl_broadcast.py | Training-side NCCL broadcast implementation that drives the SGLang weight update endpoints |
| src/xorl/server/runner/utils/routing_replay_handler.py | Decodes R3 routing data from SGLang and distributes it across SP/packing dimensions |