
Inference: xorl-sglang

xorl-sglang is Together’s fork of SGLang, the primary supported inference backend for xorl RL training. It extends SGLang with the APIs and data export capabilities needed to integrate tightly with the xorl training server.


The core RL training loop requires capabilities that upstream SGLang does not provide:

  1. Weight updates from training — the training server must push new weights to the inference server after each policy update. Upstream SGLang does not expose a distributed weight-update API.

  2. Per-token logprob export — RL loss functions require logprobs from the rollout (old policy). SGLang must return per-token log probabilities alongside completions.

  3. MoE routing data export (R3) — for MoE models, the inference server’s routing decisions must be exported so training can replay them exactly, ensuring gradient consistency. Upstream SGLang does not export expert routing indices.

  4. Numerical alignment flags — training and inference must produce identical logits for the same input. xorl-sglang supports the same numerical alignment flags as the xorl training server (router_fp32, lm_head_fp32, rmsnorm_native, attention_cast_bf16, etc.).


xorl-sglang adds three HTTP endpoints that implement the NCCL-based weight synchronization protocol:

| Endpoint | Description |
| --- | --- |
| POST /init_weights_update_group | Join an NCCL process group for weight sync. Called by the xorl training server before each sync. |
| POST /update_weights_from_distributed | Receive broadcasted weight tensors via NCCL dist.broadcast. Called per weight bucket. |
| POST /destroy_weights_update_group | Tear down the NCCL group after sync completes. |

The training server’s nccl_broadcast backend drives these calls — see Backend: nccl_broadcast for the full protocol.
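The ordering of these calls matters: one group init, one broadcast call per weight bucket, one teardown. A minimal dry-run sketch of the sequence the training side issues (request payloads are omitted; this only models call order, not the real client code):

```python
def weight_sync_call_sequence(num_buckets):
    """Ordered (method, path) pairs for one full weight sync:
    one init, one broadcast per weight bucket, one teardown."""
    calls = [("POST", "/init_weights_update_group")]
    calls += [("POST", "/update_weights_from_distributed")] * num_buckets
    calls.append(("POST", "/destroy_weights_update_group"))
    return calls
```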

xorl-sglang returns per-token log probabilities for every generated token when return_logprobs=True. These are passed directly to the xorl training server as logprobs in the loss_fn_inputs field of each Datum.

completions = sglang_client.generate(
    prompts,
    sampling_params={"max_new_tokens": 512, "return_logprobs": True},
)
logprobs = completions[i].meta_info["logprobs"]  # list of per-token log probs
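Downstream, these logprobs land in the loss_fn_inputs field of each Datum. A minimal sketch of that hand-off, assuming a dict-shaped datum for illustration (the real Datum class in xorl may have a different structure):

```python
def make_datum(tokens, logprobs):
    # Hypothetical dict-shaped datum; the real xorl Datum may differ.
    assert len(tokens) == len(logprobs), "one logprob per generated token"
    return {"tokens": tokens, "loss_fn_inputs": {"logprobs": logprobs}}
```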

For MoE models, xorl-sglang records which experts each token was routed to during generation and returns this alongside completions:

# Returned in meta_info
routed_experts = completions[i].meta_info["routed_experts"]
# Shape: [num_tokens, num_layers, top_k], dtype: int32, base64-encoded

The training server decodes and replays these routing decisions via R3, ensuring training uses the same expert assignments as inference did. This is critical for gradient consistency in GRPO/PPO training with MoE models.

Format: xorl-sglang encodes routing data as base64 int32 arrays with shape metadata to minimize transfer overhead:

{
  "routed_experts": "<base64_encoded_int32_array>",
  "routed_expert_meta": {"shape": [512, 28, 2], "dtype": "int32"}
}

The RoutingReplayHandler in xorl decodes and distributes this data across context-parallel and packing dimensions automatically.
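The decode step itself is straightforward; a minimal sketch using the field names from the JSON format above (the real RoutingReplayHandler additionally handles distribution across SP/packing dimensions):

```python
import base64

import numpy as np

def decode_routed_experts(meta_info):
    """Decode a base64-encoded routing array using its shape metadata.

    Returns an int32 array of shape [num_tokens, num_layers, top_k].
    """
    raw = base64.b64decode(meta_info["routed_experts"])
    meta = meta_info["routed_expert_meta"]
    return np.frombuffer(raw, dtype=meta["dtype"]).reshape(meta["shape"])
```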

xorl and xorl-sglang use the same numerical alignment flags to ensure training and inference produce identical logits for the same input. Mismatched logits lead to incorrect importance sampling ratios and degraded RL performance.

Configure both sides identically:

# xorl server config (server_config.yaml)
router_fp32: true
lm_head_fp32: true
rmsnorm_native: false
activation_native: false
rope_native: false
attention_cast_bf16: false

# xorl-sglang launch
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --router-fp32 \
  --lm-head-fp32 \
  --port 30000
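A small pre-flight check that compares the two configurations can catch drift before it silently degrades training. A sketch assuming both sides' flags are available as plain dicts (the flag names are the ones listed above; how you load each config is up to you):

```python
ALIGNMENT_FLAGS = (
    "router_fp32", "lm_head_fp32", "rmsnorm_native",
    "activation_native", "rope_native", "attention_cast_bf16",
)

def alignment_mismatches(train_cfg, infer_cfg):
    """Return the alignment flags whose values differ between
    the training config and the inference config."""
    return [f for f in ALIGNMENT_FLAGS
            if train_cfg.get(f, False) != infer_cfg.get(f, False)]
```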

After a weight sync, xorl-sglang flushes its KV cache. The training server sends flush_cache=True on the last weight bucket to trigger this. Without flushing, cached key/value tensors from the old policy weights would be reused with new weights, producing incorrect logits.
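The last-bucket flush can be sketched as the per-bucket flag the training server sets on each broadcast call (other payload fields are omitted here; only flush_cache is taken from the protocol described above):

```python
def bucket_flags(num_buckets):
    """flush_cache is set only on the final weight bucket of a sync,
    so the KV cache is invalidated exactly once, after all weights land."""
    return [{"bucket": i, "flush_cache": i == num_buckets - 1}
            for i in range(num_buckets)]
```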


# Clone and install
git clone https://github.com/togethercomputer/xorl-sglang
cd xorl-sglang
pip install -e ".[all]"

# Launch (single GPU)
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --port 30000 \
  --router-fp32 \
  --lm-head-fp32

# Launch (tensor parallel, 4 GPUs)
python -m sglang.launch_server \
  --model-path Qwen/Qwen3-8B \
  --port 30000 \
  --tp-size 4 \
  --router-fp32 \
  --lm-head-fp32

Wait for the server to be ready:

import requests, time

while True:
    try:
        r = requests.get("http://localhost:30000/health")
        if r.status_code == 200:
            break
    except requests.ConnectionError:
        pass  # server not up yet; retry
    time.sleep(2)

Before weight sync can happen, register the xorl-sglang instance with the training server:

import requests

requests.post("http://training-server:5555/add_inference_endpoint", json={
    "host": "inference-node-01",
    "port": 30000,
    "world_size": 4,  # match --tp-size
})

Multiple replicas can be registered — weight sync broadcasts to all of them in parallel:

for host, port, tp_size in inference_replicas:
    requests.post("http://training-server:5555/add_inference_endpoint", json={
        "host": host, "port": port, "world_size": tp_size,
    })

On nodes where training and inference share GPUs, use sleep/wake to hand off GPU memory:

# Before a large training step: free inference GPU memory
requests.post("http://inference-node:30000/sleep")

# Run training steps
for _ in range(n_train_steps):
    training.forward_backward(...)
    training.optim_step(...)

# Sync new weights and resume inference
service.sync_inference_weights(master_address=TRAIN_HOST, master_port=29600)
requests.post("http://inference-node:30000/wake_up")

xorl-sglang tracks upstream SGLang closely and aims to rebase regularly. The modifications are minimal and isolated:

  • Weight update endpoints (new files)
  • Routing data export (additions to existing generation path)
  • Numerical alignment CLI flags (additions to launch args)
  • Per-token logprob format compatibility with xorl’s LossFnOutput

If you need a feature from a newer upstream SGLang release, file an issue at the xorl-sglang repo.

| Repo | Description |
| --- | --- |
| togethercomputer/xorl-sglang | xorl's SGLang fork — weight update API, R3 export, numerical alignment |
| src/xorl/server/weight_sync/backends/nccl_broadcast.py | Training-side NCCL broadcast implementation that drives the SGLang weight update endpoints |
| src/xorl/server/runner/utils/routing_replay_handler.py | Decodes R3 routing data from SGLang and distributes it across SP/packing dimensions |