
Expert Kernels

Select the expert compute backend via moe_implementation. This controls which kernel runs the grouped GEMM over all active experts.


| Backend | Kernel | EP compatible | torch.compile | LoRA | Notes |
|---|---|---|---|---|---|
| triton | Triton group_gemm_same_nk | Yes | Partial† | Yes | Recommended for Hopper. Autotuned; works with moe_act |
| quack | Quack custom CUDA extension | Yes | No | Limited | Pairs well with DeepEP; custom CUDA kernels are not compilable |
| native | torch._grouped_mm | Yes | Yes (full) | Yes | No custom deps; slightly slower than triton on Hopper |
| eager | Python loop + index_add_ | No (ep_size=1 only) | Yes | Yes | Debugging only; slowest but simplest |
```yaml
model:
  moe_implementation: triton  # recommended for H100
```

†Triton: The group_gemm_same_nk kernel is compatible with torch.compile. However, EP dispatch/combine (AllToAll, DeepEP) is excluded from compile graphs because it involves dynamic shapes (variable token counts per rank per step).
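The compatibility rules in the table can be checked before a run starts. The sketch below is a hypothetical validator (MoEConfig and check_moe_config are illustrative names, not part of xorl) that encodes only the constraints stated above: eager requires ep_size=1, and quack cannot be combined with enable_compile: true.

```python
from dataclasses import dataclass

# Constraints transcribed from the backend table; the checker itself is hypothetical.
KNOWN_BACKENDS = {"triton", "quack", "native", "eager"}
COMPILE_INCOMPATIBLE = {"quack"}
EAGER_MAX_EP_SIZE = 1


@dataclass
class MoEConfig:
    moe_implementation: str = "triton"
    ep_size: int = 1
    enable_compile: bool = False


def check_moe_config(cfg: MoEConfig) -> list[str]:
    """Return human-readable problems; an empty list means no known conflict."""
    problems = []
    if cfg.moe_implementation not in KNOWN_BACKENDS:
        problems.append(f"unknown moe_implementation: {cfg.moe_implementation!r}")
    if cfg.moe_implementation == "eager" and cfg.ep_size > EAGER_MAX_EP_SIZE:
        problems.append("eager backend only supports ep_size=1")
    if cfg.enable_compile and cfg.moe_implementation in COMPILE_INCOMPATIBLE:
        problems.append("quack cannot be used with enable_compile: true")
    return problems
```

Returning a list rather than raising on the first problem lets a launcher report every conflict at once.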


Quack uses CuTe DSL-based GEMM kernels written by Tri Dao (lead author of FlashAttention). The kernels are written in NVIDIA's CuTe DSL (the Python DSL that ships with CUTLASS) and compiled for SM90 (H100) and SM100 (B200). They are not available on older GPUs.

Unlike the Triton backend, whose kernels are JIT-compiled from Python-defined Triton code, Quack kernels are written directly in CuTe DSL, which exposes the same low-level CuTe tensor and layout abstractions that FlashAttention's Hopper and Blackwell kernels are built on. This gives:

  • Persistent kernels with pingpong scheduling (H100): the kernel occupies SMs continuously across multiple expert GEMMs, overlapping memory loads for the next tile with computation on the current tile
  • cu_seqlens format: uses FlashAttention's cumulative sequence length format to group tokens by expert, avoiding unnecessary tensor copies
  • SwiGLU fusion: the gate_proj and up_proj GEMMs plus the SiLU activation can be fused into a single kernel launch via gemm_gated / gemm_act
  • SM100 (B200) support: a different tile config (tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False)
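As a rough illustration of the cu_seqlens layout, the pure-Python sketch below groups routed tokens by expert and builds the offsets array. It assumes FlashAttention's convention (G + 1 offsets starting at 0), which differs from the G-entry cumsum_M used by the Triton kernel; how Quack's wrapper lays this out internally is not shown, and group_by_expert is a hypothetical helper. The point is that all experts' tokens live in one contiguous buffer addressed by offsets, so no per-expert tensors need to be materialized.

```python
from itertools import accumulate


def group_by_expert(expert_ids: list[int], num_experts: int):
    """Sort routed-token indices by expert and build FlashAttention-style cu_seqlens.

    expert_ids[i] is the expert for routed entry i (top-k routing assumed to be
    flattened to one entry per (token, expert) pair). Expert g owns
    order[cu_seqlens[g]:cu_seqlens[g + 1]].
    """
    counts = [0] * num_experts
    for e in expert_ids:
        counts[e] += 1
    # sorted() is stable, so entries keep their original relative order within an expert.
    order = sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])
    # Leading 0 plus running totals: num_experts + 1 offsets.
    cu_seqlens = [0, *accumulate(counts)]
    return order, cu_seqlens
```

For expert_ids = [2, 0, 2, 1, 0] and 4 experts this returns order [1, 4, 3, 0, 2] and cu_seqlens [0, 2, 3, 5, 5]; expert 3 receives an empty range.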

Default tile config per architecture:

| Architecture | tile_m | tile_n | cluster_m | cluster_n | pingpong |
|---|---|---|---|---|---|
| SM90 (H100) | 128 | 192 | 2 | 1 | True |
| SM100 (B200) | 256 | 256 | 2 | 1 | False |
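A persistent kernel walks a fixed pool of SMs over every output tile of every expert. The back-of-the-envelope count below shows how the per-architecture tile sizes translate into work items; it assumes one tile per (expert, tile_m rows, tile_n columns) block and ignores the cluster shape, and num_output_tiles is a hypothetical helper.

```python
from math import ceil

# Default tile shapes from the table above (cluster and pingpong settings omitted).
TILE_CONFIGS = {
    "SM90": {"tile_m": 128, "tile_n": 192},
    "SM100": {"tile_m": 256, "tile_n": 256},
}


def num_output_tiles(tokens_per_expert: list[int], n: int, arch: str) -> int:
    """Count output tiles a persistent grouped GEMM covers for a [*, N] output."""
    cfg = TILE_CONFIGS[arch]
    tiles_n = ceil(n / cfg["tile_n"])
    # Experts with zero tokens contribute no tiles; a partial tile still costs a full tile.
    return sum(ceil(m / cfg["tile_m"]) * tiles_n for m in tokens_per_expert)
```

For example, with 300, 0, and 128 tokens on three experts and a hypothetical N = 768, this gives 16 tiles on SM90 and 9 on SM100: the larger SM100 tiles mean fewer, bigger work items.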

By default, autotuning is disabled (XORL_QUACK_TUNED=0). MoE routing produces variable token counts per expert per step, and each unique total_M would trigger a full benchmark of ~60 configs, making training orders of magnitude slower. The fixed default config is chosen to perform well across typical token budgets.

Enable autotuning only when shapes are stable (e.g. evaluation with fixed batch size) or when the tuning cache is already warm:

```shell
XORL_QUACK_TUNED=1 torchrun ... -m xorl.cli.train config.yaml
```
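The cost of per-shape autotuning can be seen with a toy model. The class below is hypothetical and is not the autotuner in gemm_interface.py: it charges a benchmark of all ~60 candidate configs whenever it sees a new total_M, and the per-step token counts are simulated rather than measured.

```python
import random


class ShapeKeyedAutotuner:
    """Toy autotuner that benchmarks every candidate config on each unseen total_M."""

    def __init__(self, num_configs: int = 60):
        self.num_configs = num_configs
        self.cache: dict[int, int] = {}
        self.benchmarks_run = 0

    def pick_config(self, total_m: int) -> int:
        if total_m not in self.cache:
            # Cache miss: a real tuner would time every candidate here.
            self.benchmarks_run += self.num_configs
            # Stand-in for the winning config index.
            self.cache[total_m] = total_m % self.num_configs
        return self.cache[total_m]


def simulate(steps: int, seed: int = 0) -> ShapeKeyedAutotuner:
    rng = random.Random(seed)
    tuner = ShapeKeyedAutotuner()
    for _ in range(steps):
        # Hypothetical routed-token count per rank that fluctuates with routing.
        total_m = 8192 + rng.randint(-2048, 2048)
        tuner.pick_config(total_m)
    return tuner
```

With fluctuating total_M most steps are cache misses, while a fixed shape pays the benchmarking cost exactly once, which is the stable-shape case where enabling XORL_QUACK_TUNED=1 makes sense.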

Quack has dedicated kernel variants for every training mode:

| Variant | Class / Function | Use case |
|---|---|---|
| Standard | quack_expert_forward | Default forward, no EP |
| EP | QuackEPGroupGemm.apply | EP dispatch + compute |
| LoRA | quack_moe_lora_forward | LoRA adapters, no EP |
| EP + LoRA | QuackEPGroupGemmWithLoRA.apply | EP + LoRA |
| moe_act | quack_expert_forward_moe_act | Gradient checkpointing, no EP |
| EP + moe_act | QuackEPGroupGemmMoeAct.apply | EP + gradient checkpointing |
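The variant table amounts to a lookup on three flags. The sketch below transcribes it into a dictionary; select_quack_variant is a hypothetical helper that returns names as strings, not xorl's actual dispatch code, and combinations the table does not list (LoRA together with moe_act) are rejected rather than guessed.

```python
# Mapping transcribed from the variant table; the selector is illustrative only.
QUACK_VARIANTS = {
    # (use_ep, use_lora, use_moe_act): class or function name
    (False, False, False): "quack_expert_forward",
    (True, False, False): "QuackEPGroupGemm.apply",
    (False, True, False): "quack_moe_lora_forward",
    (True, True, False): "QuackEPGroupGemmWithLoRA.apply",
    (False, False, True): "quack_expert_forward_moe_act",
    (True, False, True): "QuackEPGroupGemmMoeAct.apply",
}


def select_quack_variant(use_ep: bool, use_lora: bool, use_moe_act: bool) -> str:
    key = (use_ep, use_lora, use_moe_act)
    if key not in QUACK_VARIANTS:
        # LoRA combined with moe_act does not appear in the table.
        raise ValueError(f"no Quack variant listed for ep={use_ep}, lora={use_lora}, moe_act={use_moe_act}")
    return QUACK_VARIANTS[key]
```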

Quack and DeepEP are both optimized for H100 NVLink clusters and complement each other: DeepEP handles token dispatch/combine with dedicated SM kernels, while Quack handles the expert FFN compute with persistent CuTe kernels. Together they keep both the NVLink interconnect and the SM compute fabric busy, with deepep_num_sms controlling how many SMs go to communication.

```yaml
model:
  moe_implementation: quack
  ep_dispatch: deepep
  deepep_num_sms: 48  # more SMs for DeepEP leaves fewer for Quack; tune based on profiling
```
Limitations:

  • Requires SM90+ (H100) or SM100+ (B200); not available on A100 or older GPUs
  • Not compatible with torch.compile: custom CUDA/CuTe extensions cannot be traced
  • Limited LoRA support: moe_hybrid_shared_lora is not supported; use triton or native for hybrid shared LoRA
  • Autotuning disabled by default: variable token counts make per-step autotuning prohibitively slow
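The deepep_num_sms setting splits the GPU's SMs between communication and compute, to the extent that DeepEP kernels overlap with the expert GEMMs. A minimal sketch of the arithmetic, assuming an H100 SXM with 132 SMs (compute_sms_left is a hypothetical helper):

```python
H100_SXM_NUM_SMS = 132  # H100 SXM5; the PCIe variant has 114 SMs


def compute_sms_left(deepep_num_sms: int, total_sms: int = H100_SXM_NUM_SMS) -> int:
    """SMs not reserved by DeepEP, i.e. what remains for the expert GEMM kernels."""
    if not 0 <= deepep_num_sms < total_sms:
        raise ValueError("deepep_num_sms must be in [0, total_sms)")
    return total_sms - deepep_num_sms
```

With deepep_num_sms: 48 this leaves 84 of 132 SMs (about 64%) for the Quack GEMMs. On other GPUs, torch.cuda.get_device_properties(0).multi_processor_count reports the SM count to pass as total_sms.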

The core of the triton backend is src/xorl/ops/group_gemm/kernel/group_gemm.py. The primary kernel group_gemm_same_nk_kernel handles the case where all experts share the same K (input) and N (output) dimensions.

Inputs and outputs:

  • A: [total_tokens, K] — all tokens routed to this rank, concatenated
  • B: [G, K, N] — expert weight tensor in GKN layout
  • cumsum_M: [G] — cumulative expert token counts (prefix sum)
  • C: [total_tokens, N] — output
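The input/output contract can be pinned down with a plain-Python reference. This is a sketch for checking semantics, not the Triton kernel; it assumes cumsum_M is an inclusive prefix sum (G entries, no leading zero), which matches its [G] shape but is not spelled out above, and group_gemm_same_nk_reference is a hypothetical name.

```python
Matrix = list[list[float]]


def group_gemm_same_nk_reference(a: Matrix, b: list[Matrix], cumsum_m: list[int]) -> Matrix:
    """Plain-Python grouped GEMM: C[rows of expert g] = A[rows of expert g] @ B[g].

    a:        [total_tokens, K], rows grouped by expert
    b:        [G, K, N] expert weights in GKN layout
    cumsum_m: [G] inclusive prefix sum of per-expert token counts
    returns   [total_tokens, N]
    """
    if cumsum_m and cumsum_m[-1] != len(a):
        raise ValueError("cumsum_m[-1] must equal total_tokens")
    n = len(b[0][0])
    c = [[0.0] * n for _ in a]
    start = 0
    for g, end in enumerate(cumsum_m):
        w = b[g]  # [K, N] slice for expert g
        for row in range(start, end):
            for j in range(n):
                c[row][j] = sum(a[row][k] * w[k][j] for k in range(len(w)))
        start = end
    return c
```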

Tile strategy: Each thread block processes a BLOCK_M × BLOCK_N output tile for one expert. An inner loop steps through BLOCK_K slices of A and B, accumulating the partial sum in registers.
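The tile strategy can be mirrored in plain Python to show the loop structure, assuming the inputs described earlier (A, B in GKN layout, and cumsum_M as an inclusive prefix sum). In this sketch each (expert, M-tile, N-tile) iteration stands in for one thread block, acc stands in for its register accumulator, and M-tiles never straddle two experts. group_gemm_tiled and its small default block sizes are illustrative only; the real kernel's BLOCK_* values come from autotuning.

```python
def group_gemm_tiled(a, b, cumsum_m, block_m=2, block_n=2, block_k=2):
    """Grouped GEMM computed tile by tile, one accumulator per output tile."""
    k_dim, n = len(b[0]), len(b[0][0])
    c = [[0.0] * n for _ in a]
    start = 0
    for g, end in enumerate(cumsum_m):
        w = b[g]
        for m0 in range(start, end, block_m):        # M-tiles within this expert's rows
            m1 = min(m0 + block_m, end)
            for n0 in range(0, n, block_n):          # N-tiles over output columns
                n1 = min(n0 + block_n, n)
                acc = [[0.0] * (n1 - n0) for _ in range(m1 - m0)]
                for k0 in range(0, k_dim, block_k):  # inner loop over BLOCK_K slices
                    for i in range(m0, m1):
                        for j in range(n0, n1):
                            for k in range(k0, min(k0 + block_k, k_dim)):
                                acc[i - m0][j - n0] += a[i][k] * w[k][j]
                for i in range(m0, m1):              # single write-back per tile
                    c[i][n0:n1] = acc[i - m0]
        start = end
    return c
```

Because each tile's accumulator lives only inside its loop body, C is written exactly once per tile, as in the kernel.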

Autotuned configs: BLOCK_M ∈ {64, 128}, BLOCK_N ∈ {128, 256}, BLOCK_K ∈ {32, 64} — autotuned per (N, K) shape at first use and cached.

Figure: Triton group_gemm_same_nk tiling strategy. For expert group g, one thread block computes one BLOCK_M × BLOCK_N tile of the output from the A slice [tokens_g, K] and the expert's [K, N] slice of B [G, K, N] (GKN weights), stepping through K in BLOCK_K slices; the tile is accumulated in registers and written to C [total_tokens, N] once per tile.

Optional activation fusion: When moe_act gradient checkpointing is active, the kernel stores the post-gate×up intermediate activations to a separate buffer during the forward pass (controlled by the STORE_ACTIVATIONS flag). The custom backward function recomputes only the down projection from these stored intermediates.


torch.compile is applied per decoder block when enable_compile: true.

What is compiled:

  • Attention layers (Q/K/V projections, flash attention wrapper, output projection)
  • Router gate linear and softmax
  • Expert FFN compute (Triton group_gemm_same_nk is torch.compile-friendly)
  • Dense shared-expert FFN layers (if present)

What is excluded from compile graphs:

  • EP dispatch/combine collectives (alltoall_pre_dispatch, alltoall_post_combine, DeepEP dispatch/combine) — dynamic shapes from variable per-rank token counts break static graph tracing
  • Routing replay record/replay logic — mutable external state, decorated with @torch.compiler.disable

Recommended config:

```yaml
model:
  moe_implementation: triton  # or native; quack is incompatible with compile
train:
  enable_compile: true
```

The native backend (torch._grouped_mm) compiles cleanly end-to-end. The quack backend uses custom CUDA extensions and must not be used with enable_compile: true.


MoE Gradient Checkpointing (gradient_checkpointing_method)


gradient_checkpointing_method controls what is recomputed in the backward pass. For MoE models with EP, this choice has a large impact on throughput, because recompute_full_layer re-executes the AllToAll dispatch and combine for every layer during backward.

| Method | Backward recomputes | EP comm recomputed | Peak mem (32k) | Throughput |
|---|---|---|---|---|
| recompute_full_layer (default) | Entire decoder layer | Yes (full AllToAll re-executed) | 37.5 GB (1N) / 21.2 GB (2N) | baseline |
| recompute_before_dispatch | Attn + router only; keeps dispatch + expert + combine | No | 54.8 GB (1N) / 46.6 GB (2N) | +24.5% (1N) / +33.4% (2N) |
| no_recompute | Nothing | No | highest | max throughput |

Benchmarks: Qwen3-Coder-30B-A3B, quack + alltoall, 32k seq, H100 80GB. 1N = 1-node EP8, 2N = 2-node EP16.

```yaml
# Maximum throughput (short seq, memory permits):
train:
  enable_gradient_checkpointing: true
  gradient_checkpointing_method: no_recompute
```

```yaml
# Balanced memory/speed: recompute attn+router, keep dispatch+expert+combine
train:
  enable_gradient_checkpointing: true
  gradient_checkpointing_method: recompute_before_dispatch
```

```yaml
# Maximum memory savings (long seq, required for 128k):
train:
  enable_gradient_checkpointing: true
  gradient_checkpointing_method: recompute_full_layer
```

How recompute_before_dispatch works (Megatron-style)


In the expert MLP forward: gate_output = x @ gate_proj, up_output = x @ up_proj, postact = silu(gate_output) * up_output, output = postact @ down_proj.

With recompute_before_dispatch, the attention and router outputs are recomputed in backward, but MoE dispatch and combine are not. This avoids re-executing AllToAll communication, which is the dominant cost in multi-node runs.

The elementwise recompute (postact = silu(gate_output) * up_output) adds negligible overhead: benchmarks show recompute_before_dispatch within 0.2% of no-checkpointing speed.
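The Megatron-style trade-off can be sketched in plain Python. The backward of the SwiGLU step needs gate_output and up_output anyway, so saving those two and rebuilding postact when the down-projection weight gradient needs it avoids storing a third tensor of the same shape. The sketch assumes exactly that split of saved versus recomputed tensors; expert_forward and down_proj_weight_grad are hypothetical names, and only the down_proj weight gradient is shown.

```python
import math

Matrix = list[list[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    return [[sum(xi[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))] for xi in x]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def swiglu(gate: Matrix, up: Matrix) -> Matrix:
    # postact = silu(gate_output) * up_output, elementwise
    return [[silu(g) * u for g, u in zip(gr, ur)] for gr, ur in zip(gate, up)]


def expert_forward(x: Matrix, gate_proj: Matrix, up_proj: Matrix, down_proj: Matrix):
    gate_output = matmul(x, gate_proj)
    up_output = matmul(x, up_proj)
    output = matmul(swiglu(gate_output, up_output), down_proj)
    # Save only the two GEMM outputs; postact is dropped after use.
    return output, (gate_output, up_output)


def down_proj_weight_grad(saved, grad_output: Matrix) -> Matrix:
    gate_output, up_output = saved
    postact = swiglu(gate_output, up_output)  # cheap elementwise recompute
    postact_t = [list(col) for col in zip(*postact)]
    # dL/d(down_proj) = postact^T @ grad_output
    return matmul(postact_t, grad_output)
```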

| Scale | recompute_full_layer | recompute_before_dispatch | Speedup |
|---|---|---|---|
| 1-node EP8 (NVLink) | 17,542 tok/s | 21,851 tok/s | +24.5% |
| 2-node EP16 (InfiniBand) | 20,586 tok/s | 27,470 tok/s | +33.4% |

At 2 nodes, dispatch/combine traffic crosses InfiniBand (~100 Gbps) instead of NVLink (~900 GB/s on H100). recompute_full_layer re-runs both dispatch and combine in backward, doubling the backward pass's AllToAll traffic over IB. recompute_before_dispatch avoids this entirely.
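A simplified count of AllToAll collectives per MoE layer per step makes the 2-node result easier to reason about. The model below assumes exactly one dispatch and one combine AllToAll per layer in forward, each with one gradient AllToAll in backward; real DeepEP or hierarchical implementations may split these further. alltoalls_per_layer is a hypothetical helper.

```python
METHODS = {"recompute_full_layer", "recompute_before_dispatch", "no_recompute"}


def alltoalls_per_layer(method: str) -> dict[str, int]:
    """Count EP AllToAll collectives per MoE layer per training step (simplified)."""
    if method not in METHODS:
        raise ValueError(f"unknown gradient_checkpointing_method: {method!r}")
    forward = 2          # dispatch + combine
    backward_grads = 2   # gradients of combine and dispatch
    # Only recompute_full_layer re-runs the forward dispatch + combine in backward.
    recompute = 2 if method == "recompute_full_layer" else 0
    return {
        "forward": forward,
        "backward": backward_grads + recompute,
        "total": forward + backward_grads + recompute,
    }
```

Under this model recompute_full_layer issues 6 AllToAlls per layer per step versus 4 for the other two methods, and all of the extra traffic lands in the backward pass.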


| File | Description |
|---|---|
| src/xorl/models/layers/moe/experts.py | MoEExperts: GKN weight tensors, backend dispatch, EP forward, moe_act |
| src/xorl/ops/group_gemm/kernel/group_gemm.py | group_gemm_same_nk_kernel: Triton autotuned grouped GEMM with optional STORE_ACTIVATIONS |
| src/xorl/ops/__init__.py | triton_moe_forward, triton_moe_lora_forward, backend registry |
| src/xorl/ops/group_gemm/kernel/quack.py | quack_group_gemm_same_nk: CuTe GEMM wrapper; XORL_QUACK_TUNED env var |
| src/xorl/ops/quack/gemm_interface.py | Quack GEMM interface: SM90/SM100 configs, autotuner, gated activation kernels |
| src/xorl/models/layers/moe/backend/quack.py | quack_expert_forward: non-EP quack MoE forward |
| src/xorl/ops/moe/quack.py | QuackEPGroupGemm, QuackEPGroupGemmMoeAct: EP + moe_act variants |