Tensor Parallelism
Tensor parallelism (TP) shards individual weight matrices across multiple GPUs so that each GPU holds a vertical or horizontal slice of every linear layer. Unlike FSDP2, which shards parameters at rest and gathers them during forward, TP keeps weights sharded during the entire forward and backward pass. Every GPU in a TP group processes the same token sequence and contributes a partial result; a collective reduces partials to produce the full output.
xorl implements TP using PyTorch’s torch.distributed.tensor.parallel API (parallelize_module, ColwiseParallel, RowwiseParallel) on a 1-D DeviceMesh slice named "tp".
1. What TP Does — Column/Row Sharding
TP maps each linear layer to one of two sharding styles.
Column-wise parallel (ColwiseParallel)
The weight matrix W of shape [out_features, in_features] is split along out_features (dimension 0). Each TP rank holds W[rank * out/tp : (rank+1) * out/tp, :]. The input activations are replicated (broadcast) to all TP ranks; each rank computes a partial output y_partial = x @ W_local^T. The outputs are naturally sharded across the output dimension, so no collective is needed at the end of this layer if the next layer expects sharded input.
Layers using ColwiseParallel in Qwen3:
- `embed_tokens` (embedding variant — shards vocabulary dimension)
- `self_attn.q_proj`, `self_attn.k_proj`, `self_attn.v_proj`
- `mlp.gate_proj`, `mlp.up_proj`
- `lm_head`
Row-wise parallel (RowwiseParallel)
The weight matrix is split along in_features (dimension 1). Each TP rank holds W[:, rank * in/tp : (rank+1) * in/tp]. The input activations are expected to be sharded (which is exactly what ColwiseParallel produces), so no gather is needed on the way in. Each rank computes a partial output and an all-reduce aggregates partials across the TP group into the full output tensor, which is then replicated across TP ranks.
Layers using RowwiseParallel in Qwen3:
- `self_attn.o_proj`
- `mlp.down_proj`
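The algebra behind pairing a colwise layer with a rowwise layer can be checked on CPU with plain tensors. The sketch below is toy code, not xorl's implementation: it splits an MLP's two matmuls across tp=4 simulated ranks and verifies that summing the rowwise partials (the all-reduce) reproduces the unsharded result.

```python
import torch

# Toy simulation of colwise -> rowwise TP sharding on a single CPU process.
torch.manual_seed(0)
tp, hidden, inter = 4, 16, 32
x = torch.randn(2, 5, hidden)          # [B, S, H], replicated on every "rank"
w_up = torch.randn(inter, hidden)      # colwise layer: split along dim 0
w_down = torch.randn(hidden, inter)    # rowwise layer: split along dim 1

reference = (x @ w_up.T) @ w_down.T    # the unsharded computation

up_shards = w_up.chunk(tp, dim=0)      # rank r holds W[r*I/tp:(r+1)*I/tp, :]
down_shards = w_down.chunk(tp, dim=1)  # rank r holds W[:, r*I/tp:(r+1)*I/tp]

# Each rank: the colwise matmul yields a sharded intermediate, the rowwise
# matmul a partial output; no collective is needed in between.
partials = [(x @ up.T) @ down.T for up, down in zip(up_shards, down_shards)]
y = torch.stack(partials).sum(dim=0)   # the all-reduce is a sum of partials

assert torch.allclose(y, reference, atol=1e-3)
```

An elementwise nonlinearity between the two matmuls acts on each rank's local slice of the intermediate, so it does not change this pattern; that is why the transformer MLP needs only the single all-reduce after down_proj.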
Data flow through a transformer layer with TP=4
```
Input x [B, S, H] (replicated across 4 TP ranks)
        |
        v
q_proj (colwise) → q_local [B, S, H/4]
k_proj (colwise) → k_local [B, S, H_kv/4] (per rank)
v_proj (colwise) → v_local [B, S, H_kv/4]
        |
        v
Multi-head attention (each rank handles 1/4 of the heads, no collective)
        |
        v
o_proj (rowwise) → partial_y → all-reduce → y [B, S, H] (replicated)
        |
        v
gate_proj (colwise) → g_local [B, S, I/4]
up_proj (colwise) → u_local [B, S, I/4]
        |
        v
SiLU + multiply (local)
        |
        v
down_proj (rowwise) → partial_y → all-reduce → y [B, S, H] (replicated)
```

Two all-reduce collectives occur per transformer layer: one after o_proj and one after down_proj.
2. Device Mesh — How the TP Mesh Is Constructed
The global device mesh is built in src/xorl/distributed/parallel_state.py by init_parallel_state. All parallelism dimensions are composed into a single N-dimensional DeviceMesh:
```python
# Dimension order (only included if size > 1 or name == "dp_shard"):
# [pp, dp_replicate, dp_shard, ringattn, ulysses, tp]
device_mesh = init_device_mesh(
    device_type=device_type,
    mesh_shape=tuple(mesh_shape),
    mesh_dim_names=tuple(mesh_dim_names),
)
```

The "tp" dimension is always the innermost dimension of the mesh. On a typical 8-GPU single node with TP=4 and DP=2:
```
Ranks: [0, 1, 2, 3, 4, 5, 6, 7]
Mesh shape: [dp_shard=2, tp=4]

            tp=0  tp=1  tp=2  tp=3
dp_shard=0:    0     1     2     3
dp_shard=1:    4     5     6     7
```

- TP group for ranks 0–3: {0, 1, 2, 3}
- TP group for ranks 4–7: {4, 5, 6, 7}
- FSDP groups (dp_shard): {0, 4}, {1, 5}, {2, 6}, {3, 7}

The TP sub-mesh is accessed as parallel_state.tp_mesh (device_mesh["tp"]) and the TP process group as parallel_state.tp_group (device_mesh.get_group("tp")).
The FSDP mesh (parallel_state.fsdp_mesh) is always over the DP dimensions, excluding the TP dimension. This means TP ranks are in the same FSDP shard group — FSDP shards the TP-local weight slices, not the full weights. See Section 8 for implications.
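The rank groupings above follow mechanically from row-major rank numbering on a [dp_shard, tp] mesh. A few lines of plain Python (a sketch, independent of the real DeviceMesh API) reproduce them:

```python
# Reconstruct TP and FSDP groups for a [dp_shard=2, tp=4] mesh with
# row-major rank numbering (tp is the innermost, fastest-varying dimension).
dp_shard, tp = 2, 4
world = list(range(dp_shard * tp))

# TP groups: ranks sharing a dp_shard row (consecutive ranks).
tp_groups = [world[r * tp:(r + 1) * tp] for r in range(dp_shard)]
# FSDP (dp_shard) groups: ranks sharing a tp column (stride-tp ranks).
fsdp_groups = [world[c::tp] for c in range(tp)]

assert tp_groups == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert fsdp_groups == [[0, 4], [1, 5], [2, 6], [3, 7]]
```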
3. QKV Unfusing — Why merge_qkv Must Be False
By default, xorl models use a fused qkv_proj linear layer of shape [q_dim + 2 * kv_dim, hidden_size] for efficiency. TP cannot shard this fused layer correctly because Q, K, and V have different output dimensions in GQA (grouped-query attention) models — the number of heads differs between Q and KV, so a uniform column split would produce misaligned head boundaries.
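The misalignment is easy to see numerically. Using Qwen3-8B-like dimensions (32 query heads, 8 KV heads, head_dim 128; assumed here purely for illustration), a uniform 4-way split of the fused output dimension puts shard boundaries inside the Q and K blocks:

```python
# Why a uniform column split of a fused QKV weight breaks GQA head boundaries.
num_q_heads, num_kv_heads, head_dim, tp = 32, 8, 128, 4
q_dim = num_q_heads * head_dim       # 4096 rows for Q
kv_dim = num_kv_heads * head_dim     # 1024 rows each for K and V
fused = q_dim + 2 * kv_dim           # 6144 output rows in qkv_proj

shard = fused // tp                  # 1536 rows per rank under a uniform split
boundaries = [r * shard for r in range(1, tp)]
assert boundaries == [1536, 3072, 4608]

# The Q block ends at row 4096, which falls strictly inside rank 2's shard
# (rows 3072..4607): that rank would hold a mix of Q and K rows.
assert boundaries[1] < q_dim < boundaries[2]

# Unfused, each projection splits cleanly on head boundaries:
assert q_dim % tp == 0 and kv_dim % tp == 0
```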
Setting merge_qkv: false in the model config triggers unfuse_for_tp() before parallelize_module is called:
```python
if parallel_state.tp_enabled:
    if hasattr(model, "unfuse_for_tp"):
        model.unfuse_for_tp()  # splits qkv_proj → q_proj, k_proj, v_proj
                               # and gate_up_proj → gate_proj, up_proj
    tp_plan = _build_tp_plan(model)
    model = parallelize_module(
        model, device_mesh=parallel_state.tp_mesh, parallelize_plan=tp_plan
    )
```

At the attention layer level (src/xorl/models/layers/attention/multi_head_attention.py):
```python
def unfuse_for_tp(self):
    device = self.qkv_proj.weight.device
    dtype = self.qkv_proj.weight.dtype
    self.q_proj = nn.Linear(hidden_size, q_dim, bias=..., device=device, dtype=dtype)
    self.k_proj = nn.Linear(hidden_size, kv_dim, bias=..., device=device, dtype=dtype)
    self.v_proj = nn.Linear(hidden_size, kv_dim, bias=..., device=device, dtype=dtype)
    del self.qkv_proj
```

At the MLP layer level (src/xorl/models/transformers/qwen3/modeling_qwen3.py):
```python
def unfuse_for_tp(self):
    self.gate_proj = nn.Linear(hidden_size, intermediate_size, ...)
    self.up_proj = nn.Linear(hidden_size, intermediate_size, ...)
    del self.gate_up_proj
```

What happens if merge_qkv: true with TP enabled:
The _build_tp_plan function reads config.base_model_tp_plan, which after unfuse_for_tp is overridden to the unfused plan (q_proj, k_proj, v_proj). If unfusing is skipped, the TP plan still references the unfused projection names (q_proj, k_proj, v_proj), but those modules do not exist on the model — parallelize_module will silently skip them (or raise a KeyError depending on PyTorch version). The model will run without TP applied to the attention projections, producing incorrect training results with no visible error.
Always set merge_qkv: false when tensor_parallel_size > 1.
4. Column-wise vs Row-wise Parallelism — Complete Layer Assignment
The TP plan for Qwen3 dense models (src/xorl/models/transformers/qwen3/parallelize.py):
```python
TP_PLAN = {
    "embed_tokens": "embedding",  # RowwiseParallel(input=Replicate, output=Replicate)
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
```
```python
MODEL_TP_PLAN = {
    "lm_head": "colwise",  # dense: ColwiseParallel (sharded output)
}
```

For Qwen3 MoE models (src/xorl/models/transformers/qwen3_moe/parallelize.py), the attention and dense MLP layers use the same plan. Expert weights inside MoEBlock are not TP-sharded — they use Expert Parallelism (EP) instead. The lm_head in MoE uses "colwise_rep" (ColwiseParallel with output_layouts=Replicate()) to produce a replicated output suitable for softmax over the full vocabulary.
The "embedding" special style
The embedding table [vocab_size, hidden_size] is sharded on the vocabulary dimension (dim 0), implemented as a RowwiseParallel with input_layouts=Replicate(), output_layouts=Replicate(). Each TP rank owns vocab_size / tp_size token embeddings. During a lookup, each rank returns its partial embedding (zero for tokens outside its vocab slice) and an all-reduce sums partials to produce the full embedding vector.
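This partial-lookup-plus-all-reduce pattern can be simulated on CPU. The sketch below is toy code, not the DTensor implementation:

```python
import torch

# Toy vocab-sharded embedding: each simulated rank zeroes out-of-slice tokens,
# and summing the partials (the all-reduce) reproduces the full lookup.
torch.manual_seed(0)
vocab, hidden, tp = 16, 8, 4
table = torch.randn(vocab, hidden)
token_ids = torch.tensor([[0, 5, 11, 15]])      # [B, S]

reference = table[token_ids]

shard_rows = vocab // tp
partials = []
for rank in range(tp):
    lo, hi = rank * shard_rows, (rank + 1) * shard_rows
    local = table[lo:hi]                        # this rank's vocab slice
    in_slice = (token_ids >= lo) & (token_ids < hi)
    local_ids = (token_ids - lo).clamp(0, shard_rows - 1)
    out = local[local_ids] * in_slice.unsqueeze(-1)  # zero outside the slice
    partials.append(out)

summed = torch.stack(partials).sum(dim=0)       # the all-reduce
assert torch.equal(summed, reference)
```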
| Layer | Style | Weight split dimension | Input | Output |
|---|---|---|---|---|
| embed_tokens | embedding | vocab (dim 0) | replicated token ids | replicated hidden states |
| q_proj, k_proj, v_proj | colwise | out_features (dim 0) | replicated | sharded (heads split) |
| o_proj | rowwise | in_features (dim 1) | sharded | replicated (all-reduce) |
| gate_proj, up_proj | colwise | out_features (dim 0) | replicated | sharded (intermediate split) |
| down_proj | rowwise | in_features (dim 1) | sharded | replicated (all-reduce) |
| lm_head | colwise | out_features (vocab, dim 0) | replicated | sharded or replicated |
5. Sequence Dimension and Context Parallelism
In the current xorl TP implementation, the sequence dimension is not sharded across TP ranks. All TP ranks process the full sequence length S. This means:
- The activation tensor [B, S, H] is replicated at the start of each layer.
- The all-reduces at the end of o_proj and down_proj reduce across TP ranks.
- Memory for activations scales with S, not S / tp_size.
This is different from Megatron-style sequence parallelism (SP), where the non-attention layers use a reduce-scatter / all-gather pattern to keep activations sharded along the sequence dimension between the colwise and rowwise layers. xorl does not currently implement Megatron-style SP layered on top of TP. If you need to reduce activation memory along the sequence dimension, use Ring Attention (ringattn_parallel_size) or Ulysses (ulysses_parallel_size) instead — these are orthogonal to TP.
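A quick arithmetic sketch of what sequence sharding would save for one layer's [B, S, H] activation, using B=1, S=8192, H=4096, tp=4 in bf16 (illustrative numbers):

```python
# Activation bytes per rank for one [B, S, H] bf16 tensor.
B, S, H, tp, bytes_per_elem = 1, 8192, 4096, 4, 2

replicated = B * S * H * bytes_per_elem   # current xorl TP: full S on every rank
seq_sharded = replicated // tp            # Megatron-style SP would hold S/tp

assert replicated == 64 * 2**20           # 64 MiB per rank
assert seq_sharded == 16 * 2**20          # 16 MiB per rank
```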
TP can be freely combined with Ring Attention and Ulysses because the TP mesh dimension is independent of the CP mesh dimensions in the global device mesh.
6. Communication Patterns
Per-layer collectives
For each transformer layer with TP enabled, the following collectives occur:
| Point in the layer | Collective | Group | Cost |
|---|---|---|---|
| After o_proj | All-reduce | TP group | 2 * B * S * H * (tp-1) / tp bytes |
| After down_proj | All-reduce | TP group | 2 * B * S * H * (tp-1) / tp bytes |
The two all-reduces are the dominant TP communication cost. For a 32-layer model with TP=4 and H=4096, B=1, S=8192, this is approximately 32 × 2 × 2 × 8192 × 4096 × (3/4) × 2 bytes ≈ 6.4 GB of data moved per rank per forward pass through the TP interconnect.
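Evaluating the per-layer cost formula from the table for this configuration (a ring all-reduce moves 2(tp-1)/tp times the tensor size per rank; bf16 is 2 bytes per element):

```python
# All-reduce traffic per rank per forward pass for a 32-layer model, TP=4.
layers, allreduces_per_layer = 32, 2
B, S, H, tp, bytes_per_elem = 1, 8192, 4096, 4, 2

tensor_bytes = B * S * H * bytes_per_elem         # one [B, S, H] bf16 tensor
per_allreduce = tensor_bytes * 2 * (tp - 1) / tp  # ring all-reduce traffic
total = layers * allreduces_per_layer * per_allreduce

assert round(total / 1e9, 1) == 6.4               # ≈ 6.4 GB per rank
```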
Embedding all-reduce
During the embedding lookup (embed_tokens), a single all-reduce sums partial embeddings across TP ranks. This is a [B, S, H] tensor.
Gradient all-reduce (backward)
During the backward pass, PyTorch’s DTensor mechanism handles gradients automatically. For ColwiseParallel layers, the gradient with respect to the input is a reduce-scatter (each rank holds a shard of the input gradient). For RowwiseParallel layers, the gradient with respect to the input is an all-gather before being passed to the previous layer. The net effect mirrors the forward collectives.
TP is always synchronous
All TP collectives are synchronous blocking all-reduces or all-gathers. There is no overlap of TP communication with computation in xorl (unlike some Megatron implementations that pipeline the all-reduce with the next layer). This means TP latency adds directly to the critical path.
TP + FSDP2 interaction
When TP and FSDP2 are both enabled, weight loading must happen in a specific order: TP is applied first (weights become DTensors with TP placements), then FSDP2 wraps the TP-sharded parameters. See Section 8 for details.
7. Constraints
Attention head divisibility
- `num_attention_heads % tensor_parallel_size == 0`
- `num_key_value_heads % tensor_parallel_size == 0`

Both constraints must hold. For GQA models like Qwen3 with num_key_value_heads < num_attention_heads, the KV head count is the binding constraint.
Qwen3-8B (num_attention_heads=32, num_key_value_heads=8):
- TP=2: 32 % 2 == 0, 8 % 2 == 0. Valid.
- TP=4: 32 % 4 == 0, 8 % 4 == 0. Valid.
- TP=8: 32 % 8 == 0, 8 % 8 == 0. Valid.
Qwen3-30B-A3B (MoE, num_attention_heads=16, num_key_value_heads=8):
- TP=2: 16 % 2 == 0, 8 % 2 == 0. Valid.
- TP=4: 16 % 4 == 0, 8 % 4 == 0. Valid.
- TP=8: 16 % 8 == 0, 8 % 8 == 0. Valid.
Hidden size divisibility
hidden_size % tensor_parallel_size == 0 must hold. All Qwen3 models have hidden_size divisible by 8, so this is never a binding constraint for practical TP sizes.
Intermediate size divisibility
intermediate_size % tensor_parallel_size == 0 must hold for MLP sharding.
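The constraints in this section can be collected into one small validator. The helper below is hypothetical (not part of xorl); the head counts match this section, while the hidden and intermediate sizes are illustrative:

```python
# Hypothetical helper collecting the TP divisibility constraints listed above.
def check_tp_divisibility(tp, num_attention_heads, num_key_value_heads,
                          hidden_size, intermediate_size):
    """Return the names of violated constraints (empty list means TP size is valid)."""
    checks = {
        "num_attention_heads": num_attention_heads,
        "num_key_value_heads": num_key_value_heads,
        "hidden_size": hidden_size,
        "intermediate_size": intermediate_size,
    }
    return [name for name, value in checks.items() if value % tp != 0]

# Qwen3-8B-like config (32 Q heads, 8 KV heads; sizes illustrative):
assert check_tp_divisibility(4, 32, 8, 4096, 12288) == []
# TP=16 would violate the KV-head constraint (the binding one for GQA):
assert check_tp_divisibility(16, 32, 8, 4096, 12288) == ["num_key_value_heads"]
```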
World size constraint
```
world_size = PP × DP_shard × DP_replicate × TP × CP_ring × CP_ulysses
```

This is validated in ParallelState.__post_init__ and also in TrainArguments.__post_init__:
```python
non_dp_size = (
    ulysses_parallel_size
    * tensor_parallel_size
    * ringattn_parallel_size
    * pipeline_parallel_size
)
self.data_parallel_size = world_size // non_dp_size
```

LoRA incompatibility
TP + LoRA is not currently supported. build_parallelize_model explicitly raises NotImplementedError if any LoraLinear modules are found when TP is enabled:
```python
if parallel_state.tp_enabled:
    if any(isinstance(m, LoraLinear) for m in model.modules()):
        raise NotImplementedError(
            "Tensor parallelism + LoRA is not currently supported."
        )
```

Use FSDP2 alone for LoRA fine-tuning.
PyTorch version
TP requires PyTorch >= 2.4. The ColwiseParallel, RowwiseParallel, and parallelize_module imports are guarded by is_torch_version_greater_than("2.4").
8. Interaction with FSDP2
TP and FSDP2 compose so that TP shards within a node and FSDP2 shards the TP-local slices across the DP group. The key property is that the "tp" dimension is innermost in the global device mesh, meaning consecutive ranks on the same node form a TP group, while the FSDP2 mesh spans ranks that have the same TP rank across nodes.
Ordering requirement: TP before FSDP
parallelize_module must be called before fully_shard. After parallelize_module, each parameter is a DTensor with a Shard(0) or Shard(1) placement on the TP mesh, representing the local TP slice. When fully_shard then wraps the module, it shards this already TP-sharded local slice further across the FSDP group.
Weight loading order
When init_device: meta is used, weights must be loaded after TP is applied but before FSDP wrapping. This is because after parallelize_module, parameters are meta DTensors with TP placements. FSDP’s lazy init cannot handle meta DTensors correctly (the logical DTensor shape does not match the local shard shape). xorl handles this automatically:
```python
if parallel_state.tp_enabled:
    model = parallelize_module(model, ...)  # apply TP first
    if kwargs.get("init_device") == "meta" and weights_path is not None:
        rank0_load_and_broadcast_weights(...)  # load weights into TP-sharded params
        kwargs["skip_weight_loading"] = True   # skip again in FSDP path

# then FSDP2 wraps the already-materialized TP DTensors
model = parallelize_model_fsdp2(model, ...)
```

What each GPU actually stores
For a 32-GPU job with TP=4, DP_shard=8:
- Each GPU has a TP-local slice of every parameter (1/4 of output features for colwise layers).
- FSDP2 further shards this TP-local slice across 8 GPUs in the FSDP group.
- GPU effective parameter storage = full_param_size / (TP × FSDP_shard_size) = full_param_size / 32.
The FSDP2 all-gather during the forward pass gathers the TP-local slice (not the full parameter). The TP communication then distributes the work across TP ranks.
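Back-of-envelope parameter storage for this kind of layout, using an illustrative 8B-parameter model in bf16 with TP=4 and FSDP shard size 8:

```python
# Per-GPU parameter bytes when TP and FSDP2 sharding compose multiplicatively.
params, bytes_per_param = 8_000_000_000, 2
tp, fsdp_shard = 4, 8

full_bytes = params * bytes_per_param
per_gpu = full_bytes // (tp * fsdp_shard)  # each GPU stores 1/32 of the weights

assert per_gpu == full_bytes // 32
assert per_gpu == 500_000_000              # 0.5 GB of parameter storage per GPU
```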
Gradient reduce-scatter
FSDP2’s reduce-scatter during the backward pass averages TP-local gradients across the FSDP group. The TP all-reduce (inside PyTorch’s DTensor autograd) runs separately during the backward pass over the TP group. The two collectives operate on different process groups and do not interfere.
Mixed precision
The same MixedPrecisionPolicy(param_dtype=bfloat16, reduce_dtype=float32) applies to TP-sharded parameters. TP all-reduces happen in the dtype of the activations (bfloat16 during forward, float32 during gradient reduce via FSDP2’s reduce_dtype).
9. Configuration Examples
Unless stated otherwise, the examples assume 8 GPUs on a single node (the Qwen3-32B example uses 16).
TP=4, FSDP2 shard=2 (Qwen3-8B, 8 GPUs)
```yaml
model:
  model_path: Qwen/Qwen3-8B
  merge_qkv: false  # required for TP

train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 4
  data_parallel_shard_size: 2  # 4 TP × 2 FSDP = 8 GPUs
  data_parallel_replicate_size: 1
  ulysses_parallel_size: 1
  ringattn_parallel_size: 1
  pipeline_parallel_size: 1
  expert_parallel_size: 1

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_compile: true  # recommended with TP for fused kernels
  init_device: meta
  load_weights_mode: broadcast
```

GPU layout (each row is one TP group; each column is one FSDP shard group):

```
         tp=0  tp=1  tp=2  tp=3
fsdp=0: [  0,    1,    2,    3 ]
fsdp=1: [  4,    5,    6,    7 ]
```

TP=2, FSDP2 shard=4 (Qwen3-8B, 8 GPUs)
```yaml
model:
  model_path: Qwen/Qwen3-8B
  merge_qkv: false

train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 2
  data_parallel_shard_size: 4  # 2 TP × 4 FSDP = 8 GPUs
  data_parallel_replicate_size: 1
  enable_compile: true
  init_device: meta
```

TP=4, FSDP2 shard=4 (Qwen3-32B, 16 GPUs)
```yaml
model:
  model_path: Qwen/Qwen3-32B
  merge_qkv: false

train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 4
  data_parallel_shard_size: 4  # 4 TP × 4 FSDP = 16 GPUs
  data_parallel_replicate_size: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  init_device: meta
  load_weights_mode: broadcast
```

TP=4 + compile (from qwen3_8b_tp4_compile.yaml)
The example config at examples/local/dummy/configs/full/qwen3_8b_tp4_compile.yaml combines TP=4 with torch.compile:
```yaml
model:
  model_path: Qwen/Qwen3-8B
  attn_implementation: flash_attention_3

train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 4
  data_parallel_shard_size: 2
  data_parallel_replicate_size: 1
  ulysses_parallel_size: 1
  ringattn_parallel_size: 1
  expert_parallel_size: 1

  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_compile: true
  init_device: meta
  load_weights_mode: broadcast
```

torch.compile is applied to each decoder block before FSDP wrapping (build_parallelize_model handles the ordering). Compiled kernels can fuse the TP linear operations with activation functions for better throughput.
TP + PP (Qwen3-8B, PP=2, TP=2, 8 GPUs)
```yaml
model:
  model_path: Qwen/Qwen3-8B
  merge_qkv: false

train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 2
  pipeline_parallel_size: 2
  pipeline_parallel_schedule: 1F1B
  data_parallel_shard_size: 2  # 2 TP × 2 PP × 2 FSDP = 8 GPUs
  data_parallel_replicate_size: 1
  gradient_accumulation_steps: 2  # >= pipeline_parallel_size
  init_device: meta
```

When PP is combined with TP, build_parallelize_model applies TP per model-part (per PP stage) before wrapping each part with FSDP2.
10. When to Use TP vs FSDP2 vs PP
FSDP2 (default recommendation)
Use FSDP2 alone when:
- The model’s full parameter count fits on the available GPUs after FSDP sharding (param_bytes / dp_shard_size fits in GPU memory).
- You want maximum throughput with minimum communication overhead.
- You are doing LoRA or QLoRA fine-tuning (TP is incompatible with LoRA).
FSDP2 has lower overhead than TP: it only communicates during the all-gather before a layer and reduce-scatter after, and these can be overlapped with computation via prefetching.
Use TP when:
- FSDP2 memory reduction is insufficient — you cannot fit even a single FSDP-sharded parameter slice in GPU memory.
- You are training very large dense models (30B+) and want to reduce per-GPU activation memory by distributing matrix multiplications.
- You are on a node with fast NVLink interconnect (TP all-reduces are latency-sensitive; they are not suitable for slow cross-node Ethernet).
- You want to reduce the memory footprint of activations during the compute phase (since each GPU only computes a partial result).
TP is best used within a single node (NVLink bandwidth). Cross-node TP is generally not recommended due to PCIe/Infiniband bandwidth limitations relative to the volume of data in the all-reduce.
Use PP when:
- The model is too large to fit even with FSDP2 + TP on a single node.
- You are distributing across nodes and want to minimize cross-node communication (PP only sends activations between adjacent stages, which is a small [B, S, H] tensor rather than full parameter shards).
- You have a large number of gradient accumulation steps (PP requires gradient_accumulation_steps >= pipeline_parallel_size; more steps means better pipeline utilization).
Combining TP + FSDP2 vs. pure FSDP2
| Configuration | Per-GPU memory | Communication volume | Flexibility |
|---|---|---|---|
| FSDP2 only, shard=8 | params / 8 | All-gather + reduce-scatter over 8 GPUs | LoRA, QLoRA supported |
| TP=4, FSDP2 shard=2 | params / 8 | TP all-reduce (×2/layer) + FSDP over 2 GPUs | LoRA not supported |
| TP=8, FSDP2 shard=1 | params / 8 | TP all-reduce (×2/layer, over 8 GPUs) | Higher TP comm. cost |
For the same memory budget, FSDP2 with larger shard groups is generally more efficient than TP with smaller shard groups because FSDP communication can be overlapped with computation (prefetch), while TP all-reduces block the forward pass. Use TP when you need to reduce per-layer activation memory or when the model is larger than can be handled by FSDP2 sharding alone.
Summary table
| Parallelism | Memory savings | Activation cost | Communication pattern | Composability |
|---|---|---|---|---|
| FSDP2 | Params + grads + optimizer states | Full activations per GPU | All-gather (fwd), reduce-scatter (bwd) — overlappable | Composes with all |
| TP | Params + grads + optimizer states (sharded tp-ways); intermediate activations sharded | Full [B, S, H] activations per GPU (sequence not sharded) | All-reduce per layer (blocking) | No LoRA; within-node recommended |
| PP | Activations only on active stage | Low (only active stage) | Point-to-point between stages | Requires grad_accum ≥ pp_size |
11. Parameter Reference
Training arguments (src/xorl/arguments.py)
| Parameter | Type | Default | Description |
|---|---|---|---|
| tensor_parallel_size | int | 1 | Number of TP ranks. Shards model weights across this many GPUs within a node. Must divide num_attention_heads, num_key_value_heads, hidden_size, and intermediate_size. |
| data_parallel_shard_size | int | derived | FSDP shard group size. tensor_parallel_size × data_parallel_shard_size × data_parallel_replicate_size × pipeline_parallel_size × ringattn_parallel_size × ulysses_parallel_size == world_size. |
| pipeline_parallel_size | int | 1 | PP stage count. Compose with TP by assigning tensor_parallel_size × pipeline_parallel_size GPUs to model parallelism. |
| data_parallel_mode | str | "fsdp2" | Must be "fsdp2" when using TP with meta-device initialization. |
| init_device | str | "meta" | Use "meta" with TP to avoid materializing full weights before TP sharding. |
| load_weights_mode | str | "broadcast" | With TP, rank 0 loads from disk and broadcasts to TP peers. Use "all_ranks" if filesystem supports parallel reads. |
| enable_compile | bool | false | torch.compile per decoder block. Recommended with TP for fused kernels; must be applied before FSDP wrapping (handled automatically). |
Model arguments (src/xorl/arguments.py, ModelArguments dataclass)
| Parameter | Type | Default | Description |
|---|---|---|---|
| merge_qkv | bool | true | Must be set to false when TP is enabled. Controls whether q_proj, k_proj, v_proj are fused into a single qkv_proj. TP requires unfused projections. Also controls whether gate_proj and up_proj are fused as gate_up_proj. |
ParallelState properties (src/xorl/distributed/parallel_state.py)
| Property | Returns | Description |
|---|---|---|
| tp_enabled | bool | True if tp_size > 1 |
| tp_size | int | Number of TP ranks |
| tp_rank | int | This rank’s index within its TP group |
| tp_mesh | DeviceMesh | 1-D sub-mesh over the "tp" dimension |
| tp_group | ProcessGroup | TP process group for explicit collectives |
| fsdp_mesh | DeviceMesh | FSDP sub-mesh (excludes TP dimension) |
12. Implementation Notes
TP plan construction
The TP plan is a flat dict mapping fully-qualified module names (with wildcard * for layer indices) to ColwiseParallel or RowwiseParallel instances. It is assembled in _build_tp_plan from two sources:
- `model.config.base_model_tp_plan` — the base model’s plan (e.g., `Qwen3Model`). After `unfuse_for_tp`, this is overridden to point to the unfused layer names. The prefix for the base model attribute (`model.`, `transformer.`, etc.) is prepended automatically.
- `model._tp_plan` — the top-level causal LM wrapper plan (e.g., `lm_head`).
String-to-style resolution
```python
def _resolve_tp_style(style_str):
    if style_str == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    elif style_str == "embedding":
        return RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate())
    elif style_str == "colwise":
        return ColwiseParallel()
    elif style_str == "rowwise":
        return RowwiseParallel()
```

Weight initialization order
1. Build model on meta device (no weights allocated)
2. Call unfuse_for_tp() — splits fused projections (still meta)
3. Call parallelize_module() — annotates params as DTensors with TP placement (still meta)
4. Load weights (rank0_load_and_broadcast_weights) — materializes meta DTensors into TP-sharded real tensors
5. Call parallelize_model_fsdp2() — wraps TP DTensors with FSDP2

Steps 4 and 5 happen inside build_parallelize_model with skip_weight_loading=True passed to the FSDP path so that weights are not loaded a second time.
MoE and TP
For Qwen3-MoE models, the TP plan covers attention and dense MLP layers only. Expert weights (model.layers.*.mlp.experts.*) are excluded from the TP plan and instead use Expert Parallelism (EP) with FSDP2. This avoids TP all-reduces for expert computation, which would defeat the purpose of EP’s local expert dispatch.
Source
| File | Description |
|---|---|
| src/xorl/distributed/torch_parallelize.py | build_parallelize_model, TP + FSDP2 ordering, _build_tp_plan |
| src/xorl/distributed/parallel_state.py | tp_mesh, tp_group, device mesh construction |
| src/xorl/models/transformers/qwen3/parallelize.py | Qwen3 TP plan (ColwiseParallel, RowwiseParallel assignments) |