# Gradient Accumulation & Token Normalization
Correctly accumulating gradients across micro-batches, data-parallel ranks, and context-parallel ranks requires token-normalized loss scaling. This page explains the full pipeline — from counting valid tokens to syncing SP gradients — and why each step is necessary.
## The Problem with Naive Averaging

A standard gradient accumulation loop averages the loss over micro-batches, and FSDP normally averages gradients over its shard group. Both are wrong when sequences are packed: different micro-batches and different DP ranks process different numbers of valid (non-padding) tokens, so uniform averaging over-weights the tokens in small micro-batches.
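A toy example (plain Python, no framework code) makes the discrepancy concrete. Two micro-batches with 3 and 1 valid tokens and per-token losses of 1.0 and 5.0:

```python
# Two micro-batches with different valid-token counts.
# mb1 has 3 tokens at loss 1.0 each; mb2 has 1 token at loss 5.0.
token_losses = [[1.0, 1.0, 1.0], [5.0]]

# Naive averaging: mean of per-micro-batch means.
naive = sum(sum(mb) / len(mb) for mb in token_losses) / len(token_losses)

# Token-normalized: sum of all token losses over the global token count.
total_tokens = sum(len(mb) for mb in token_losses)
token_avg = sum(sum(mb) for mb in token_losses) / total_tokens

print(naive)      # 3.0 — the short micro-batch dominates
print(token_avg)  # 2.0 — every token counts equally
```

The naive average gives the lone token in the second micro-batch three times the weight of each token in the first.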
## Step 1: Count Global Valid Tokens (before the loop)

Before any micro-batch forward pass, xorl counts the total non-padding tokens across all micro-batches and all DP ranks with a single all-reduce:
```python
# src/xorl/trainers/training_utils.py — count_valid_tokens()
global_valid_tokens = torch.tensor(0, device=device)
for mb in micro_batches:
    labels = mb.get("labels", mb.get("target_tokens"))
    if labels is not None:
        global_valid_tokens += (labels != IGNORE_INDEX).sum()  # local sum

dist.all_reduce(global_valid_tokens, op=dist.ReduceOp.SUM)  # sum across DP
```

`IGNORE_INDEX = -100` marks padding positions. The all-reduce uses the FSDP process group, summing across all DP ranks (and across SP ranks within the same shard group when CP is folded into FSDP).
This value is computed once before the gradient accumulation loop and reused for every micro-batch backward pass.
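The local counting step can be sketched without tensors; a minimal version using plain lists of labels (the function name `count_valid_tokens_local` is illustrative, not the xorl API):

```python
IGNORE_INDEX = -100  # marks padding positions, as in xorl

def count_valid_tokens_local(micro_batches):
    """Sum non-padding tokens across this rank's micro-batches."""
    total = 0
    for labels in micro_batches:
        total += sum(1 for t in labels if t != IGNORE_INDEX)
    return total

# One rank's micro-batches: two sequences padded to length 4.
mbs = [[7, 8, -100, -100], [1, 2, 3, -100]]
print(count_valid_tokens_local(mbs))  # 5
```

In the real pipeline this local sum is then all-reduced (SUM) across the FSDP group to produce `global_valid_tokens`.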
## Step 2: Token-Normalized Loss per Micro-Batch

For each micro-batch, xorl calls `GradientAccumulateLoss` (`src/xorl/distributed/gradient_accumulate_loss.py`):
```python
local_valid_tokens = (labels != IGNORE_INDEX).sum()  # tokens in this micro-batch on this rank
ga_loss, _ = gradient_accumulate_loss(loss, local_valid_tokens, global_valid_tokens)
ga_loss.backward()
```

`GradientAccumulateLoss` is a custom `torch.autograd.Function` that implements:
Forward:

```python
loss_sum = loss * local_valid_tokens   # scale by local token count
all_reduce(loss_sum, SUM)              # sum across all DP ranks
return loss_sum / global_valid_tokens  # normalize to per-token loss
```

Backward:

```python
grad = grad_output * (local_valid_tokens / global_valid_tokens)
```

The gradient flowing back through `ga_loss.backward()` is thus scaled by the fraction of valid tokens on this rank and micro-batch. Summing these gradients across micro-batches and DP ranks gives the correct token-averaged gradient — identical to computing the loss over the entire global batch at once.
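This equivalence claim can be checked numerically with per-token gradient contributions as plain numbers (a toy check, not xorl code):

```python
# Per-token gradient contributions, split into uneven micro-batches.
micro_batches = [[0.5, 1.5, 2.0], [4.0], [1.0, 3.0]]
global_valid_tokens = sum(len(mb) for mb in micro_batches)

# Reference: gradient of the per-token mean loss over the entire batch.
all_tokens = [g for mb in micro_batches for g in mb]
reference = sum(all_tokens) / global_valid_tokens

# GradientAccumulateLoss-style: each micro-batch's mean gradient is scaled
# by local_valid_tokens / global_valid_tokens, then summed.
accumulated = 0.0
for mb in micro_batches:
    local_valid_tokens = len(mb)
    mean_grad = sum(mb) / local_valid_tokens
    accumulated += mean_grad * (local_valid_tokens / global_valid_tokens)

print(abs(accumulated - reference) < 1e-9)  # True: the scaled sum recovers the global mean
```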
## Step 3: FSDP Gradient Averaging Disabled

FSDP2 normally divides gradients by `fsdp_size` during its reduce-scatter. xorl disables this by calling `set_gradient_divide_factor(1.0)` on every FSDP module:
```python
for module in model.modules():
    if isinstance(module, FSDPModule) and not getattr(module, "_is_ep_fsdp", False):
        module.set_gradient_divide_factor(1.0)  # raw SUM, not average
```

With `divide_factor=1.0`, the reduce-scatter sums gradients across ranks rather than averaging them. This is correct because `GradientAccumulateLoss` already scales each rank's backward contribution by `local_valid_tokens / global_valid_tokens`, so summing the contributions across ranks gives the correct token average.
The only exception: EP expert modules retain `set_gradient_divide_factor(ep_size)` and are marked `_is_ep_fsdp = True` so the loop above skips them. See Expert Parallelism for details.
## Step 4: SP Gradient Sync

After the accumulation loop completes, context-parallel (CP) dimensions that are not folded into the FSDP mesh require a separate all-reduce to synchronize gradients:
```python
# src/xorl/trainers/training_utils.py — sync_sp_gradients()
if sp_grad_sync_group is not None:
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=sp_grad_sync_group)
```

**Why this is needed:** CP ranks process complementary (non-overlapping) token chunks of the same sequence. Each rank computes gradients only for the tokens it saw. To get the full gradient for a shared parameter (e.g., attention weights), the contributions from all CP ranks must be summed.
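To see why SUM is the right reduction, consider a scalar parameter `w` with loss `sum(w * x_i)` over tokens, where two CP ranks each see half of the sequence (a toy model, not xorl code):

```python
# Full sequence of token inputs; d/dw of sum(w * x_i) is sum(x_i).
tokens = [1.0, 2.0, 3.0, 4.0]
full_grad = sum(tokens)

# CP ranks each process a complementary chunk and compute a partial gradient
# covering only the tokens they saw.
rank0_grad = sum(tokens[:2])
rank1_grad = sum(tokens[2:])

# The all-reduce SUM recovers the full-sequence gradient.
print(rank0_grad + rank1_grad == full_grad)  # True
```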
**When it's a no-op:** If `cp_fsdp_mode: all` (the default), the CP ranks are folded into the FSDP shard group, so FSDP's reduce-scatter already handles the CP gradient sync. In that case `sp_grad_sync_group` is `None` and this step is skipped.
| cp_fsdp_mode | sp_grad_sync_group | Behavior |
|---|---|---|
| `all` (default) | `None` | CP folded into FSDP — no separate sync needed |
| `ulysses_only` | ring attention group | Only ring attention ranks need an extra all-reduce |
| `ring_only` | Ulysses group | Only Ulysses ranks need an extra all-reduce |
| `none` | unified SP group | Both ring and Ulysses ranks need an extra all-reduce |
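The table reads as a selection function. A hypothetical sketch (the helper name `make_sp_grad_sync_group` and the group arguments are illustrative, not the actual xorl API):

```python
def make_sp_grad_sync_group(cp_fsdp_mode, ring_group, ulysses_group, unified_sp_group):
    """Pick which group (if any) needs the extra gradient all-reduce.

    Hypothetical helper mirroring the table above; the real xorl wiring
    may differ.
    """
    if cp_fsdp_mode == "all":
        return None  # CP folded into FSDP; reduce-scatter already syncs
    if cp_fsdp_mode == "ulysses_only":
        return ring_group  # ring attention ranks still need syncing
    if cp_fsdp_mode == "ring_only":
        return ulysses_group  # Ulysses ranks still need syncing
    if cp_fsdp_mode == "none":
        return unified_sp_group  # both dimensions need syncing
    raise ValueError(f"unknown cp_fsdp_mode: {cp_fsdp_mode}")

print(make_sp_grad_sync_group("all", "ring", "ulysses", "sp"))  # None
```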
## Full Pipeline Summary

1. Count global valid tokens once with a single all-reduce (Step 1).
2. For each micro-batch, scale the backward pass by `local_valid_tokens / global_valid_tokens` via `GradientAccumulateLoss` (Step 2).
3. Let FSDP's reduce-scatter SUM (not average) gradients across ranks (Step 3).
4. All-reduce gradients across any CP dimension not folded into FSDP (Step 4).

## Server Training: Deferred Normalization
Section titled “Server Training: Deferred Normalization”The server training path (ModelRunner) uses a different strategy: raw gradients are accumulated across multiple forward_backward API calls, and normalization is deferred to optim_step.
This is necessary because the client controls the gradient accumulation loop externally — the server doesn’t know in advance how many forward_backward calls will happen before optim_step.
### Per `forward_backward` call

Each API call does:
```python
# 1. Count gvt for this call's micro-batches (same as local training)
global_valid_tokens = count_valid_tokens(micro_batches)

# 2. For each micro-batch: raw (unnormalized) backward
raw_loss = loss * local_valid_tokens  # scale by local token count only
raw_loss.backward()                   # accumulates raw gradients

# 3. SP gradient sync (same as local training)
sync_sp_gradients(model, sp_grad_sync_group)

# 4. Accumulate gvt for deferred normalization
_accumulated_valid_tokens[model_id] += global_valid_tokens
```

The key difference from local training: `raw_loss = loss * local_valid_tokens` is not divided by `global_valid_tokens`. Gradients therefore accumulate as a raw token-weighted sum rather than a per-token normalized value.
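That the deferred division produces the same result as per-call normalization can be checked with plain numbers (token counts and per-token mean losses standing in for tensors):

```python
# Two forward_backward calls with different token counts and mean losses.
calls = [(3, 2.0), (5, 4.0)]  # (valid_tokens, per-token mean loss)

# Local-training style: normalize inside each backward pass.
total_tokens = sum(n for n, _ in calls)
local_style = sum(mean * n / total_tokens for n, mean in calls)

# Server style: accumulate raw token-weighted sums, divide at optim_step.
raw = sum(mean * n for n, mean in calls)           # raw_loss accumulation
accumulated_valid_tokens = sum(n for n, _ in calls)
server_style = raw / accumulated_valid_tokens      # deferred normalization

print(local_style == server_style)  # True: both are the global per-token mean
```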
### At `optim_step`

When the client calls `optim_step`, the server normalizes all accumulated raw gradients by the total token count across every `forward_backward` call since the last `optim_step`:
```python
accumulated = _accumulated_valid_tokens.pop(model_id)  # total tokens across all calls

# Normalize all gradients in-place: grad / accumulated_valid_tokens
scale = 1.0 / accumulated
for p in model.parameters():
    if p.grad is not None:
        p.grad.mul_(scale)  # in-place to preserve FSDP2 DTensor metadata

clip_gradients(model, max_grad_norm)
optimizer.step()
```

`p.grad.mul_(scale)` is used instead of `p.grad = p.grad * scale` to preserve the DTensor metadata that FSDP2 stores on gradient tensors.
### Comparison with local training

Both approaches produce mathematically identical gradients: token-averaged over all tokens seen since the last optimizer step. The server defers the division to give the client flexibility to call `forward_backward` any number of times before stepping.
## Source

| File | Description |
|---|---|
| `src/xorl/distributed/gradient_accumulate_loss.py` | `GradientAccumulateLoss` — custom autograd for token-normalized loss in local training |
| `src/xorl/trainers/training_utils.py` | `count_valid_tokens()` — all-reduce valid token count; `sync_sp_gradients()` — CP gradient sync |
| `src/xorl/distributed/torch_parallelize.py` | `set_gradient_divide_factor(1.0)` — disables FSDP automatic averaging |
| `src/xorl/trainers/trainer.py` | `train_step()`, `_forward_backward()` — local training pipeline |
| `src/xorl/server/runner/model_runner.py` | `_forward_backward_core()` — raw gradient accumulation; `optim_step()` — deferred normalization via `grad.mul_(1/accumulated)` |