
Gradient Accumulation & Token Normalization

Correctly accumulating gradients across micro-batches, data-parallel ranks, and context-parallel ranks requires token-normalized loss scaling. This page explains the full pipeline — from counting valid tokens to syncing SP gradients — and why each step is necessary.


A standard gradient accumulation loop averages the loss over micro-batches. FSDP normally averages gradients over its shard group. Both of these are wrong when sequences are packed: different micro-batches and different DP ranks process variable numbers of valid (non-padding) tokens.

**Why simple averaging is wrong with packed sequences.** Suppose Rank 0 holds 900 valid tokens and Rank 1 holds 100.

- ❌ Naive `loss.mean()` per rank: `loss₀ = CE_sum / 900`, `loss₁ = CE_sum / 100`, then FSDP averages the gradients as `(grad₀ + grad₁) / 2`. Rank 1's 100 tokens end up weighted equally to Rank 0's 900.
- ✓ Token-normalized: `gvt = 900 + 100 = 1000`, where gvt (global valid tokens) is all-reduced across all DP ranks before the accumulation loop begins. Each rank computes `lossᵢ = CE_sum(rankᵢ) / gvt` (equivalently, the mean loss scaled by `lvtᵢ / gvt`), and FSDP sums the gradients (`divide_factor=1.0`, no FSDP averaging). Every token contributes equally to the gradient, which is equivalent to computing CE over all 1000 tokens on one GPU.
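The failure mode above is easy to check numerically. A minimal sketch in plain Python, with two hypothetical ranks and made-up per-token losses (no distributed code; the sums stand in for CE reductions):

```python
# Toy per-token CE losses on two hypothetical DP ranks (made-up numbers).
rank0 = [2.0] * 900   # 900 valid tokens, per-token loss 2.0
rank1 = [8.0] * 100   # 100 valid tokens, per-token loss 8.0
gvt = len(rank0) + len(rank1)  # global valid tokens = 1000

# Ground truth: CE averaged over all 1000 tokens on a single device.
global_mean = (sum(rank0) + sum(rank1)) / gvt          # ≈ 2.6

# Naive: per-rank mean, then average across ranks (FSDP's default behavior).
naive = (sum(rank0) / len(rank0) + sum(rank1) / len(rank1)) / 2   # ≈ 5.0

# Token-normalized: scale each rank's mean by lvt/gvt, then SUM across ranks.
token_norm = (sum(rank0) / len(rank0)) * (len(rank0) / gvt) \
           + (sum(rank1) / len(rank1)) * (len(rank1) / gvt)       # ≈ 2.6
```

The naive path overweights Rank 1's few tokens by a factor of 9; the token-normalized path reproduces the single-device average exactly.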

Step 1: Count Global Valid Tokens (before the loop)

Before any micro-batch forward pass, xorl counts the total non-padding tokens across all micro-batches and all DP ranks with a single all-reduce:

```python
# src/xorl/trainers/training_utils.py — count_valid_tokens()
global_valid_tokens = torch.tensor(0, device=device)
for mb in micro_batches:
    labels = mb.get("labels", mb.get("target_tokens"))
    if labels is not None:
        global_valid_tokens += (labels != IGNORE_INDEX).sum()  # local sum
dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.SUM)     # sum across DP
```

IGNORE_INDEX = -100 marks padding positions. The all-reduce uses the FSDP process group, summing across all DP ranks (and SP ranks within the same shard group if CP is folded into FSDP).

This value is computed once before the gradient accumulation loop and reused for every micro-batch backward pass.
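On a single rank the counting step reduces to summing non-ignored labels. A minimal sketch with made-up micro-batches, using plain Python lists in place of tensors and omitting the all-reduce:

```python
IGNORE_INDEX = -100  # padding positions in the labels

# Hypothetical micro-batches; the second key name mirrors the fallback above.
micro_batches = [
    {"labels": [5, 17, IGNORE_INDEX, IGNORE_INDEX]},  # 2 valid tokens
    {"labels": [3, IGNORE_INDEX, 9, 21]},             # 3 valid tokens
    {"target_tokens": [IGNORE_INDEX, 8]},             # 1 valid token
]

local_valid = 0
for mb in micro_batches:
    labels = mb.get("labels", mb.get("target_tokens"))
    if labels is not None:
        local_valid += sum(t != IGNORE_INDEX for t in labels)

# local_valid == 6; dist.all_reduce(SUM) would then sum this across DP ranks.
```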


Step 2: Token-Normalized Loss per Micro-Batch

Section titled “Step 2: Token-Normalized Loss per Micro-Batch”

For each micro-batch, xorl calls GradientAccumulateLoss (src/xorl/distributed/gradient_accumulate_loss.py):

```python
local_valid_tokens = (labels != IGNORE_INDEX).sum()  # tokens in this micro-batch on this rank
ga_loss, _ = gradient_accumulate_loss(loss, local_valid_tokens, global_valid_tokens)
ga_loss.backward()
```

GradientAccumulateLoss is a custom torch.autograd.Function that implements:

Forward:

```
loss_sum = loss × local_valid_tokens   # scale by local token count
all_reduce(loss_sum, SUM)              # sum across all DP ranks
return loss_sum / global_valid_tokens  # normalize to per-token loss
```

Backward:

```
grad = grad_output × (local_valid_tokens / global_valid_tokens)
```

The gradient flowing back through ga_loss.backward() is thus scaled by the fraction of valid tokens on this rank and micro-batch. Summing these gradients across micro-batches and DP ranks gives the correct token-averaged gradient — identical to computing the loss over the entire global batch at once.
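The forward/backward arithmetic can be sketched without torch. Here `ga_forward` and `ga_backward` are hypothetical stand-ins for the autograd Function's two methods, on a single rank so the all-reduce is a no-op:

```python
def ga_forward(loss_mean, lvt, gvt):
    loss_sum = loss_mean * lvt  # undo the local per-token mean
    # dist.all_reduce(loss_sum, SUM) across DP ranks would happen here
    return loss_sum / gvt       # per-token loss over the global batch

def ga_backward(grad_output, lvt, gvt):
    return grad_output * (lvt / gvt)  # scale upstream grad by token fraction

# Two micro-batches with 900 and 100 valid tokens; gvt = 1000.
gvt = 1000
accumulated = sum(ga_backward(1.0, lvt, gvt) for lvt in (900, 100))
# accumulated ≈ 1.0: the per-micro-batch fractions sum to exactly one
# "global batch" worth of gradient, with no double counting.
```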

**Gradient Accumulation Loop (non-PP path).**

1. `gvt = all_reduce(Σ valid_tokens across ALL micro-batches × ALL DP ranks)`, computed once before the loop.
2. For each micro-batch i (repeated `gradient_accumulation_steps` times): `lvtᵢ` = valid tokens in mbᵢ on this rank; `lossᵢ = CE_sum(mbᵢ) / gvt`; `lossᵢ.backward()` accumulates `grad += gradᵢ × (lvtᵢ / gvt)`.
3. FSDP reduce-scatter (`divide_factor = 1.0` → SUM, not average) → gradient shard on each rank.
4. SP gradient all-reduce (if ring/Ulysses not folded into FSDP) → sync across the CP group.

Final gradient = `Σ (gradᵢ × lvtᵢ / gvt)` across all micro-batches and DP ranks = the token-averaged gradient.

FSDP2 normally divides gradients by fsdp_size during its reduce-scatter. xorl disables this by calling set_gradient_divide_factor(1.0) on every FSDP module:

```python
# src/xorl/distributed/torch_parallelize.py
for module in model.modules():
    if isinstance(module, FSDPModule) and not getattr(module, "_is_ep_fsdp", False):
        module.set_gradient_divide_factor(1.0)  # raw SUM, not average
```

With divide_factor=1.0, the reduce-scatter sums gradients across ranks rather than averaging them. This is correct because GradientAccumulateLoss already scales each rank’s backward contribution by local_valid_tokens / global_valid_tokens — so summing the contributions across ranks gives the correct token-average.
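To see why SUM is the right reduction here, a toy check with scalar gradients and made-up numbers: summing per-rank contributions that were pre-scaled by `lvt/gvt` reproduces the token-weighted average, with no further division needed.

```python
gvt = 1000
lvt = {0: 900, 1: 100}        # valid tokens per DP rank (toy values)
raw_grad = {0: 4.0, 1: 12.0}  # hypothetical per-token gradient on each rank

# Backward already scaled each rank's contribution by lvt/gvt...
scaled = {r: raw_grad[r] * lvt[r] / gvt for r in raw_grad}

# ...so a divide_factor=1.0 reduce-scatter just SUMs:
summed = sum(scaled.values())

# Reference: token-weighted average over all 1000 tokens.
reference = sum(raw_grad[r] * lvt[r] for r in raw_grad) / gvt
# summed ≈ reference ≈ 4.8; FSDP's default mean would halve it to 2.4.
```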

The only exception: EP expert modules retain set_gradient_divide_factor(ep_size) and are marked _is_ep_fsdp = True so the loop above skips them. See Expert Parallelism for details.


After the accumulation loop completes, context-parallel (CP) dimensions that are not folded into the FSDP mesh require a separate all-reduce to synchronize gradients:

```python
# src/xorl/trainers/training_utils.py — sync_sp_gradients()
if sp_grad_sync_group is not None:
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=sp_grad_sync_group)
```

Why this is needed: CP ranks process complementary (non-overlapping) token chunks of the same sequence. Each rank computes gradients only for the tokens it saw. To get the full gradient for a shared parameter (e.g., attention weights), the contributions from all CP ranks must be summed.

When it’s a no-op: If cp_fsdp_mode: all (the default), the CP ranks are folded into the FSDP shard group, so FSDP’s reduce-scatter already handles the CP gradient sync. The sp_grad_sync_group is None and this step is skipped.

| cp_fsdp_mode | sp_grad_sync_group | Behavior |
| --- | --- | --- |
| `all` (default) | `None` | CP folded into FSDP — no separate sync needed |
| `ulysses_only` | ring attention group | Only ring attention ranks need an extra all-reduce |
| `ring_only` | Ulysses group | Only Ulysses ranks need an extra all-reduce |
| `none` | unified SP group | Both ring and Ulysses need an extra all-reduce |
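Why the sync is a plain SUM can be seen with a toy shared parameter: each CP rank holds only the partial gradient from its own token chunk, and the full gradient is their elementwise sum. The numbers below are made up, with two hypothetical CP ranks:

```python
# Partial gradients for one shared weight vector, one entry per CP rank.
grad_cp = {
    0: [0.10, -0.20, 0.05],   # from tokens [0:4) of the sequence
    1: [0.30,  0.08, -0.01],  # from tokens [4:8) of the same sequence
}

# all_reduce(SUM, group=sp_grad_sync_group) leaves every rank with the total:
full_grad = [a + b for a, b in zip(grad_cp[0], grad_cp[1])]
# full_grad ≈ [0.4, -0.12, 0.04]
```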

**Complete Gradient Accumulation Pipeline.**

1. Count global valid tokens: `gvt = all_reduce(Σ (labels != -100).sum())` over all micro-batches × DP ranks.
2. For each micro-batch, forward + token-normalized backward: `ga_loss = GradientAccumulateLoss(CE_loss, lvt, gvt)`; `.backward()` accumulates `grad × (lvt/gvt)`.
3. FSDP reduce-scatter with `divide_factor=1.0`: sums raw gradients across DP ranks (no averaging); each rank holds its gradient shard.
4. SP gradient all-reduce (conditional): `all_reduce(grad, SUM, group=sp_grad_sync_group)`, only if CP dims are not folded into FSDP.
5. Gradient clip + optimizer step: `clip_grad_norm_()` → `optimizer.step()` → `scheduler.step()`.

Result: a gradient equivalent to computing CE over all tokens globally, on a single device.

The server training path (ModelRunner) uses a different strategy: raw gradients are accumulated across multiple forward_backward API calls, and normalization is deferred to optim_step.

This is necessary because the client controls the gradient accumulation loop externally — the server doesn’t know in advance how many forward_backward calls will happen before optim_step.

Each API call does:

```python
# 1. Count gvt for this call's micro-batches (same as local training)
global_valid_tokens = count_valid_tokens(micro_batches)

# 2. For each micro-batch: raw (unnormalized) backward
raw_loss = loss * local_valid_tokens  # scale by local token count only
raw_loss.backward()                   # accumulates raw gradients

# 3. SP gradient sync (same as local training)
sync_sp_gradients(model, sp_grad_sync_group)

# 4. Accumulate gvt for deferred normalization
_accumulated_valid_tokens[model_id] += global_valid_tokens
```

The key difference from local training: `raw_loss = loss × local_valid_tokens` is *not* divided by `global_valid_tokens`. Gradients therefore accumulate as a raw token-weighted sum rather than a per-token normalized value.

When the client calls optim_step, the server normalizes all accumulated raw gradients by the total token count across every forward_backward call since the last optim_step:

```python
accumulated = _accumulated_valid_tokens.pop(model_id)  # total tokens across all calls

# Normalize all gradients in-place: grad ÷ accumulated_valid_tokens
scale = 1.0 / accumulated
for p in model.parameters():
    if p.grad is not None:
        p.grad.mul_(scale)  # in-place to preserve FSDP2 DTensor metadata
clip_gradients(model, max_grad_norm)
optimizer.step()
```

p.grad.mul_(scale) is used instead of p.grad = p.grad * scale to preserve the DTensor metadata that FSDP2 stores on gradient tensors.

**Local vs Server: when normalization happens.**

- Local training: every micro-batch backward accumulates `grad += CE_loss × (lvt / gvt)`, so by `optimizer.step()` the gradient is already normalized and needs no extra scaling. Normalization happens per micro-batch, during backward.
- Server training: `forward_backward` call i accumulates the raw `grad += CE_loss × lvtᵢ`; at `optim_step()` the server applies `grad.mul_(1 / Σ gvtᵢ)`, then clips and steps. Normalization happens once at `optim_step`, across all calls.

Both approaches produce mathematically identical gradients: token-averaged over all tokens seen since the last optimizer step. The server defers the division to give the client flexibility to call forward_backward any number of times before stepping.
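The claimed equivalence is easy to verify with toy numbers. Below, each tuple is a hypothetical `forward_backward` call as a pair of (per-token gradient, valid tokens); the server's single deferred division matches the local path's per-backward scaling:

```python
calls = [(2.0, 300), (5.0, 200), (1.0, 500)]  # (per-token grad, valid tokens)
total_tokens = sum(t for _, t in calls)       # 1000

# Server path: accumulate raw token-weighted sums, divide once at optim_step.
raw_sum = sum(g * t for g, t in calls)        # 600 + 1000 + 500 = 2100
server_grad = raw_sum / total_tokens

# Local path: each backward pre-scaled by lvt/gvt, with gvt = total tokens.
local_grad = sum(g * (t / total_tokens) for g, t in calls)

# server_grad ≈ local_grad ≈ 2.1: identical up to float rounding.
```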


| File | Description |
| --- | --- |
| `src/xorl/distributed/gradient_accumulate_loss.py` | `GradientAccumulateLoss` — custom autograd Function for token-normalized loss in local training |
| `src/xorl/trainers/training_utils.py` | `count_valid_tokens()` — all-reduced valid-token count; `sync_sp_gradients()` — CP gradient sync |
| `src/xorl/distributed/torch_parallelize.py` | `set_gradient_divide_factor(1.0)` — disables FSDP's automatic averaging |
| `src/xorl/trainers/trainer.py` | `train_step()`, `_forward_backward()` — local training pipeline |
| `src/xorl/server/runner/model_runner.py` | `_forward_backward_core()` — raw gradient accumulation; `optim_step()` — deferred normalization via `grad.mul_(1/accumulated)` |