
Expert Kernels

Select the expert compute backend via moe_implementation. This controls which kernel executes the grouped GEMM over all active experts.


| Backend | Kernel | EP compatible | torch.compile | LoRA | Notes |
| --- | --- | --- | --- | --- | --- |
| triton | Triton group_gemm_same_nk | Yes | Partial† | Yes | Recommended for Hopper. Autotuned; works with moe_act |
| quack | Quack custom CUDA extension | Yes | No | Limited | Pairs well with DeepEP; custom CUDA kernels not compilable |
| native | torch._grouped_mm | Yes | Yes (full) | Yes | No custom deps; slightly slower than triton on Hopper |
| eager | Python loop + index_add_ | No (ep_size=1 only) | Yes | Yes | Debugging only; slowest but simplest |
```yaml
model:
  moe_implementation: triton # recommended for H100
```

†Triton: The group_gemm_same_nk kernel is compatible with torch.compile. However, EP dispatch/combine (AllToAll, DeepEP) is excluded from compile graphs because it involves dynamic shapes (variable token counts per rank per step).
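For intuition about the semantics all backends share, the eager backend's loop-plus-index_add_ approach can be sketched as follows. This is a simplified, hypothetical version (function and argument names are illustrative, not the actual backend code):

```python
import torch
import torch.nn.functional as F

def eager_moe_forward(x, topk_idx, topk_w, w_gate, w_up, w_down):
    """Hypothetical sketch of the eager backend: loop over experts,
    combine results with index_add_.
    x: [T, H]; topk_idx, topk_w: [T, k];
    w_gate, w_up: [G, H, I]; w_down: [G, I, H]."""
    out = torch.zeros_like(x)
    for g in range(w_gate.shape[0]):
        # which (token, slot) pairs route to expert g
        tok, slot = (topk_idx == g).nonzero(as_tuple=True)
        if tok.numel() == 0:
            continue
        xg = x[tok]
        h = F.silu(xg @ w_gate[g]) * (xg @ w_up[g])            # SwiGLU intermediate
        y = (h @ w_down[g]) * topk_w[tok, slot].unsqueeze(-1)  # scale by router weight
        out.index_add_(0, tok, y)                              # scatter-add back per token
    return out
```

The grouped-GEMM backends compute the same function without the Python loop, which is why eager is useful as a correctness reference but nothing more.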


Quack uses CuTe DSL-based GEMM kernels written by Tri Dao (the author of Flash Attention). The kernels are built with NVIDIA's CuTe DSL (part of CUTLASS) and compiled for SM90+ (H100) and SM100+ (B200). They are not available on older GPUs.

Unlike the Triton backend, which JIT-compiles kernels defined in Python, Quack kernels are written directly in CuTe DSL — the same low-level CUDA tensor abstraction used for Flash Attention. This gives:

  • Persistent kernels with pingpong double-buffering (H100) — the kernel occupies SMs continuously across multiple expert GEMMs, overlapping memory loads of the next tile with computation of the current tile
  • cu_seqlens format — uses Flash Attention’s cumulative sequence length format to group tokens by expert, avoiding unnecessary tensor copies
  • SwiGLU fusion — gate_proj and up_proj GEMMs plus the SiLU activation can be fused into a single kernel dispatch via gemm_gated / gemm_act
  • SM100 (B200) support — different tile config: tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False
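The cu_seqlens grouping above is just a prefix sum over per-expert token counts; each expert then owns a contiguous row range. A minimal sketch (the counts are made up):

```python
import torch

# hypothetical per-expert token counts for one step (4 experts)
counts = torch.tensor([3, 0, 5, 2])

# cu_seqlens: cumulative boundaries, Flash Attention style
cu_seqlens = torch.zeros(counts.numel() + 1, dtype=torch.int64)
cu_seqlens[1:] = counts.cumsum(0)
# expert g owns rows cu_seqlens[g] : cu_seqlens[g + 1]
# here: [0, 3, 3, 8, 10] — expert 1 is empty, expert 2 owns rows 3..8
```

Because the boundaries are plain offsets into one packed tensor, no per-expert copies are needed.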

Default tile config per architecture:

| Architecture | tile_m | tile_n | cluster_m | cluster_n | pingpong |
| --- | --- | --- | --- | --- | --- |
| SM90 (H100) | 128 | 192 | 2 | 1 | True |
| SM100 (B200) | 256 | 256 | 2 | 1 | False |

By default, autotuning is disabled (XORL_QUACK_TUNED=0). MoE routing produces variable token counts per expert per step — each unique total_M would trigger a full ~60-config benchmark, making training orders of magnitude slower. The fixed default config is chosen to be performant across typical token budgets.

Enable autotuning only when shapes are stable (e.g. evaluation with fixed batch size) or when the tuning cache is already warm:

```sh
XORL_QUACK_TUNED=1 torchrun ... -m xorl.cli.train config.yaml
```

Quack has dedicated kernel variants for every training mode:

| Variant | Class / Function | Use case |
| --- | --- | --- |
| Standard | quack_expert_forward | Default forward, no EP |
| EP | QuackEPGroupGemm.apply | EP dispatch + compute |
| LoRA | quack_moe_lora_forward | LoRA adapters, no EP |
| EP + LoRA | QuackEPGroupGemmWithLoRA.apply | EP + LoRA |
| moe_act | quack_expert_forward_moe_act | Gradient checkpointing, no EP |
| EP + moe_act | QuackEPGroupGemmMoeAct.apply | EP + gradient checkpointing |
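The selection among these variants can be pictured as a flag-driven dispatch. The helper below is hypothetical (the names come from the table, not from actual registry code), but it captures the decision order:

```python
def select_quack_variant(ep: bool, lora: bool, moe_act: bool) -> str:
    """Hypothetical dispatch mirroring the variant table: EP combinations
    first, then the non-EP specializations, then the default."""
    if ep and lora:
        return "QuackEPGroupGemmWithLoRA.apply"
    if ep and moe_act:
        return "QuackEPGroupGemmMoeAct.apply"
    if ep:
        return "QuackEPGroupGemm.apply"
    if lora:
        return "quack_moe_lora_forward"
    if moe_act:
        return "quack_expert_forward_moe_act"
    return "quack_expert_forward"
```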

Quack and DeepEP are both optimized for H100 NVLink clusters and complement each other well: DeepEP handles the token dispatch/combine with dedicated SM kernels, while Quack handles the expert FFN compute with persistent CuTe kernels. Combining them maximizes utilization of both the NVLink interconnect and the SM compute fabric.

```yaml
model:
  moe_implementation: quack
  ep_dispatch: deepep
  deepep_num_sms: 48 # more SMs for DeepEP; fewer left for Quack — tune based on profiling
```
  • Requires SM90+ (H100) or SM100+ (B200) — not available on A100 or older
  • Not compatible with torch.compile — custom CUDA/CuTe extensions cannot be traced
  • LoRA support is limited — moe_hybrid_shared_lora is not supported; use triton or native for hybrid shared LoRA
  • Autotuning disabled by default — variable token counts make per-step autotuning prohibitively slow

The core of the triton backend is src/xorl/ops/group_gemm/kernel/group_gemm.py. The primary kernel group_gemm_same_nk_kernel handles the case where all experts share the same K (input) and N (output) dimensions.

Inputs and outputs:

  • A: [total_tokens, K] — all tokens routed to this rank, concatenated
  • B: [G, K, N] — expert weight tensor in GKN layout
  • cumsum_M: [G] — cumulative expert token counts (prefix sum)
  • C: [total_tokens, N] — output
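These semantics can be stated as a short reference implementation. This is not the Triton kernel, just a plain-PyTorch sketch of what it computes, using the same input layout:

```python
import torch

def grouped_gemm_reference(A, B, cumsum_M):
    """Reference semantics of group_gemm_same_nk (not the kernel itself).
    A: [total_tokens, K]; B: [G, K, N]; cumsum_M: [G] prefix sums."""
    G, K, N = B.shape
    C = A.new_empty(A.shape[0], N)
    start = 0
    for g in range(G):
        end = int(cumsum_M[g])
        C[start:end] = A[start:end] @ B[g]  # expert g's contiguous token slice
        start = end
    return C
```

Each expert owns a contiguous slice of A (tokens are pre-grouped by expert), so the kernel only needs the prefix sums to find slice boundaries.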

Tile strategy: Each thread block processes a BLOCK_M × BLOCK_N output tile for one expert. An inner loop steps through BLOCK_K slices of A and B, accumulating the partial sum in registers.

Autotuned configs: BLOCK_M ∈ {64, 128}, BLOCK_N ∈ {128, 256}, BLOCK_K ∈ {32, 64} — autotuned per (N, K) shape at first use and cached.
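The per-shape caching behavior amounts to memoization keyed on (N, K): the first call for a shape pays for the benchmark sweep, later calls reuse the result. A hypothetical sketch (not the Triton autotuner itself):

```python
# Hypothetical sketch of a per-(N, K) autotune cache: the first call for a
# shape runs the benchmark once; subsequent calls reuse the cached config.
_config_cache = {}

def pick_config(N, K, benchmark):
    key = (N, K)
    if key not in _config_cache:
        _config_cache[key] = benchmark(N, K)  # expensive sweep, done once per shape
    return _config_cache[key]
```

This also makes concrete why keying on total_M (as Quack's full autotune would) is a problem for MoE: total_M changes nearly every step, so the cache would almost never hit.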

Figure: Triton group_gemm_same_nk tiling strategy — one thread block per BLOCK_M × BLOCK_N output tile per expert; the inner loop steps through BLOCK_K slices of the expert's A slice ([tokens_g, K]) and its [K, N] slice of the GKN weight tensor B; the partial sum is accumulated in registers and C[total_tokens, N] is written once per tile.

Optional activation fusion: When moe_act gradient checkpointing is active, the kernel stores the post-gate×up intermediate activations to a separate buffer during the forward pass (controlled by the STORE_ACTIVATIONS flag). The custom backward function recomputes only the down projection from these stored intermediates.


torch.compile is applied per decoder block when enable_compile: true.

What is compiled:

  • Attention layers (Q/K/V projections, flash attention wrapper, output projection)
  • Router gate linear and softmax
  • Expert FFN compute (Triton group_gemm_same_nk is torch.compile-friendly)
  • Dense shared-expert FFN layers (if present)

What is excluded from compile graphs:

  • EP dispatch/combine collectives (alltoall_pre_dispatch, alltoall_post_combine, DeepEP dispatch/combine) — dynamic shapes from variable per-rank token counts break static graph tracing
  • Routing replay record/replay logic — mutable external state, decorated with @torch.compiler.disable

Recommended config:

```yaml
model:
  moe_implementation: triton # or native; quack is incompatible with compile
train:
  enable_compile: true
```

The native backend (torch._grouped_mm) compiles cleanly end-to-end. The quack backend uses custom CUDA extensions and must not be used with enable_compile: true.


MoE Gradient Checkpointing (moe_checkpoint_method)


xorl provides two MoE-aware checkpoint methods:

| Method | What is recomputed | EP comm recomputed | Memory savings |
| --- | --- | --- | --- |
| moe_act (default) | Expert MLP activations only (gate×up product inside each expert) | No — dispatch/combine skipped via R3 | Moderate — recovers large expert FFN tensors |
| full_recompute | All activations: attention, router, expert FFN | Yes — full AllToAll re-executed | Maximum — stores nothing |
| None / not set | Nothing | N/A | None — all activations kept |

```yaml
train:
  enable_gradient_checkpointing: true
  moe_checkpoint_method: moe_act # strongly recommended for EP models
```

Why moe_act is the right default: Expert FFN activations (gate×up intermediate tensor of shape [tokens, intermediate_size]) are the largest memory consumer in a MoE forward pass. moe_act drops these and recomputes them from stored inputs during backward, without re-running the expensive AllToAll dispatch. This gives most of the memory savings of full recompute at a fraction of the communication cost.

How it works in the Triton kernel: The group_gemm_same_nk_kernel has an optional STORE_ACTIVATIONS mode. When moe_act is active, the kernel stores the post-gate×up intermediate activations to a separate buffer. The custom backward function recomputes only the down projection from these stored intermediates.
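In plain PyTorch terms, the moe_act pattern resembles a custom autograd function that saves only the inputs and recomputes the large gate×up intermediate during backward. The sketch below is a simplified, hypothetical single-expert version for intuition; the real implementation is the Triton kernel path described above:

```python
import torch
import torch.nn.functional as F

class MoeActCheckpoint(torch.autograd.Function):
    """Sketch of the moe_act idea: drop the [tokens, intermediate] tensor
    in forward, recompute it from the saved inputs in backward."""

    @staticmethod
    def forward(ctx, x, w_gate, w_up, w_down):
        h = F.silu(x @ w_gate) * (x @ w_up)   # large intermediate
        ctx.save_for_backward(x, w_gate, w_up, w_down)  # note: h is NOT saved
        return h @ w_down

    @staticmethod
    def backward(ctx, grad_out):
        x, w_gate, w_up, w_down = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_(True)
            wg = w_gate.detach().requires_grad_(True)
            wu = w_up.detach().requires_grad_(True)
            wd = w_down.detach().requires_grad_(True)
            y = (F.silu(x_ @ wg) * (x_ @ wu)) @ wd  # recompute forward
        return torch.autograd.grad(y, (x_, wg, wu, wd), grad_out)
```

The key difference from full_recompute is that the saved inputs here are the post-dispatch tokens, so no AllToAll has to be replayed.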


| File | Description |
| --- | --- |
| `src/xorl/models/layers/moe/experts.py` | MoEExperts — GKN weight tensors, backend dispatch, EP forward, moe_act |
| `src/xorl/ops/group_gemm/kernel/group_gemm.py` | group_gemm_same_nk_kernel — Triton autotuned grouped GEMM with optional STORE_ACTIVATIONS |
| `src/xorl/ops/__init__.py` | triton_moe_forward, triton_moe_lora_forward, backend registry |
| `src/xorl/ops/group_gemm/kernel/quack.py` | quack_group_gemm_same_nk — CuTe GEMM wrapper; XORL_QUACK_TUNED env var |
| `src/xorl/ops/quack/gemm_interface.py` | Quack GEMM interface — SM90/SM100 configs, autotuner, gated activation kernels |
| `src/xorl/models/layers/moe/backend/quack.py` | quack_expert_forward — non-EP quack MoE forward |
| `src/xorl/ops/moe/quack.py` | QuackEPGroupGemm, QuackEPGroupGemmMoeAct — EP + moe_act variants |