[RL] Prepare vllm definition for support_torch_compile compatibility#2393
Lucaskabela wants to merge 2 commits into pytorch:main
Conversation
- Handle DTensor-to-local conversion in VLLMAttention with narrow() workaround for prim_to_local symbolic shape corruption under compile
- Remove PrepareModuleInputOutput from TP plan; VLLMAttention now handles DTensor conversion internally, RowwiseParallel uses from_local(Shard(-1))
- Add weak_ref_tensor DTensor compatibility patches (FakeTensor kernel, sharding strategy, Python-level guard) for piecewise CUDA-graph capture
- Pre-extend RoPE cache in __init__ instead of dynamic extension (avoids graph breaks from .item() and DTensor checks during forward)
- Simplify forward() to be compile-friendly: remove conditionals, use _tp_enabled bool flag, reshape with -1
```python
# shape before calling to_local(), then uses narrow() to fix the
# corrupted symbolic dimension that prim_to_local produces under
# torch.compile. Its output is a plain tensor; RowwiseParallel on
# attention.wo wraps it back into a DTensor via from_local(Shard(-1)).
```
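For context, a minimal sketch of the narrow() pattern this comment describes; the helper name, the dim index, and `expected_len` are illustrative assumptions, not the PR's actual code:

```python
import torch

# Hypothetical helper: after DTensor.to_local(), dim 0 can carry a corrupted
# symbolic size under torch.compile; narrow() returns a view whose length is
# pinned to a value we know is correct.
def fix_local_shape(local: torch.Tensor, expected_len: int) -> torch.Tensor:
    return local.narrow(0, 0, expected_len)

x = torch.randn(6, 4)
y = fix_local_shape(x, 6)
assert y.shape == (6, 4)
assert y.data_ptr() == x.data_ptr()  # narrow() is a view, no copy
```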
So the only change needed to make the qwen3 model compilable is the attention part (letting vllm.Attention handle TP instead of applying TP ourselves)? What's the difference between using PrepareModuleInputOutput, which inserts hooks behind the scenes, and converting to DTensor manually?
For compilation enablement, we also need to modify some of the forward code in vllm_wrapper to avoid data dependence where possible (this is mostly achieved through the unsqueeze changes).
In regard to the TP change, we move the DTensor conversion to be handled internally by vllm attention, to fix an incorrect shape propagation causing runtime crashes (potentially a bug in DTensor + compile; I am working on getting a minimal repro).
The other changes are needed for cudagraph enablement, and they are:
- Monkeypatching `weak_ref_tensors` to extract the memory from the DTensor in vllm_wrapper (needed for cudagraphs)
- Statically allocating the rope_cache (since cudagraph requires memory to be statically recorded)
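To illustrate the static rope_cache point, here is a hedged sketch of building the full cos/sin cache once at init time so forward() only indexes into it and its memory address stays stable across CUDA-graph replays. The names (`max_seq_len`, `head_dim`, `theta`) and cache layout are assumptions, not the PR's code:

```python
import torch

def build_rope_cache(max_seq_len: int, head_dim: int, theta: float = 10000.0):
    # Standard RoPE frequency table: one frequency per pair of dims.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    pos = torch.arange(max_seq_len).float()
    freqs = torch.outer(pos, inv_freq)  # (max_seq_len, head_dim // 2)
    return torch.cos(freqs), torch.sin(freqs)

# Allocated once in __init__ for the maximum sequence length; forward() only
# slices it, so no .item() calls or resizes happen during capture.
cos, sin = build_rope_cache(max_seq_len=4096, head_dim=128)
assert cos.shape == (4096, 64)
assert sin.shape == (4096, 64)
```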
Re the difference: see 175690.
This is the repro, which I will have a potential PR fix for shortly. In the event we don't get that landed, though, the edits in this PR will also fix the symbolic shape failure.
pytorch/pytorch#175692 is the potential fix (need feedback on this, of course).
```python
head_dim=model_args.head_dim,
layer_name=layer_name,
scale=model_args.head_dim**-0.5,
tp_enabled=(tp_degree > 1),
```
This patch is only applied in the Generator model. Are you only applying compile to the generator, not the trainer model, for now?
And if we only apply compile on the generator side, not the trainer side, would this make the log-probability distributions on the two sides diverge further once we also apply Inductor?
Yes, to clarify: this PR is strictly aimed at compiling the generator/rollout/vllm definition. The trainer is orthogonal since it won't use the same vllm compilation mechanism.
Inductor will impact numerics regardless of whether we apply it to one or both models (due to capturing different graphs), so we don't turn it on for now (that is a later question around where/how to apply Inductor; a prerequisite is graph capturability, which this PR stack handles).
From the summary, see Lucaskabela#4 for the next step of turning on vLLM compilation. This does not affect logits, since we only use eager compile and cudagraphs.
> The trainer is orthogonal since it won't use the same vllm compilation mechanism.

But since we use the same model definition, shouldn't enabling compile be the same on both sides? Can you explain more about the differences between enabling compile in vLLM vs. non-vLLM?
They are not the same definition, since the trainer also needs graph-capturable backward functions; these are the orthogonal components I was referring to. When we enable compile for the trainer model, most changes will be related to that.
Beyond the fact that the graphs we are capturing are different, vLLM's compile integration is quite different: the biggest difference is that it drops guards for inference speed at the cost of safety (causing issues such as the DTensor one we observe; if the guards were kept, a change in the symbolic dimension would trigger a recompilation).
See https://docs.vllm.ai/en/latest/design/torch_compile/ for more info on this.
```python
# ---------------------------------------------------------------------------
# vLLM weak_ref_tensor + DTensor compatibility patches
#
# Piecewise CUDA-graph capture calls weak_ref_tensor() on every subgraph
```
What is weak_ref_tensor()? Can you point me to some docs?
Sure - following the code pointer in this comment leads to the definition:
https://github.com/vllm-project/vllm/blob/a0c70816956298f7dd1d0cf47cfa1a169a413692/vllm/utils/torch_utils.py#L649
weak_ref_tensor is a utility used to record the tensor memory that must be kept alive for cudagraphs.
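As a rough Python-level analogy only (vLLM's weak_ref_tensor is a C++ op, and this sketch is not its implementation): the idea is a new tensor that aliases the original's storage without copying it.

```python
import torch

# Illustrative only: build a new tensor sharing the original's storage.
def weak_ref_like(t: torch.Tensor) -> torch.Tensor:
    out = torch.empty(0, dtype=t.dtype, device=t.device)
    out.set_(t.untyped_storage(), t.storage_offset(), t.shape, t.stride())
    return out

x = torch.randn(3, 3)
y = weak_ref_like(x)
y[0, 0] = 42.0
assert x[0, 0].item() == 42.0  # both tensors view the same memory
```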
Thanks, it seems this change is needed because the model is DTensor-ified while weak_ref_tensor assumes a plain tensor.
Also curious: without weak_ref_tensor, does CUDAGraph in vLLM still work, or is it totally broken?
We were seeing crashes, but once I patched the fix for DTensor I no longer need this code, so I can probably clean it up.
wwwjn left a comment:
Also, from the PR description it seems enabling CUDAGraph + the compile front-end (Dynamo) didn't bring much runtime/memory gain or burden. This might be because the benchmark is not representative, e.g. the prompt is too short. Can you check how much speedup we get from a larger-scale benchmark (e.g. the cudagraph path in https://fburl.com/gdoc/mr91ct3v)?
I was wondering whether users would only compile the generator model, or would strictly prefer to compile the trainer and generator models together. Theoretically this PR also helps enable trainer-side compile, right?
```python
"_C::weak_ref_tensor",
lambda tensor: torch.empty_like(tensor),
)
register_op_strategy(torch.ops._C.weak_ref_tensor.default)(pointwise_strategy)
```
Does registering this tell DTensor's sharding-propagation system that weak_ref_tensor is a pointwise op?
Yes, that is correct; it's mostly for fake-tensor propagation, which is needed for tracing.
```python
k = k.reshape(batch_size * seq_len, num_kv_heads, head_dim)
v = v.reshape(batch_size * seq_len, num_kv_heads, head_dim)
# Flatten batch and seq_len: (batch * seq_len, num_heads, head_dim)
# Use -1 to avoid relying on symbolic shape from prim_to_local
```
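A tiny illustration of the -1 trick in that comment (shapes are made up): passing -1 lets the runtime infer the flattened dimension instead of multiplying two possibly-corrupted symbolic sizes.

```python
import torch

batch_size, seq_len, num_kv_heads, head_dim = 2, 3, 4, 8
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)

# -1 infers batch_size * seq_len from the element count, avoiding any
# arithmetic on symbolic shapes.
k_flat = k.reshape(-1, num_kv_heads, head_dim)
assert k_flat.shape == (batch_size * seq_len, num_kv_heads, head_dim)
```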
Please educate me: what is prim_to_local?
prim_to_local is the function that backs DTensor.to_local(), extracting the local tensor shard from a distributed tensor (DTensor).
I tried to make this clear in the summary, but I can reiterate here: there is a significant speedup to the rollout engine (vLLM), going from 10s for the 10 steps down to 2s. It is listed in the follow-up PR which actually enables compile; this PR is the prep work to get us ready to turn it on. Still, I will test the larger benchmark too.
Yes, but the trainer has more work to be done (capturing the backward pass); this will also impact the cudagraph strategy, but that is outside the scope of this PR. It is best to keep this PR focused solely on the vLLM rollout engine.
Running vLLM TorchTitan without cudagraphs vs. with cudagraphs, using the benchmark script from above:

Without cudagraphs:

```
vLLM TorchTitan
--------------------------------------------------------------------------------
Throughput (tokens/sec):
  Mean: 276.64 ± 5.97
  Median: 276.40
  Range: 268.52 - 283.93
Latency (ms/token):
  Mean: 3.62 ± 0.08
  Median: 3.62
  Range: 3.52 - 3.72
First Token Latency (ms):
  Mean: 0.00 ± 0.00
Peak Memory (GB):
  Mean: 0.00 ± 0.00
Total tokens: 10,240
Runs: 5
```

With cudagraphs:

```
vLLM TorchTitan
--------------------------------------------------------------------------------
Throughput (tokens/sec):
  Mean: 955.85 ± 1.08
  Median: 955.49
  Range: 954.88 - 957.94
Latency (ms/token):
  Mean: 1.05 ± 0.00
  Median: 1.05
  Range: 1.04 - 1.05
First Token Latency (ms):
  Mean: 0.00 ± 0.00
Peak Memory (GB):
  Mean: 0.00 ± 0.00
Total tokens: 10,240
Runs: 5
```

So nearly a 4-5x speedup in the generator (consistent with previous findings). Of course this won't be enabled directly until the next PR.
Summary

This PR makes the necessary model changes to support `support_torch_compile` for the vLLM model definition. These are:

For Graph Capture
- to avoid data-dependent errors

For CudaGraph Capture

See Lucaskabela#4 for the next-step enablement (will resubmit against main after this PR lands).
Test Plan
Baseline (losses/rewards)
This PR (losses/rewards)
These match exactly, showing we preserve training stability.
Metric comparison
On top of #2398 we get the following metrics:
Of note: there is no significant change in overall runtime, nor in memory usage.
Authored with Claude