
Enable torch.compile and CUDA graphs for vLLM inference #4

Draft
Lucaskabela wants to merge 1 commit into lucaskabela/compile_logic_changes from lucaskabela/compile_infra

Conversation


@Lucaskabela Lucaskabela commented Feb 18, 2026

Summary

We now enable using support_torch_compile in the vLLM wrapper definition in order to improve our end-to-end training runtime.

The particular changes in this PR are:

  • Add build_compilation_config() to parallelism_utils with TP-aware cudagraph_mode selection (piecewise for TP > 1, full_and_piecewise otherwise). Note: this form of TP cannot be used with full CUDA graphs
  • Add @support_torch_compile decorator to TorchTitanVLLMModelWrapper
  • Add vllm_compile_and_cudagraph flag to Generator and VLLMRolloutEngine in both unified and vllm_compat paths
  • Add --disable-vllm-compile-and-cudagraph CLI arg to infer.py
  • Wire compilation_config and enforce_eager through all LLM instantiation sites
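The TP-aware cudagraph_mode selection described in the first bullet can be sketched as follows. This is a minimal sketch: the helper names (`get_cudagraph_mode`, `build_compilation_config`) come from this PR's description, but the mode strings and the dict-based return value are illustrative assumptions, not the actual vLLM CompilationConfig API.

```python
def get_cudagraph_mode(tp_size: int) -> str:
    # Full CUDA graphs are incompatible with this form of TP, so fall back
    # to piecewise capture when tp_size > 1 (per the PR description).
    return "piecewise" if tp_size > 1 else "full_and_piecewise"


def build_compilation_config(cudagraph_mode: str) -> dict:
    # Returned as a plain dict here for illustration; the real helper
    # presumably constructs a vLLM compilation config object and this is
    # where other compilation knobs would be wired in.
    return {"cudagraph_mode": cudagraph_mode}


# Callers derive the mode from tp_size and pass the string through:
config = build_compilation_config(get_cudagraph_mode(tp_size=1))
```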

Test Plan

Execute

VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py

Baseline (losses/rewards)

| Step | Loss |
|---|---|
| 0 | -0.0196 |
| 1 | 7.4252 |
| 2 | 3.4577 |
| 3 | 1.2154 |
| 4 | 0.4386 |
| 5 | 0.2471 |
| 6 | 0.2507 |
| 7 | 0.2722 |
| 8 | 0.2140 |
| 9 | 0.1330 |

This PR (losses/rewards)

| Step | Loss |
|---|---|
| 0 | -0.0196 |
| 1 | 7.4252 |
| 2 | 3.4577 |
| 3 | 1.2154 |
| 4 | 0.4386 |
| 5 | 0.2471 |
| 6 | 0.2507 |
| 7 | 0.2722 |
| 8 | 0.2140 |
| 9 | 0.1330 |

These match exactly, showing that enabling compilation preserves training stability.
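The exact match can be checked mechanically from the two loss columns above (with VLLM_BATCH_INVARIANT=1 set, the runs should agree bit-for-bit at this print precision):

```python
# Per-step losses copied from the baseline and this-PR tables above.
baseline = [-0.0196, 7.4252, 3.4577, 1.2154, 0.4386,
            0.2471, 0.2507, 0.2722, 0.2140, 0.1330]
this_pr = [-0.0196, 7.4252, 3.4577, 1.2154, 0.4386,
           0.2471, 0.2507, 0.2722, 0.2140, 0.1330]

# Identical losses at every step => compilation did not perturb training.
assert baseline == this_pr
```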

Metric comparison

On top of pytorch#2398 we get the following metrics:

| Metric | Baseline | % | After Changes | % |
|---|---|---|---|---|
| Total wall-clock | 135.15s | | 125.34s | |
| Cumul. rollout | 9.45s | 7.0% | 2.60s | 2.1% |
| Cumul. train | 56.23s | 41.6% | 46.56s | 37.1% |
| Cumul. optimizer | 0.31s | 0.2% | 0.30s | 0.2% |
| Cumul. weight_sync | 69.47s | 51.4% | 76.18s | 60.8% |
| Peak mem (rollout) | 4.09 GiB | | 4.12 GiB | |
| Peak mem (train) | 9.58 GiB | | 9.58 GiB | |
| Peak mem (optimizer) | 8.44 GiB | | 8.44 GiB | |

Of note: runtime improves significantly, with cumulative rollout time dropping from 9.45s to 2.60s (roughly 3.6x). There is no significant change in memory usage.

Authored with Claude

- Add build_compilation_config() and get_cudagraph_mode() to
  parallelism_utils; callers derive cudagraph_mode from tp_size via
  the helper and pass the string to build_compilation_config
- Add @support_torch_compile decorator to TorchTitanVLLMModelWrapper
- Add vllm_compile_and_cudagraph flag to Generator and VLLMRolloutEngine
  in both unified and vllm_compat paths
- Add --disable-vllm-compile-and-cudagraph CLI arg to infer.py
- Wire compilation_config and enforce_eager through all LLM instantiation sites
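The CLI wiring for the new flag can be sketched as below. The flag name matches the PR; the default behavior (compile and CUDA graphs enabled unless explicitly disabled) and the mapping onto enforce_eager are assumptions about how infer.py consumes the flag, not code from the PR itself.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-vllm-compile-and-cudagraph",
    action="store_true",
    help="Disable torch.compile and CUDA graph capture for vLLM (run eagerly).",
)

# Simulate invoking infer.py with the flag set.
args = parser.parse_args(["--disable-vllm-compile-and-cudagraph"])

# The boolean flag threaded through Generator / VLLMRolloutEngine; when
# compilation is off, eager execution is enforced at LLM instantiation sites.
vllm_compile_and_cudagraph = not args.disable_vllm_compile_and_cudagraph
enforce_eager = not vllm_compile_and_cudagraph
```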