MoE Router

The TopKRouter (src/xorl/models/layers/moe/router.py) is stateless — it owns no parameters. The gate linear (nn.Linear(hidden_size, num_experts, bias=False)) lives on MoEBlock so that the checkpoint path resolves to mlp.gate.weight, matching HuggingFace conventions.


```python
routing_weights = softmax(gate(hidden_states), dim=-1)                       # [T, E]
routing_weights, selected_experts = topk(routing_weights, k=top_k, dim=-1)   # [T, k]
if norm_topk_prob:
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize to sum to 1
```

Each token independently selects its top-K experts from the softmax distribution over all E experts. Gate scores are broadcast into the weighted sum after expert compute.
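The selection math above can be sketched without torch. The following is a minimal, torch-free illustration of the same steps for a single token (plain-Python softmax and top-K; the function and variable names are ours, not xorl's API):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of logits.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_route(logits, k, norm_topk_prob=True):
    """Illustrative single-token top-K routing: softmax, select, renormalize."""
    probs = softmax(logits)
    # Indices of the k largest probabilities (the selected experts).
    selected = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    weights = [probs[i] for i in selected]
    if norm_topk_prob:
        s = sum(weights)
        weights = [w / s for w in weights]  # convex combination: weights sum to 1
    return weights, selected

# One token, four experts, top-2 routing.
weights, experts = top_k_route([2.0, 1.0, 0.5, 0.1], k=2)
```

With `norm_topk_prob=False` the two weights would instead be the raw softmax values, which shrink as the number of experts grows.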


| Parameter | Type | Description |
| --- | --- | --- |
| `norm_topk_prob` | bool | Renormalize the top-k gate scores to sum to 1. Qwen3-MoE uses `true`. With renormalization the output is a convex combination of expert outputs; without it, scores are raw softmax values, which can be very small when E is large. |
| `router_fp32` | bool | Upcast the gate linear and hidden states to float32 before computing logits. Recommended for numerical stability and alignment with SGLang inference. |
| `freeze_router` | bool | Freeze the gate linear weights during training (server training). Useful for fine-tuning where pre-trained routing should be preserved. |
| `train_router` | bool | When `false`, detaches routing weights before they flow into the expert weighted sum. The gate still trains via auxiliary losses on `router_logits`, but not through expert gradients. Stabilizes training when routing is already good. |
```yaml
model:
  router_fp32: true    # recommended for stability + SGLang alignment
  freeze_router: true  # server training: freeze gate.weight entirely
train:
  no_decay_modules: [gate]  # local training: exclude gate from weight decay
```
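A common way to apply a `no_decay_modules`-style setting is to partition parameters into decay and no-decay optimizer groups by matching module names against each parameter's dotted path. A minimal sketch of that pattern (the parameter names below are hypothetical, and a real model would supply them via `named_parameters()`):

```python
# Hypothetical parameter names; a real model provides these via named_parameters().
param_names = [
    "layers.0.mlp.gate.weight",
    "layers.0.mlp.experts.0.up_proj.weight",
    "layers.0.self_attn.q_proj.weight",
]
no_decay_modules = ["gate"]

decay, no_decay = [], []
for name in param_names:
    # A parameter is excluded from weight decay if any configured
    # module substring appears in its dotted path.
    bucket = no_decay if any(m in name for m in no_decay_modules) else decay
    bucket.append(name)
```

The two buckets would then feed two optimizer parameter groups, one with `weight_decay=0.0`.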

Routing Replay (R3) solves a correctness problem that arises when combining EP with gradient checkpointing.

The problem: Flash attention is non-deterministic on recompute — small numerical differences mean the hidden states fed to the router during backward’s recompute differ from those in the original forward. Top-K selection can therefore pick different experts, changing the AllToAll dispatch shapes and producing tensor-shape mismatches.

The solution: R3 records routing decisions (expert indices, optionally gate weights) during the forward pass and replays them verbatim during backward recompute.
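The record/replay idea can be sketched in a few lines. The class below is an illustrative stand-in for the buffer, not the real `RoutingReplay` class or its API:

```python
class ReplayBufferSketch:
    """Illustrative stand-in for RoutingReplay; not the real class or its API."""

    def __init__(self):
        self.top_indices_list = []  # routing decisions recorded during forward
        self.backward_index = 0     # read cursor for the replay phase

    def record(self, indices):
        # Forward pass: store the routing decision for later replay.
        self.top_indices_list.append(indices)

    def replay(self):
        # Backward recompute: return the cached decision, ignoring the re-run router.
        indices = self.top_indices_list[self.backward_index]
        self.backward_index += 1
        return indices

buf = ReplayBufferSketch()
buf.record([3, 1])       # forward: token routed to experts 3 and 1
recomputed = [2, 1]      # recompute might disagree (non-deterministic attention)
replayed = buf.replay()  # cached indices win, so AllToAll shapes stay fixed
```

Because the cached indices are used verbatim, the dispatch shapes in the replayed forward always match the original, regardless of what the re-run router would have chosen.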

Diagram: R3 — Routing Replay for Gradient Checkpointing Determinism.

- Forward pass (stage `record`): the router computes top-K indices, saves them to the CPU-pinned `RoutingReplay` buffer, and AllToAll dispatches with the real indices. Activation checkpointing then drops the expert activations.
- Backward recompute (replay stages): the router runs again but its result is discarded; the cached indices are popped from the replay buffer, so AllToAll uses the same indices and shapes. Gradients are correct despite non-deterministic attention.
- Pre-loaded routing mode (RL use case): routing decisions are loaded from the inference engine (SGLang rollout data) before the training step, so training exactly replicates inference routing.
- `RoutingReplay` buffer internals: `top_indices_list: List[Tensor]` (CPU-pinned), separate `forward_index` / `backward_index` cursors, marked `@torch.compiler.disable` (not JIT traceable).
| Stage | When | What happens |
| --- | --- | --- |
| `record` | Forward pass | Router runs normally; top-K indices are stored to a CPU-pinned buffer |
| `replay_forward` | Recompute in backward | Router result discarded; cached indices loaded instead |
| `replay_backward` | Backward through the replayed forward | AllToAll shapes guaranteed identical to the original forward |

R3 Mode for RL (Pre-loaded Routing from SGLang)

R3 also powers rollout routing replay for RL training. When SGLang generates rollouts, it records which experts each token was routed to during inference. xorl loads these pre-computed routing decisions so that training uses exactly the same expert assignments as inference, ensuring gradient-level consistency between the policy model’s training and rollout steps.
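The pre-loaded mode amounts to consulting the rollout's recorded decisions before falling back to live routing. A minimal sketch under assumed names (the payload shape, `route_token`, and `live_top_k` are illustrative, not xorl's API):

```python
# Hypothetical rollout payload: per-token expert indices captured by the
# inference engine during generation.
rollout_routing = {0: [5, 2], 1: [0, 7]}

def live_top_k(token_id):
    # Placeholder for the real router's top-K computation.
    return [0, 1]

def route_token(token_id, preloaded=None):
    """Use pre-loaded inference routing when present; otherwise route live."""
    if preloaded is not None and token_id in preloaded:
        return preloaded[token_id]
    return live_top_k(token_id)

chosen = route_token(0, preloaded=rollout_routing)
```

With the pre-loaded table supplied, training sees exactly the expert assignments the inference engine used, which is what makes the policy's training and rollout steps consistent at the gradient level.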

R3 is activated automatically when ep_size > 1 and gradient checkpointing is enabled. No configuration is required.


| File | Description |
| --- | --- |
| `src/xorl/models/layers/moe/router.py` | `TopKRouter` — softmax, top-K, optional renormalization |
| `src/xorl/models/layers/moe/routing_replay.py` | `RoutingReplay` — R3 record/replay for checkpoint determinism and RL routing consistency |
| `src/xorl/models/layers/moe/moe_block.py` | `MoEBlock` — gate linear ownership, router invocation, R3 integration |