Expert Parallelism
Expert Parallelism (EP) is xorl’s strategy for distributing the expert weights of
Mixture-of-Experts (MoE) layers across multiple GPUs. Each GPU owns a contiguous shard of
the full expert stack and only processes the tokens routed to its local experts. Between the
routing decision and the expert compute, tokens are redistributed across GPUs using either an
AllToAll collective (the default alltoall backend) or DeepEP’s NVLink-optimized
dispatch/combine kernels.
Table of contents
- What EP does — expert weight sharding
- Token dispatch — routing and redistribution
- AllToAll dispatch backend
- DeepEP dispatch backend
- EP device mesh — 2D [EP, FSDP] for expert params
- Gradient handling
- Routing replay (R3)
- EP + FSDP2 composition
- EP + context parallelism (CP) — folded axis
- EP + pipeline parallelism (PP)
- MoE kernel selection
- Hard constraints
- Configuration examples
- Parameter reference
1. What EP does — expert weight sharding
A vanilla MoE layer stores all expert weights on every GPU, which quickly exhausts HBM
for large expert counts. EP distributes the expert stack so that each EP rank holds only
num_experts / ep_size experts.
Expert weight layout
Expert weights for Qwen3-30B-A3B (128 experts, hidden 2048, intermediate 768) are stored as fused 3-D tensors:
```
gate_proj  [num_experts, hidden_dim, intermediate_size]   e.g. [128, 2048, 768]
up_proj    [num_experts, hidden_dim, intermediate_size]
down_proj  [num_experts, intermediate_size, hidden_dim]
```
With ep_size=8, each EP rank holds 128 / 8 = 16 experts:
```
EP rank 0: gate_proj[  0:16, :, :]   experts 0-15
EP rank 1: gate_proj[ 16:32, :, :]   experts 16-31
...
EP rank 7: gate_proj[112:128, :, :]  experts 112-127
```
The sharding is applied along dimension 0 (the expert axis):
```python
ep_plan = {
    "model.layers.*.mlp.experts.gate_proj": Shard(0),
    "model.layers.*.mlp.experts.up_proj": Shard(0),
    "model.layers.*.mlp.experts.down_proj": Shard(0),
}
```
How sharding is applied
ParallelPlan.apply() (src/xorl/distributed/parallel_plan.py) iterates all model
parameters, matches each expert parameter against the EP plan, then either:
- Redistributes the full tensor into a local shard using DTensor.redistribute() with a Shard(0) placement on the EP mesh (normal path when loading from a checkpoint).
- Annotates already-local tensors with a spec_info attribute (fast path when weights were loaded EP-aware, e.g. with load_weights_mode: all_ranks).
After sharding, each rank’s expert parameter has shape [num_local_experts, ...]. The
spec_info attribute records the ep_fsdp_mesh, the shard placement, and the FQN for
checkpoint reconstruction.
LoRA adapters injected into expert layers are sharded the same way
(gate_proj_lora_A/B, up_proj_lora_A/B, down_proj_lora_A/B all follow Shard(0)).
Shared LoRA weights with a size-1 expert dimension are replicated instead of sharded.
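The per-rank shard ranges described above can be computed in a few lines. This is an illustrative sketch (the helper name is hypothetical, not the actual ParallelPlan code):

```python
def local_expert_range(ep_rank: int, num_experts: int, ep_size: int) -> tuple[int, int]:
    """Return the [start, end) slice of the expert axis owned by an EP rank."""
    assert num_experts % ep_size == 0, "num_experts must be divisible by ep_size"
    num_local = num_experts // ep_size
    start = ep_rank * num_local
    return start, start + num_local

# Qwen3-30B-A3B: 128 experts, ep_size=8 -> 16 experts per rank
print(local_expert_range(0, 128, 8))  # (0, 16)
print(local_expert_range(7, 128, 8))  # (112, 128)
```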
2. Token dispatch — routing and redistribution
Routing decision
MoEBlock.forward() (src/xorl/models/layers/moe/moe_block.py) first computes routing
logits through the gate linear:
```python
router_logits = gate(hidden_states)  # [num_tokens, num_experts]
```
TopKRouter applies softmax, selects the top-k experts, and optionally renormalizes
the weights so they sum to 1. The result is:
```
routing_weights   [num_tokens, top_k]
selected_experts  [num_tokens, top_k]
```
Each token therefore has top_k expert assignments, each pointing at one of the
num_experts experts on one of the EP ranks.
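The routing math can be illustrated with a minimal pure-Python stand-in (names are illustrative, not the TopKRouter API): softmax over expert logits, top-k selection, optional renormalization so the kept weights sum to 1.

```python
import math

def route(logits: list[float], top_k: int, renormalize: bool = True):
    """Softmax over expert logits, pick top-k, optionally renormalize to sum to 1."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # top-k expert indices by probability (ties broken by lower index)
    selected = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:top_k]
    weights = [probs[i] for i in selected]
    if renormalize:
        s = sum(weights)
        weights = [w / s for w in weights]
    return weights, selected

weights, selected = route([2.0, 0.5, 1.0, -1.0], top_k=2)
print(selected)                # [0, 2]
print(round(sum(weights), 6))  # 1.0
```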
Token redistribution overview
Because expert weights are split across ranks, tokens must be sent to the GPU that owns the target expert. The dispatch step collects all tokens that belong to a given rank’s experts and forwards them; the combine step returns computed outputs back to their originating ranks.
```
Before dispatch:
  GPU 0: tokens [t0, t1, t2, t3]   (all local tokens)

After dispatch (ep_size=4, 128 experts):
  GPU 0 receives: tokens routed to experts  0-31
  GPU 1 receives: tokens routed to experts 32-63
  GPU 2 receives: tokens routed to experts 64-95
  GPU 3 receives: tokens routed to experts 96-127
```
Token counts per rank are highly variable (load-imbalanced routing is common), so the dispatch collective must handle variable-length sends and receives.
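The variable split sizes can be illustrated with a small helper (hypothetical, mirroring what a preprocess step must compute): each token-expert assignment maps to the rank that owns that expert.

```python
def send_splits(selected_experts: list[list[int]], num_experts: int, ep_size: int) -> list[int]:
    """input_splits[r] = how many token copies this rank sends to EP rank r."""
    experts_per_rank = num_experts // ep_size
    splits = [0] * ep_size
    for token_assignments in selected_experts:
        for expert_id in token_assignments:
            splits[expert_id // experts_per_rank] += 1
    return splits

# 3 tokens, top_k=2, 8 experts over ep_size=4 (2 experts per rank)
assignments = [[0, 5], [3, 2], [7, 1]]
print(send_splits(assignments, num_experts=8, ep_size=4))  # [2, 2, 1, 1]
```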
3. AllToAll dispatch backend
The default backend (ep_dispatch: alltoall) implements dispatch/combine using
torch.distributed.all_to_all_single with variable split sizes.
Source files
src/xorl/distributed/moe/alltoall.py
Pre-dispatch: alltoall_pre_dispatch()
```
alltoall_pre_dispatch(
    hidden_states, routing_weights, selected_experts, num_experts, ep_group
) -> (permuted_tokens, cumsum, AllToAllDispatchContext)
```
Steps:
1. Build expert mask — one_hot(selected_experts, num_experts) gives a [num_experts, top_k, num_tokens] indicator tensor.
2. preprocess() — computes the variable send/receive split sizes:
   - input_splits[r] = number of tokens this rank must send to EP rank r.
   - Gathers num_global_tokens_per_expert across all EP ranks via dist.all_gather_into_tensor so every rank knows exactly how many tokens each other rank will send to each expert.
   - output_splits[r] = number of tokens this rank will receive from EP rank r.
3. token_pre_all2all() — executes the actual communication:
   - Permutes local tokens into grouped-by-destination order (permute()).
   - Calls all_to_all(ep_group, ...) using the computed split sizes.
   - Re-groups received tokens by local expert using sort_chunks_by_idxs() so each expert's input is contiguous.
The returned AllToAllDispatchContext carries input_splits, output_splits,
num_tokens_per_expert, the permutation mapping, original shape, routing weights, and
selected experts — everything needed to reverse the operation.
Post-combine: alltoall_post_combine()
After expert computation, alltoall_post_combine() reverses the dispatch:
- Re-groups expert outputs by source rank using sort_chunks_by_idxs() with the transposed token count matrix.
- Calls all_to_all() again (with input_splits and output_splits swapped) to return outputs to their originating ranks.
- unpermute() restores the original token order and accumulates weighted expert contributions into the final hidden state.
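The final unpermute-and-accumulate step can be sketched in plain Python (hypothetical names; the real code operates on tensors):

```python
def unpermute_combine(expert_outputs, perm_idx, weights, num_tokens):
    """Accumulate weighted expert outputs back into original token order.

    expert_outputs[i] is the output for the i-th dispatched token copy,
    perm_idx[i] is the original token it came from, weights[i] its routing weight.
    """
    out = [0.0] * num_tokens
    for y, tok, w in zip(expert_outputs, perm_idx, weights):
        out[tok] += w * y
    return out

# token 0 routed to two experts with weights 0.75 / 0.25
print(unpermute_combine([4.0, 8.0], perm_idx=[0, 0], weights=[0.75, 0.25], num_tokens=2))
# [5.0, 0.0]  (0.75*4 + 0.25*8 = 5.0)
```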
Communication cost
Each AllToAll transfers num_tokens * hidden_dim * sizeof(bfloat16) bytes per direction.
For ep_size=8, seq_len=32768, hidden_dim=2048, top_k=8:
```
tokens = 32768 * 8 = 262144   (token-expert pairs before deduplication per destination)
bytes per direction ≈ 262144 * 2048 * 2 / 8 ≈ 134 MB per EP rank per direction
```
The preprocess() step requires an additional lightweight all_gather_into_tensor to
exchange token count metadata (negligible relative to the main transfer).
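The 134 MB figure follows directly from the numbers above; a quick check, assuming evenly balanced routing across the ep_size=8 destinations:

```python
seq_len, top_k, hidden_dim, ep_size = 32768, 8, 2048, 8
bytes_per_elem = 2  # bfloat16

token_expert_pairs = seq_len * top_k                        # 262144
payload = token_expert_pairs * hidden_dim * bytes_per_elem  # full dispatch payload
per_rank = payload // ep_size                               # even-routing assumption
print(token_expert_pairs, per_rank, per_rank / 1e6)         # 262144 134217728 134.217728
```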
4. DeepEP dispatch backend
When ep_dispatch: deepep is set, xorl uses
DeepEP — an NVLink-optimized communication
library that runs dispatch/combine on a dedicated communication stream, freeing compute SMs
for overlapped expert computation.
Source files
src/xorl/distributed/moe/deepep.py
Buffer lifecycle: DeepEPBuffer
DeepEPBuffer wraps a deep_ep.Buffer object. The buffer is lazily initialized on the
first dispatch call. Key initialization parameters:
- buffer_size_gb (default 2.0 GB) — NVLink staging buffer. If DeepEP provides config hints for the given ep_size, the buffer is sized automatically via config.get_nvl_buffer_size_hint(hidden_bytes, num_ranks).
- num_sms (default 20) — number of SMs allocated to the communication kernels via Buffer.set_num_sms(). Fewer SMs leave more for overlapped compute.
- low_latency_mode — reserved for inference use cases.
A global singleton _default_buffer is shared across all MoE layers in a forward pass.
Calling destroy_default_buffer() at the end of training frees the NVLink staging memory.
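The lifecycle (lazy singleton, explicit destroy) follows a common pattern; a generic sketch in plain Python, where Buffer is a stand-in and not the deep_ep API:

```python
class Buffer:
    """Stand-in for an expensive communication buffer."""
    def __init__(self, size_gb: float):
        self.size_gb = size_gb

_default_buffer = None

def get_default_buffer(size_gb: float = 2.0) -> Buffer:
    """Lazily create one buffer shared by all MoE layers."""
    global _default_buffer
    if _default_buffer is None:
        _default_buffer = Buffer(size_gb)
    return _default_buffer

def destroy_default_buffer() -> None:
    """Free the staging memory at the end of training."""
    global _default_buffer
    _default_buffer = None

a = get_default_buffer()
b = get_default_buffer()
print(a is b)  # True: all layers share one buffer
destroy_default_buffer()
print(get_default_buffer() is a)  # False: a fresh buffer after destroy
```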
Fused autograd: _FusedDispatchAndPermute
The forward dispatch is a single torch.autograd.Function boundary:
Forward:
- buffer.get_dispatch_layout(topk_idx, num_experts) — computes num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank without any communication.
- buffer.dispatch(async_finish=True, allocate_on_comm_stream=True) — sends tokens to expert-owning ranks on the NVLink comm stream. Returns a handle for the reverse combine and an EventOverlap to sync streams.
- event.current_stream_wait() — blocks the compute stream until dispatch is done.
- permute_for_experts() — argsort + index_select to arrange received tokens in expert-sorted order, slicing invalid entries (expert id == -1) to the end.
Backward:
- index_add_ — scatter gradient from expert order back to received order.
- buffer.combine(async_finish=True) — reverse-dispatch gradients to originating ranks.
Fused autograd: _FusedUnpermuteAndCombine
Forward:
- Weighted scatter-add: output[perm_idx] += score * expert_output (unpermute).
- buffer.combine(async_finish=True, allocate_on_comm_stream=True) — sends accumulated outputs back to originating ranks.
- If async_combine=True, stores the pending EventOverlap in a global variable _pending_combine_event so the next layer's computation can overlap with the transfer. The next call to token_pre_dispatch() automatically calls sync_pending_combine().
Backward:
- buffer.dispatch() — reverse combine (sends grads to expert-owning ranks).
- Reverse weighted scatter-add: grad_expert[i] = score[i] * grad[perm_idx[i]].
Async combine overlap
Setting deepep_async_combine: true enables the async combine path. The combine
communication for layer L overlaps with the attention/dense-FFN computation of layer
L+1. The pending event is synchronized at the start of the next layer’s dispatch.
Supported EP sizes
Buffer config hints from DeepEP are available for the following EP sizes:
2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 144, 160. Other sizes fall back to the
fixed buffer_size_gb allocation.
Profiling
Set the environment variable XORL_DEEPEP_PROFILE=1 to enable timing printouts for
dispatch and combine operations on rank 0.
5. EP device mesh — 2D [EP, FSDP] for expert params
Expert parameters require a different device mesh from non-expert parameters because they are already sharded along the expert dimension (dim 0) and must not participate in the global FSDP all-gather along that dimension.
Mesh construction
init_parallel_state() (src/xorl/distributed/parallel_state.py) builds a separate
ep_fsdp_device_mesh whenever ep_size > 1:
```python
ranks_per_stage = world_size // pp_size
ep_fsdp_size = ranks_per_stage // ep_size
```
Without PP, the mesh is 2-D with shape [ep_size, ep_fsdp_size]:
```python
ep_fsdp_device_mesh = DeviceMesh(
    device_type=device_type,
    mesh=mesh,  # [ep_size, ep_fsdp_size]
    mesh_dim_names=("ep", "ep_fsdp"),
)
```
With PP, a 3-D mesh (_pp_ep, ep, ep_fsdp) is created so that each PP stage gets its
own EP process groups. Slicing by ["ep"] automatically returns the correct per-stage
submesh for the calling rank.
ep_outside flag
By default (ep_outside=False), the rank assignment within each [ep_size, ep_fsdp_size]
block interleaves EP and FSDP ranks so that consecutive ranks belong to the same EP group.
Setting ep_outside=True transposes the layout so consecutive ranks belong to the same
FSDP group instead.
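Concretely, the two rank layouts can be sketched with nested lists standing in for the arange/view/T tensor ops (illustrative only):

```python
def ep_mesh_matrix(ep_size: int, ep_fsdp_size: int, ep_outside: bool) -> list[list[int]]:
    """Return an [ep_size, ep_fsdp_size] rank matrix; rows index EP, columns EP-FSDP."""
    if ep_outside:
        # EP-major: consecutive ranks share a row, i.e. the same FSDP group
        return [list(range(i * ep_fsdp_size, (i + 1) * ep_fsdp_size)) for i in range(ep_size)]
    # default, FSDP-major + transpose: consecutive ranks share a column, i.e. the same EP group
    return [[j * ep_size + i for j in range(ep_fsdp_size)] for i in range(ep_size)]

default = ep_mesh_matrix(4, 2, ep_outside=False)
print([row[0] for row in default])  # EP group 0: [0, 1, 2, 3] (consecutive ranks)
outside = ep_mesh_matrix(4, 2, ep_outside=True)
print(outside[0])                   # FSDP group of rank 0: [0, 1] (consecutive ranks)
```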
init_ep_mesh_matrix() implements both layouts:
```python
# ep_outside=False (default): FSDP-major, EP-minor
mesh = arange(ep_size * ep_fsdp_size).view(ep_fsdp_size, ep_size).T
# ep_outside=True: EP-major, FSDP-minor
mesh = arange(ep_size * ep_fsdp_size).view(ep_size, ep_fsdp_size)
```
Expert FSDP uses ep_fsdp dimension only
When applying fully_shard to the experts module, the mesh is restricted to the
ep_fsdp sub-dimension:
```python
ep_fsdp_mesh = parallel_state.ep_fsdp_device_mesh["ep_fsdp"]
expert_fsdp_kwargs["mesh"] = ep_fsdp_mesh
```
This means expert parameters are further sharded along hidden_dim (dim 1, not dim 0)
across ep_fsdp_size ranks within the same EP group, while non-expert parameters use
the full global FSDP mesh. When ep_fsdp_size == 1, there is no hidden-dim sharding of
experts; FSDP still wraps the module but performs no all-gather.
Non-expert FSDP uses the global mesh
Attention and dense-FFN parameters continue to use parallel_state.fsdp_mesh (which
spans dp_shard * cp_size ranks), so their all-gather communicates globally across all
data and context parallel ranks as usual.
6. Gradient handling
set_gradient_divide_factor(ep_size) for expert params
During the backward pass, FSDP2 performs a reduce-scatter of gradients within the
all-gather mesh. For non-expert parameters this mesh is the full FSDP mesh, so gradients
are already averaged across fsdp_size ranks.
For expert parameters, the reduce-scatter operates only within the ep_fsdp sub-mesh
(size ep_fsdp_size). This means gradients are summed across the ep_fsdp_size ranks
that hold the same expert shard, but not averaged across the ep_size EP ranks that
hold different experts.
To correct this, xorl calls:
```python
experts_mod.set_gradient_divide_factor(parallel_state.ep_size)
```
This instructs FSDP2 to divide the reduce-scattered gradient by ep_size in addition to
the normal averaging, so the effective gradient across all EP ranks is properly averaged
as if the experts were not partitioned.
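A toy numeric check of why the extra divide restores the same normalization non-expert gradients get (pure Python; an illustrative model, not FSDP2 internals):

```python
ep_size, ep_fsdp_size = 4, 2
ranks_per_stage = ep_size * ep_fsdp_size

# local gradient contributions for one expert weight, one value per rank
# in the ep_fsdp group that co-owns this expert's shard
local_grads = [1.0, 3.0]

# expert path: reduce-scatter averages only within the ep_fsdp group,
# then set_gradient_divide_factor adds an extra divide by ep_size
summed = sum(local_grads)
effective = summed / ep_fsdp_size / ep_size

# target: the same divisor (ranks_per_stage) that full-mesh averaging would apply
target = summed / ranks_per_stage
print(effective == target)  # True
```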
_is_ep_fsdp flag
After calling set_gradient_divide_factor, the module is marked with _is_ep_fsdp = True.
This flag prevents the global gradient divide-factor reset loop from overriding the
expert-specific factor:
```python
for module in model.modules():
    if isinstance(module, FSDPModule) and not getattr(module, "_is_ep_fsdp", False):
        module.set_gradient_divide_factor(1.0)  # disable auto-averaging for non-expert
```
Non-expert modules have their divide factor set to 1.0 because xorl manages gradient
normalization manually (via gradient_accumulate_loss or explicit grad.mul_(1/gvt) in
the PP path).
Dim-1 sharding for expert FSDP
Expert FSDP shards along dim 1 (hidden dimension) rather than dim 0:
```python
def _experts_shard_placement_fn(param):
    return Shard(1)

expert_fsdp_kwargs["shard_placement_fn"] = _experts_shard_placement_fn
```
Dim 0 is already used by EP to shard the expert axis. FSDP2 must shard orthogonally along dim 1 to avoid double-sharding the same dimension.
7. Routing replay (R3)
Problem: non-deterministic routing breaks gradient checkpointing
Gradient checkpointing reruns the forward pass during backward to reconstruct activations.
In MoE+EP training, hidden_states depend on flash attention which is
non-deterministic (random tiling order). If the recomputed hidden_states differ
slightly from the original forward’s values, the gate’s softmax output changes, top-k
expert selection may change, and the AllToAll split sizes become inconsistent — causing
a runtime error.
Solution: record expert indices; replay on recompute
RoutingReplay (src/xorl/models/layers/moe/routing_replay.py) is a per-layer buffer
that records expert routing decisions and replays them during checkpoint recomputation.
Each MoEBlock that has routing replay enabled (_routing_replay is set) operates in
one of four stages controlled by a global _replay_stage flag:
| Stage | Behavior |
|---|---|
| None | Normal routing: gate → softmax → top-k → experts |
| "record" | top-k runs under torch.no_grad(); result saved to CPU pinned memory |
| "replay_forward" | Pops pre-recorded indices; recomputes weights via _regather_routing() |
| "replay_backward" | Pops pre-recorded indices for checkpoint recompute |
_regather_routing() always recomputes routing weights from a fresh softmax over the
gate output and gathers with cached expert indices. This ensures the autograd graph
structure is identical between the original forward and the recomputed forward
(softmax → gather → normalize), which is required by non-reentrant checkpointing.
Routing replay is only activated when both conditions are true:
- ep_size > 1 (EP is enabled)
- moe_recomputed is True (MoE layers are inside gradient checkpointed regions)
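The record/replay idea can be illustrated in a few lines (a hypothetical stand-in, not the RoutingReplay class): top-k indices are cached on the first forward and reused on recompute, while weights are regathered from a fresh softmax so the autograd graph matches.

```python
import math

def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

class ReplayBuffer:
    def __init__(self):
        self.cached = []  # recorded top-k indices, one entry per forward
        self.read = 0

    def record(self, logits, top_k):
        probs = softmax(logits)
        idx = sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k]
        self.cached.append(idx)
        return idx

    def replay(self, logits):
        idx = self.cached[self.read]
        self.read += 1
        # weights come from a *fresh* softmax over the (possibly slightly
        # different) recomputed logits, gathered with the cached indices
        probs = softmax(logits)
        w = [probs[i] for i in idx]
        s = sum(w)
        return idx, [v / s for v in w]

buf = ReplayBuffer()
first = buf.record([2.0, 1.9999, -1.0], top_k=1)
# recompute produces slightly perturbed logits; argmax would flip, replay does not
idx, w = buf.replay([1.9999, 2.0, -1.0])
print(first, idx)  # [0] [0]
```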
Dual read pointer for PP
RoutingReplay maintains two independent read pointers: forward_index and
backward_index. In PP mode, each microbatch calls forward multiple times (for
different stages), requiring separate advancement of the forward replay pointer and the
backward (recompute) replay pointer.
R3: Routing Replay from Rollout (inference-to-training)
When the xorl server receives training batches from an inference rollout, it can also
supply pre-computed routing decisions (routed_experts) and optionally pre-computed
routing weights (routed_expert_logits).
RoutingReplayHandler.fill_routing_replay() (src/xorl/server/runner/utils/routing_replay_handler.py)
pre-populates all RoutingReplay instances from this data before calling
set_r3_mode(True) and set_replay_stage("replay_forward"). MoE blocks then pop
pre-populated routing instead of computing it from the gate.
When routed_expert_logits (softmax weights from inference) is also provided, the
weights are replayed directly from the pre-populated buffer (cached_weights), bypassing
the _regather_routing() call entirely. This matches the exact numerical values produced
during inference for GRPO-style training.
SP alignment fix for R3
When context parallelism is active, sequences are padded before being sharded across CP
ranks. fill_routing_replay() must replicate this padding when slicing the routing
tensor to match the CP rank’s subsequence boundary. The effective padding multiple is:
```python
pad_to_multiple_of = math.lcm(128, cp_size)
```
The value 128 matches the hardware-level alignment used by SequentialPacker; cp_size
ensures the padded length is divisible by the number of SP ranks.
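A worked example of the padding rule (math.lcm is the actual stdlib call; the padding helper itself is illustrative):

```python
import math

def padded_len(seq_len: int, cp_size: int) -> int:
    """Pad sequence length up to a multiple of lcm(128, cp_size)."""
    multiple = math.lcm(128, cp_size)
    return ((seq_len + multiple - 1) // multiple) * multiple

print(math.lcm(128, 4))     # 128: power-of-two cp sizes change nothing
print(math.lcm(128, 6))     # 384
print(padded_len(1000, 4))  # 1024, divisible by both 128 and cp_size=4
```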
8. EP + FSDP2 composition
The two strategies compose cleanly because they operate on different parameter subsets and different mesh dimensions.
Initialization order
parallelize_model_fsdp2() (src/xorl/distributed/torch_parallelize.py) applies EP
then FSDP2 in strict order:
1. EP sharding (parallel_plan.apply()) — slice expert tensors from [num_experts, H, I] to [num_local_experts, H, I] on each rank.
2. Expert FSDP (fully_shard(experts_mod, mesh=ep_fsdp_mesh, shard_placement_fn=Shard(1))) — further shard local expert tensors along hidden dim across ep_fsdp_size ranks.
3. Non-expert FSDP (fully_shard(layer_mod, mesh=fsdp_mesh)) — standard dim-0 sharding over the global FSDP mesh for attention and dense-FFN layers.
4. Root FSDP (fully_shard(model, mesh=fsdp_mesh)) — covers embeddings and the LM head.
Manual prefetching
Because the experts module is wrapped by a separate fully_shard call inside each decoder
layer, FSDP2’s automatic prefetching does not know the correct prefetch order. xorl
configures manual prefetching:
```python
# Forward: prefetch next block's attention first, then gate, then experts
current_block.set_modules_to_forward_prefetch(list(reversed(next_block._fsdp_modules)))
# Backward: prefetch previous block in reverse
current_block.set_modules_to_backward_prefetch(...)
```
Separate optimizers/schedulers
Expert and non-expert parameters have different FSDP meshes and therefore different
all-gather/reduce-scatter collectives. Some optimizers (e.g. Muon) maintain separate
parameter groups for expert and non-expert weights, each with its own learning rate
schedule. The optimizer state is sharded consistently with the FSDP2 shard placement,
meaning optimizer states for expert parameters are only synchronized within ep_fsdp
groups, not across EP ranks.
9. EP + context parallelism (CP) — folded axis
Context parallelism (ring attention or Ulysses) splits sequences across multiple GPUs.
When EP and CP are combined on the same node, xorl can fold the CP axis into the FSDP
shard dimension (cp_fsdp_mode), which reduces the number of distinct process groups
needed.
How the axes interact
In a typical EP+CP setup (e.g. ep_size=4, ringattn_size=4, dp_shard_size=1):
- Each EP rank communicates expert tokens with its ep_size peers via AllToAll.
- Ring attention communicates hidden states with its ringattn_size peers via P2P.
- The two collectives are independent and can run on different GPUs without conflict.
The ep_fsdp_device_mesh is constructed to span only the ranks within a single PP stage,
so EP groups never cross PP boundaries.
cp_fsdp_mode effect on expert params
Expert FSDP sharding uses the ep_fsdp mesh dimension. The ep_fsdp_size is:
```
ep_fsdp_size = ranks_per_stage / ep_size
```
This value absorbs whatever CP+DP ranks are co-located with each EP rank. With
ep_size=4, ringattn_size=4, dp_shard_size=1 on 16 GPUs per PP stage:
```
ep_fsdp_size = 16 / 4 = 4   (the 4 ring-attention peers within one EP group)
```
Expert FSDP all-gathers hidden-dim shards across the 4 ring-attention peers only. Ring attention communication and expert FSDP all-gather thus operate on the same set of physical ranks, maximizing NVLink locality.
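The ep_fsdp_size arithmetic for the setups discussed in this doc, as a quick check:

```python
def ep_fsdp_size(world_size: int, pp_size: int, ep_size: int) -> int:
    ranks_per_stage = world_size // pp_size
    assert ranks_per_stage % ep_size == 0, "EP groups must fit within a PP stage"
    return ranks_per_stage // ep_size

print(ep_fsdp_size(world_size=16, pp_size=1, ep_size=4))  # 4: experts FSDP-sharded over CP peers
print(ep_fsdp_size(world_size=8,  pp_size=2, ep_size=4))  # 1: no hidden-dim sharding of experts
```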
Example config (EP=4, ring-attn=4, PP=2)
See examples/local/dummy/configs/full/qwen3_30b_a3b_pp2_ep4_cp4_muon.yaml:
```yaml
pipeline_parallel_size: 2
ringattn_parallel_size: 4
data_parallel_shard_size: 1
expert_parallel_size: 4
```
With 8 GPUs: world_size = PP × CP_ring × DP = 2 × 4 × 1 = 8. EP=4 is folded onto each PP stage’s 4 ranks (ep_fsdp_size = 4 / 4 = 1).
10. EP + pipeline parallelism (PP)
Section titled “10. EP + pipeline parallelism (PP)”EP groups are per-PP-stage
An all-to-all across stages would deadlock during asynchronous pipeline execution because ranks in different stages are at different points in the forward/backward schedule. xorl enforces that EP groups are strictly within a single PP stage:
```python
assert ranks_per_stage % ep_size == 0
```
The ep_fsdp_device_mesh is created with a leading _pp_ep dimension so PyTorch’s
DeviceMesh can automatically select the correct per-stage submesh for each rank when
sliced by ["ep"] or ["ep", "ep_fsdp"].
Same microbatch across stages
PP schedules (e.g. 1F1B) pipeline microbatches through stages. Each stage sees the same sequence of microbatch token tensors — EP dispatch is called independently within each stage on its local tokens. Routing decisions are consistent across stages because each stage runs its own gate and router on its local hidden states.
Routing replay with PP
In the PP path, routing replay requires more care. The PP scheduler calls forward()
multiple times before calling backward(). The global replay stage is set to
"replay_backward" at the start, and _pp_forward temporarily switches to "record"
during each forward call then restores "replay_backward". Checkpoint recomputation
reads from the backward_index pointer, while repeated PP forwards use forward_index.
11. MoE kernel selection
MoEBlock dispatches expert compute to one of four backends controlled by the
moe_implementation argument:
| Value | Description | Best for |
|---|---|---|
| "eager" | Per-expert Python loop with index_add_ | Debugging, small models |
| "triton" | Triton group GEMM kernels | General training (default) |
| "native" | torch._grouped_mm | PyTorch-native, no Triton dependency |
| "quack" | Quack kernels | High-throughput, works well with DeepEP |
The "eager" mode does not use EP dispatch at all — it iterates over all local experts
sequentially, which is only correct when ep_size == 1.
For the "triton", "native", and "quack" backends, MoEExperts.forward() receives
the permuted_tokens tensor (already in expert-sorted order after dispatch) and the
cumsum of tokens per expert. This interface is identical for both the AllToAll and
DeepEP dispatch backends.
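This interface can be mimicked with an eager reference loop (pure Python, illustrative): tokens arrive expert-sorted, and the cumsum of per-expert counts delimits each expert's contiguous slice.

```python
def run_experts(permuted_tokens, cumsum, expert_fns):
    """Apply expert e to its contiguous slice permuted_tokens[cumsum[e-1]:cumsum[e]]."""
    out = []
    start = 0
    for e, end in enumerate(cumsum):
        out.extend(expert_fns[e](t) for t in permuted_tokens[start:end])
        start = end
    return out

experts = [lambda t: t + 1, lambda t: t * 10]  # two toy "experts"
tokens = [1.0, 2.0, 3.0]                       # first 2 for expert 0, last 1 for expert 1
print(run_experts(tokens, cumsum=[2, 3], expert_fns=experts))  # [2.0, 3.0, 30.0]
```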
The moe_implementation field in the model config block controls the kernel selection.
For Qwen3 MoE models, "triton" and "quack" are the recommended production choices;
DeepEP dispatch pairs well with "quack".
12. Hard constraints
| Constraint | Enforcement |
|---|---|
| num_experts % ep_size == 0 | Asserted in ParallelPlan.apply() for each expert parameter |
| ranks_per_stage % ep_size == 0 | Asserted in init_parallel_state() |
| EP with init_device: cpu is not supported | Checked in TrainingArgs.__post_init__() |
| EP + eager moe_implementation is undefined | No error raised, but produces wrong results with ep_size > 1; do not combine |
| DeepEP requires deep_ep package installed | check_deepep_available() raises ImportError with install instructions |
| Routing replay requires both ep_size > 1 AND moe_recomputed=True | Checked in trainer and direct train CLI |
| EP groups must not span PP stage boundaries | Asserted via ranks_per_stage % ep_size == 0 and per-stage mesh slicing |
13. Configuration examples
EP only — Qwen3-30B-A3B on 8 GPUs
```yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  moe_implementation: triton

train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8
  expert_parallel_size: 8
  load_weights_mode: all_ranks
```
Result: 8 EP ranks, each holding 16 of 128 experts. No separate FSDP within EP groups
(ep_fsdp_size = 8 / 8 = 1).
EP + DeepEP + Ulysses + Quack — Qwen3-30B-A3B on 8 GPUs
```yaml
model:
  moe_implementation: quack
  ep_dispatch: deepep
  deepep_buffer_size_gb: 2.0
  deepep_num_sms: 48

train:
  ulysses_parallel_size: 8
  expert_parallel_size: 8
  data_parallel_shard_size: 1
```
Result: EP=8 and Ulysses-SP=8 on the same 8 GPUs. DeepEP handles dispatch with 48 SMs allocated to communication, leaving the remaining SMs for Quack expert kernels.
EP + PP — Qwen3-30B-A3B on 16 GPUs
```yaml
model:
  moe_implementation: triton

train:
  pipeline_parallel_size: 2
  data_parallel_shard_size: 4
  expert_parallel_size: 4
```
With 16 GPUs total and PP = 2, there are 8 ranks per stage.
EP=4 is folded onto each stage’s 8 ranks: ep_fsdp_size = 8 / 4 = 2. Each EP group’s
experts are further FSDP-sharded across 2 ranks within the same stage.
EP + PP + ring-attn — Qwen3-30B-A3B on 8 GPUs
```yaml
model:
  moe_implementation: triton

train:
  pipeline_parallel_size: 2
  ringattn_parallel_size: 4
  data_parallel_shard_size: 1
  expert_parallel_size: 4
```
Result: world_size = PP × CP_ring × DP = 2 × 4 × 1 = 8 GPUs.
Each PP stage has 4 ranks. EP=4 is folded onto those 4 ranks (ep_fsdp_size = 4 / 4 = 1).
sample_packing_sequence_len: 96000 allows long sequences distributed over 4 ring-attn
ranks.
14. Parameter reference
Model config block
| Parameter | Type | Default | Description |
|---|---|---|---|
| moe_implementation | str \| None | None | MoE compute kernel: "eager", "triton", "native", "quack". None inherits from the model class default. |
| ep_dispatch | str | "alltoall" | EP token dispatch strategy: "alltoall" or "deepep". |
| deepep_buffer_size_gb | float | 2.0 | NVLink staging buffer size in GB for DeepEP. |
| deepep_num_sms | int | 20 | SMs allocated to DeepEP communication kernels (must be even). Lower values leave more for expert compute. |
| deepep_async_combine | bool | False | Overlap DeepEP combine with next layer’s compute. |
Train config block
| Parameter | Type | Default | Description |
|---|---|---|---|
| expert_parallel_size | int | 1 | Number of EP ranks. Must divide num_experts and ranks_per_pp_stage. |
| ep_outside | bool | False | When True, consecutive ranks are in the same EP-FSDP group rather than the same EP group. |
| data_parallel_shard_size | int | -1 | FSDP shard size. Together with expert_parallel_size, determines ep_fsdp_size = ranks_per_stage / ep_size. |
| pipeline_parallel_size | int | 1 | PP degree. EP groups are confined within each PP stage. |
| ringattn_parallel_size | int | 1 | Ring attention CP size. Can be folded into the EP-FSDP axis via ep_fsdp_size. |
| load_weights_mode | str | "broadcast" | "all_ranks" lets every rank read its local expert shard directly without a broadcast. Preferred for EP. |
Environment variables
| Variable | Default | Description |
|---|---|---|
| XORL_DEEPEP_PROFILE | 0 | Set to 1 to print dispatch/combine timing on rank 0. |
Key internal identifiers
| Symbol | Location | Description |
|---|---|---|
| ParallelState.ep_mesh | parallel_state.py | Sub-mesh along the EP dimension (for EP collectives). |
| ParallelState.ep_fsdp_mesh | parallel_state.py | 2-D sub-mesh [ep, ep_fsdp] for expert module FSDP. |
| ParallelState.ep_group | parallel_state.py | Process group for AllToAll dispatch. |
| ParallelState.ep_rank | parallel_state.py | This rank’s position within the EP group (0-based). |
| experts_mod._is_ep_fsdp | torch_parallelize.py | Marks expert FSDP units so the global divide-factor reset skips them. |
| experts_mod.set_gradient_divide_factor(ep_size) | torch_parallelize.py | Averages expert gradients across EP ranks during backward. |
| RoutingReplay | moe/routing_replay.py | Per-layer buffer for deterministic routing replay under gradient checkpointing. |
| AllToAllDispatchContext | moe/alltoall.py | State carrier between alltoall_pre_dispatch() and alltoall_post_combine(). |
| DispatchContext | moe/deepep.py | State carrier between token_pre_dispatch() and tokens_post_combine() for DeepEP. |
| DeepEPBuffer | moe/deepep.py | Wrapper around deep_ep.Buffer; manages lifecycle and SM count. |
Source
| File | Description |
|---|---|
| src/xorl/distributed/parallel_state.py | EP mesh construction (ep_mesh, ep_fsdp_mesh, ep_group) |
| src/xorl/models/layers/moe/ | MoEBlock, TopKRouter, expert weight tensors, AllToAll and DeepEP dispatch |