[rl] Use torchtitan config system for inference and simple GRPO #2191
wwwjn wants to merge 17 commits into gh/wwwjn/2/base from
Conversation
[ghstack-poisoned]
allenwang28
left a comment
I like this direction, thanks! Mostly nits here
…inference and simple GRPO"

1. Add job_config.py to extend the current JobConfig. One open issue: the `trainer` config and the `generator` config are not symmetric, e.g. `Parallelism` vs. `Generation.parallelism`.
2. Use the job config system as the centralized, source-of-truth config, loading config from the `run_configs/qwen3_0.6b.toml` file.
3. Refactor the generator to use EngineArgs() and LLMEngine() instead of LLM().
4. Rename simple_rl_multiprocess -> simple_grpo to be more descriptive.
5. Clean up unused code branches.

Test: (trainer ddp = 2, n_generator = 1)
<img width="755" height="294" alt="Screenshot 2025-12-30 at 7 34 00 PM" src="https://github.com/user-attachments/assets/94a3038f-6e5c-4749-9f7b-c575c63be2a1" />

Follow-up refactors:
- Refactor 2: vLLM model registration, using setup.py and a plugin instead of an import
- Refactor 3: Weight updater, directly passing the state_dict (DTensor) between trainer and generator
- Refactor 4: Use the torchtitan Trainer and modularize each component

[ghstack-poisoned]
…ight tying (#2410) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0) (oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410

This is an alternative fix to #2402 (comment). Weight updating between trainer and generator was totally broken because we called `reload_weights` when updating the weights. `reload_weights` has the following steps:
- initialize_layerwise_reload(model): saves the current real GPU tensors as info.kernel_tensors, and replaces all parameters with meta tensors.
- model.load_weights(weights_iter): this function is written by us and calls set_model_state_dict. Internally, set_model_state_dict tries to do param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.

In this PR:
- Totally bypass reload_weights, and don't load from a file when we update the weights.
- Get the model via self.engine.model_executor.driver_worker.get_model().
- Iterate over model.named_parameters() to find the matching parameter by name.
- Do param.data.copy_(new_tensor) directly.
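The fix described above (iterate over named_parameters and copy in place) can be sketched on a toy module. The helper name and the toy model are illustrative only, not the PR's actual engine hookup:

```python
import torch
import torch.nn as nn


def update_weights_in_place(model: nn.Module, new_state: dict) -> int:
    """Copy new tensors into the live model parameters by name.

    Unlike copying into meta tensors (a silent no-op), copying into the
    real parameter storage actually updates the weights the engine runs.
    Returns the number of parameters updated.
    """
    updated = 0
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in new_state:
                # in-place copy into the existing storage, no reload from file
                param.data.copy_(new_state[name])
                updated += 1
    return updated


# Minimal demonstration on a toy module instead of the vLLM model.
model = nn.Linear(4, 2, bias=False)
n = update_weights_in_place(model, {"weight": torch.ones(2, 4)})
```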
        token_log_probs: List of per-token log prob lists for each completion
        prompt_token_ids: List of prompt token ID lists for each completion
    """
    logger.info(
nit: maybe keep this at debug level. It seems like @daniellepintz pointed out these can flood the terminal
Bunch of these throughout generation
Good point, cleaned up
        ddp_size: int = 1,
        tp_size: int = 1,
        config: Config,
        policy_optimization: PolicyOptimizationConfig,
nit: the arguments after the `config` position should be keyword-only
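The suggested signature shape can be sketched as follows; the class name and fields are hypothetical, not the PR's actual trainer class:

```python
class PolicyTrainer:
    """Sketch: `config` is positional; everything after the bare `*`
    must be passed by keyword, so call sites stay self-documenting."""

    def __init__(self, config: dict, *, ddp_size: int = 1, tp_size: int = 1):
        self.config = config
        self.ddp_size = ddp_size
        self.tp_size = tp_size


# Keyword-only arguments must be named at the call site.
trainer = PolicyTrainer({"model": "qwen3-0.6b"}, ddp_size=2)
```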
            selective_ac_option="op",
        ),
    ),
    batch_invariant_mode=True,
This is only applied in generator therefore it should be part of the Generator config, not the top-level config
My understanding is that to achieve true "on-policy" RL, both the trainer and the generator side need to be switched to batch-invariant kernels (so that even when the batch sizes on trainer and generator differ, they produce exactly the same logprob for each sample). So I made this a top-level config: it is a "mode" we want to switch on/off globally.
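The "top-level flag" layout under discussion can be sketched with dataclasses; field and class names here are illustrative, not the actual torchtitan JobConfig fields:

```python
from dataclasses import dataclass, field


@dataclass
class TrainerConfig:
    ddp_size: int = 1


@dataclass
class GeneratorConfig:
    tp_size: int = 1


@dataclass
class Config:
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    # Global on/off switch, consumed by both the trainer and generator sides,
    # which is why it sits at the top level rather than under Generator.
    batch_invariant_mode: bool = False


cfg = Config(batch_invariant_mode=True)
```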
        self,
        config: Config,
        *,
        model_spec: ModelSpec,
same question here. Curious what the criteria are for keeping fields here vs. in Config? Should we align with how the rest of torchtitan behaves?
I passed it here because it is set by the upper-level class's config (here the upper-level class is RLTrainer).
Thanks for making the changes and rebasing to main. Overall it looks good to me! Just have a few questions around naming and config structure.
    # NOTE: We need to apply parallelize within model.__init__ because vllm
    # doesn't separate model creation and parallelism application and instead
    # requires parallelization to be done inside model constructor.
    cfg = self.parallel_config
This is the trainer config, not the parallel_config, no?
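The constraint in the NOTE above can be sketched as follows; the class and helper names are hypothetical, only the shape of the constraint (parallelism applied inside the constructor, with no later hook) matches the discussion:

```python
class ParallelizedModel:
    """Sketch: vLLM does not separate model creation from parallelism
    application, so parallelization happens inside __init__ itself."""

    def __init__(self, config: dict):
        self.layers = ["attn", "mlp"]
        # Parallelize immediately: there is no later point where the
        # engine would call back into the model to apply parallelism.
        self.parallelized = self._apply_parallelism(config)

    def _apply_parallelism(self, config: dict) -> dict:
        # Stand-in for real tensor-parallel sharding of self.layers.
        return {"tp": config.get("tp_size", 1), "applied_to": list(self.layers)}


model = ParallelizedModel({"tp_size": 2})
```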
    # Ensure weights stay in bfloat16
    vllm_state = {
        k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v
In VLLMEngine.Config you have a dtype field. Shouldn't this cast use it?
This dtype conversion is from the original code, and the VLLMEngine.Config.dtype field is used to set EngineArgs(dtype=...), which is also bfloat16. The latter field controls the weight dtype of the vLLM model.
My understanding is that vLLM model weights are usually set to bfloat16 because: 1) the FA kernels only take bfloat16 inputs, and 2) the kv_cache needs to be in bfloat16/float16 (float32 is not supported because it takes too much memory).
Once we have control of the forward path, we can control weight/activation dtypes at a finer granularity, but we will need to figure out where to convert to bfloat16 so it stays compatible with vLLM's kv cache mechanism. For now I simply set the generator model's weights to bfloat16 (and the trainer's weights to bfloat16 when checking bit-wise identity) to bypass this issue; it needs more careful investigation of vLLM internals.
I still don't see why they wouldn't agree.
When the engine doesn't exist, you use EngineArgs to create one with VLLMEngine.Config.dtype.
Here the state dict comes from PolicyTrainer to update the existing engine, so it has to align with what's set above.
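The cast under discussion can be sketched as a small helper that mirrors the engine dtype (bfloat16 here, matching what VLLMEngine.Config.dtype is described as setting); the helper name is illustrative, not the PR's actual code path:

```python
import torch


def cast_fp32_to_engine_dtype(state: dict, target: torch.dtype = torch.bfloat16) -> dict:
    """Cast only float32 tensors to the engine dtype; integer tensors
    (e.g. token indices) and tensors already in the target dtype pass through."""
    return {
        k: v.to(target) if v.dtype == torch.float32 else v
        for k, v in state.items()
    }


state = {
    "w": torch.zeros(2, 2, dtype=torch.float32),
    "idx": torch.zeros(2, dtype=torch.int64),
}
out = cast_fp32_to_engine_dtype(state)
```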
    from torchtitan.experiments.rl.vllm_compat.weights import vllm_to_torchtitan

-    titan_state = vllm_compat_to_torchtitan(vllm_compat_state)
+    titan_state = vllm_to_torchtitan(vllm_state)
maybe leave a TODO on renaming vllm_to_torchtitan to hint that it's for state dict conversion
But actually, I wonder why we don't use our state_dict adapter / checkpointer utils for this conversion.
> wonder why we don't use our state_dict adapter / checkpointer utils for this conversion.

This is achieved by the 3rd PR in this stack (which focuses on removing the fqn conversion and the saving-to-file step during weight transfer). In this PR I tried to focus on the config system changes and keep each PR easy to read.
        self.llm = None
        self.tp_size = tp_size
        self.engine = None
Why would VLLMEngine own another engine?
    model = model_args.build()
    # Set global default dtype to bfloat16. This is needed because vLLM's Attention
    # layer uses torch.get_default_dtype() and it doesn't support float32
    torch.set_default_dtype(torch.bfloat16)
hmm, does this mean the trainer can also only use torch.bfloat16 master weights?
This is mostly for the bit-wise identity check (vLLM uses bfloat16 weights for the reasons linked above). I remove it when bit-wise identity mode is not enabled.
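The effect of the global default-dtype switch in the snippet above can be demonstrated in isolation; this sketch also restores the previous default, which the snippet itself does not do:

```python
import torch

# vLLM's Attention layer reads torch.get_default_dtype(), so tensors created
# without an explicit dtype follow the global default set here.
prev = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)

x = torch.zeros(2, 2)  # created in bfloat16 under the new default

torch.set_default_dtype(prev)  # restore so the rest of the process is unaffected
```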
Stack from ghstack (oldest at bottom):

1. Add job_config.py to extend the current JobConfig. One open issue: the `trainer` config and the `generator` config are not symmetric, e.g. `Parallelism` and `Generation.parallelism`.
2. Use the job config system as the centralized, source-of-truth config, loading config from the `run_configs/qwen3_0.6b.toml` file.
3. Refactor the generator to use EngineArgs() and LLMEngine() instead of LLM().
4. Rename simple_rl_multiprocess -> simple_grpo to be more descriptive.
5. Clean up unused code branches.

Test: (trainer ddp = 2, n_generator = 1)

Follow-up refactors:
- Refactor 2: vLLM model registration, using setup.py and a plugin instead of an import
- Refactor 3: Weight updater, directly passing the state_dict (DTensor) between trainer and generator
- Refactor 4: Use the torchtitan Trainer and modularize each component