# Expert Kernels
Select the expert compute backend via `moe_implementation`. This setting controls which kernel executes the batched grouped GEMM over all active experts.
## Backend Comparison

| Backend | Kernel | EP compatible | torch.compile | LoRA | Notes |
|---|---|---|---|---|---|
| `triton` | Triton `group_gemm_same_nk` | Yes | Partial† | Yes | Recommended for Hopper. Autotuned; works with `moe_act` |
| `quack` | Quack custom CUDA extension | Yes | No | Limited | Pairs well with DeepEP; custom CUDA kernels are not compilable |
| `native` | `torch._grouped_mm` | Yes | Yes (full) | Yes | No custom deps; slightly slower than `triton` on Hopper |
| `eager` | Python loop + `index_add_` | No (`ep_size=1` only) | Yes | Yes | Debugging only; slowest but simplest |
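To make the `eager` row concrete, here is a plain-Python stand-in for its loop-over-experts semantics (a hypothetical helper with top-1 routing and scalar "tokens" for illustration; the real backend operates on tensors and scatters results back with `index_add_`):

```python
# Plain-Python stand-in for the eager backend (illustrative, not xorl's
# code): loop over experts, apply each expert's FFN to the tokens routed
# to it, and write results back to their original positions.
def eager_moe(tokens, expert_ids, experts):
    out = [None] * len(tokens)
    for g, ffn in enumerate(experts):      # one Python-level loop per expert
        for i, eid in enumerate(expert_ids):
            if eid == g:
                out[i] = ffn(tokens[i])    # scatter back, index_add_-style
    return out

tokens = [1.0, 2.0, 3.0]      # scalar "tokens" for illustration
expert_ids = [0, 1, 0]        # top-1 routing decision per token
experts = [lambda x: x * 10, lambda x: x + 1]
out = eager_moe(tokens, expert_ids, experts)   # [10.0, 3.0, 30.0]
```

The per-expert Python loop is why this backend is the slowest: every expert costs a host-side iteration, whereas the grouped-GEMM backends launch one kernel for all experts.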
```yaml
model:
  moe_implementation: triton  # recommended for H100
```

† Triton: the `group_gemm_same_nk` kernel is compatible with `torch.compile`. However, EP dispatch/combine (AllToAll, DeepEP) is excluded from compile graphs because it involves dynamic shapes (variable token counts per rank per step).
## Quack Backend

Quack uses CuTe DSL-based GEMM kernels written by Tri Dao (the author of Flash Attention). The kernels are built on NVIDIA's CuTe DSL (part of CUTLASS) and are compiled for SM90+ (H100) and SM100+ (B200). They are not available on older GPUs.
### What makes Quack different

Unlike the Triton backend, which JIT-compiles Python-defined Triton kernels, Quack kernels are written directly in CuTe DSL — the same low-level CUDA tensor abstraction used for Flash Attention. This gives:

- Persistent kernels with pingpong double-buffering (H100) — the kernel occupies SMs continuously across multiple expert GEMMs, overlapping memory loads of the next tile with computation of the current tile
- `cu_seqlens` format — uses Flash Attention's cumulative sequence length format to group tokens by expert, avoiding unnecessary tensor copies
- SwiGLU fusion — the `gate_proj` and `up_proj` GEMMs plus the SiLU activation can be fused into a single kernel dispatch via `gemm_gated`/`gemm_act`
- SM100 (B200) support — different tile config: `tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False`
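The `cu_seqlens` grouping can be illustrated in a few lines of plain Python (a sketch of the format only; the real kernels consume GPU tensors):

```python
# cu_seqlens: prefix sums over per-expert token counts. Expert g owns
# rows cu[g]:cu[g+1] of the concatenated, expert-sorted token matrix,
# so no per-expert tensor copies are needed.
def cu_seqlens(tokens_per_expert):
    cu = [0]
    for m in tokens_per_expert:
        cu.append(cu[-1] + m)
    return cu

counts = [3, 0, 5, 2]              # tokens routed to each of 4 experts
cu = cu_seqlens(counts)            # [0, 3, 3, 8, 10]
slices = [(cu[g], cu[g + 1]) for g in range(len(counts))]
# expert 1 is empty: its slice (3, 3) has zero rows, nothing to copy
```
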
Default tile config per architecture:
| Architecture | tile_m | tile_n | cluster_m | cluster_n | pingpong |
|---|---|---|---|---|---|
| SM90 (H100) | 128 | 192 | 2 | 1 | True |
| SM100 (B200) | 256 | 256 | 2 | 1 | False |
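As a rough illustration of what the tile sizes mean for the persistent kernel's schedule, here is the output-tile count for one expert GEMM at the SM90 defaults (hypothetical problem shapes; this is not the actual scheduler code):

```python
import math

# Number of BLOCK_M x BLOCK_N output tiles one M x N expert GEMM
# produces at the SM90 defaults (tile_m=128, tile_n=192). The persistent
# kernel walks this tile grid without relaunching between experts.
def num_tiles(M, N, tile_m=128, tile_n=192):
    return math.ceil(M / tile_m) * math.ceil(N / tile_n)

tiles = num_tiles(M=4096, N=1536)   # 32 * 8 = 256 tiles
```
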
### Autotuning

By default, autotuning is disabled (`XORL_QUACK_TUNED=0`). MoE routing produces variable token counts per expert per step — each unique `total_M` would trigger a full ~60-config benchmark, making training orders of magnitude slower. The fixed default config is chosen to be performant across typical token budgets.
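The cost model behind this default can be sketched with a shape-keyed cache (hypothetical code, not xorl's autotuner): every previously unseen `total_M` pays for a benchmark sweep, and under MoE routing the shape changes almost every step.

```python
# Shape-keyed autotune cache: unseen (total_M, N, K) shapes trigger a
# benchmark sweep; repeated shapes are free cache hits.
def get_config(total_m, n, k, cache, benchmarked):
    key = (total_m, n, k)
    if key not in cache:
        benchmarked.append(key)   # stands in for a ~60-config sweep
        cache[key] = {"BLOCK_M": 128, "BLOCK_N": 192}
    return cache[key]

cache, benchmarked = {}, []
for total_m in (4096, 4100, 4096, 5000):   # routing varies step to step
    get_config(total_m, 1536, 2048, cache, benchmarked)
# three sweeps for four steps: only the one repeated shape hit the cache
```

This is why the env var is worth enabling only when shapes repeat, as described next.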
Enable autotuning only when shapes are stable (e.g. evaluation with fixed batch size) or when the tuning cache is already warm:
```sh
XORL_QUACK_TUNED=1 torchrun ... -m xorl.cli.train config.yaml
```

### Full feature matrix

Quack has dedicated kernel variants for every training mode:
| Variant | Class / Function | Use case |
|---|---|---|
| Standard | `quack_expert_forward` | Default forward, no EP |
| EP | `QuackEPGroupGemm.apply` | EP dispatch + compute |
| LoRA | `quack_moe_lora_forward` | LoRA adapters, no EP |
| EP + LoRA | `QuackEPGroupGemmWithLoRA.apply` | EP + LoRA |
| `moe_act` | `quack_expert_forward_moe_act` | Gradient checkpointing, no EP |
| EP + `moe_act` | `QuackEPGroupGemmMoeAct.apply` | EP + gradient checkpointing |
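Selection among these variants could be sketched as follows (hypothetical dispatch logic; the actual selection lives in xorl's backend code):

```python
# Hypothetical selector mirroring the feature matrix above: EP status,
# LoRA, and moe_act checkpointing together pick the kernel variant.
def pick_quack_variant(ep: bool, lora: bool, moe_act: bool) -> str:
    if ep and lora:
        return "QuackEPGroupGemmWithLoRA.apply"
    if ep and moe_act:
        return "QuackEPGroupGemmMoeAct.apply"
    if ep:
        return "QuackEPGroupGemm.apply"
    if lora:
        return "quack_moe_lora_forward"
    if moe_act:
        return "quack_expert_forward_moe_act"
    return "quack_expert_forward"
```
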
### Pairing with DeepEP

Quack and DeepEP are both optimized for H100 NVLink clusters and complement each other well: DeepEP handles token dispatch/combine with dedicated SM kernels, while Quack handles the expert FFN compute with persistent CuTe kernels. Combining them maximizes utilization of both the NVLink interconnect and the SM compute fabric.
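The SM split is zero-sum, which is why `deepep_num_sms` is worth profiling. A back-of-envelope example (assuming an H100 SXM, which exposes 132 SMs):

```python
# Back-of-envelope SM budget: SMs reserved for DeepEP's dispatch/combine
# kernels are unavailable to Quack's persistent GEMM kernels.
total_sms = 132          # H100 SXM
deepep_num_sms = 48      # example setting, as in the config
compute_sms = total_sms - deepep_num_sms   # 84 SMs left for expert GEMMs
```
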
```yaml
model:
  moe_implementation: quack
  ep_dispatch: deepep
  deepep_num_sms: 48  # more SMs for DeepEP; fewer left for Quack — tune based on profiling
```

### Limitations
- Requires SM90+ (H100) or SM100+ (B200) — not available on A100 or older
- Not compatible with `torch.compile` — custom CUDA/CuTe extensions cannot be traced
- LoRA support is limited — `moe_hybrid_shared_lora` is not supported; use `triton` or `native` for hybrid shared LoRA
- Autotuning is disabled by default — variable token counts make per-step autotuning prohibitively slow
## Triton Grouped GEMM Details

The core of the `triton` backend is `src/xorl/ops/group_gemm/kernel/group_gemm.py`. The primary kernel, `group_gemm_same_nk_kernel`, handles the case where all experts share the same K (input) and N (output) dimensions.
Inputs and outputs:

- `A: [total_tokens, K]` — all tokens routed to this rank, concatenated
- `B: [G, K, N]` — expert weight tensor in GKN layout
- `cumsum_M: [G]` — cumulative expert token counts (prefix sum)
- `C: [total_tokens, N]` — output
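The reference semantics of this grouped GEMM can be written in pure Python (a didactic stand-in for the Triton kernel, using nested lists in place of tensors):

```python
# For every expert g: C[start:end] = A[start:end] @ B[g], where the
# row range comes from cumsum_M, the inclusive prefix sum of per-expert
# token counts.
def grouped_mm(A, B, cumsum_M):
    K, N = len(B[0]), len(B[0][0])
    C = [[0.0] * N for _ in range(len(A))]
    for g in range(len(B)):
        start = 0 if g == 0 else cumsum_M[g - 1]
        for i in range(start, cumsum_M[g]):
            for n in range(N):
                C[i][n] = sum(A[i][k] * B[g][k][n] for k in range(K))
    return C

A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 tokens, K=2
B = [[[1.0], [2.0]], [[3.0], [4.0]]]       # G=2 experts, each K=2, N=1
C = grouped_mm(A, B, cumsum_M=[2, 3])      # expert 0 -> rows 0..1, expert 1 -> row 2
# C == [[1.0], [2.0], [7.0]]
```
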
Tile strategy: Each thread block processes a `BLOCK_M × BLOCK_N` output tile for one expert. An inner loop steps through `BLOCK_K` slices of A and B, accumulating the partial sum in registers.

Autotuned configs: `BLOCK_M ∈ {64, 128}`, `BLOCK_N ∈ {128, 256}`, `BLOCK_K ∈ {32, 64}` — autotuned per (N, K) shape at first use and cached.

Optional activation fusion: When `moe_act` gradient checkpointing is active, the kernel stores the post-gate×up intermediate activations to a separate buffer during the forward pass (controlled by the `STORE_ACTIVATIONS` flag). The custom backward function recomputes only the down projection from these stored intermediates.
## torch.compile with MoE

`torch.compile` is applied per decoder block when `enable_compile: true`.
What is compiled:

- Attention layers (Q/K/V projections, flash attention wrapper, output projection)
- Router gate linear and softmax
- Expert FFN compute (the Triton `group_gemm_same_nk` kernel is torch.compile-friendly)
- Dense shared-expert FFN layers (if present)
What is excluded from compile graphs:

- EP dispatch/combine collectives (`alltoall_pre_dispatch`, `alltoall_post_combine`, DeepEP `dispatch`/`combine`) — dynamic shapes from variable per-rank token counts break static graph tracing
- Routing replay record/replay logic — mutable external state, decorated with `@torch.compiler.disable`
Recommended config:

```yaml
model:
  moe_implementation: triton  # or native; quack is incompatible with compile
train:
  enable_compile: true
```

The `native` backend (`torch._grouped_mm`) compiles cleanly end-to-end. The `quack` backend uses custom CUDA extensions and must not be used with `enable_compile: true`.
## MoE Gradient Checkpointing (`moe_checkpoint_method`)

xorl provides two MoE-aware checkpoint methods:
| Method | What is recomputed | EP comm recomputed | Memory savings |
|---|---|---|---|
| `moe_act` (default) | Expert MLP activations only (gate×up product inside each expert) | No — dispatch/combine skipped via R3 | Moderate — recovers large expert FFN tensors |
| `full_recompute` | All activations: attention, router, expert FFN | Yes — full AllToAll re-executed | Maximum — stores nothing |
| None / not set | Nothing | N/A | None — all activations kept |
```yaml
train:
  enable_gradient_checkpointing: true
  moe_checkpoint_method: moe_act  # strongly recommended for EP models
```

Why `moe_act` is the right default: Expert FFN activations (the gate×up intermediate tensor of shape `[tokens, intermediate_size]`) are the largest memory consumer in a MoE forward pass. `moe_act` drops these and recomputes them from stored inputs during backward, without re-running the expensive AllToAll dispatch. This gives most of the memory savings of full recompute at a fraction of the communication cost.
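A back-of-envelope estimate of what `moe_act` recovers per MoE layer (illustrative shapes, bf16 storage assumed; actual sizes depend on model config and routing):

```python
# Size of the gate*up intermediate tensor that moe_act drops, per layer.
def act_bytes(tokens, intermediate_size, bytes_per_el=2):  # bf16 = 2 bytes
    return tokens * intermediate_size * bytes_per_el

per_layer = act_bytes(tokens=16384, intermediate_size=2048)
per_layer_mib = per_layer / 2**20   # 64 MiB per layer at these shapes
# Multiply by the number of MoE layers to see why this tensor dominates.
```
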
How it works in the Triton kernel: The `group_gemm_same_nk_kernel` has an optional `STORE_ACTIVATIONS` mode. When `moe_act` is active, the kernel stores the post-gate×up intermediate activations to a separate buffer. The custom backward function recomputes only the down projection from these stored intermediates.
## Source

| File | Description |
|---|---|
| `src/xorl/models/layers/moe/experts.py` | `MoEExperts` — GKN weight tensors, backend dispatch, EP forward, `moe_act` |
| `src/xorl/ops/group_gemm/kernel/group_gemm.py` | `group_gemm_same_nk_kernel` — Triton autotuned grouped GEMM with optional `STORE_ACTIVATIONS` |
| `src/xorl/ops/__init__.py` | `triton_moe_forward`, `triton_moe_lora_forward`, backend registry |
| `src/xorl/ops/group_gemm/kernel/quack.py` | `quack_group_gemm_same_nk` — CuTe GEMM wrapper; `XORL_QUACK_TUNED` env var |
| `src/xorl/ops/quack/gemm_interface.py` | Quack GEMM interface — SM90/SM100 configs, autotuner, gated activation kernels |
| `src/xorl/models/layers/moe/backend/quack.py` | `quack_expert_forward` — non-EP Quack MoE forward |
| `src/xorl/ops/moe/quack.py` | `QuackEPGroupGemm`, `QuackEPGroupGemmMoeAct` — EP + `moe_act` variants |