Inference: xorl-sglang
xorl-sglang is Together’s fork of SGLang, the primary supported inference backend for xorl RL training. It extends SGLang with the APIs and data export capabilities needed to integrate tightly with the xorl training server.
Why a Fork?
The core RL training loop requires capabilities that upstream SGLang does not provide:
- Weight updates from training: the training server must push new weights to the inference server after each policy update. Upstream SGLang does not expose a distributed weight-update API.
- Per-token logprob export: RL loss functions require `logprobs` from the rollout (old policy). SGLang must return per-token log probabilities alongside completions.
- MoE routing data export (R3): for MoE models, the inference server’s routing decisions must be exported so training can replay them exactly, ensuring gradient consistency. Upstream SGLang does not export expert routing indices.
- Numerical alignment flags: training and inference must produce identical logits for the same input. xorl-sglang supports the same numerical alignment flags as the xorl training server (`router_fp32`, `lm_head_fp32`, `rmsnorm_mode`, `attention_cast_bf16`, etc.).
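To illustrate why the rollout's per-token log probabilities matter, here is a minimal sketch of the per-token importance ratio that policy-gradient losses such as PPO and GRPO are built on. The function name and the sample values are illustrative assumptions, not part of xorl's API.

```python
import math

def importance_ratios(new_logprobs, old_logprobs):
    """Per-token ratio pi_new(token) / pi_old(token), computed from log probabilities."""
    if len(new_logprobs) != len(old_logprobs):
        raise ValueError("logprob sequences must have the same length")
    return [math.exp(n - o) for n, o in zip(new_logprobs, old_logprobs)]

# Hypothetical values: old_logprobs come from the rollout (inference server),
# new_logprobs from the training forward pass on the same tokens.
old_logprobs = [-0.5, -1.2, -0.1]
new_logprobs = [-0.5, -1.0, -0.1]
ratios = importance_ratios(new_logprobs, old_logprobs)
```

A ratio of exactly 1.0 means both sides assigned the token the same probability. Any numerical mismatch between training and inference shows up here as a ratio that drifts away from 1.0 even when the weights are identical.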
What Was Modified
xorl-sglang adds ~2,500 lines across 32 files on top of upstream SGLang. The changes fall into five categories. The xorl-rl-target branch in the repo contains all xorl-specific commits.
1. Weight Update Endpoints and Protocol
This is the largest change (~1,700 lines across `model_runner.py`, `scheduler_update_weights_mixin.py`, `tokenizer_communicator_mixin.py`, and `http_server.py`). xorl-sglang adds the following HTTP endpoints for NCCL-based weight synchronization:
| Endpoint | Description |
|---|---|
| `POST /init_weights_update_group` | Join an NCCL process group for weight sync. Forces eager NCCL communicator creation with `device_id` and `NCCL_CUMEM_ENABLE=0` to match xorl’s training side. |
| `POST /update_weights_from_distributed` | Receive broadcasted weight tensors via NCCL `dist.broadcast`. Called per weight bucket. |
| `POST /prepare_weights_update` | Phase 1 of two-phase protocol: starts background recv threads ready to receive NCCL broadcasts. |
| `POST /complete_weights_update` | Phase 2: waits for recv threads to complete and applies the received weights. |
| `POST /receive_weights` | Single-call weight receive via NCCL broadcast (pause/broadcast/resume protocol). |
| `POST /receive_weights_ep_scatter` | Receive expert weights from multiple EP training ranks via NCCL P2P scatter. Receives directly into `param.data` slices (zero-copy). |
| `POST /list_weights` | List all weight names and shapes. Used to verify weight name mapping between training and inference. |
| `POST /destroy_weights_update_group` | Tear down the NCCL group after sync completes. |
Key implementation details:
- Eager NCCL init: Both sides must use `device_id` to force eager communicator creation. Without this, sglang uses lazy init and xorl’s rank 0 hangs waiting for peers.
- Two-phase protocol: `prepare_weights_update` starts background recv threads, then `complete_weights_update` waits for them and applies the received weights after the training-side broadcast finishes. This avoids blocking the scheduler.
- EP scatter: For MoE models with expert parallelism, each EP training rank sends its local experts directly into the corresponding slice of `param.data` via P2P ops, with no intermediate buffers.
- Health check bypass: During weight updates, the `/health` endpoint returns 200 immediately instead of running a test generation, avoiding timeouts during NCCL operations.
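The two-phase pattern described above can be sketched with standard-library threads. This is a toy model only: a `queue.Queue` stands in for the NCCL broadcast, and `TwoPhaseReceiver` and its method names are hypothetical, not the xorl-sglang implementation.

```python
import queue
import threading

class TwoPhaseReceiver:
    """Toy stand-in for an inference worker; a queue plays the role of the NCCL broadcast."""

    def __init__(self, weights):
        self.weights = dict(weights)
        self._staged = {}
        self._thread = None

    def prepare(self, channel, num_tensors):
        # Phase 1: start a background thread that waits for incoming tensors,
        # then return immediately so the caller is not blocked.
        def recv():
            for _ in range(num_tensors):
                name, value = channel.get()
                self._staged[name] = value
        self._thread = threading.Thread(target=recv, daemon=True)
        self._thread.start()

    def complete(self):
        # Phase 2: wait for the receive thread, then apply everything at once.
        self._thread.join()
        self.weights.update(self._staged)
        self._staged = {}

channel = queue.Queue()
receiver = TwoPhaseReceiver({"layer.weight": 1.0, "layer.bias": 0.0})
receiver.prepare(channel, num_tensors=2)   # returns right away
channel.put(("layer.weight", 2.0))         # the "training side" sends after prepare
channel.put(("layer.bias", 0.5))
receiver.complete()
```

Staging received values and applying them only in `complete()` means the live weights are never observed half-updated, which is the property the two-phase split is meant to give.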
The training server’s nccl_broadcast backend drives these calls — see Backend: nccl_broadcast for the full protocol.
2. MoE Routing Data Export (R3)
Modified files: routed_experts_capturer.py, qwen3_moe.py, qwen2_moe.py, io_struct.py, detokenizer_manager.py, scheduler_output_processor_mixin.py
For MoE models, xorl-sglang records which experts each token was routed to during generation and returns this alongside completions:
```python
# Returned in meta_info
routed_experts = completions[i].meta_info["routed_experts"]
# Shape: [num_tokens, num_layers, top_k], dtype: int32, base64-encoded
```
The training server decodes and replays these routing decisions via R3, ensuring training uses the same expert assignments as inference did. This is critical for gradient consistency in GRPO/PPO training with MoE models.
Expert routing weights export (--enable-return-expert-logits): In addition to expert indices, xorl-sglang can also capture and return topk_weights (the softmax routing probabilities). The RoutedExpertsCapturer allocates a separate topk_weights_buffer in float32 to record these per layer.
Format: xorl-sglang encodes routing data as base64 int32 arrays with shape metadata to minimize transfer overhead:
```json
{
  "routed_experts": "<base64_encoded_int32_array>",
  "routed_expert_meta": {"shape": [512, 28, 2], "dtype": "int32"}
}
```
The RoutingReplayHandler in xorl decodes and distributes this data across context-parallel and packing dimensions automatically.
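For readers who want to inspect the routing data by hand, the following is a minimal standard-library sketch of decoding a payload in the format shown above. It assumes little-endian int32 values in row-major order, which the documentation does not state explicitly, and `decode_routed_experts` is a hypothetical helper, not the RoutingReplayHandler.

```python
import base64
import struct

def decode_routed_experts(payload):
    """Decode a base64 int32 routing array into nested lists [token][layer][k]."""
    meta = payload["routed_expert_meta"]
    if meta["dtype"] != "int32":
        raise ValueError(f"unexpected dtype: {meta['dtype']}")
    num_tokens, num_layers, top_k = meta["shape"]
    raw = base64.b64decode(payload["routed_experts"])
    count = num_tokens * num_layers * top_k
    if len(raw) != 4 * count:
        raise ValueError("payload size does not match shape metadata")
    flat = struct.unpack(f"<{count}i", raw)  # assumption: little-endian int32
    # Assumption: row-major (C-order) layout, last dimension varying fastest.
    return [
        [
            list(flat[(t * num_layers + layer) * top_k:(t * num_layers + layer + 1) * top_k])
            for layer in range(num_layers)
        ]
        for t in range(num_tokens)
    ]

# Round-trip a tiny synthetic payload: 2 tokens, 2 layers, top_k = 2.
ids = [0, 3, 1, 2, 5, 7, 4, 6]
payload = {
    "routed_experts": base64.b64encode(struct.pack("<8i", *ids)).decode("ascii"),
    "routed_expert_meta": {"shape": [2, 2, 2], "dtype": "int32"},
}
decoded = decode_routed_experts(payload)
```

Here `decoded[t][layer]` is the list of expert indices chosen for token `t` at MoE layer `layer`.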
3. Numerical Alignment
Modified files: qwen3_moe.py, qwen3.py, qwen2_moe.py, qwen2.py, layernorm.py, server_args.py
xorl and xorl-sglang must produce identical logits for the same input. Mismatched logits lead to incorrect importance sampling ratios and degraded RL performance. xorl-sglang adds:
- `--enable-fp32-router`: MoE router gate computation in FP32. In `Qwen3MoeSparseMoeBlock`, when `rl_on_policy_target` is set, the router uses `F.linear` in float32 with explicit `F.softmax` + `torch.topk` instead of the fused TopK kernel.
- `--enable-fp32-lm-head`: LM head logits in FP32 (already upstream, but used together with the router flag).
- RMSNorm alignment: When `rl_on_policy_target` is set, Q/K norms use `cast_x_before_out_mul=True` and `fp32_residual=False` to match xorl’s training-side norm behavior.
- Fused kernel bypass: Disables `fused_qk_norm_rope` and `fused_kv_buffer` when `rl_on_policy_target` is set, since fused kernels can produce slightly different numerics.
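To see why router precision can matter, the sketch below emulates bfloat16 rounding using only the standard library. It is an illustration of the precision gap, not xorl-sglang code: the logit values are made up, and it assumes a scenario where the router would otherwise run in bfloat16.

```python
import struct

def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the 16 dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_fp32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# Hypothetical router logits for two experts that are close but not equal.
logit_a, logit_b = 1.0, 1.003

fp32_distinct = to_fp32(logit_a) != to_fp32(logit_b)
bf16_tied = to_bf16(logit_a) == to_bf16(logit_b)
```

In float32 the two logits stay distinct, so top-k has a well-defined winner. In bfloat16 they collapse to the same value, and which expert is selected then depends on tie-breaking, which two different implementations need not agree on.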
Configure both sides identically:
```yaml
# xorl server config (server_config.yaml)
router_fp32: true
lm_head_fp32: true
rmsnorm_mode: compile
activation_native: false
rope_native: false
attention_cast_bf16: false
```
```bash
# xorl-sglang launch
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --enable-fp32-router \
  --enable-fp32-lm-head \
  --rl-on-policy-target xorl \
  --port 30000
```
4. Batch-Invariant Mode
Modified files: model_runner.py, batch_invariant_ops.py
When --rl-on-policy-target xorl-batch-invariant is set, xorl-sglang enables batch-invariant mode. This ensures that the model produces identical outputs regardless of how requests are batched together — a requirement for correct importance sampling in RL. Without this, different batching of the same request can produce different logits due to padding interactions in attention and MoE kernels.
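One general mechanism by which batching can change results is that floating-point addition is not associative: a kernel that reduces in a different order for a different batch shape can return slightly different values. The snippet below demonstrates only that mechanism with plain Python floats. It does not model xorl-sglang's kernels.

```python
# The same three values, summed with two different groupings.
values = [0.1, 0.2, 0.3]

left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

orders_agree = left_to_right == right_to_left
difference = abs(left_to_right - right_to_left)
```

The two sums differ in the last bit. Differences of this size in logits are enough to make per-token log probabilities disagree between two runs of the same request.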
5. Bug Fixes
Two additional fixes on top of the upstream merge:
- `req_to_token_pool` slot leak (`schedule_batch.py`, `scheduler.py`): When a `max_new_tokens=0` (prefill-only) request arrives during an idle window, its `ScheduleBatch` gets `is_prefill_only=True`. If normal generation requests were later merged in, `merge_batch()` never cleared this flag, so `get_next_batch_to_run()` skipped the decode path. Those requests allocated pool slots during prefill but were never decoded, never finished, and never freed their slots, which eventually exhausted the pool. Fixed by clearing `is_prefill_only` on merge and recomputing it from actual request state.
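The shape of the fix can be sketched with simplified stand-in classes. `Req` and `Batch` below are hypothetical miniatures, not SGLang's `Req` and `ScheduleBatch`. They only show the idea of recomputing the flag from the requests actually present after a merge.

```python
from dataclasses import dataclass, field

@dataclass
class Req:
    max_new_tokens: int

@dataclass
class Batch:
    reqs: list = field(default_factory=list)
    is_prefill_only: bool = False

    def recompute_prefill_only(self):
        # Derive the flag from the requests actually in the batch.
        self.is_prefill_only = bool(self.reqs) and all(
            r.max_new_tokens == 0 for r in self.reqs
        )

    def merge_batch(self, other):
        self.reqs.extend(other.reqs)
        # The fix: never carry a stale flag across a merge.
        self.recompute_prefill_only()

batch = Batch(reqs=[Req(max_new_tokens=0)])
batch.recompute_prefill_only()
flag_before = batch.is_prefill_only

batch.merge_batch(Batch(reqs=[Req(max_new_tokens=64)]))
flag_after = batch.is_prefill_only
```

Once a request that needs decoding joins the batch, the flag drops to `False`, so a scheduler keyed on it would no longer skip the decode path.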
KV Cache Flush on Weight Update
After a weight sync, xorl-sglang flushes its KV cache. The training server sends flush_cache=True on the last weight bucket to trigger this. Without flushing, cached key/value tensors from the old policy weights would be reused with new weights, producing incorrect logits.
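As a small illustration of the "flush on the last bucket" rule, the sketch below builds one request body per weight bucket and sets `flush_cache` only on the final one. The `names` field and the bucket contents are assumptions made for the example. Only the `flush_cache` flag comes from the description above.

```python
def bucket_requests(bucket_names):
    """Build one request body per weight bucket; only the last one asks for a flush."""
    last = len(bucket_names) - 1
    return [
        {"names": names, "flush_cache": index == last}
        for index, names in enumerate(bucket_names)
    ]

# Hypothetical bucketing of parameter names.
buckets = [
    ["embed.weight"],
    ["layers.0.attn.weight", "layers.0.mlp.weight"],
    ["lm_head.weight"],
]
requests_to_send = bucket_requests(buckets)
flush_flags = [body["flush_cache"] for body in requests_to_send]
```

Flushing once at the end avoids clearing the cache repeatedly while the model is still in a partially updated state.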
Installation
xorl-sglang is included as a git submodule under submodules/xorl-sglang. If you cloned with --recurse-submodules, it’s already checked out.
```bash
pip install -e "submodules/xorl-sglang/python[all]"
```
Or use pyproject.sglang.toml to install xorl, xorl-client, and xorl-sglang together (pins PyTorch to 2.9.1):
```bash
cp pyproject.sglang.toml pyproject.toml
uv sync  # or: pip install -e .
```
See the installation guide for full details.
Launching xorl-sglang
Single GPU (Qwen3-8B FP8)
```bash
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B-FP8 \
  --port 30000 \
  --rl-on-policy-target xorl \
  --enable-fp32-router \
  --enable-fp32-lm-head \
  --mem-fraction-static 0.88
```
Tensor Parallel (Qwen3-30B FP8, 2 GPUs)
```bash
CUDA_VISIBLE_DEVICES=4,5 python -m sglang.launch_server \
  --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8 \
  --port 30000 \
  --tp-size 2 \
  --rl-on-policy-target xorl \
  --enable-fp32-router \
  --enable-fp32-lm-head \
  --mem-fraction-static 0.88
```
Tensor Parallel (Qwen3-235B FP8, 4 GPUs, remote node)
```bash
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-235B-A22B-Instruct-2507-FP8 \
  --port 30000 \
  --host 0.0.0.0 \
  --tp-size 4 \
  --rl-on-policy-target xorl \
  --enable-fp32-router \
  --enable-fp32-lm-head \
  --mem-fraction-static 0.88
```
Key launch flags
| Flag | Description |
|---|---|
| `--rl-on-policy-target xorl` | Enable xorl weight sync endpoints and routing capture. Use `xorl-batch-invariant` for batch-invariant mode. |
| `--enable-fp32-router` | MoE router gate in FP32 (must match training config `router_fp32: true`) |
| `--enable-fp32-lm-head` | LM head logits in FP32 (must match training config `lm_head_fp32: true`) |
| `--enable-return-routed-experts` | Return expert routing indices in response metadata (for R3) |
| `--enable-return-expert-logits` | Return expert routing weights alongside indices |
| `--tp-size N` | Tensor parallelism across N GPUs |
| `--mem-fraction-static 0.88` | Fraction of GPU memory reserved for static allocation (model weights and the KV cache pool); leave headroom for weight sync buffers |
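When launching from a script, it can be convenient to assemble the command programmatically. The helper below is a hypothetical convenience wrapper, not something shipped with xorl-sglang. It only combines flags listed in the table above, with defaults copied from the launch examples.

```python
import shlex

def build_launch_command(model_path, port=30000, tp_size=1,
                         batch_invariant=False, mem_fraction_static=0.88):
    """Assemble an argv list for sglang.launch_server from the flags documented above."""
    target = "xorl-batch-invariant" if batch_invariant else "xorl"
    argv = [
        "python", "-m", "sglang.launch_server",
        "--model-path", model_path,
        "--port", str(port),
        "--rl-on-policy-target", target,
        "--enable-fp32-router",
        "--enable-fp32-lm-head",
        "--mem-fraction-static", str(mem_fraction_static),
    ]
    if tp_size > 1:
        argv += ["--tp-size", str(tp_size)]
    return argv

argv = build_launch_command("Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8", tp_size=2)
command = shlex.join(argv)  # shell-quoted string, e.g. for logging
```

Passing `argv` as a list to `subprocess.Popen` avoids shell quoting issues. The joined `command` string is useful mainly for logs.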
Wait for ready
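```python
import time

import requests

# Poll /health until the server answers with HTTP 200.
while True:
    try:
        r = requests.get("http://localhost:30000/health", timeout=5)
        if r.status_code == 200:
            break
    except requests.RequestException:
        pass  # server is not accepting connections yet
    time.sleep(2)
```
Registering with the Training Server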
Before weight sync can happen, register the xorl-sglang instance with the training server:
```python
import requests

requests.post("http://training-server:6000/add_inference_endpoint", json={
    "host": "inference-node-01",
    "port": 30000,
    "worker_port": 30000,
    "world_size": 4,  # match --tp-size
})
```
Multiple replicas can be registered; weight sync broadcasts to all of them in parallel:
```python
for host, port, tp_size in inference_replicas:
    requests.post("http://training-server:6000/add_inference_endpoint", json={
        "host": host,
        "port": port,
        "worker_port": port,
        "world_size": tp_size,
    })
```
Sleep / Wake for Memory Sharing
On nodes where training and inference share GPUs, use sleep/wake to hand off GPU memory:
```python
# Before a large training step: free inference GPU memory
requests.post("http://inference-node:30000/sleep")

# Run training steps
for _ in range(n_train_steps):
    training.forward_backward(...)
    training.optim_step(...)

# Sync new weights and resume inference
training.sync_inference_weights(master_address=TRAIN_HOST, master_port=29600).result()
requests.post("http://inference-node:30000/wake_up")
```
Upstream Compatibility
xorl-sglang tracks upstream SGLang closely and aims to rebase regularly. The modifications touch 32 files but are concentrated in a few areas:
| Area | Files modified | Nature of change |
|---|---|---|
| Weight sync protocol | model_runner.py, scheduler_update_weights_mixin.py, tokenizer_communicator_mixin.py, http_server.py | New endpoints and NCCL recv logic (~1,700 lines added) |
| Routing data export | routed_experts_capturer.py, io_struct.py, detokenizer_manager.py, scheduler_output_processor_mixin.py | Additions to existing generation path |
| Numerical alignment | qwen3_moe.py, qwen3.py, qwen2_moe.py, qwen2.py, layernorm.py | Conditional paths gated on rl_on_policy_target |
| CLI args | server_args.py | New flags: --enable-fp32-router, --enable-return-expert-logits, --enable-rdma-weight-updates, --rl-on-policy-target xorl |
| Bug fixes | schedule_batch.py, scheduler.py | Prefill-only slot leak fix |
The main branch is periodically rebased onto upstream SGLang. The xorl-rl-target branch contains the xorl-specific commits on top. If you need a feature from a newer upstream SGLang release, file an issue at the xorl-sglang repo.
Source
| Repo | Description |
|---|---|
| togethercomputer/xorl-sglang | xorl’s SGLang fork: weight update API, R3 export, numerical alignment |
| src/xorl/server/weight_sync/backends/nccl_broadcast.py | Training-side NCCL broadcast implementation that drives the SGLang weight update endpoints |
| src/xorl/server/runner/utils/routing_replay_handler.py | Decodes R3 routing data from SGLang and distributes it across SP/packing dimensions |