
Expert Parallelism

Expert Parallelism (EP) is xorl’s strategy for distributing the expert weights of Mixture-of-Experts (MoE) layers across multiple GPUs. Each GPU owns a contiguous shard of the full expert stack and only processes the tokens routed to its local experts. Between the routing decision and the expert compute, tokens are redistributed across GPUs using either an AllToAll collective (the default alltoall backend) or DeepEP’s NVLink-optimized dispatch/combine kernels.

[Figure] Token flow: Tokens [T, H] → Router (top-k gate, per-token routing) → AllToAll dispatch of tokens to expert GPUs (GPU 0: experts 0-15, GPU 1: experts 16-31, ..., GPU N) → AllToAll combine and weighted sum. Each GPU processes only its local expert shard; tokens cross GPU boundaries twice (dispatch + combine).

[Figure] Expert weight distribution: all E experts (fused [E, K, N]) are sharded along the expert dimension: rank 0 holds experts 0..E/4, rank 1 holds E/4..E/2, rank 2 holds E/2..3E/4, rank 3 holds 3E/4..E (for 4 ranks). Each rank stores and computes E/ep_size experts.
  1. What EP does — expert weight sharding
  2. Token dispatch — routing and redistribution
  3. AllToAll dispatch backend
  4. DeepEP dispatch backend
  5. EP device mesh — 2D [EP, FSDP] for expert params
  6. Gradient handling
  7. Routing replay (R3)
  8. EP + FSDP2 composition
  9. EP + context parallelism (CP) — folded axis
  10. EP + pipeline parallelism (PP)
  11. MoE kernel selection
  12. Hard constraints
  13. Configuration examples
  14. Parameter reference

1. What EP does — expert weight sharding

A vanilla MoE layer stores all expert weights on every GPU, which quickly exhausts HBM for large expert counts. EP distributes the expert stack so that each EP rank holds only num_experts / ep_size experts.

Expert weights for Qwen3-30B-A3B (128 experts, hidden 2048, intermediate 768) are stored as fused 3-D tensors:

gate_proj [num_experts, hidden_dim, intermediate_size] e.g. [128, 2048, 768]
up_proj [num_experts, hidden_dim, intermediate_size]
down_proj [num_experts, intermediate_size, hidden_dim]

With ep_size=8, each EP rank holds 128 / 8 = 16 experts:

EP rank 0: gate_proj[ 0:16, :, :] experts 0-15
EP rank 1: gate_proj[ 16:32, :, :] experts 16-31
...
EP rank 7: gate_proj[112:128, :, :] experts 112-127
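
The slice arithmetic can be sketched in plain Python (the function name is illustrative, not from the codebase):

```python
def local_expert_slice(num_experts: int, ep_size: int, ep_rank: int) -> tuple[int, int]:
    """Return the [start, end) expert range owned by an EP rank (contiguous sharding)."""
    assert num_experts % ep_size == 0, "num_experts must be divisible by ep_size"
    num_local = num_experts // ep_size
    return ep_rank * num_local, (ep_rank + 1) * num_local

# Qwen3-30B-A3B with ep_size=8: rank 7 owns experts 112..127
print(local_expert_slice(128, 8, 7))  # (112, 128)
```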

The sharding is applied along dimension 0 (the expert axis):

src/xorl/models/transformers/qwen3_moe/parallelize.py
ep_plan = {
    "model.layers.*.mlp.experts.gate_proj": Shard(0),
    "model.layers.*.mlp.experts.up_proj": Shard(0),
    "model.layers.*.mlp.experts.down_proj": Shard(0),
}

ParallelPlan.apply() (src/xorl/distributed/parallel_plan.py) iterates all model parameters, matches each expert parameter against the EP plan, then either:

  • Redistributes the full tensor into a local shard using DTensor.redistribute() with a Shard(0) placement on the EP mesh (normal path when loading from a checkpoint).
  • Annotates already-local tensors with a spec_info attribute (fast path when weights were loaded EP-aware, e.g. with load_weights_mode: all_ranks).

After sharding, each rank’s expert parameter has shape [num_local_experts, ...]. The spec_info attribute records the ep_fsdp_mesh, the shard placement, and the FQN for checkpoint reconstruction.

LoRA adapters injected into expert layers are sharded the same way (gate_proj_lora_A/B, up_proj_lora_A/B, down_proj_lora_A/B all follow Shard(0)). Shared LoRA weights with a size-1 expert dimension are replicated instead of sharded.


2. Token dispatch — routing and redistribution

MoEBlock.forward() (src/xorl/models/layers/moe/moe_block.py) first computes routing logits through the gate linear:

router_logits = gate(hidden_states) # [num_tokens, num_experts]

TopKRouter applies softmax, selects the top-k experts, and optionally renormalizes the weights so they sum to 1. The result is:

routing_weights [num_tokens, top_k]
selected_experts [num_tokens, top_k]

Each token therefore has top_k expert assignments, each pointing at one of the num_experts experts on one of the EP ranks.

Because expert weights are split across ranks, tokens must be sent to the GPU that owns the target expert. The dispatch step collects all tokens that belong to a given rank’s experts and forwards them; the combine step returns computed outputs back to their originating ranks.

Before dispatch:
GPU 0: tokens [t0, t1, t2, t3] (all local tokens)
After dispatch (ep_size=4, 128 experts):
GPU 0 receives: tokens routed to experts 0-31
GPU 1 receives: tokens routed to experts 32-63
GPU 2 receives: tokens routed to experts 64-95
GPU 3 receives: tokens routed to experts 96-127

Token counts per rank are highly variable (load-imbalanced routing is common), so the dispatch collective must handle variable-length sends and receives.
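
The split-size bookkeeping can be illustrated with a small pure-Python sketch (names are illustrative; the real preprocess() operates on fused tensors):

```python
def compute_input_splits(selected_experts, num_experts, ep_size):
    """Count how many token-expert pairs this rank must send to each EP rank.

    selected_experts: per-token list of top_k expert ids, with contiguous
    expert sharding (rank r owns experts [r*E/ep_size, (r+1)*E/ep_size)).
    """
    experts_per_rank = num_experts // ep_size
    splits = [0] * ep_size
    for assignments in selected_experts:
        for expert_id in assignments:
            splits[expert_id // experts_per_rank] += 1
    return splits

# 4 tokens, top_k=2, 128 experts over ep_size=4 (32 experts per rank)
print(compute_input_splits([[0, 33], [1, 64], [100, 101], [2, 3]], 128, 4))
# [4, 1, 1, 2]
```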


3. AllToAll dispatch backend

The default backend (ep_dispatch: alltoall) implements dispatch/combine using torch.distributed.all_to_all_single with variable split sizes.

  • src/xorl/distributed/moe/alltoall.py

    alltoall_pre_dispatch(hidden_states, routing_weights, selected_experts, num_experts, ep_group)
        -> (permuted_tokens, cumsum, AllToAllDispatchContext)

Steps:

  1. Build expert mask — one_hot(selected_experts, num_experts) gives a [num_experts, top_k, num_tokens] indicator tensor.

  2. preprocess() — computes the variable send/receive split sizes:

    • input_splits[r] = number of tokens this rank must send to EP rank r.
    • Gathers num_global_tokens_per_expert across all EP ranks via dist.all_gather_into_tensor so every rank knows exactly how many tokens each other rank will send to each expert.
    • output_splits[r] = number of tokens this rank will receive from EP rank r.
  3. token_pre_all2all() — executes the actual communication:

    • Permutes local tokens into grouped-by-destination order (permute()).
    • Calls all_to_all(ep_group, ...) using the computed split sizes.
    • Re-groups received tokens by local expert using sort_chunks_by_idxs() so each expert’s input is contiguous.

The returned AllToAllDispatchContext carries input_splits, output_splits, num_tokens_per_expert, the permutation mapping, original shape, routing weights, and selected experts — everything needed to reverse the operation.

After expert computation, alltoall_post_combine() reverses the dispatch:

  1. Re-groups expert outputs by source rank using sort_chunks_by_idxs() with the transposed token count matrix.
  2. Calls all_to_all() again (with input_splits and output_splits swapped) to return outputs to their originating ranks.
  3. unpermute() restores the original token order and accumulates weighted expert contributions into the final hidden state.
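
The final weighted accumulation in step 3 amounts to a scatter-add; a scalar-hidden-dim sketch (illustrative names, not the real unpermute()):

```python
def unpermute_combine(expert_outputs, token_ids, weights, num_tokens):
    """Accumulate weighted expert outputs back into original token order.

    expert_outputs[i] is the output for the i-th dispatched token-expert pair,
    token_ids[i] its originating token, weights[i] its routing weight.
    """
    out = [0.0] * num_tokens
    for y, t, w in zip(expert_outputs, token_ids, weights):
        out[t] += w * y
    return out

# token 0 routed to two experts with weights 0.75 / 0.25
print(unpermute_combine([4.0, 8.0], [0, 0], [0.75, 0.25], 1))  # [5.0]
```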

Each AllToAll transfers num_tokens * hidden_dim * sizeof(bfloat16) bytes per direction. For ep_size=8, seq_len=32768, hidden_dim=2048, top_k=8:

tokens = 32768 * 8 = 262144 (token-expert pairs before deduplication per destination)
bytes per direction ≈ 262144 * 2048 * 2 / 8 ≈ 134 MB per EP rank per direction

The preprocess() step requires an additional lightweight all_gather_into_tensor to exchange token count metadata (negligible relative to the main transfer).
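
The back-of-envelope bandwidth estimate above can be packaged as a helper (assumes bf16 payloads and perfectly uniform routing across EP ranks):

```python
def alltoall_bytes_per_direction(seq_len, top_k, hidden_dim, ep_size, dtype_bytes=2):
    """Approximate per-rank AllToAll payload per direction, uniform routing assumed."""
    token_expert_pairs = seq_len * top_k
    return token_expert_pairs * hidden_dim * dtype_bytes // ep_size

b = alltoall_bytes_per_direction(32768, 8, 2048, 8)
print(b)  # 134217728 bytes (~134 MB, i.e. 128 MiB) per EP rank per direction
```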


4. DeepEP dispatch backend

When ep_dispatch: deepep is set, xorl uses DeepEP — an NVLink-optimized communication library that runs dispatch/combine on a dedicated communication stream, freeing compute SMs for overlapped expert computation.

  • src/xorl/distributed/moe/deepep.py

DeepEPBuffer wraps a deep_ep.Buffer object. The buffer is lazily initialized on the first dispatch call. Key initialization parameters:

  • buffer_size_gb (default 2.0 GB) — NVLink staging buffer. If DeepEP provides config hints for the given ep_size, the buffer is sized automatically via config.get_nvl_buffer_size_hint(hidden_bytes, num_ranks).
  • num_sms (default 20) — number of SMs allocated to the communication kernels via Buffer.set_num_sms(). Fewer SMs leave more for overlapped compute.
  • low_latency_mode — reserved for inference use cases.

A global singleton _default_buffer is shared across all MoE layers in a forward pass. Calling destroy_default_buffer() at the end of training frees the NVLink staging memory.

Token dispatch is wrapped in a single torch.autograd.Function boundary:

Forward:

  1. buffer.get_dispatch_layout(topk_idx, num_experts) — computes num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank without any communication.
  2. buffer.dispatch(async_finish=True, allocate_on_comm_stream=True) — sends tokens to expert-owning ranks on the NVLink comm stream. Returns a handle for the reverse combine and an EventOverlap to sync streams.
  3. event.current_stream_wait() — blocks the compute stream until dispatch is done.
  4. permute_for_experts() — argsort + index_select to arrange received tokens in expert-sorted order, slicing invalid entries (expert id == -1) to the end.

Backward:

  1. index_add_ — scatter gradient from expert order back to received order.
  2. buffer.combine(async_finish=True) — reverse-dispatch gradients to originating ranks.

The combine is a second torch.autograd.Function boundary.

Forward:

  1. Weighted scatter-add: output[perm_idx] += score * expert_output (unpermute).
  2. buffer.combine(async_finish=True, allocate_on_comm_stream=True) — sends accumulated outputs back to originating ranks.
  3. If async_combine=True, stores the pending EventOverlap in a global variable _pending_combine_event so the next layer’s computation can overlap with the transfer. The next call to token_pre_dispatch() automatically calls sync_pending_combine().

Backward:

  1. buffer.dispatch() — reverse combine (sends grads to expert-owning ranks).
  2. Reverse weighted scatter-add: grad_expert[i] = score[i] * grad[perm_idx[i]].

Setting deepep_async_combine: true enables the async combine path. The combine communication for layer L overlaps with the attention/dense-FFN computation of layer L+1. The pending event is synchronized at the start of the next layer’s dispatch.

Buffer config hints from DeepEP are available for the following EP sizes: 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 144, 160. Other sizes fall back to the fixed buffer_size_gb allocation.

Set the environment variable XORL_DEEPEP_PROFILE=1 to enable timing printouts for dispatch and combine operations on rank 0.


5. EP device mesh — 2D [EP, FSDP] for expert params

Expert parameters require a different device mesh from non-expert parameters because they are already sharded along the expert dimension (dim 0) and must not participate in the global FSDP all-gather along that dimension.

init_parallel_state() (src/xorl/distributed/parallel_state.py) builds a separate ep_fsdp_device_mesh whenever ep_size > 1:

ranks_per_stage = world_size // pp_size
ep_fsdp_size = ranks_per_stage // ep_size

Without PP, the mesh is 2-D with shape [ep_size, ep_fsdp_size]:

ep_fsdp_device_mesh = DeviceMesh(
    device_type=device_type,
    mesh=mesh,  # [ep_size, ep_fsdp_size]
    mesh_dim_names=("ep", "ep_fsdp"),
)

With PP, a 3-D mesh (_pp_ep, ep, ep_fsdp) is created so that each PP stage gets its own EP process groups. Slicing by ["ep"] automatically returns the correct per-stage submesh for the calling rank.

By default (ep_outside=False), the rank assignment within each [ep_size, ep_fsdp_size] block interleaves EP and FSDP ranks so that consecutive ranks belong to the same EP group. Setting ep_outside=True transposes the layout so consecutive ranks belong to the same FSDP group instead.

init_ep_mesh_matrix() implements both layouts:

# ep_outside=False (default): FSDP-major, EP-minor
mesh = arange(ep_size * ep_fsdp_size).view(ep_fsdp_size, ep_size).T
# ep_outside=True: EP-major, FSDP-minor
mesh = arange(ep_size * ep_fsdp_size).view(ep_size, ep_fsdp_size)
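
The same two layouts can be reproduced with plain Python lists to check the rank-grouping claim (illustrative helper, no torch required):

```python
def ep_mesh_matrix(ep_size, ep_fsdp_size, ep_outside=False):
    """Build the [ep, ep_fsdp] rank matrix for both layouts."""
    if ep_outside:
        # EP-major: row r holds ranks r*ep_fsdp_size .. (r+1)*ep_fsdp_size - 1
        return [[r * ep_fsdp_size + c for c in range(ep_fsdp_size)]
                for r in range(ep_size)]
    # default: transpose of arange().view(ep_fsdp_size, ep_size)
    return [[c * ep_size + r for c in range(ep_fsdp_size)]
            for r in range(ep_size)]

# ep_size=2, ep_fsdp_size=2; an EP group is a column (fixed ep_fsdp index)
print(ep_mesh_matrix(2, 2))                   # [[0, 2], [1, 3]] -> EP groups {0,1}, {2,3}
print(ep_mesh_matrix(2, 2, ep_outside=True))  # [[0, 1], [2, 3]] -> EP groups {0,2}, {1,3}
```

In the default layout consecutive ranks 0 and 1 land in the same EP group; with ep_outside=True they land in the same FSDP group instead, matching the description above.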

When applying fully_shard to the experts module, the mesh is restricted to the ep_fsdp sub-dimension:

ep_fsdp_mesh = parallel_state.ep_fsdp_device_mesh["ep_fsdp"]
expert_fsdp_kwargs["mesh"] = ep_fsdp_mesh

This means expert parameters are further sharded along hidden_dim (dim 1, not dim 0) across ep_fsdp_size ranks within the same EP group, while non-expert parameters use the full global FSDP mesh. When ep_fsdp_size == 1, there is no hidden-dim sharding of experts; FSDP still wraps the module but performs no all-gather.

Attention and dense-FFN parameters continue to use parallel_state.fsdp_mesh (which spans dp_shard * cp_size ranks), so their all-gather communicates globally across all data and context parallel ranks as usual.


6. Gradient handling

set_gradient_divide_factor(ep_size) for expert params

During the backward pass, FSDP2 performs a reduce-scatter of gradients within the all-gather mesh. For non-expert parameters this mesh is the full FSDP mesh, so gradients are already averaged across fsdp_size ranks.

For expert parameters, the reduce-scatter operates only within the ep_fsdp sub-mesh (size ep_fsdp_size). This means gradients are summed across the ep_fsdp_size ranks that hold the same expert shard, but not averaged across the ep_size EP ranks that hold different experts.

To correct this, xorl calls:

experts_mod.set_gradient_divide_factor(parallel_state.ep_size)

This instructs FSDP2 to divide the reduce-scattered gradient by ep_size in addition to the normal averaging, so the effective gradient across all EP ranks is properly averaged as if the experts were not partitioned.
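
Numerically, the extra factor makes the expert gradient match a global average. A sketch of the arithmetic only (not the FSDP2 implementation):

```python
def expert_grad_after_reduce(per_rank_grads, ep_size):
    """per_rank_grads: gradient contributions from the ep_fsdp_size ranks
    holding the same expert shard. The reduce-scatter averages over that
    sub-mesh; dividing by ep_size again yields the average over all
    ep_size * ep_fsdp_size data-parallel ranks, as if experts were unpartitioned."""
    ep_fsdp_size = len(per_rank_grads)
    reduced = sum(per_rank_grads) / ep_fsdp_size  # normal FSDP2 averaging
    return reduced / ep_size                      # effect of set_gradient_divide_factor(ep_size)

# ep_size=4, ep_fsdp_size=2: equivalent to averaging over all 8 DP ranks
print(expert_grad_after_reduce([8.0, 16.0], 4))  # 3.0
```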

After calling set_gradient_divide_factor, the module is marked with _is_ep_fsdp = True. This flag prevents the global gradient divide-factor reset loop from overriding the expert-specific factor:

src/xorl/distributed/torch_parallelize.py
for module in model.modules():
    if isinstance(module, FSDPModule) and not getattr(module, "_is_ep_fsdp", False):
        module.set_gradient_divide_factor(1.0)  # disable auto-averaging for non-expert

Non-expert modules have their divide factor set to 1.0 because xorl manages gradient normalization manually (via gradient_accumulate_loss or explicit grad.mul_(1/gvt) in the PP path).

Expert FSDP shards along dim 1 (hidden dimension) rather than dim 0:

def _experts_shard_placement_fn(param):
    return Shard(1)

expert_fsdp_kwargs["shard_placement_fn"] = _experts_shard_placement_fn

Dim 0 is already used by EP to shard the expert axis. FSDP2 must shard orthogonally along dim 1 to avoid double-sharding the same dimension.


7. Routing replay (R3)

Problem: non-deterministic routing breaks gradient checkpointing

Gradient checkpointing reruns the forward pass during backward to reconstruct activations. In MoE+EP training, hidden_states depend on flash attention which is non-deterministic (random tiling order). If the recomputed hidden_states differ slightly from the original forward’s values, the gate’s softmax output changes, top-k expert selection may change, and the AllToAll split sizes become inconsistent — causing a runtime error.

Solution: record expert indices; replay on recompute

RoutingReplay (src/xorl/models/layers/moe/routing_replay.py) is a per-layer buffer that records expert routing decisions and replays them during checkpoint recomputation.

Each MoEBlock that has routing replay enabled (_routing_replay is set) operates in one of four stages controlled by a global _replay_stage flag:

Stage               Behavior
None                Normal routing: gate → softmax → top-k → experts
"record"            Top-k runs under torch.no_grad(); result saved to CPU pinned memory
"replay_forward"    Pops pre-recorded indices; recomputes weights via _regather_routing()
"replay_backward"   Pops pre-recorded indices for checkpoint recompute

_regather_routing() always recomputes routing weights from a fresh softmax over the gate output and gathers with cached expert indices. This ensures the autograd graph structure is identical between the original forward and the recomputed forward (softmax → gather → normalize), which is required by non-reentrant checkpointing.
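
A scalar sketch of that regather for a single token (illustrative, no autograd):

```python
import math

def regather_routing(gate_logits, cached_indices, renormalize=True):
    """Fresh softmax over gate logits, then gather with replayed expert indices."""
    m = max(gate_logits)  # shift for numerical stability
    exps = [math.exp(x - m) for x in gate_logits]
    z = sum(exps)
    weights = [exps[i] / z for i in cached_indices]  # gather, not a new top-k
    if renormalize:
        s = sum(weights)
        weights = [w / s for w in weights]
    return weights

# uniform logits, replayed top-2: equal weights, and the expert choice can
# never drift even if recomputed logits differ slightly
print(regather_routing([0.0, 0.0, 0.0, 0.0], [0, 2]))  # [0.5, 0.5]
```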

Routing replay is only activated when both conditions are true:

  • ep_size > 1 (EP is enabled)
  • moe_recomputed is True (MoE layers are inside gradient checkpointed regions)

RoutingReplay maintains two independent read pointers: forward_index and backward_index. In PP mode, each microbatch calls forward multiple times (for different stages), requiring separate advancement of the forward replay pointer and the backward (recompute) replay pointer.

R3: Routing Replay from Rollout (inference-to-training)

When the xorl server receives training batches from an inference rollout, it can also supply pre-computed routing decisions (routed_experts) and optionally pre-computed routing weights (routed_expert_logits).

RoutingReplayHandler.fill_routing_replay() (src/xorl/server/runner/utils/routing_replay_handler.py) pre-populates all RoutingReplay instances from this data before calling set_r3_mode(True) and set_replay_stage("replay_forward"). MoE blocks then pop pre-populated routing instead of computing it from the gate.

When routed_expert_logits (softmax weights from inference) is also provided, the weights are replayed directly from the pre-populated buffer (cached_weights), bypassing the _regather_routing() call entirely. This matches the exact numerical values produced during inference for GRPO-style training.

When context parallelism is active, sequences are padded before being sharded across CP ranks. fill_routing_replay() must replicate this padding when slicing the routing tensor to match the CP rank’s subsequence boundary. The effective padding multiple is:

pad_to_multiple_of = math.lcm(128, cp_size)

The value 128 matches the hardware-level alignment used by SequentialPacker; cp_size ensures the padded length is divisible by the number of CP ranks.
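
A minimal helper for that boundary computation (hypothetical name, mirroring the lcm rule above):

```python
import math

def cp_padded_len(seq_len, cp_size, align=128):
    """Round seq_len up so it is divisible by both the packer alignment and cp_size."""
    multiple = math.lcm(align, cp_size)
    return -(-seq_len // multiple) * multiple  # ceiling division, then scale

print(cp_padded_len(1000, 4))  # lcm(128, 4) = 128 -> 1024
print(cp_padded_len(1000, 3))  # lcm(128, 3) = 384 -> 1152
```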


8. EP + FSDP2 composition

The two strategies compose cleanly because they operate on different parameter subsets and different mesh dimensions.

parallelize_model_fsdp2() (src/xorl/distributed/torch_parallelize.py) applies EP then FSDP2 in strict order:

  1. EP sharding (parallel_plan.apply()) — slice expert tensors from [num_experts, H, I] to [num_local_experts, H, I] on each rank.
  2. Expert FSDP (fully_shard(experts_mod, mesh=ep_fsdp_mesh, shard_placement_fn=Shard(1))) — further shard local expert tensors along hidden dim across ep_fsdp_size ranks.
  3. Non-expert FSDP (fully_shard(layer_mod, mesh=fsdp_mesh)) — standard dim-0 sharding over the global FSDP mesh for attention and dense-FFN layers.
  4. Root FSDP (fully_shard(model, mesh=fsdp_mesh)) — covers embeddings and the LM head.

Because the experts module is wrapped by a separate fully_shard call inside each decoder layer, FSDP2’s automatic prefetching does not know the correct prefetch order. xorl configures manual prefetching:

# Forward: prefetch next block's attention first, then gate, then experts
current_block.set_modules_to_forward_prefetch(list(reversed(next_block._fsdp_modules)))
# Backward: prefetch previous block in reverse
current_block.set_modules_to_backward_prefetch(...)

Expert and non-expert parameters have different FSDP meshes and therefore different all-gather/reduce-scatter collectives. Some optimizers (e.g. Muon) maintain separate parameter groups for expert and non-expert weights, each with its own learning rate schedule. The optimizer state is sharded consistently with the FSDP2 shard placement, meaning optimizer states for expert parameters are only synchronized within ep_fsdp groups, not across EP ranks.


9. EP + context parallelism (CP) — folded axis

Context parallelism (ring attention or Ulysses) splits sequences across multiple GPUs. When EP and CP are combined on the same node, xorl can fold the CP axis into the FSDP shard dimension (cp_fsdp_mode), which reduces the number of distinct process groups needed.

In a typical EP+CP setup (e.g. ep_size=4, ringattn_size=4, dp_shard_size=1):

  • Each EP rank communicates expert tokens with its ep_size peers via AllToAll.
  • Ring attention communicates hidden states with its ringattn_size peers via P2P.
  • The two collectives are independent and can run on different GPUs without conflict.

The ep_fsdp_device_mesh is constructed to span only the ranks within a single PP stage, so EP groups never cross PP boundaries.

Expert FSDP sharding uses the ep_fsdp mesh dimension. The ep_fsdp_size is:

ep_fsdp_size = ranks_per_stage / ep_size

This value absorbs whatever CP+DP ranks are co-located with each EP rank. With ep_size=4, ringattn_size=4, dp_shard_size=1 on 16 GPUs per PP stage:

ep_fsdp_size = 16 / 4 = 4 (the 4 ring-attention peers within one EP group)

Expert FSDP all-gathers hidden-dim shards across the 4 ring-attention peers only. Ring attention communication and expert FSDP all-gather thus operate on the same set of physical ranks, maximizing NVLink locality.

See examples/local/dummy/configs/full/qwen3_30b_a3b_pp2_ep4_cp4_muon.yaml:

pipeline_parallel_size: 2
ringattn_parallel_size: 4
data_parallel_shard_size: 1
expert_parallel_size: 4

With 8 GPUs: world_size = PP × CP_ring × DP = 2 × 4 × 1 = 8. EP=4 is folded onto each PP stage’s 4 ranks (ep_fsdp_size = 4 / 4 = 1).


10. EP + pipeline parallelism (PP)

An all-to-all across stages would deadlock during asynchronous pipeline execution because ranks in different stages are at different points in the forward/backward schedule. xorl enforces that EP groups are strictly within a single PP stage:

assert ranks_per_stage % ep_size == 0

The ep_fsdp_device_mesh is created with a leading _pp_ep dimension so PyTorch’s DeviceMesh can automatically select the correct per-stage submesh for each rank when sliced by ["ep"] or ["ep", "ep_fsdp"].

PP schedules (e.g. 1F1B) pipeline microbatches through stages. Each stage sees the same sequence of microbatch token tensors — EP dispatch is called independently within each stage on its local tokens. Routing decisions are consistent across stages because each stage runs its own gate and router on its local hidden states.

In the PP path, routing replay requires more care. The PP scheduler calls forward() multiple times before calling backward(). The global replay stage is set to "replay_backward" at the start, and _pp_forward temporarily switches to "record" during each forward call then restores "replay_backward". Checkpoint recomputation reads from the backward_index pointer, while repeated PP forwards use forward_index.


11. MoE kernel selection

MoEBlock dispatches expert compute to one of four backends controlled by the moe_implementation argument:

Value      Description                              Best for
"eager"    Per-expert Python loop with index_add_   Debugging, small models
"triton"   Triton group GEMM kernels                General training (default)
"native"   torch._grouped_mm                        PyTorch-native, no Triton dependency
"quack"    Quack kernels                            High throughput; pairs well with DeepEP

The "eager" mode does not use EP dispatch at all — it iterates over all local experts sequentially, which is only correct when ep_size == 1.
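
For intuition, the eager path is essentially this loop (scalar sketch with illustrative names):

```python
def eager_moe(tokens, selected_experts, routing_weights, expert_fns):
    """Per-expert loop over ALL experts; only valid when every expert is
    local, i.e. ep_size == 1."""
    out = [0.0] * len(tokens)
    for expert_id, fn in enumerate(expert_fns):
        for t, (sel, wts) in enumerate(zip(selected_experts, routing_weights)):
            for k, e in enumerate(sel):
                if e == expert_id:
                    out[t] += wts[k] * fn(tokens[t])
    return out

experts = [lambda x: 2 * x, lambda x: 3 * x]
print(eager_moe([1.0, 2.0], [[0], [1]], [[1.0], [1.0]], experts))  # [2.0, 6.0]
```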

For the "triton", "native", and "quack" backends, MoEExperts.forward() receives the permuted_tokens tensor (already in expert-sorted order after dispatch) and the cumsum of tokens per expert. This interface is identical for both the AllToAll and DeepEP dispatch backends.

The moe_implementation field in the model config block controls the kernel selection. For Qwen3 MoE models, "triton" and "quack" are the recommended production choices; DeepEP dispatch pairs well with "quack".


12. Hard constraints

Constraint                                                     Enforcement
num_experts % ep_size == 0                                     Asserted in ParallelPlan.apply() for each expert parameter
ranks_per_stage % ep_size == 0                                 Asserted in init_parallel_state()
EP with init_device: cpu is not supported                      Checked in TrainingArgs.__post_init__()
EP + eager moe_implementation is undefined                     No error raised, but produces wrong results with ep_size > 1; do not combine
DeepEP requires the deep_ep package installed                  check_deepep_available() raises ImportError with install instructions
Routing replay requires ep_size > 1 AND moe_recomputed=True    Checked in trainer and direct train CLI
EP groups must not span PP stage boundaries                    Asserted via ranks_per_stage % ep_size == 0 and per-stage mesh slicing
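
The divisibility constraints and the derived ep_fsdp_size can be checked in a few lines (hypothetical helper mirroring the assertions above):

```python
def validate_ep_layout(num_experts, ep_size, world_size, pp_size=1):
    """Check EP divisibility constraints; return (num_local_experts, ep_fsdp_size)."""
    ranks_per_stage = world_size // pp_size
    assert num_experts % ep_size == 0, "num_experts must be divisible by ep_size"
    assert ranks_per_stage % ep_size == 0, "EP groups must fit within one PP stage"
    return num_experts // ep_size, ranks_per_stage // ep_size

print(validate_ep_layout(128, 8, 8))             # (16, 1) -> the ep8 example below
print(validate_ep_layout(128, 4, 8, pp_size=2))  # (32, 1) -> the pp2_ep4_cp4 example
```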

13. Configuration examples

examples/local/dummy/configs/full/qwen3_30b_a3b_ep8.yaml
model:
  model_path: Qwen/Qwen3-30B-A3B
  moe_implementation: triton
train:
  data_parallel_mode: fsdp2
  data_parallel_shard_size: 8
  expert_parallel_size: 8
  load_weights_mode: all_ranks

Result: 8 EP ranks, each holding 16 of 128 experts. No separate FSDP within EP groups (ep_fsdp_size = 8 / 8 = 1).

EP + DeepEP + Ulysses + Quack — Qwen3-30B-A3B on 8 GPUs

examples/local/dummy/configs/full/qwen3_30b_a3b_muon.yaml
model:
  moe_implementation: quack
  ep_dispatch: deepep
  deepep_buffer_size_gb: 2.0
  deepep_num_sms: 48
train:
  ulysses_parallel_size: 8
  expert_parallel_size: 8
  data_parallel_shard_size: 1

Result: EP=8 and Ulysses-SP=8 on the same 8 GPUs. DeepEP handles dispatch with 48 SMs allocated to communication, leaving the remaining SMs for Quack expert kernels.

examples/local/dummy/configs/full/qwen3_30b_a3b_pp2_ep4_muon.yaml
model:
  moe_implementation: triton
train:
  pipeline_parallel_size: 2
  data_parallel_shard_size: 4
  expert_parallel_size: 4

Result: world_size = PP × DP_shard = 2 × 4 = 8 GPUs, so each PP stage has 4 ranks. EP=4 spans each stage's 4 ranks (ep_fsdp_size = 4 / 4 = 1), so there is no further hidden-dim FSDP sharding of experts within a stage.

EP + PP + ring-attn — Qwen3-30B-A3B on 8 GPUs

examples/local/dummy/configs/full/qwen3_30b_a3b_pp2_ep4_cp4_muon.yaml
model:
  moe_implementation: triton
train:
  pipeline_parallel_size: 2
  ringattn_parallel_size: 4
  data_parallel_shard_size: 1
  expert_parallel_size: 4

Result: world_size = PP × CP_ring × DP = 2 × 4 × 1 = 8 GPUs. Each PP stage has 4 ranks. EP=4 is folded onto those 4 ranks (ep_fsdp_size = 4 / 4 = 1). sample_packing_sequence_len: 96000 allows long sequences distributed over 4 ring-attn ranks.


14. Parameter reference

Model config:

Parameter              Type        Default      Description
moe_implementation     str | None  None         MoE compute kernel: "eager", "triton", "native", "quack". None inherits from the model class default.
ep_dispatch            str         "alltoall"   EP token dispatch strategy: "alltoall" or "deepep".
deepep_buffer_size_gb  float       2.0          NVLink staging buffer size in GB for DeepEP.
deepep_num_sms         int         20           SMs allocated to DeepEP communication kernels (must be even). Lower values leave more for expert compute.
deepep_async_combine   bool        False        Overlap DeepEP combine with the next layer's compute.

Train config:

Parameter                 Type  Default      Description
expert_parallel_size      int   1            Number of EP ranks. Must divide num_experts and ranks_per_pp_stage.
ep_outside                bool  False        When True, consecutive ranks are in the same EP-FSDP group rather than the same EP group.
data_parallel_shard_size  int   -1           FSDP shard size. Together with expert_parallel_size, determines ep_fsdp_size = ranks_per_stage / ep_size.
pipeline_parallel_size    int   1            PP degree. EP groups are confined within each PP stage.
ringattn_parallel_size    int   1            Ring attention CP size. Can be folded into the EP-FSDP axis via ep_fsdp_size.
load_weights_mode         str   "broadcast"  "all_ranks" lets every rank read its local expert shard directly without a broadcast. Preferred for EP.

Environment variables:

Variable             Default  Description
XORL_DEEPEP_PROFILE  0        Set to 1 to print dispatch/combine timing on rank 0.

Key symbols:

Symbol                                           Location               Description
ParallelState.ep_mesh                            parallel_state.py      Sub-mesh along the EP dimension (for EP collectives).
ParallelState.ep_fsdp_mesh                       parallel_state.py      2-D sub-mesh [ep, ep_fsdp] for expert module FSDP.
ParallelState.ep_group                           parallel_state.py      Process group for AllToAll dispatch.
ParallelState.ep_rank                            parallel_state.py      This rank's position within the EP group (0-based).
experts_mod._is_ep_fsdp                          torch_parallelize.py   Marks expert FSDP units so the global divide-factor reset skips them.
experts_mod.set_gradient_divide_factor(ep_size)  torch_parallelize.py   Averages expert gradients across EP ranks during backward.
RoutingReplay                                    moe/routing_replay.py  Per-layer buffer for deterministic routing replay under gradient checkpointing.
AllToAllDispatchContext                          moe/alltoall.py        State carrier between alltoall_pre_dispatch() and alltoall_post_combine().
DispatchContext                                  moe/deepep.py          State carrier between token_pre_dispatch() and tokens_post_combine() for DeepEP.
DeepEPBuffer                                     moe/deepep.py          Wrapper around deep_ep.Buffer; manages lifecycle and SM count.

Source files:

File                                    Description
src/xorl/distributed/parallel_state.py  EP mesh construction (ep_mesh, ep_fsdp_mesh, ep_group)
src/xorl/models/layers/moe/             MoEBlock, TopKRouter, expert weight tensors, AllToAll and DeepEP dispatch