
[RL] Prepare vllm definition for support_torch_compile compatibility#2393

Open
Lucaskabela wants to merge 2 commits into pytorch:main from Lucaskabela:lucaskabela/compile_logic_changes

Conversation


@Lucaskabela Lucaskabela commented Feb 18, 2026

Summary

This PR makes the necessary model changes to support support_torch_compile for the vLLM model definition. These are:

For Graph Capture

  • Handle DTensor-to-local conversion in VLLMAttention with narrow() - this is a workaround for prim_to_local symbolic shape propagation causing issues under compile
  • Remove PrepareModuleInputOutput from TP plan, as VLLMAttention now handles DTensor conversion internally (RowwiseParallel uses from_local(Shard(-1))) (related to the DTensor change above)
  • Simplify forward() to be compile-friendly: remove conditionals, use a _tp_enabled bool flag, and reshape with -1 to avoid data-dependent errors
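The reshape change can be sketched as follows (shapes here are illustrative, not the actual model's; the point is that -1 lets PyTorch derive the flattened dimension instead of computing it from symbolic sizes that prim_to_local may mis-propagate under compile):

```python
import torch

# Hypothetical shapes for illustration only.
batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16
q = torch.randn(batch_size, seq_len, num_heads * head_dim)

# Instead of q.reshape(batch_size * seq_len, num_heads, head_dim),
# which bakes in a product of (possibly corrupted) symbolic sizes:
q_flat = q.reshape(-1, num_heads, head_dim)
```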

For CudaGraph Capture

  • Add weak_ref_tensor DTensor patches (FakeTensor kernel, sharding strategy, Python-level guard) for piecewise CUDA-graph capture in vLLM
  • Pre-extend the RoPE cache in __init__ instead of extending it dynamically (to keep memory static for cudagraph)
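A minimal sketch of the static-cache idea (the formula and sizes here are illustrative, not the repo's actual RoPE code): allocate the full buffer once so its storage address stays fixed, which CUDA graph capture requires; forward() then only slices it.

```python
import torch

# Assumed sizes for illustration.
max_seq_len, head_dim = 4096, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len).float()
freqs = torch.outer(t, inv_freq)                        # (max_seq_len, head_dim // 2)
rope_cache = torch.cat([freqs.cos(), freqs.sin()], -1)  # static buffer, never resized

seq_len = 128
cos_sin = rope_cache[:seq_len]  # pure slicing: no .item(), no reallocation
```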

See Lucaskabela#4 for the next-step enablement (will resubmit against main after this PR lands)

Test Plan

VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py

Baseline (losses/rewards)

Step Loss
0 -0.0196
1 7.4252
2 3.4577
3 1.2154
4 0.4386
5 0.2471
6 0.2507
7 0.2722
8 0.2140
9 0.1330

This PR (losses/rewards)

Step Loss
0 -0.0196
1 7.4252
2 3.4577
3 1.2154
4 0.4386
5 0.2471
6 0.2507
7 0.2722
8 0.2140
9 0.1330

These match exactly, showing we preserve training stability.

Metric comparison

On top of #2398 we get the following metrics:

Metric Baseline % After Changes %
Total wall-clock 135.15s 133.62s
Cumul. rollout 9.45s 7.0% 8.27s 6.2%
Cumul. train 56.23s 41.6% 49.94s 37.4%
Cumul. optimizer 0.31s 0.2% 0.30s 0.2%
Cumul. weight_sync 69.47s 51.4% 75.40s 56.4%
Peak mem (rollout) 4.09 GiB 4.09 GiB
Peak mem (train) 9.58 GiB 9.58 GiB
Peak mem (optimizer) 8.44 GiB 8.44 GiB

Of note: there is no significant change in overall runtime or memory usage.

Authored with claude

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 18, 2026
- Handle DTensor-to-local conversion in VLLMAttention with narrow()
  workaround for prim_to_local symbolic shape corruption under compile
- Remove PrepareModuleInputOutput from TP plan; VLLMAttention now handles
  DTensor conversion internally, RowwiseParallel uses from_local(Shard(-1))
- Add weak_ref_tensor DTensor compatibility patches (FakeTensor kernel,
  sharding strategy, Python-level guard) for piecewise CUDA-graph capture
- Pre-extend RoPE cache in __init__ instead of dynamic extension (avoids
  graph breaks from .item() and DTensor checks during forward)
- Simplify forward() to be compile-friendly: remove conditionals, use
  _tp_enabled bool flag, reshape with -1
@Lucaskabela Lucaskabela force-pushed the lucaskabela/compile_logic_changes branch from 88de8a8 to 6612d3e Compare February 19, 2026 00:00
@Lucaskabela Lucaskabela marked this pull request as ready for review February 20, 2026 20:59
# shape before calling to_local(), then uses narrow() to fix the
# corrupted symbolic dimension that prim_to_local produces under
# torch.compile. Its output is a plain tensor; RowwiseParallel on
# attention.wo wraps it back into a DTensor via from_local(Shard(-1)).
Contributor

So the only change needed to make the qwen3 model compile-able is changing the attention part (letting vllm.Attention handle TP instead of applying TP ourselves)? What's the difference between using PrepareModuleInputOutput, which inserts hooks behind the scenes, and converting to DTensor manually?

Contributor Author

For compilation enablement, we also need to modify some of the forward code in vllm_wrapper to avoid data dependence where possible (this is mostly achieved through the unsqueeze changes)

Regarding the TP change: we move the DTensor conversion to be handled internally in vLLM attention to fix an incorrect shape propagation that caused runtime crashes (potentially a bug in DTensor + compile; I am working on getting a minimal repro)

For the other changes, these are needed for cudagraph enablement, and these are:

  1. Monkeypatching weak_ref_tensor to extract the memory from the DTensor in vllm_wrapper (needed for cudagraphs)
  2. Statically allocating the rope_cache (since cudagraph requires memory to be statically recorded)

Contributor Author

Re the difference: See 175690

This is the repro, which I will have a potential PR fix for shortly. In the event we don't get that landed though, the edits in this PR will also fix the symbolic shape failure

Contributor Author

pytorch/pytorch#175692 is a potential fix (need feedback on this, of course)

head_dim=model_args.head_dim,
layer_name=layer_name,
scale=model_args.head_dim**-0.5,
tp_enabled=(tp_degree > 1),
Contributor

This patch is only applied in the Generator model. Do you only apply compile to the generator, not the trainer model, for now?

Contributor

@wwwjn wwwjn Feb 24, 2026

And if we only apply compile on the generator side, not the trainer side, would this make the log probability distributions on the two sides diverge further once we also apply Inductor?

Contributor Author

Yes, to clarify: this PR is strictly aimed at compiling the generator/rollout/vLLM definition. The trainer is orthogonal since it won't use the same vLLM compilation mechanism.

Inductor will impact numerics regardless of whether we apply it to one or both models (due to capturing different graphs), so we don't turn it on for now (where/how to apply Inductor is a later question - a prerequisite is graph capturability, which this PR stack handles)

From the summary, see Lucaskabela#4 for the next step of turning on vLLM compilation - this does not affect logits since we only use eager compile and cudagraphs

Contributor

The trainer is orthogonal since it won't use the same vllm compilation mechanism.

But since we use the same model definition, shouldn't enabling compile be the same on both sides? Can you explain more about the differences in how compile is enabled in vLLM vs. non-vLLM?

Contributor Author

They are not the same definition, since the trainer also needs graph-capturable backward functions - these are the orthogonal components I referenced. When we enable compile for the trainer model, most changes will relate to that.

Beyond the fact that the graphs we are capturing are different, vLLM's compile integration is quite different: the biggest difference is that it drops guards for inference speed at the cost of safety (causing issues such as the DTensor one we observe - if the guards were kept, a change in the symbolic dimension would trigger a recompilation)

See https://docs.vllm.ai/en/latest/design/torch_compile/ for more info on this.

# ---------------------------------------------------------------------------
# vLLM weak_ref_tensor + DTensor compatibility patches
#
# Piecewise CUDA-graph capture calls weak_ref_tensor() on every subgraph
Contributor

What is weak_ref_tensor()? Can you point me to some docs?

Contributor Author

Sure - following the code pointer in this comment leads to:

https://github.com/vllm-project/vllm/blob/a0c70816956298f7dd1d0cf47cfa1a169a413692/vllm/compilation/cuda_graph.py#L21

weak_ref_tensor is a utility used to record tensor memory that must be kept alive for cudagraphs.

https://github.com/vllm-project/vllm/blob/a0c70816956298f7dd1d0cf47cfa1a169a413692/vllm/utils/torch_utils.py#L649 is the definition

Contributor

Thanks - it seems this change is needed because the model is DTensor-ified and weak_ref_tensor assumes a plain tensor.

Also curious: without this weak_ref_tensor change, does CUDAGraph in vLLM still work, or is it totally broken?

Contributor Author

We were seeing crashes, but once I patched the DTensor fix I no longer need this code, so I can probably clean it up.

@Lucaskabela Lucaskabela changed the title Prepare model logic for torch.compile compatibility [RL] Prepare vllm definition for support_torch_compile compatibility Feb 24, 2026
@Lucaskabela Lucaskabela requested a review from wwwjn February 24, 2026 16:08
Contributor

@wwwjn wwwjn left a comment

Also, from the PR description it seems enabling CUDAGraph + the compile front-end (Dynamo) didn't bring much runtime/memory gain or burden. This might be because the benchmark is not representative, e.g. the prompt is too short. Can you try measuring how much speedup we get from a larger-scale benchmark (e.g., the cudagraph path in https://fburl.com/gdoc/mr91ct3v)?

I was wondering whether users would only compile the generator model, or would strictly prefer to compile both trainer and generator models together. Theoretically this PR also helps enable trainer-side compile, right?



"_C::weak_ref_tensor",
lambda tensor: torch.empty_like(tensor),
)
register_op_strategy(torch.ops._C.weak_ref_tensor.default)(pointwise_strategy)
Contributor

Registering this tells DTensor's sharding propagation system that weak_ref_tensor is a pointwise op?

Contributor Author

Yes, that is correct - mostly for fake tensor propagation, which is needed for tracing.

k = k.reshape(batch_size * seq_len, num_kv_heads, head_dim)
v = v.reshape(batch_size * seq_len, num_kv_heads, head_dim)
# Flatten batch and seq_len: (batch * seq_len, num_heads, head_dim)
# Use -1 to avoid relying on symbolic shape from prim_to_local
Contributor

Please educate me - what is prim_to_local?

Contributor Author

prim_to_local is the function that backs DTensor.to_local(), extracting the local tensor shard from a distributed tensor (DTensor)

@Lucaskabela
Contributor Author

Also, from the PR description it seems enabling CUDAGraph + the compile front-end (Dynamo) didn't bring much runtime/memory gain or burden. This might be because the benchmark is not representative, e.g. the prompt is too short. Can you try measuring how much speedup we get from a larger-scale benchmark (e.g., the cudagraph path in https://fburl.com/gdoc/mr91ct3v)?

I tried to make this clear in the summary, but to reiterate: there is a significant speedup to the rollout engine (vLLM), going from 10s to 2s for the 10 steps. It is listed in the follow-up PR which actually enables compile - this PR is the prep work to get us ready to turn compile on.

Still, I will test the larger benchmark too

I was wondering whether users would only compile the generator model, or would strictly prefer to compile both trainer and generator models together. Theoretically this PR also helps enable trainer-side compile, right?

Yes, but the trainer has more work to be done (capturing backward); this will also impact the cudagraph strategy, but that is outside the scope of this PR - it is best to keep this PR focused on the vLLM rollout engine.

@Lucaskabela
Contributor Author

Running vLLM TorchTitan without cudagraphs vs. with cudagraphs, using the benchmark script from above:

vLLM TorchTitan
--------------------------------------------------------------------------------
  Throughput (tokens/sec):
    Mean:       276.64 ± 5.97
    Median:     276.40
    Range:      268.52 - 283.93

  Latency (ms/token):
    Mean:         3.62 ± 0.08
    Median:       3.62
    Range:        3.52 - 3.72

  First Token Latency (ms):
    Mean:         0.00 ± 0.00

  Peak Memory (GB):
    Mean:         0.00 ± 0.00

  Total tokens: 10,240
  Runs: 5

vs

vLLM TorchTitan
--------------------------------------------------------------------------------
  Throughput (tokens/sec):
    Mean:       955.85 ± 1.08
    Median:     955.49
    Range:      954.88 - 957.94

  Latency (ms/token):
    Mean:         1.05 ± 0.00
    Median:       1.05
    Range:        1.04 - 1.05

  First Token Latency (ms):
    Mean:         0.00 ± 0.00

  Peak Memory (GB):
    Mean:         0.00 ± 0.00

  Total tokens: 10,240
  Runs: 5

So nearly a 4-5x speedup in the generator (consistent with previous findings). Of course, this won't be enabled directly until the next PR.

Labels

ciflow/8gpu, CLA Signed