# MoE Router

The `TopKRouter` (`src/xorl/models/layers/moe/router.py`) is stateless — it owns no parameters. The gate linear (`nn.Linear(hidden_size, num_experts, bias=False)`) lives on `MoEBlock` so that the checkpoint path resolves to `mlp.gate.weight`, matching HuggingFace conventions.
## Routing Algorithm

```python
routing_weights = softmax(gate(hidden_states), dim=-1)                      # [T, E]
routing_weights, selected_experts = topk(routing_weights, k=top_k, dim=-1)  # [T, k]
if norm_topk_prob:
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)            # renormalize to sum=1
```

Each token independently selects its top-K experts from the softmax distribution over all E experts. Gate scores are broadcast into the weighted sum after expert compute.
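As a concrete illustration, here is a minimal, dependency-free sketch of the same select-then-renormalize logic for a single token. This is toy code, not the actual `TopKRouter`, which operates on batched tensors:

```python
import math

def topk_route(logits, top_k, norm_topk_prob=True):
    """Toy single-token router: softmax over experts, keep the k
    largest scores, optionally renormalize them to sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]                      # softmax over all E experts
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    selected = order[:top_k]                           # selected_experts
    weights = [probs[i] for i in selected]             # routing_weights
    if norm_topk_prob:
        s = sum(weights)
        weights = [w / s for w in weights]             # convex combination
    return selected, weights

# 4 experts, top-2: raw scores are a sub-unit slice of the softmax;
# with renormalization they sum to exactly 1.
sel, w = topk_route([2.0, 1.0, 0.5, -1.0], top_k=2)
print(sel)                # [0, 1]
print(round(sum(w), 6))   # 1.0
```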
## Router Parameters

| Parameter | Type | Description |
|---|---|---|
| `norm_topk_prob` | bool | Renormalize the top-k gate scores to sum to 1. Qwen3-MoE uses `true`. With renormalization the output is a convex combination of expert outputs; without it, scores are raw softmax values, which can be very small when E is large. |
| `router_fp32` | bool | Upcast the gate linear and hidden states to float32 before computing logits. Recommended for numerical stability and alignment with SGLang inference. |
| `freeze_router` | bool | Freeze the gate linear weights during training (server training). Useful for fine-tuning where pre-trained routing should be preserved. |
| `train_router` | bool | When `false`, detaches routing weights before they flow into the expert weighted sum. The gate still trains via auxiliary losses on `router_logits`, but not through expert gradients. Stabilizes training when routing is already good. |
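To make the `norm_topk_prob` rationale concrete, a small standalone computation (illustrative only, not xorl code) shows how the raw top-2 softmax mass shrinks as the number of experts grows:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

# With many experts, even the two largest raw softmax scores carry
# little total mass, which scales down the MoE output unless
# norm_topk_prob renormalizes the selected scores to sum to 1.
for num_experts in (8, 64, 128):
    logits = [1.0] + [0.0] * (num_experts - 1)   # one mildly preferred expert
    probs = sorted(softmax(logits), reverse=True)
    top2_mass = probs[0] + probs[1]
    print(num_experts, round(top2_mass, 3))
```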
```yaml
model:
  router_fp32: true    # recommended for stability + SGLang alignment
  freeze_router: true  # server training: freeze gate.weight entirely
train:
  no_decay_modules: [gate]  # local training: exclude gate from weight decay
```

## Routing Replay (R3)
Routing Replay (R3) solves a correctness problem that arises when combining expert parallelism (EP) with gradient checkpointing.
The problem: Flash attention is non-deterministic on recompute — small numerical differences mean the hidden states fed to the router during backward’s recompute differ from the original forward. This causes top-K selection to pick different experts, changing the AllToAll shapes and causing tensor mismatches.
The solution: R3 records routing decisions (expert indices, optionally gate weights) during the forward pass and replays them verbatim during backward recompute.
## Lifecycle Stages

| Stage | When | What happens |
|---|---|---|
| `record` | Forward pass | Router runs normally; top-K indices are stored to a CPU-pinned buffer |
| `replay_forward` | Recompute in backward | Router result is discarded; cached indices are loaded instead |
| `replay_backward` | Backward through the replayed forward | AllToAll shapes are guaranteed identical to the original forward |
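The record/replay idea can be sketched in a few lines. This is a hypothetical toy class, not the actual `RoutingReplay` API: it records top-K indices on the first forward and returns the cached indices verbatim on the checkpointed recompute, regardless of what the recomputed router would pick:

```python
class RoutingReplayBuffer:
    """Toy R3 sketch: cache routing decisions in forward order and
    replay them verbatim during the checkpointed recompute."""

    def __init__(self):
        self.cache = []      # (layer_idx, indices) in forward order
        self.cursor = 0
        self.mode = "record"

    def start_replay(self):
        self.mode = "replay"
        self.cursor = 0

    def route(self, layer_idx, compute_topk):
        if self.mode == "record":
            indices = compute_topk()              # run the real router
            self.cache.append((layer_idx, indices))
            return indices
        # replay: discard the recomputed router result and return the
        # recorded indices so AllToAll shapes match the original forward
        recorded_layer, indices = self.cache[self.cursor]
        assert recorded_layer == layer_idx
        self.cursor += 1
        return indices

buf = RoutingReplayBuffer()
# Forward: the router picks experts [3, 7] for some token in layer 0.
first = buf.route(0, compute_topk=lambda: [3, 7])
buf.start_replay()
# Recompute: even if numerical drift would now pick [3, 6], the
# cached decision is replayed.
replayed = buf.route(0, compute_topk=lambda: [3, 6])
print(first, replayed)   # [3, 7] [3, 7]
```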
## R3 Mode for RL (Pre-loaded Routing from SGLang)

R3 also powers rollout routing replay for RL training. When SGLang generates rollouts, it records which experts each token was routed to during inference. xorl loads these pre-computed routing decisions so that training uses exactly the same expert assignments as inference, ensuring gradient-level consistency between the policy model's training and rollout steps.
R3 is activated automatically when `ep_size > 1` and gradient checkpointing is enabled. No configuration is required.
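The pre-loaded variant can be sketched the same way. The names below are invented for illustration; the real loading path lives in `routing_replay.py`. Instead of recording during forward, the replay cache is pre-filled with the expert indices recorded at inference time, and the live router is only a fallback:

```python
# Hypothetical pre-filled cache: per (layer, token) expert indices
# recorded by the inference engine during rollout generation.
rollout_routing = {
    (0, 0): [3, 7],
    (0, 1): [1, 4],
}

def route_with_preload(layer, token, compute_topk):
    """Replay inference-time routing when available; otherwise fall
    back to the live router."""
    cached = rollout_routing.get((layer, token))
    if cached is not None:
        return cached          # training reuses the rollout's routing
    return compute_topk()      # no recorded decision for this position

print(route_with_preload(0, 0, lambda: [3, 6]))  # [3, 7]
print(route_with_preload(1, 0, lambda: [2, 5]))  # [2, 5]
```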
## Source

| File | Description |
|---|---|
| `src/xorl/models/layers/moe/router.py` | `TopKRouter` — softmax, top-K, optional renormalization |
| `src/xorl/models/layers/moe/routing_replay.py` | `RoutingReplay` — R3 record/replay for checkpoint determinism and RL routing consistency |
| `src/xorl/models/layers/moe/moe_block.py` | `MoEBlock` — gate linear ownership, router invocation, R3 integration |