
[rl] Use torchtitan config system for inference and simple GRPO#2191

Open
wwwjn wants to merge 17 commits into gh/wwwjn/2/base from gh/wwwjn/2/head

Conversation

@wwwjn
Contributor

@wwwjn wwwjn commented Jan 2, 2026

Stack from ghstack (oldest at bottom):

  1. Add job_config.py to extend the current JobConfig. One open issue: the trainer's config and the generator's config are not symmetric, e.g. Parallelism vs. Generation.parallelism.
  2. Use the job config system as the centralized, source-of-truth config, loading from the run_configs/qwen3_0.6b.toml file.
  3. Refactor the generator to use EngineArgs() and LLMEngine() instead of LLM().
  4. Rename simple_rl_multiprocess -> simple_grpo to be more descriptive.
  5. Clean up unused code branches.

Test (trainer ddp = 2, n_generator = 1):
(screenshot: 2025-12-30 at 7:34 PM)

Follow-up refactors:

  • Refactor 2: vLLM model registration - use setup.py and a plugin instead of an import
  • Refactor 3: weight updater - directly pass the state_dict (DTensor) between trainer and generator
  • Refactor 4: use the torchtitan Trainer and modularize each component

[ghstack-poisoned]
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 2, 2026
@wwwjn wwwjn closed this Jan 2, 2026
@wwwjn wwwjn reopened this Jan 2, 2026
@wwwjn wwwjn changed the title config sys v1 [rl] Using JobConfig as the centralized config system for inference and simple GRPO Jan 2, 2026
Contributor

@allenwang28 allenwang28 left a comment


I like this direction, thanks! Mostly nits here

wwwjn added 2 commits January 12, 2026 16:13
wwwjn added a commit that referenced this pull request Feb 22, 2026
…ight tying (#2410)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.13.0)
(oldest at bottom):
* #2395
* #2244
* #2221
* #2194
* #2191
* __->__ #2410


This is an alternative fix to
#2402 (comment).

Weight updating between the trainer and generator is completely broken because we called reload_weights when updating the weights. reload_weights has the following steps:

- initialize_layerwise_reload(model): saves the current real GPU tensors as info.kernel_tensors and replaces all parameters with meta tensors.
- model.load_weights(weights_iter): this function is written by us and calls set_model_state_dict. Internally, set_model_state_dict does param.data.copy_(loaded_weight) for each parameter. When the parameters are meta tensors, the copy is a no-op, so the weights never get updated.


In this PR:

- Bypass reload_weights entirely; don't load from a file when updating the weights
- Get the model via self.engine.model_executor.driver_worker.get_model()
- Iterate over model.named_parameters() to find the matching parameter by name
- Do param.data.copy_(new_tensor) directly
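The direct-copy approach above can be sketched as follows. This is a minimal illustration on a toy model, not the actual PR code; `update_weights_in_place` is a hypothetical helper name, and in the PR the target model would come from the vLLM engine's driver worker.

```python
# Sketch of updating an existing model's weights in place by matching
# parameter names, as described above. The helper name is hypothetical;
# in the PR the model is obtained via
# self.engine.model_executor.driver_worker.get_model().
import torch
from torch import nn


def update_weights_in_place(model: nn.Module, new_state: dict) -> None:
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, new_tensor in new_state.items():
            param = params.get(name)
            if param is None:
                continue  # skip names the target model doesn't have
            # Direct copy into the live tensor: no meta-tensor reload
            # path, so the update actually takes effect.
            param.data.copy_(new_tensor.to(dtype=param.dtype))


# Toy usage: push fresh "trainer" weights into a "generator" model.
model = nn.Linear(4, 2, bias=False)
update_weights_in_place(model, {"weight": torch.ones(2, 4)})
assert torch.equal(model.weight.data, torch.ones(2, 4))
```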
[ghstack-poisoned]
@wwwjn wwwjn changed the title [rl] Using JobConfig as the centralized config system for inference and simple GRPO [rl] Use torchtitan config system for inference and simple GRPO Feb 23, 2026
@wwwjn wwwjn mentioned this pull request Feb 23, 2026
token_log_probs: List of per-token log prob lists for each completion
prompt_token_ids: List of prompt token ID lists for each completion
"""
logger.info(
Member

nit: maybe keep this at debug level. It seems like @daniellepintz pointed out these can flood the terminal

Member

Bunch of these throughout generation

Contributor Author

Good point, cleaned up

ddp_size: int = 1,
tp_size: int = 1,
config: Config,
policy_optimization: PolicyOptimizationConfig,
Member

kwargs only after config position

selective_ac_option="op",
),
),
batch_invariant_mode=True,
Member

This is only applied in generator therefore it should be part of the Generator config, not the top-level config

Contributor Author

My understanding is that to achieve true "on-policy" RL, both the trainer and generator sides need to use batch-invariant kernels (so that even if the trainer's and generator's batch sizes differ, they produce exactly the same logprob for each sample). So I made this a top-level config: it's a "mode" we want to switch on/off globally.

self,
config: Config,
*,
model_spec: ModelSpec,
Contributor

Same here - curious what the criteria is for keeping fields here vs. in Config? Should we align with how the rest of torchtitan behaves?

Contributor Author

I passed it here because it is set by the upper-level class's config (here the upper-level class is RLTrainer).

@acisseJZhong
Contributor

thanks for making the changes and rebasing to main, overall it looks good to me! just have a few questions around naming and config structure.

# NOTE: We need to apply parallelize within model.__init__ because vllm
# doesn't separate model creation and parallelism application and instead
# requires parallelization to be done inside model constructor.
cfg = self.parallel_config
Member

This is the trainer config, not the parallel_config, no?

Contributor Author

Nice catch, fixed!


# Ensure weights stay in bfloat16
vllm_state = {
k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v
Contributor

In VLLMEngine.Config you have a dtype field. Isn't it meant to be used for this cast?

Contributor Author

This dtype conversion is from the original code, and the VLLMEngine.Config.dtype field is used to set EngineArgs(dtype=...), which is also bfloat16. The latter field controls the weight dtype of the vLLM model.

My understanding is that a vLLM model's weights are usually set to bfloat16 because: 1) FA kernels only take bfloat16 inputs, and 2) the kv_cache needs to be in bfloat16/float16 (the kv_cache does not support float32 because it takes too much memory).

Once we have control of the forward path, we can control weight/activation dtypes at a finer granularity, but we will need to figure out where to cast to bfloat16 so it stays compatible with vLLM's kv-cache mechanism. For now I simply set the generator model's weights to bfloat16 (and the trainer's weights to bfloat16 when checking bit-wise identity) to bypass this issue, which needs more careful investigation of vLLM internals.

Contributor

I still don't see why they wouldn't agree.

When the engine doesn't exist, you use EngineArgs to create one, with VLLMEngine.Config.dtype. Here the state dict comes from PolicyTrainer to update the existing engine, so it has to align with what's set above.

from torchtitan.experiments.rl.vllm_compat.weights import vllm_to_torchtitan

titan_state = vllm_compat_to_torchtitan(vllm_compat_state)
titan_state = vllm_to_torchtitan(vllm_state)
Contributor

maybe leave a TODO on renaming vllm_to_torchtitan to hint that it's for state-dict conversion.

But actually, I wonder why we don't use our state_dict adapter / checkpointer utils for this conversion.

Contributor Author

wonder why we don't use our state_dict adapter / checkpointer utils for this conversion.

This is achieved by the 3rd PR in this stack (which focuses on removing fqn conversion and saving to file during weight transfer). In this PR I tried to focus on the config system changes and keep each PR easy to read.


self.llm = None
self.tp_size = tp_size
self.engine = None
Contributor

why would VLLMEngine own another engine?

model = model_args.build()
# Set global default dtype to bfloat16. This is needed because vLLM's Attention
# layer uses torch.get_default_dtype() and it doesn't support float32
torch.set_default_dtype(torch.bfloat16)
Contributor

hmm, so the trainer can also only use torch.bfloat16 master weights?

Contributor Author

This is mainly for the bit-wise identity check (vLLM uses bfloat16 weights for the reasons mentioned above). I removed it when bit-wise identity mode is not enabled.


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.


6 participants