
Tensor Parallelism

Tensor parallelism (TP) shards individual weight matrices across multiple GPUs so that each GPU holds a vertical or horizontal slice of every linear layer. Unlike FSDP2, which shards parameters at rest and gathers them during forward, TP keeps weights sharded during the entire forward and backward pass. Every GPU in a TP group processes the same token sequence and contributes a partial result; a collective reduces partials to produce the full output.

xorl implements TP using PyTorch’s torch.distributed.tensor.parallel API (parallelize_module, ColwiseParallel, RowwiseParallel) on a 1-D DeviceMesh slice named "tp".

[Figure: Tensor Parallelism — weight sharding under the two TP styles. ColwiseParallel splits W on out_features (dim 0); with replicated input, each rank produces a slice of the output, so the output is sharded on the out dimension and no all-reduce is needed. RowwiseParallel splits W on in_features (dim 1); each rank produces a partial output and an all-reduce over the partials yields the replicated full output. Two all-reduce collectives occur per transformer layer (after o_proj and down_proj). All TP ranks receive identical input; each computes over its weight shard.]

TP maps each linear layer to one of two sharding styles.

Under ColwiseParallel, the weight matrix W of shape [out_features, in_features] is split along out_features (dimension 0). Each TP rank holds W[rank * out/tp : (rank+1) * out/tp, :]. The input activations are replicated (broadcast) to all TP ranks; each rank computes a partial output y_partial = x @ W_local^T. The outputs are naturally sharded along the output dimension, so no collective is needed at the end of the layer when the next layer expects sharded input.
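The colwise arithmetic can be checked in a few lines of NumPy. This is an illustrative sketch of the math, not the xorl implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, H, tp = 2, 4, 8, 2

x = rng.standard_normal((B, S, H))        # replicated input
W = rng.standard_normal((H, H))           # [out_features, in_features]

# Each "rank" holds a slice of W along out_features (dim 0).
shards = np.split(W, tp, axis=0)
partials = [x @ W_r.T for W_r in shards]  # each -> [B, S, H/tp]

# Concatenating the per-rank outputs reproduces the full matmul;
# no collective is needed if the next layer consumes sharded input.
y = np.concatenate(partials, axis=-1)
assert np.allclose(y, x @ W.T)
```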

Layers using ColwiseParallel in Qwen3:

  • embed_tokens (embedding variant — shards vocabulary dimension)
  • self_attn.q_proj, self_attn.k_proj, self_attn.v_proj
  • mlp.gate_proj, mlp.up_proj
  • lm_head

Under RowwiseParallel, the weight matrix is split along in_features (dimension 1). Each TP rank holds W[:, rank * in/tp : (rank+1) * in/tp]. The input activations are expected to be sharded (exactly what ColwiseParallel produces), so no gather is needed on the way in. Each rank computes a partial output, and an all-reduce sums the partials across the TP group into the full output tensor, which is then replicated across TP ranks.

Layers using RowwiseParallel in Qwen3:

  • self_attn.o_proj
  • mlp.down_proj
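The rowwise case can be verified the same way; here the all-reduce is literally an elementwise sum of the per-rank partial outputs (NumPy sketch, not xorl code):

```python
import numpy as np

rng = np.random.default_rng(1)
B, S, H, tp = 2, 4, 8, 2

x = rng.standard_normal((B, S, H))
W = rng.standard_normal((H, H))           # [out_features, in_features]

# Rowwise: each rank holds W[:, slice] and the matching input shard,
# which is exactly the layout a preceding colwise layer produces.
W_shards = np.split(W, tp, axis=1)
x_shards = np.split(x, tp, axis=-1)

# Each rank produces a full-shape but partial output...
partials = [x_r @ W_r.T for x_r, W_r in zip(x_shards, W_shards)]

# ...and the all-reduce is just a sum over the partials.
y = sum(partials)
assert np.allclose(y, x @ W.T)
```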

Data flow through a transformer layer with TP=4

Input x [B, S, H] (replicated across 4 TP ranks)
|
v
q_proj (colwise) → q_local [B, S, H/4]
k_proj (colwise) → k_local [B, S, H_kv/4] (per rank)
v_proj (colwise) → v_local [B, S, H_kv/4]
|
v
Multi-head attention (each rank handles H/4 heads, no collective)
|
v
o_proj (rowwise) → partial_y → all-reduce → y [B, S, H] (replicated)
|
v
gate_proj (colwise) → g_local [B, S, I/4]
up_proj (colwise) → u_local [B, S, I/4]
|
v
SiLU + multiply (local)
|
v
down_proj (rowwise) → partial_y → all-reduce → y [B, S, H] (replicated)

Two all-reduce collectives occur per transformer layer: one after o_proj and one after down_proj.


2. Device Mesh — How the TP Mesh Is Constructed


The global device mesh is built in src/xorl/distributed/parallel_state.py by init_parallel_state. All parallelism dimensions are composed into a single N-dimensional DeviceMesh:

# Dimension order (only included if size > 1 or name == "dp_shard"):
# [pp, dp_replicate, dp_shard, ringattn, ulysses, tp]
device_mesh = init_device_mesh(
    device_type=device_type,
    mesh_shape=tuple(mesh_shape),
    mesh_dim_names=tuple(mesh_dim_names),
)

The "tp" dimension is always the innermost dimension of the mesh. On a typical 8-GPU single node with TP=4 and DP=2:

Ranks: [0, 1, 2, 3, 4, 5, 6, 7]
Mesh shape: [dp_shard=2, tp=4]
Rows (fixed dp_shard index) are TP groups; columns (fixed tp index) are FSDP groups.
TP group for ranks 0–3: {0, 1, 2, 3}
TP group for ranks 4–7: {4, 5, 6, 7}
FSDP group (dp_shard): {0, 4}, {1, 5}, {2, 6}, {3, 7}

The TP sub-mesh is accessed as parallel_state.tp_mesh (device_mesh["tp"]) and the TP process group as parallel_state.tp_group (device_mesh.get_group("tp")).

The FSDP mesh (parallel_state.fsdp_mesh) spans the DP dimensions and excludes the TP dimension, so each FSDP group contains one rank per TP index. FSDP therefore shards the TP-local weight slices, not the full weights. See Section 8 for implications.
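The group assignments follow directly from the row-major layout of the [dp_shard, tp] mesh. A quick sanity check in plain Python (hypothetical helper names, not part of xorl):

```python
def tp_group(rank, tp_size):
    """Ranks sharing this rank's row of the [dp_shard, tp] mesh."""
    base = (rank // tp_size) * tp_size
    return list(range(base, base + tp_size))

def fsdp_group(rank, tp_size, world_size):
    """Ranks sharing this rank's column (same TP index, different DP shard)."""
    return list(range(rank % tp_size, world_size, tp_size))

# 8 GPUs, TP=4, dp_shard=2 -- matches the layout above.
assert tp_group(2, 4) == [0, 1, 2, 3]
assert tp_group(5, 4) == [4, 5, 6, 7]
assert fsdp_group(1, 4, 8) == [1, 5]
assert fsdp_group(6, 4, 8) == [2, 6]
```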


3. QKV Unfusing — Why merge_qkv Must Be False


By default, xorl models use a fused qkv_proj linear layer of shape [q_dim + 2 * kv_dim, hidden_size] for efficiency. TP cannot shard this fused layer correctly because Q, K, and V have different output dimensions in GQA (grouped-query attention) models — the number of heads differs between Q and KV, so a uniform column split would produce misaligned head boundaries.

Setting merge_qkv: false in the model config triggers unfuse_for_tp() before parallelize_module is called:

src/xorl/distributed/torch_parallelize.py
if parallel_state.tp_enabled:
    if hasattr(model, "unfuse_for_tp"):
        model.unfuse_for_tp()  # splits qkv_proj → q_proj, k_proj, v_proj
                               # and gate_up_proj → gate_proj, up_proj
    tp_plan = _build_tp_plan(model)
    model = parallelize_module(
        model,
        device_mesh=parallel_state.tp_mesh,
        parallelize_plan=tp_plan,
    )

At the attention layer level (src/xorl/models/layers/attention/multi_head_attention.py):

def unfuse_for_tp(self):
    device = self.qkv_proj.weight.device
    dtype = self.qkv_proj.weight.dtype
    self.q_proj = nn.Linear(hidden_size, q_dim, bias=..., device=device, dtype=dtype)
    self.k_proj = nn.Linear(hidden_size, kv_dim, bias=..., device=device, dtype=dtype)
    self.v_proj = nn.Linear(hidden_size, kv_dim, bias=..., device=device, dtype=dtype)
    del self.qkv_proj

At the MLP layer level (src/xorl/models/transformers/qwen3/modeling_qwen3.py):

def unfuse_for_tp(self):
    self.gate_proj = nn.Linear(hidden_size, intermediate_size, ...)
    self.up_proj = nn.Linear(hidden_size, intermediate_size, ...)
    del self.gate_up_proj

What happens if merge_qkv: true with TP enabled:

The _build_tp_plan function reads config.base_model_tp_plan, which unfuse_for_tp overrides to reference the unfused projections (q_proj, k_proj, v_proj). If unfusing is skipped, the TP plan still references the unfused projection names, but those modules do not exist on the model: parallelize_module will silently skip them (or raise a KeyError, depending on the PyTorch version). The model then runs without TP applied to the attention projections, producing incorrect training results with no visible error.

Always set merge_qkv: false when tensor_parallel_size > 1.
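A cheap guard against this failure mode is to verify that every wildcard key in the TP plan matches at least one module name before calling parallelize_module. The helper below is hypothetical (not part of xorl) and operates on plain module-name strings such as those returned by model.named_modules():

```python
from fnmatch import fnmatch

def unmatched_plan_keys(plan_keys, module_names):
    """Return plan entries that match no module -- a sign of a still-fused model."""
    return [k for k in plan_keys
            if not any(fnmatch(name, k) for name in module_names)]

# A fused model still has qkv_proj, so the unfused plan keys dangle:
fused = ["layers.0.self_attn.qkv_proj", "layers.0.mlp.gate_up_proj"]
plan = ["layers.*.self_attn.q_proj", "layers.*.self_attn.k_proj"]
assert unmatched_plan_keys(plan, fused) == plan  # every key is dangling

# After unfuse_for_tp() the keys resolve:
unfused = ["layers.0.self_attn.q_proj", "layers.0.self_attn.k_proj"]
assert unmatched_plan_keys(plan, unfused) == []
```

Raising on a non-empty result turns the silent mis-configuration into a hard error at startup.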


4. Column-wise vs Row-wise Parallelism — Complete Layer Assignment


The TP plan for Qwen3 dense models (src/xorl/models/transformers/qwen3/parallelize.py):

TP_PLAN = {
    "embed_tokens": "embedding",  # RowwiseParallel(input=Replicate, output=Replicate)
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.k_proj": "colwise",
    "layers.*.self_attn.v_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}
MODEL_TP_PLAN = {
    "lm_head": "colwise",  # dense: ColwiseParallel (sharded output)
}

For Qwen3 MoE models (src/xorl/models/transformers/qwen3_moe/parallelize.py), the attention and dense MLP layers use the same plan. Expert weights inside MoEBlock are not TP-sharded — they use Expert Parallelism (EP) instead. The lm_head in MoE uses "colwise_rep" (ColwiseParallel with output_layouts=Replicate()) to produce a replicated output suitable for softmax over the full vocabulary.

The embedding table [vocab_size, hidden_size] is sharded on the vocabulary dimension (dim 0), implemented as a RowwiseParallel with input_layouts=Replicate(), output_layouts=Replicate(). Each TP rank owns vocab_size / tp_size token embeddings. During a lookup, each rank returns its partial embedding (zero for tokens outside its vocab slice) and an all-reduce sums partials to produce the full embedding vector.
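A NumPy sketch of the vocab-sharded lookup (illustrative only, not the xorl implementation):

```python
import numpy as np

rng = np.random.default_rng(2)
vocab, H, tp = 8, 4, 2
table = rng.standard_normal((vocab, H))
token_ids = np.array([0, 3, 5, 7])

shard = vocab // tp
partials = []
for rank in range(tp):
    lo, hi = rank * shard, (rank + 1) * shard
    local = table[lo:hi]                     # this rank's vocab slice
    mask = (token_ids >= lo) & (token_ids < hi)
    out = np.zeros((len(token_ids), H))
    out[mask] = local[token_ids[mask] - lo]  # zero for out-of-slice tokens
    partials.append(out)

# The all-reduce sums the partials into the full embedding vectors.
assert np.allclose(sum(partials), table[token_ids])
```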

| Layer | Style | Weight split dimension | Input | Output |
| --- | --- | --- | --- | --- |
| embed_tokens | embedding | vocab (dim 0) | replicated token ids | replicated hidden states |
| q_proj, k_proj, v_proj | colwise | out_features (dim 0) | replicated | sharded (heads split) |
| o_proj | rowwise | in_features (dim 1) | sharded | replicated (all-reduce) |
| gate_proj, up_proj | colwise | out_features (dim 0) | replicated | sharded (intermediate split) |
| down_proj | rowwise | in_features (dim 1) | sharded | replicated (all-reduce) |
| lm_head | colwise | out_features (vocab, dim 0) | replicated | sharded or replicated |

5. Sequence Dimension and Context Parallelism


In the current xorl TP implementation, the sequence dimension is not sharded across TP ranks. All TP ranks process the full sequence length S. This means:

  • The activation tensor [B, S, H] is replicated at the start of each layer.
  • The all-reduce at the end of o_proj and down_proj reduces across TP ranks.
  • Memory for activations scales with S, not S / tp_size.

This is different from Megatron-style sequence parallelism (SP), where the non-attention layers use a reduce-scatter / all-gather pattern to keep activations sharded along the sequence dimension between the colwise and rowwise layers. xorl does not currently implement Megatron-style SP on top of TP. If you need to reduce activation memory along the sequence dimension, use Ring Attention (ringattn_parallel_size) or Ulysses (ulysses_parallel_size) instead; these are orthogonal to TP.

TP can be freely combined with Ring Attention and Ulysses because the TP mesh dimension is independent of the CP mesh dimensions in the global device mesh.


6. Communication Costs

For each transformer layer with TP enabled, the following collectives occur:

| Point in the layer | Collective | Group | Cost (per rank, ring, bf16) |
| --- | --- | --- | --- |
| After o_proj | All-reduce | TP group | 2 × (tp−1)/tp × 2·B·S·H bytes |
| After down_proj | All-reduce | TP group | 2 × (tp−1)/tp × 2·B·S·H bytes |

The two all-reduces are the dominant TP communication cost. For a 32-layer model with TP=4, H=4096, B=1, S=8192, each rank moves approximately 32 × 2 × 2 × (3/4) × 8192 × 4096 × 2 bytes ≈ 6.4 GB per forward pass through the TP interconnect.
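This estimate (ring all-reduce of bf16 activations, GB = 1e9 bytes) can be reproduced with a small helper; the function name is illustrative, not part of xorl:

```python
def tp_allreduce_gb(layers, B, S, H, tp, bytes_per_elem=2):
    """Per-rank forward traffic: 2 all-reduces per layer, ring algorithm."""
    # Ring all-reduce moves 2*(tp-1)/tp times the tensor size per rank.
    per_allreduce = 2 * (tp - 1) / tp * B * S * H * bytes_per_elem
    return layers * 2 * per_allreduce / 1e9

gb = tp_allreduce_gb(layers=32, B=1, S=8192, H=4096, tp=4)
assert round(gb, 1) == 6.4
```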

During the embedding lookup (embed_tokens), a single all-reduce sums partial embeddings across TP ranks. This is a [B, S, H] tensor.

During the backward pass, PyTorch’s DTensor machinery inserts the gradient collectives automatically. For ColwiseParallel layers the input is replicated, so each rank’s partial input gradient must be summed across the TP group before flowing to the previous layer. For RowwiseParallel layers the forward all-reduce backpropagates as an identity, and the input gradient is already sharded to match the sharded input, so no extra collective is needed there. The net effect is again two reductions per transformer layer.
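The per-rank partial input gradients of a colwise layer sum to the full input gradient, which is exactly what the backward reduction computes. A NumPy sketch of that identity:

```python
import numpy as np

rng = np.random.default_rng(3)
B, S, H, tp = 2, 4, 8, 2
W = rng.standard_normal((H, H))              # [out_features, in_features]

W_shards = np.split(W, tp, axis=0)           # colwise: split out_features
dY = rng.standard_normal((B, S, H))          # upstream grad of the full output
dY_shards = np.split(dY, tp, axis=-1)        # each rank sees its output slice

# Forward was y = x @ W.T, so dL/dx = dY @ W. Each rank computes a
# partial dL/dx from its shard; the reduction sums the partials.
partials = [dy_r @ W_r for dy_r, W_r in zip(dY_shards, W_shards)]
assert np.allclose(sum(partials), dY @ W)
```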

All TP collectives are synchronous blocking all-reduces or all-gathers. There is no overlap of TP communication with computation in xorl (unlike some Megatron implementations that pipeline the all-reduce with the next layer). This means TP latency adds directly to the critical path.

When TP and FSDP2 are both enabled, weight loading must happen in a specific order: TP is applied first (weights become DTensors with TP placements), then FSDP2 wraps the TP-sharded parameters. See Section 8 for details.


7. Constraints and Limitations

The TP size must divide the attention head counts:

num_attention_heads % tensor_parallel_size == 0
num_key_value_heads % tensor_parallel_size == 0

Both constraints must hold. For GQA models like Qwen3 with num_key_value_heads < num_attention_heads, the KV head count is the binding constraint.

Qwen3-8B (num_attention_heads=32, num_key_value_heads=8):

  • TP=2: 32 % 2 == 0, 8 % 2 == 0. Valid.
  • TP=4: 32 % 4 == 0, 8 % 4 == 0. Valid.
  • TP=8: 32 % 8 == 0, 8 % 8 == 0. Valid.

Qwen3-30B-A3B (MoE, num_attention_heads=16, num_key_value_heads=8):

  • TP=2: 16 % 2 == 0, 8 % 2 == 0. Valid.
  • TP=4: 16 % 4 == 0, 8 % 4 == 0. Valid.
  • TP=8: 16 % 8 == 0, 8 % 8 == 0. Valid.

hidden_size % tensor_parallel_size == 0 must hold. All Qwen3 models have hidden_size divisible by 8, so this is never a binding constraint for practical TP sizes.

intermediate_size % tensor_parallel_size == 0 must hold for MLP sharding.
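The divisibility checks can be bundled into one validator. The helper is illustrative (not xorl code); the attention dims are Qwen3-8B's from the text above, while the MLP width is only for illustration:

```python
def check_tp_divisibility(tp, n_heads, n_kv_heads, hidden, intermediate):
    """Raise ValueError if any TP divisibility constraint is violated."""
    for name, dim in [("num_attention_heads", n_heads),
                      ("num_key_value_heads", n_kv_heads),
                      ("hidden_size", hidden),
                      ("intermediate_size", intermediate)]:
        if dim % tp != 0:
            raise ValueError(f"{name}={dim} not divisible by tp={tp}")

# heads=32, kv_heads=8 (Qwen3-8B); hidden/intermediate illustrative
for tp in (2, 4, 8):
    check_tp_divisibility(tp, 32, 8, 4096, 12288)
```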

world_size = PP × DP_shard × DP_replicate × TP × CP_ring × CP_ulysses

This is validated in ParallelState.__post_init__ and also in TrainArguments.__post_init__:

non_dp_size = (ulysses_parallel_size * tensor_parallel_size
               * ringattn_parallel_size * pipeline_parallel_size)
self.data_parallel_size = world_size // non_dp_size

TP + LoRA is not currently supported. build_parallelize_model explicitly raises NotImplementedError if any LoraLinear modules are found when TP is enabled:

if parallel_state.tp_enabled:
    if any(isinstance(m, LoraLinear) for m in model.modules()):
        raise NotImplementedError(
            "Tensor parallelism + LoRA is not currently supported."
        )

Use FSDP2 alone for LoRA fine-tuning.

TP requires PyTorch >= 2.4. The ColwiseParallel, RowwiseParallel, and parallelize_module imports are guarded by is_torch_version_greater_than("2.4").


8. Composition with FSDP2

TP and FSDP2 compose so that TP shards within a node and FSDP2 shards the TP-local slices across the DP group. The key property is that the "tp" dimension is innermost in the global device mesh, meaning consecutive ranks on the same node form a TP group, while the FSDP2 mesh spans ranks that have the same TP rank across nodes.

parallelize_module must be called before fully_shard. After parallelize_module, each parameter is a DTensor with a Shard(0) or Shard(1) placement on the TP mesh, representing the local TP slice. When fully_shard then wraps the module, it shards this already-TP-sharded tensor further across the FSDP group.

When init_device: meta is used, weights must be loaded after TP is applied but before FSDP wrapping. This is because after parallelize_module, parameters are meta DTensors with TP placements. FSDP’s lazy init cannot handle meta DTensors correctly (the logical DTensor shape does not match the local shard shape). xorl handles this automatically:

src/xorl/distributed/torch_parallelize.py
if parallel_state.tp_enabled:
    model = parallelize_module(model, ...)  # apply TP first
    if kwargs.get("init_device") == "meta" and weights_path is not None:
        rank0_load_and_broadcast_weights(...)  # load weights into TP-sharded params
        kwargs["skip_weight_loading"] = True   # skip again in FSDP path
# then FSDP2 wraps the already-materialized TP DTensors
model = parallelize_model_fsdp2(model, ...)

For a 32-GPU job with TP=4, DP_shard=8:

  • Each GPU has a TP-local slice of every parameter (1/4 of output features for colwise layers).
  • FSDP2 further shards this TP-local slice across 8 GPUs in the FSDP group.
  • GPU effective parameter storage = full_param_size / (TP × FSDP_shard_size) = full_param_size / 32.
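Per-GPU parameter storage under TP + FSDP2 is easy to back-of-envelope (bf16 params, GB = 1e9 bytes; the helper name is illustrative, not xorl code):

```python
def param_gb_per_gpu(n_params, tp, fsdp_shard, bytes_per_param=2):
    """Parameter storage per GPU once both TP and FSDP2 shard the weights."""
    return n_params * bytes_per_param / (tp * fsdp_shard) / 1e9

# 8B parameters, TP=4, FSDP shard=8 -> full_param_size / 32
assert param_gb_per_gpu(8e9, tp=4, fsdp_shard=8) == 0.5
```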

The FSDP2 all-gather during the forward pass gathers the TP-local slice (not the full parameter). The TP communication then distributes the work across TP ranks.

FSDP2’s reduce-scatter during the backward pass averages TP-local gradients across the FSDP group. The TP all-reduce (inside PyTorch’s DTensor autograd) runs separately during the backward pass over the TP group. The two collectives operate on different process groups and do not interfere.

The same MixedPrecisionPolicy(param_dtype=bfloat16, reduce_dtype=float32) applies to TP-sharded parameters. TP all-reduces happen in the dtype of the activations (bfloat16 during forward, float32 during gradient reduce via FSDP2’s reduce_dtype).


9. Example Configurations

Unless noted otherwise, the examples below assume 8 GPUs on a single node.

TP=4 with FSDP shard=2:

model:
  model_path: Qwen/Qwen3-8B
  merge_qkv: false  # required for TP
train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 4
  data_parallel_shard_size: 2  # 4 TP × 2 FSDP = 8 GPUs
  data_parallel_replicate_size: 1
  ulysses_parallel_size: 1
  ringattn_parallel_size: 1
  pipeline_parallel_size: 1
  expert_parallel_size: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_compile: true  # recommended with TP for fused kernels
  init_device: meta
  load_weights_mode: broadcast

GPU layout (TP groups across columns, FSDP shard groups across rows):

tp=0 tp=1 tp=2 tp=3
fsdp=0: [ 0, 1, 2, 3 ]
fsdp=1: [ 4, 5, 6, 7 ]

TP=2 with a larger FSDP shard group:

model:
  model_path: Qwen/Qwen3-8B
  merge_qkv: false
train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 2
  data_parallel_shard_size: 4  # 2 TP × 4 FSDP = 8 GPUs
  data_parallel_replicate_size: 1
  enable_compile: true
  init_device: meta

Qwen3-32B on 16 GPUs (TP=4, FSDP shard=4):

model:
  model_path: Qwen/Qwen3-32B
  merge_qkv: false
train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 4
  data_parallel_shard_size: 4  # 4 TP × 4 FSDP = 16 GPUs
  data_parallel_replicate_size: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  init_device: meta
  load_weights_mode: broadcast

TP=4 + compile (from qwen3_8b_tp4_compile.yaml)


The example config at examples/local/dummy/configs/full/qwen3_8b_tp4_compile.yaml combines TP=4 with torch.compile:

model:
  model_path: Qwen/Qwen3-8B
  attn_implementation: flash_attention_3
train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 4
  data_parallel_shard_size: 2
  data_parallel_replicate_size: 1
  ulysses_parallel_size: 1
  ringattn_parallel_size: 1
  expert_parallel_size: 1
  enable_mixed_precision: true
  enable_gradient_checkpointing: true
  enable_compile: true
  init_device: meta
  load_weights_mode: broadcast

torch.compile is applied to each decoder block before FSDP wrapping (build_parallelize_model handles the ordering). Compiled kernels can fuse the TP linear operations with activation functions for better throughput.

TP=2 combined with PP=2:

model:
  model_path: Qwen/Qwen3-8B
  merge_qkv: false
train:
  data_parallel_mode: fsdp2
  tensor_parallel_size: 2
  pipeline_parallel_size: 2
  pipeline_parallel_schedule: 1F1B
  data_parallel_shard_size: 2  # 2 TP × 2 PP × 2 FSDP = 8 GPUs
  data_parallel_replicate_size: 1
  gradient_accumulation_steps: 2  # >= pipeline_parallel_size
  init_device: meta

When PP is combined with TP, build_parallelize_model applies TP per model-part (per PP stage) before wrapping each part with FSDP2.


10. Choosing Between FSDP2, TP, and PP

Use FSDP2 alone when:

  • The model’s full parameter count fits on the available GPUs after FSDP sharding (param_bytes / dp_shard_size fits in GPU memory).
  • You want maximum throughput with minimum communication overhead.
  • You are doing LoRA or QLoRA fine-tuning (TP is incompatible with LoRA).

FSDP2 has lower overhead than TP: it only communicates during the all-gather before a layer and reduce-scatter after, and these can be overlapped with computation via prefetching.

Use TP when:

  • FSDP2 memory reduction is insufficient: even with parameters fully sharded at rest, the unsharded per-layer weights gathered during forward plus activations exceed GPU memory.
  • You are training very large dense models (30B+) and want to reduce per-GPU activation memory by distributing matrix multiplications.
  • You are on a node with fast NVLink interconnect (TP all-reduces are latency-sensitive; they are not suitable for slow cross-node Ethernet).
  • You want to reduce the memory footprint of activations during the compute phase (since each GPU only computes a partial result).

TP is best used within a single node (NVLink bandwidth). Cross-node TP is generally not recommended due to PCIe/Infiniband bandwidth limitations relative to the volume of data in the all-reduce.

Use PP when:

  • The model is too large to fit even with FSDP2 + TP on a single node.
  • You are distributing across nodes and want to minimize cross-node communication (PP only sends activations between adjacent stages, which is a small [B, S, H] tensor rather than full parameter shards).
  • You have a large number of gradient accumulation steps (PP requires gradient_accumulation_steps >= pipeline_parallel_size; more steps means better pipeline utilization).
| Configuration | Per-GPU memory | Communication volume | Flexibility |
| --- | --- | --- | --- |
| FSDP2 only, shard=8 | params / 8 | All-gather + reduce-scatter over 8 GPUs | LoRA, QLoRA supported |
| TP=4, FSDP2 shard=2 | params / 8 | TP all-reduce (×2/layer) + FSDP over 2 GPUs | LoRA not supported |
| TP=8, FSDP2 shard=1 | params / 8 | TP all-reduce (×2/layer, over 8 GPUs) | Higher TP comm. cost |

For the same memory budget, FSDP2 with larger shard groups is generally more efficient than TP with smaller shard groups because FSDP communication can be overlapped with computation (prefetch), while TP all-reduces block the forward pass. Use TP when you need to reduce per-layer activation memory or when the model is larger than can be handled by FSDP2 sharding alone.

| Parallelism | Memory savings | Activation cost | Communication pattern | Composability |
| --- | --- | --- | --- | --- |
| FSDP2 | Params + grads + optimizer states | Full activations per GPU | All-gather (fwd), reduce-scatter (bwd); overlappable | Composes with all |
| TP | Partial activations per GPU (matrix partial results) | Reduced per-layer compute | All-reduce per layer (blocking) | No LoRA; within-node recommended |
| PP | Activations only on active stage | Low (only active stage) | Point-to-point between stages | Requires grad_accum ≥ pp_size |

Training arguments (src/xorl/arguments.py)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| tensor_parallel_size | int | 1 | Number of TP ranks. Shards model weights across this many GPUs within a node. Must divide num_attention_heads, num_key_value_heads, hidden_size, and intermediate_size. |
| data_parallel_shard_size | int | derived | FSDP shard group size. tensor_parallel_size × data_parallel_shard_size × data_parallel_replicate_size × pipeline_parallel_size × ringattn_parallel_size × ulysses_parallel_size == world_size. |
| pipeline_parallel_size | int | 1 | PP stage count. Composes with TP by assigning tensor_parallel_size × pipeline_parallel_size GPUs to model parallelism. |
| data_parallel_mode | str | "fsdp2" | Must be "fsdp2" when using TP with meta-device initialization. |
| init_device | str | "meta" | Use "meta" with TP to avoid materializing full weights before TP sharding. |
| load_weights_mode | str | "broadcast" | With TP, rank 0 loads from disk and broadcasts to TP peers. Use "all_ranks" if the filesystem supports parallel reads. |
| enable_compile | bool | false | torch.compile per decoder block. Recommended with TP for fused kernels; must be applied before FSDP wrapping (handled automatically). |

Model arguments (src/xorl/arguments.py, ModelArguments dataclass)

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| merge_qkv | bool | true | Must be set to false when TP is enabled. Controls whether q_proj, k_proj, v_proj are fused into a single qkv_proj; TP requires unfused projections. Also controls whether gate_proj and up_proj are fused as gate_up_proj. |

ParallelState properties (src/xorl/distributed/parallel_state.py)

| Property | Returns | Description |
| --- | --- | --- |
| tp_enabled | bool | True if tp_size > 1 |
| tp_size | int | Number of TP ranks |
| tp_rank | int | This rank’s index within its TP group |
| tp_mesh | DeviceMesh | 1-D sub-mesh over the "tp" dimension |
| tp_group | ProcessGroup | TP process group for explicit collectives |
| fsdp_mesh | DeviceMesh | FSDP sub-mesh (excludes TP dimension) |

The TP plan is a flat dict mapping fully-qualified module names (with wildcard * for layer indices) to ColwiseParallel or RowwiseParallel instances. It is assembled in _build_tp_plan from two sources:

  1. model.config.base_model_tp_plan — the base model’s plan (e.g., Qwen3Model). After unfuse_for_tp, this is overridden to point to the unfused layer names. The prefix for the base model attribute (model., transformer., etc.) is prepended automatically.
  2. model._tp_plan — the top-level causal LM wrapper plan (e.g., lm_head).
def _resolve_tp_style(style_str):
    if style_str == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    elif style_str == "embedding":
        return RowwiseParallel(input_layouts=Replicate(), output_layouts=Replicate())
    elif style_str == "colwise":
        return ColwiseParallel()
    elif style_str == "rowwise":
        return RowwiseParallel()
Initialization order with TP and meta-device weights:

1. Build model on meta device (no weights allocated)
2. Call unfuse_for_tp() — splits fused projections (still meta)
3. Call parallelize_module() — annotates params as DTensors with TP placement (still meta)
4. Load weights (rank0_load_and_broadcast_weights) — materializes meta DTensors into TP-sharded real tensors
5. Call parallelize_model_fsdp2() — wraps TP DTensors with FSDP2

Steps 4 and 5 happen inside build_parallelize_model with skip_weight_loading=True passed to the FSDP path so that weights are not loaded a second time.

For Qwen3-MoE models, the TP plan covers attention and dense MLP layers only. Expert weights (model.layers.*.mlp.experts.*) are excluded from the TP plan and instead use Expert Parallelism (EP) with FSDP2. This avoids TP all-reduces for expert computation, which would defeat the purpose of EP’s local expert dispatch.

| File | Description |
| --- | --- |
| src/xorl/distributed/torch_parallelize.py | build_parallelize_model, TP + FSDP2 ordering, _build_tp_plan |
| src/xorl/distributed/parallel_state.py | tp_mesh, tp_group, device mesh construction |
| src/xorl/models/transformers/qwen3/parallelize.py | Qwen3 TP plan (ColwiseParallel, RowwiseParallel assignments) |