Expert Kernels
Select the expert compute backend via moe_implementation. This controls which kernel runs the grouped GEMM over all active experts.
Backend Comparison
| Backend | Kernel | EP compatible | torch.compile | LoRA | Notes |
|---|---|---|---|---|---|
| triton | Triton group_gemm_same_nk | Yes | Partial† | Yes | Recommended for Hopper. Autotuned; works with moe_act |
| quack | Quack custom CUDA extension | Yes | No | Limited | Pairs well with DeepEP; custom CUDA kernels not compilable |
| native | torch._grouped_mm | Yes | Yes (full) | Yes | No custom deps; slightly slower than triton on Hopper |
| eager | Python loop + index_add_ | No (ep_size=1 only) | Yes | Yes | Debugging only; slowest but simplest |
```yaml
model:
  moe_implementation: triton  # recommended for H100
```

† Triton: the group_gemm_same_nk kernel is compatible with torch.compile. However, EP dispatch/combine (AllToAll, DeepEP) is excluded from compile graphs because it involves dynamic shapes (variable token counts per rank per step).
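The eager row in the table above is the simplest way to see what every backend computes. The sketch below is a hypothetical, list-based illustration, not the xorl eager backend: it assumes top-k routing is given as per-token lists of (expert_id, gate_weight) pairs and that each expert is a single [K, N] matrix (real experts have gate/up/down projections). It loops over experts, runs each expert's tokens through its matrix, and scatter-adds the weighted rows back the way index_add_ does.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Row-by-column product: [M, K] @ [K, N] -> [M, N]."""
    n_cols = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(len(row))) for j in range(n_cols)] for row in a]


def eager_moe_reference(
    x: Matrix,                                       # [T, K] token activations
    routing: Sequence[Sequence[Tuple[int, float]]],  # per token: [(expert_id, gate_weight), ...]
    expert_weights: Sequence[Matrix],                # G hypothetical single-matrix experts, [K, N]
) -> Matrix:
    """Hypothetical sketch of a Python-loop MoE; not the xorl implementation."""
    num_tokens, n_out = len(x), len(expert_weights[0][0])
    out = [[0.0] * n_out for _ in range(num_tokens)]
    for g, w_g in enumerate(expert_weights):
        # Gather the tokens routed to expert g together with their gate weights.
        picked = [(t, w) for t in range(num_tokens) for e, w in routing[t] if e == g]
        if not picked:
            continue  # an expert that received no tokens does no work
        y = matmul([x[t] for t, _ in picked], w_g)
        # index_add_-style scatter: accumulate weighted expert outputs into their token rows.
        for (t, gate_w), row in zip(picked, y):
            for j, v in enumerate(row):
                out[t][j] += gate_w * v
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
experts = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]  # identity, then 2x scale
routing = [[(0, 0.5), (1, 0.5)], [(1, 1.0)]]
print(eager_moe_reference(x, routing, experts))  # [[1.5, 3.0], [6.0, 8.0]]
```

The per-expert Python loop is exactly what the grouped GEMM backends replace with one kernel launch over all experts.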
Quack Backend
Quack uses CuTe DSL-based GEMM kernels written by Tri Dao (the author of Flash Attention). The kernels are built on NVIDIA's CuTe DSL (the Python DSL that ships with CUTLASS) and compiled for SM90+ (H100) and SM100+ (B200). They are not available on older GPUs.
What makes Quack different
Unlike the Triton backend, whose kernels are defined in Python and JIT-compiled by Triton, Quack kernels are written directly in the CuTe DSL, using the same low-level CUDA tensor abstraction (CuTe) that Flash Attention builds on. This gives:
- Persistent kernels with pingpong double-buffering (H100): the kernel occupies SMs continuously across multiple expert GEMMs, overlapping memory loads of the next tile with computation of the current tile
- cu_seqlens format: uses Flash Attention's cumulative sequence length format to group tokens by expert, avoiding unnecessary tensor copies
- SwiGLU fusion: the gate_proj and up_proj GEMMs plus the SiLU activation can be fused into a single kernel dispatch via gemm_gated / gemm_act
- SM100 (B200) support: a different tile config, tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False
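The cu_seqlens convention from Flash Attention is a prefix sum with a leading zero, so expert g owns rows cu_seqlens[g] to cu_seqlens[g+1] (exclusive) of the expert-sorted token buffer, and no per-expert copies are needed. A minimal sketch, assuming per-expert token counts are already known; both helper names are hypothetical:

```python
from itertools import accumulate
from typing import List


def counts_to_cu_seqlens(tokens_per_expert: List[int]) -> List[int]:
    """Prefix sum with a leading 0, so len(result) == num_experts + 1."""
    return [0, *accumulate(tokens_per_expert)]


def expert_row_range(cu_seqlens: List[int], g: int) -> range:
    """Rows of the expert-sorted token buffer that belong to expert g."""
    return range(cu_seqlens[g], cu_seqlens[g + 1])


cu_seqlens = counts_to_cu_seqlens([3, 0, 5, 2])
print(cu_seqlens)                              # [0, 3, 3, 8, 10]
print(list(expert_row_range(cu_seqlens, 2)))   # [3, 4, 5, 6, 7]
```

An expert that received no tokens simply gets an empty range. Dropping the leading zero gives a [G]-shaped cumulative count like the Triton kernel's cumsum_M.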
Default tile config per architecture:
| Architecture | tile_m | tile_n | cluster_m | cluster_n | pingpong |
|---|---|---|---|---|---|
| SM90 (H100) | 128 | 192 | 2 | 1 | True |
| SM100 (B200) | 256 | 256 | 2 | 1 | False |
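To get a feel for how these tile shapes partition the work, the sketch below copies the two table rows into a lookup keyed by compute-capability major version and counts the output tiles one grouped GEMM must cover. It assumes no tile straddles two experts (each expert's rows are rounded up to whole tiles) and ignores clusters; the names are hypothetical and are not the gemm_interface.py API.

```python
import math
from dataclasses import dataclass
from typing import Dict, Sequence


@dataclass(frozen=True)
class TileConfig:
    tile_m: int
    tile_n: int
    cluster_m: int
    cluster_n: int
    pingpong: bool


# Values copied from the default tile-config table, keyed by compute-capability major version.
DEFAULT_TILES: Dict[int, TileConfig] = {
    9: TileConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True),    # SM90, H100
    10: TileConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False),  # SM100, B200
}


def default_tile_config(cc_major: int) -> TileConfig:
    if cc_major not in DEFAULT_TILES:
        raise ValueError(f"no Quack tile config for compute capability {cc_major}.x")
    return DEFAULT_TILES[cc_major]


def num_output_tiles(tokens_per_expert: Sequence[int], n: int, cc_major: int) -> int:
    """Output tiles of one grouped GEMM, assuming no tile straddles two experts."""
    cfg = default_tile_config(cc_major)
    tiles_n = math.ceil(n / cfg.tile_n)
    return sum(math.ceil(m / cfg.tile_m) * tiles_n for m in tokens_per_expert)


print(num_output_tiles([300, 0, 129], n=1536, cc_major=9))   # 40
print(num_output_tiles([300, 0, 129], n=1536, cc_major=10))  # 18
```

Under this model a partially filled tile costs as much as a full one, so experts with only a few tokens waste a larger share of each tile.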
Autotuning
By default, autotuning is disabled (XORL_QUACK_TUNED=0). MoE routing produces variable token counts per expert per step, so each unique total_M would trigger a full benchmark of ~60 configs, making training orders of magnitude slower. The fixed default config is chosen to perform well across typical token budgets.
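The cost argument can be made concrete with a small simulation. The sketch assumes an autotuner whose cache is keyed only by total_M and that each cache miss benchmarks about 60 candidate configs, as described above; the helper name and the random token counts are illustrative, not measured.

```python
import random
from typing import Dict, Iterable


def autotune_cost(total_m_per_step: Iterable[int], configs_per_miss: int = 60) -> Dict[str, int]:
    """Count cache misses for a hypothetical autotuner keyed by total_M."""
    seen = set()
    misses = 0
    for total_m in total_m_per_step:
        if total_m not in seen:  # new shape: every candidate config gets benchmarked
            seen.add(total_m)
            misses += 1
    return {"misses": misses, "benchmarks": misses * configs_per_miss}


# Fixed-shape evaluation: one miss, then every step hits the cache.
print(autotune_cost([8192] * 100))  # {'misses': 1, 'benchmarks': 60}

# Training: the per-rank token count drifts every step (illustrative range).
rng = random.Random(0)
print(autotune_cost(rng.randint(7000, 9000) for _ in range(100)))
```

With drifting shapes almost every step is a miss, which is why the fixed default config is used unless shapes are stable or the cache is already warm.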
Enable autotuning only when shapes are stable (e.g. evaluation with fixed batch size) or when the tuning cache is already warm:
```bash
XORL_QUACK_TUNED=1 torchrun ... -m xorl.cli.train config.yaml
```
Full feature matrix
Quack has dedicated kernel variants for every training mode:
| Variant | Class / Function | Use case |
|---|---|---|
| Standard | quack_expert_forward | Default forward, no EP |
| EP | QuackEPGroupGemm.apply | EP dispatch + compute |
| LoRA | quack_moe_lora_forward | LoRA adapters, no EP |
| EP + LoRA | QuackEPGroupGemmWithLoRA.apply | EP + LoRA |
| moe_act | quack_expert_forward_moe_act | Gradient checkpointing, no EP |
| EP + moe_act | QuackEPGroupGemmMoeAct.apply | EP + gradient checkpointing |
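The feature matrix maps three switches (EP, LoRA, moe_act) to six entry points. The hypothetical selector below encodes exactly those rows and rejects combinations the table does not list, such as LoRA together with moe_act. It returns the names as strings and does not import or call the real xorl functions.

```python
from typing import Dict, Tuple

# (ep, lora, moe_act) -> entry point, copied from the feature matrix.
QUACK_VARIANTS: Dict[Tuple[bool, bool, bool], str] = {
    (False, False, False): "quack_expert_forward",
    (True, False, False): "QuackEPGroupGemm.apply",
    (False, True, False): "quack_moe_lora_forward",
    (True, True, False): "QuackEPGroupGemmWithLoRA.apply",
    (False, False, True): "quack_expert_forward_moe_act",
    (True, False, True): "QuackEPGroupGemmMoeAct.apply",
}


def select_quack_variant(ep: bool, lora: bool, moe_act: bool) -> str:
    """Hypothetical lookup; raises for combinations the table does not list."""
    try:
        return QUACK_VARIANTS[(ep, lora, moe_act)]
    except KeyError:
        raise NotImplementedError(
            f"no Quack variant listed for ep={ep}, lora={lora}, moe_act={moe_act}"
        ) from None


print(select_quack_variant(ep=True, lora=True, moe_act=False))  # QuackEPGroupGemmWithLoRA.apply
```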
Pairing with DeepEP
Quack and DeepEP are both optimized for H100 NVLink clusters and complement each other well: DeepEP handles the token dispatch/combine with dedicated SM kernels, while Quack handles the expert FFN compute with persistent CuTe kernels. Combining them maximizes utilization of both the NVLink interconnect and the SM compute fabric.
```yaml
model:
  moe_implementation: quack
  ep_dispatch: deepep
  deepep_num_sms: 48  # more SMs for DeepEP leaves fewer for Quack; tune based on profiling
```
Limitations
- Requires SM90+ (H100) or SM100+ (B200); not available on A100 or older
- Not compatible with torch.compile: custom CUDA/CuTe extensions cannot be traced
- LoRA support is limited: moe_hybrid_shared_lora is not supported; use triton or native for hybrid shared LoRA
- Autotuning is disabled by default: variable token counts make per-step autotuning prohibitively slow
Triton Grouped GEMM Details
The core of the triton backend is src/xorl/ops/group_gemm/kernel/group_gemm.py. The primary kernel group_gemm_same_nk_kernel handles the case where all experts share the same K (input) and N (output) dimensions.
Inputs and outputs:
- A: [total_tokens, K], all tokens routed to this rank, concatenated in expert order
- B: [G, K, N], expert weight tensor in GKN layout
- cumsum_M: [G], cumulative expert token counts (prefix sum)
- C: [total_tokens, N], output
Tile strategy: Each thread block processes a BLOCK_M × BLOCK_N output tile for one expert. An inner loop steps through BLOCK_K slices of A and B, accumulating the partial sum in registers.
Autotuned configs: BLOCK_M ∈ {64, 128}, BLOCK_N ∈ {128, 256}, BLOCK_K ∈ {32, 64} — autotuned per (N, K) shape at first use and cached.
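The tile strategy above can be mirrored in plain Python to check shapes and indexing. This is a slow, list-based reference, not the Triton kernel: it assumes A is already sorted by expert and that cumsum_M is the inclusive prefix sum from the inputs list, and it walks BLOCK_M x BLOCK_N output tiles with an inner BLOCK_K loop, the way each Triton program instance does. The function name is hypothetical.

```python
from itertools import product
from typing import List, Sequence

Matrix = List[List[float]]

# The autotune grid from the text: BLOCK_M x BLOCK_N x BLOCK_K = 2 x 2 x 2 = 8 configs.
AUTOTUNE_GRID = list(product((64, 128), (128, 256), (32, 64)))


def group_gemm_same_nk_reference(
    a: Matrix,                # [total_tokens, K], rows grouped by expert
    b: Sequence[Matrix],      # [G, K, N], GKN layout
    cumsum_m: Sequence[int],  # [G], inclusive prefix sum of tokens per expert
    block_m: int,
    block_n: int,
    block_k: int,
) -> Matrix:
    k_dim, n_dim = len(b[0]), len(b[0][0])
    c = [[0.0] * n_dim for _ in range(len(a))]
    start = 0
    for g, end in enumerate(cumsum_m):
        # One output tile per (expert, m-block, n-block); empty experts produce no tiles.
        for m0 in range(start, end, block_m):
            rows = range(m0, min(m0 + block_m, end))
            for n0 in range(0, n_dim, block_n):
                cols = range(n0, min(n0 + block_n, n_dim))
                acc = [[0.0] * len(cols) for _ in rows]  # stands in for the register accumulator
                for k0 in range(0, k_dim, block_k):      # inner loop over BLOCK_K slices
                    ks = range(k0, min(k0 + block_k, k_dim))
                    for i, r in enumerate(rows):
                        for j, col in enumerate(cols):
                            acc[i][j] += sum(a[r][kk] * b[g][kk][col] for kk in ks)
                for i, r in enumerate(rows):
                    c[r][n0:n0 + len(cols)] = acc[i]  # store the finished tile
        start = end
    return c
```

Every tiling must produce the same C, so running this over AUTOTUNE_GRID (or deliberately tiny blocks such as (2, 3, 2)) and comparing against a plain per-expert matmul is a cheap way to check the indexing.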
Optional activation fusion: When moe_act gradient checkpointing is active, the kernel stores the post-gate×up intermediate activations to a separate buffer during the forward pass (controlled by the STORE_ACTIVATIONS flag). The custom backward function recomputes only the down projection from these stored intermediates.
torch.compile with MoE
torch.compile is applied per decoder block when enable_compile: true.
What is compiled:
- Attention layers (Q/K/V projections, flash attention wrapper, output projection)
- Router gate linear and softmax
- Expert FFN compute (Triton group_gemm_same_nk is torch.compile-friendly)
- Dense shared-expert FFN layers (if present)
What is excluded from compile graphs:
- EP dispatch/combine collectives (alltoall_pre_dispatch, alltoall_post_combine, DeepEP dispatch/combine): dynamic shapes from variable per-rank token counts break static graph tracing
- Routing replay record/replay logic: mutable external state, decorated with @torch.compiler.disable
Recommended config:
```yaml
model:
  moe_implementation: triton  # or native; quack is incompatible with compile
train:
  enable_compile: true
```

The native backend (torch._grouped_mm) compiles cleanly end-to-end, apart from the EP collectives excluded above. The quack backend uses custom CUDA extensions and must not be used with enable_compile: true.
MoE Gradient Checkpointing (gradient_checkpointing_method)
gradient_checkpointing_method controls what is recomputed in the backward pass. For MoE models with EP, this has a large impact on throughput because recompute_full_layer re-executes the AllToAll dispatch and combine for every layer during backward.
Methods
| Method | Backward recomputes | EP comm recomputed | Peak memory (32k seq) | Throughput |
|---|---|---|---|---|
| recompute_full_layer (default) | Entire decoder layer | Yes: full AllToAll re-executed | 37.5 GB (1N) / 21.2 GB (2N) | baseline |
| recompute_before_dispatch | Attn + router only; keeps dispatch + expert + combine | No | 54.8 GB (1N) / 46.6 GB (2N) | +24.5% (1N) / +33.4% (2N) |
| no_recompute | Nothing | No | highest | max throughput |
Benchmarks: Qwen3-Coder-30B-A3B, quack + alltoall, 32k seq, H100 80GB. 1N = 1-node EP8, 2N = 2-node EP16.
```yaml
# Maximum throughput (short seq, memory permits):
train:
  enable_gradient_checkpointing: true
  gradient_checkpointing_method: no_recompute
```

```yaml
# Balanced memory/speed: recompute attn+router, keep dispatch+expert+combine:
train:
  enable_gradient_checkpointing: true
  gradient_checkpointing_method: recompute_before_dispatch
```

```yaml
# Maximum memory savings (long seq, required for 128k):
train:
  enable_gradient_checkpointing: true
  gradient_checkpointing_method: recompute_full_layer
```
How recompute_before_dispatch works (Megatron-style)
In the expert MLP forward: gate_output = x @ gate_proj, up_output = x @ up_proj, postact = silu(gate_output) * up_output, output = postact @ down_proj.
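Written out for a single token, the four equations look like this. The sketch is pure Python over lists, with gate_proj and up_proj as [H, I] matrices and down_proj as [I, H], matching the x @ W convention above; silu(v) is v * sigmoid(v). It also returns gate_output and up_output, because those are the inputs the elementwise recompute of postact needs.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def vec_matmul(x: Vector, w: Matrix) -> Vector:
    """Row vector times matrix: [R] @ [R, C] -> [C]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def silu(v: float) -> float:
    """v * sigmoid(v), with sigmoid written via tanh so large |v| cannot overflow exp."""
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))


def expert_mlp_forward(
    x: Vector, gate_proj: Matrix, up_proj: Matrix, down_proj: Matrix
) -> Tuple[Vector, Vector, Vector]:
    gate_output = vec_matmul(x, gate_proj)   # [I]
    up_output = vec_matmul(x, up_proj)       # [I]
    postact = [silu(g) * u for g, u in zip(gate_output, up_output)]  # elementwise SwiGLU
    output = vec_matmul(postact, down_proj)  # [H]
    return output, gate_output, up_output


out, gate, up = expert_mlp_forward([1.0, 2.0], [[1.0], [0.0]], [[0.0], [1.0]], [[1.0, -1.0]])
print(gate, up)  # [1.0] [2.0]
```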
With recompute_before_dispatch, the attention and router outputs are recomputed in backward, but MoE dispatch and combine are not. This avoids re-executing AllToAll communication, which is the dominant recompute cost in multi-node runs.
The elementwise recompute (postact = silu(gate_output) * up_output) adds negligible overhead: benchmarks show recompute_before_dispatch within 0.2% of no-checkpointing speed.
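A FLOP count gives the intuition for why the elementwise recompute is nearly free. The three expert GEMMs cost about 6·T·H·I FLOPs for T routed tokens, while recomputing silu(gate_output) * up_output touches only T·I elements, so T and I cancel and the ratio depends on H alone. The sizes below (H=2048, I=768, which we assume match Qwen3-30B-A3B's hidden and per-expert intermediate sizes) and the ~5 FLOPs charged per SiLU element are assumptions; a FLOP ratio also ignores that elementwise kernels are memory-bound, so treat it as intuition rather than a prediction.

```python
def expert_gemm_flops(tokens: int, hidden: int, inter: int) -> int:
    # gate_proj and up_proj: [T, H] @ [H, I]; down_proj: [T, I] @ [I, H]; 2 FLOPs per multiply-add.
    return 3 * 2 * tokens * hidden * inter


def swiglu_recompute_flops(tokens: int, inter: int, flops_per_elem: int = 5) -> int:
    # silu(g) * u per element; flops_per_elem is a rough, assumed cost.
    return flops_per_elem * tokens * inter


def recompute_overhead(tokens: int, hidden: int, inter: int) -> float:
    return swiglu_recompute_flops(tokens, inter) / expert_gemm_flops(tokens, hidden, inter)


r = recompute_overhead(tokens=32768 * 8, hidden=2048, inter=768)  # assumed 32k seq, top-8
print(f"{r:.4%}")  # 0.0407%
```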
Why the speedup grows with node count
| Scale | recompute_full_layer | recompute_before_dispatch | Speedup |
|---|---|---|---|
| 1-node EP8 (NVLink) | 17,542 tok/s | 21,851 tok/s | +24.5% |
| 2-node EP16 (InfiniBand) | 20,586 tok/s | 27,470 tok/s | +33.4% |
At 2 nodes, dispatch/combine crosses InfiniBand (~100 Gb/s) instead of NVLink (~900 GB/s per H100). recompute_full_layer re-executes both dispatch and combine in backward, so the forward-direction AllToAll traffic crosses IB twice per step. recompute_before_dispatch avoids this extra traffic entirely.
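A rough per-layer tally shows the effect. With AllToAll-based EP, the forward pass runs one dispatch and one combine, and the backward pass runs one AllToAll-style exchange for the gradient of each; recompute_full_layer adds a second forward dispatch and combine. That doubles the forward-direction calls (2 to 4) and raises the total from 4 to 6. The payload size and bandwidths below are placeholders, not measurements from the benchmark above.

```python
from typing import Dict, Tuple

# AllToAll calls per MoE layer per step: (forward-direction calls, gradient calls in backward).
ALLTOALLS: Dict[str, Tuple[int, int]] = {
    "recompute_full_layer": (4, 2),       # dispatch + combine in forward, again during recompute
    "recompute_before_dispatch": (2, 2),  # dispatch + combine run once
    "no_recompute": (2, 2),
}


def comm_seconds(method: str, bytes_per_alltoall: float, bytes_per_second: float) -> float:
    """Idealized transfer time per layer: calls * payload / bandwidth (no latency or overlap)."""
    fwd, bwd = ALLTOALLS[method]
    return (fwd + bwd) * bytes_per_alltoall / bytes_per_second


# Placeholder payload (1 GB per AllToAll) over a placeholder 50 GB/s link.
full = comm_seconds("recompute_full_layer", 1e9, 50e9)
before = comm_seconds("recompute_before_dispatch", 1e9, 50e9)
print(full / before)  # ~1.5 at any payload or bandwidth
```

The ratio is fixed, but the absolute cost of the two extra calls scales with 1/bandwidth, which is why the gap widens once traffic moves from NVLink to InfiniBand.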
Source
| File | Description |
|---|---|
src/xorl/models/layers/moe/experts.py | MoEExperts — GKN weight tensors, backend dispatch, EP forward, moe_act |
src/xorl/ops/group_gemm/kernel/group_gemm.py | group_gemm_same_nk_kernel — Triton autotuned grouped GEMM with optional STORE_ACTIVATIONS |
src/xorl/ops/__init__.py | triton_moe_forward, triton_moe_lora_forward, backend registry |
src/xorl/ops/group_gemm/kernel/quack.py | quack_group_gemm_same_nk — CuTe GEMM wrapper; XORL_QUACK_TUNED env var |
src/xorl/ops/quack/gemm_interface.py | Quack GEMM interface — SM90/SM100 configs, autotuner, gated activation kernels |
src/xorl/models/layers/moe/backend/quack.py | quack_expert_forward — non-EP quack MoE forward |
src/xorl/ops/moe/quack.py | QuackEPGroupGemm, QuackEPGroupGemmMoeAct — EP + moe_act variants |