
Conversation

@jreiml (Contributor) commented on Jan 10, 2026

What does this PR do?

Adds automatic tensor parallel (TP) resharding support when training and inference use different TP configurations. This fixes weight loading failures in fully async workflows where Megatron exports full HuggingFace weights but vLLM expects TP-sharded weights.

Related to #400, #1063, #4497

Problem

When using fully async training with:

  • Megatron training: tensor_model_parallel_size=8, use_mbridge=True
  • vLLM inference: tensor_model_parallel_size=32

Weight synchronization fails because:

  1. mbridge exports full (unsharded) HuggingFace-format weights
  2. vLLM parameters expect weights pre-sharded for TP=32
  3. Parameters without a weight_loader attribute cause assertion errors:
    AssertionError: param.shape [hidden_size/32, ...] != loaded_weight.shape [hidden_size, ...]
    

This affects large MoE models (DeepSeek-V3/R1) where training requires pipeline parallelism but inference benefits from higher TP across nodes.
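
For intuition, a toy illustration of the mismatch, assuming a square projection weight (the 7168 hidden size matches DeepSeek-V3, but the layer shape is purely illustrative):

import torch

hidden_size = 7168   # DeepSeek-V3 hidden size
infer_tp = 32        # vLLM tensor parallel degree at inference

# mbridge exports the full HuggingFace-format weight ...
full_weight = torch.empty(hidden_size, hidden_size)

# ... but each vLLM rank allocates only its TP shard of that weight.
local_param = torch.nn.Parameter(torch.empty(hidden_size // infer_tp, hidden_size))

# A parameter with no weight_loader falls back to a plain copy guarded by a shape
# assertion, which fails: torch.Size([7168, 7168]) vs torch.Size([224, 7168])
print(full_weight.shape, local_param.shape)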

Solution

  1. Add VERL_ENABLE_TP_RESHARD environment variable (opt-in to avoid breaking existing setups)
  2. Patch parameters without weight_loader to use a TP-aware loader (sketched after this list) that:
    • Detects which dimension differs by a factor of tp_size
    • Automatically shards along that dimension based on tp_rank
  3. Refactor patch_vllm_moe_model_weight_loader to handle both MoE expert patching and general TP resharding
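
As a rough illustration of step 2, a minimal sketch of such a loader follows. make_tp_aware_loader and its signature are hypothetical names for this write-up, and the logic is simplified to the single-dimension case described above; it is not the exact code in this PR.

# Minimal sketch of a TP-aware weight loader; names are illustrative only.
import torch


def make_tp_aware_loader(param_name: str, tp_rank: int, tp_size: int):
    """Return a weight_loader that shards a full HF weight for this TP rank."""

    def load_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
        # Shapes already match: the weight arrived pre-sharded, just copy it.
        if tuple(param.shape) == tuple(loaded_weight.shape):
            param.data.copy_(loaded_weight)
            return

        if param.dim() != loaded_weight.dim():
            raise ValueError(
                f"Rank mismatch for {param_name}: "
                f"{tuple(loaded_weight.shape)} vs {tuple(param.shape)}"
            )

        # Find the single dimension where the loaded weight is exactly tp_size
        # times larger than the parameter; that is the shard dimension.
        shard_dim = None
        for dim, (p_len, w_len) in enumerate(zip(param.shape, loaded_weight.shape)):
            if w_len == p_len:
                continue
            if w_len == p_len * tp_size and shard_dim is None:
                shard_dim = dim
            else:
                # Not a clean tp_size multiple, or more than one dimension differs.
                shard_dim = None
                break

        if shard_dim is None:
            raise ValueError(
                f"Cannot determine sharding strategy for {param_name}: "
                f"loaded {tuple(loaded_weight.shape)} vs param {tuple(param.shape)}"
            )

        # Copy only this rank's slice along the detected dimension.
        shard_size = param.shape[shard_dim]
        shard = loaded_weight.narrow(shard_dim, tp_rank * shard_size, shard_size)
        param.data.copy_(shard)

    return load_weight

In the patch, a loader like this would be attached as the weight_loader of parameters that lack one.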

Test

Validated on DeepSeek-V3-Base with:

  • 16-node Megatron training (TP=8, PP=16, EP=8)
  • 16-node vLLM inference (TP=32, EP=32)
  • Fully async DAPO training with use_mbridge=True

Before fix: Weight loading fails with shape mismatch assertions
After fix: Weights load correctly, training proceeds normally

Config used

parallelism:
  tp: 8
  pp: 16
  ep: 8
  etp: 1
  infer_tp: 32
  infer_ep: 32
  infer_dp: 1

Launch command

VERL_ENABLE_TP_RESHARD=1 \
VLLM_USE_DEEP_GEMM=1 \
VLLM_ALL2ALL_BACKEND=deepep_high_throughput \
python3 -m ta_verl.recipe.fully_async_dapo.fully_async_main \
  --config-name="deepseek_v3_20251219.yaml" \
  trainer.nnodes="16" \
  rollout.nnodes="16"

API and Usage Example

# Enable TP resharding for async rollout
VERL_ENABLE_TP_RESHARD=1 python -m verl.trainer.main ...

No config changes required. The fix automatically detects and handles TP mismatches.

Design & Code Changes

File: verl/utils/vllm/patch.py

  1. _get_tp_rank_and_size(): Helper to get vLLM's tensor parallel configuration
  2. _create_tp_aware_weight_loader(): Creates weight loaders that:
    • Pass through if shapes match
    • Auto-shard if loaded weight is tp_size times larger in one dimension
  3. patch_vllm_moe_model_weight_loader(): Extended to patch all parameters without weight_loader when VERL_ENABLE_TP_RESHARD=1 (see the sketch after this list)
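
For reference, a hedged sketch of how this wiring could look. patch_missing_weight_loaders is a hypothetical name used only here (the PR folds this into patch_vllm_moe_model_weight_loader), and the vllm.distributed imports are an assumption about the deployed vLLM version.

# Hedged sketch of the patch wiring; not the exact code in verl/utils/vllm/patch.py.
import os

import torch.nn as nn


def _get_tp_rank_and_size() -> tuple[int, int]:
    # Assumed vLLM helpers; adjust if the deployed version exposes them elsewhere.
    from vllm.distributed import (
        get_tensor_model_parallel_rank,
        get_tensor_model_parallel_world_size,
    )

    return get_tensor_model_parallel_rank(), get_tensor_model_parallel_world_size()


def patch_missing_weight_loaders(model: nn.Module) -> None:
    """Attach a TP-aware loader to every parameter that has no weight_loader."""
    # Opt-in only: existing setups keep the old behavior unless the flag is set.
    if os.environ.get("VERL_ENABLE_TP_RESHARD", "0") != "1":
        return

    tp_rank, tp_size = _get_tp_rank_and_size()
    if tp_size == 1:
        return  # nothing to reshard

    for name, param in model.named_parameters():
        if not hasattr(param, "weight_loader"):
            # make_tp_aware_loader is the loader sketched in the Solution section.
            param.weight_loader = make_tp_aware_loader(name, tp_rank, tp_size)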

Checklist Before Submitting

  • Read the Contribute Guide
  • Apply pre-commit checks
  • Add/Update documentation
  • Add unit tests (manual validation on DeepSeek-V3 scale)

Related Issues

  • #4497: MoE weight format mismatch (same root cause)
  • #708: DeepSeek R1 infrastructure project
  • #400: TP rollout + FSDP / TP actor feature request
  • #1063: RFC for the auto resharding design

@jreiml marked this pull request as draft on January 10, 2026 at 18:20

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable feature for automatic tensor-parallel resharding of weights, which is crucial for flexible deployment scenarios where training and inference parallelism configurations differ. The implementation is clean, using an environment variable for opt-in, and refactors the existing patching logic for better clarity. The core logic for detecting the sharding dimension and applying it seems correct. I have one suggestion to improve the robustness of the error handling within the new TP-aware weight loader.

Comment on lines +216 to +224
if shard_dim is None:
    # Can't determine sharding, fall back to assertion (will fail with clear error)
    assert param.shape == loaded_weight.shape, (
        f"Cannot determine sharding strategy for {param_name}. "
        f"Loaded weight shape {loaded_weight.shape} does not match "
        f"parameter shape {param.shape} and is not a simple TP multiple."
    )
    param.data.copy_(loaded_weight)
    return

Severity: high

Using assert for error handling in this case is not fully robust, as assertions can be disabled with the -O Python flag. If disabled, the param.data.copy_(loaded_weight) line would execute and likely raise a less informative RuntimeError due to a shape mismatch, as this code path is only taken when shapes are already known to be different.

To ensure this critical error is always caught and clearly reported, it's better to explicitly raise a ValueError. This makes the code more robust and its intent clearer. The lines following the assert are also unreachable if assertions are enabled, so replacing the block simplifies the code.

Suggested change
Original:

if shard_dim is None:
    # Can't determine sharding, fall back to assertion (will fail with clear error)
    assert param.shape == loaded_weight.shape, (
        f"Cannot determine sharding strategy for {param_name}. "
        f"Loaded weight shape {loaded_weight.shape} does not match "
        f"parameter shape {param.shape} and is not a simple TP multiple."
    )
    param.data.copy_(loaded_weight)
    return

Suggested:

if shard_dim is None:
    # Can't determine sharding, so raise an error with a clear message.
    raise ValueError(
        f"Cannot determine sharding strategy for {param_name}. "
        f"Loaded weight shape {loaded_weight.shape} does not match "
        f"parameter shape {param.shape} and is not a simple TP multiple."
    )
