Orchestrate runtime and memory profile measurement in simple_rl and simple_rl_multiprocess #2398
Lucaskabela wants to merge 2 commits into pytorch:main
Conversation
Force-pushed from d3ef4d9 to 8985a6a
```python
    lambda: self.state == GeneratorState.READY_TO_GENERATE
)

# Generate samples using vLLM
```
Purely an indentation change here (moved inside the gpu_timer context manager).
Force-pushed from 8985a6a to 07acc32
torchtitan/components/metrics.py (Outdated)

```python
@contextmanager
def gpu_timer(sync: bool = True, enabled: bool = True):
```
It is likely that the current torchtitan logging / profiling tools cannot cover new use cases, but I would hope we have a more systematic approach to it.
cc @allenwang28 @felipemello1
Yeah I looked at the torchtitan logger and it doesn't seem to offer fine-grained enough control over timing individual stages of the trainer
hey @Lucaskabela thanks for looking into this!
@felipemello1 built similar tooling here: https://github.com/meta-pytorch/torchforge/tree/main/src/forge/observability
We plan to propose more logging capabilities soon-ish into Titan, but in the meantime I feel like a starting point we can edit later is:
- Prefer context manager approaches for the regions you want measured
- Restrict these changes to the simple_rl folder where we plan to iterate quickly
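For concreteness, a minimal sketch of what such a region-scoped timing context manager could look like. Only the `gpu_timer(sync, enabled)` signature comes from the diff in this PR; the result dict, its `elapsed_s` key, and the CUDA-availability fallback are assumptions for illustration, not the actual implementation.

```python
import time
from contextlib import contextmanager


@contextmanager
def gpu_timer(sync: bool = True, enabled: bool = True):
    """Yield a dict whose 'elapsed_s' key is filled in on exit.

    Hypothetical sketch: when sync=True and CUDA is available, we
    synchronize before and after the region so in-flight GPU kernels
    are included in the measured wall time.
    """
    result = {"elapsed_s": 0.0}
    if not enabled:
        # Disabled: no synchronization, no timing overhead.
        yield result
        return
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False
    if sync and has_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    try:
        yield result
    finally:
        if sync and has_cuda:
            torch.cuda.synchronize()
        result["elapsed_s"] = time.perf_counter() - start


# Usage: wrap the region to measure, read the result afterwards.
with gpu_timer() as t:
    time.sleep(0.01)
print(f"generation took {t['elapsed_s']:.3f}s")
```

Keeping the timer as a context manager matches the recommendation above: call sites stay readable, and the `enabled` flag lets the instrumentation be compiled out of the hot path by configuration.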
Sounds good - I think this PR has been updated to align with this goal :) let me know any specific code we feel needs changing
If it's only used in RL folder, no need to change core.
That's fair - will move this to the RL-specific metrics file
Force-pushed from 2c36e4f to 04428a5
Thanks for the change; observability is a very important area. I want to postpone this PR for a while for the following reasons:
- I have a stack of PRs which optimize / change the API signatures a lot; I'd prefer to reconsider metrics/observability once we have a relatively stable RL loop.
- What infra metrics do we really need? Are these timers really enough to analyze RL runs?
- Can we make the metrics logging cleaner (today in main torchtitan it's "almost" just one config), without changing a lot of APIs everywhere in the generator and trainer?
If you need this data for performance comparison before and after some changes, you can put up a script to do so for now.
```python
gen_time_s=gen_time_s,
gen_peak_active_gib=gen_peak_active_gib,
gen_peak_active_pct=gen_peak_active_pct,
gen_peak_reserved_gib=gen_peak_reserved_gib,
```
These metrics should not be put into the Trajectory data.
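For reference, peak-memory numbers like `gen_peak_active_gib` and `gen_peak_reserved_gib` are typically derived from the CUDA caching allocator's statistics. A hedged sketch of such a helper (the function name and return keys are hypothetical; this is not the PR's code), kept runnable on machines without CUDA:

```python
def peak_memory_gib(device=None):
    """Hypothetical helper: report peak 'active' and 'reserved' bytes
    from torch.cuda.memory_stats() in GiB.

    Falls back to zeros when torch or CUDA is unavailable so the
    sketch can run anywhere.
    """
    zeros = {"peak_active_gib": 0.0, "peak_reserved_gib": 0.0}
    try:
        import torch
    except ImportError:
        return zeros
    if not torch.cuda.is_available():
        return zeros
    stats = torch.cuda.memory_stats(device)
    gib = 1024 ** 3
    return {
        # Peak bytes in allocated (active) blocks since the last
        # reset_peak_memory_stats() call.
        "peak_active_gib": stats["active_bytes.all.peak"] / gib,
        # Peak bytes reserved from the driver by the caching allocator.
        "peak_reserved_gib": stats["reserved_bytes.all.peak"] / gib,
    }
```

Computing these alongside the timing results, rather than threading them through `Trajectory`, keeps rollout data and infra metrics separate, which is the concern raised in this comment.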
allenwang28 left a comment
Thanks for making these changes! As I look through them more though, I want to be upfront - @felipemello1 has been thinking about this since the Forge days, and we've been discussing how we want to iterate and check these changes into Titan.
Rather than going back and forth in review, I think it'd be more efficient to propose a logging RFC that addresses these concerns holistically. I'd hold off on further changes here until then. Let me know if you'd like to be part of the discussions or if this is blocking you now!
Sure :) In the meantime I will leave this PR as a reference for some subsequent PRs I am putting up, so that we have timing data.
Force-pushed from 04428a5 to 8caac90
Summary
This PR instruments the RL code with timing and memory logging in order to evaluate subsequent changes (compile enablement) in an apples-to-apples manner.
Test
vllm_compat - Simple RL
Execute

```shell
CUDA_VISIBLE_DEVICES=7 with-proxy VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/vllm_compat/simple_rl.py
```

unified - Simple multiprocess RL

Execute

```shell
VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
```