[rl][combo] Refactor simple RL loop with torchtitan components #2443
Conversation
    @dataclass(kw_only=True, slots=True)
    class PolicyOptimizationConfig:
If it's GRPO-specific, the name should reflect that. Also, a better place to put this might be simple_grpo.py.
    @dataclass(kw_only=True, slots=True)
    class VLLMSamplingConfig:
This could be put into generator.py if that's the only place using it.
    # Model-agnostic name used for vLLM model registration.
    # Must match the hf_overrides["architectures"] value passed to EngineArgs.
    VLLM_MODEL_NAME = "TorchTitanForCausalLM"
What does the "For" mean?
    - VLLM_MODEL_NAME = "TorchTitanForCausalLM"
    + VLLM_MODEL_NAME = "TorchTitanCausalLM"
    Usage:
        from torchtitan.experiments.rl.unified.plugin import register
        register(model_spec)
    config = ConfigManager().parse_args()
    # Patch model_spec to use the RL-specific parallelize function.
Most of what follows should be moved into RLTrainer, after config.build() gives you an instance of RLTrainer.
    from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import (
        vllm_compat_to_torchtitan,
    )

    vllm_attention_backend: str = "FLASH_ATTN"
If you call it VLLMGenerator.Config, you no longer need the "vllm_" prefix.
    )
    self.config = model_spec.model
    logger.debug(f"Creating model with config: {self.config}")
    self.model = self.config.build()
You may need to do meta init when the model is large; OK to put a TODO.
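For reference, a minimal sketch of the meta-device init being suggested: construct the module tree on the "meta" device so no real memory is allocated, and materialize weights later when the checkpoint loads. `build` here stands in for whatever factory the config exposes:

```python
# Sketch only: parameters created under torch.device("meta") allocate
# no storage, which lets a large model be constructed cheaply before
# real weights are loaded (e.g. via model.to_empty() + a checkpoint).
import torch
import torch.nn as nn


def build_on_meta(build) -> nn.Module:
    with torch.device("meta"):
        return build()


model = build_on_meta(lambda: nn.Linear(4096, 4096))
assert model.weight.is_meta  # no real memory was allocated
# Later: model.to_empty(device="cuda") before loading real weights.
```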
    max_position = 0

    - rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position)
    + rope_cache = self._extend_rope_cache_if_needed(self.model.freqs_cis, max_position)
    def load_weights(self, weights_iter):
        """
        vLLM required API.
        Load weights from HF checkpoint using the provided state dict adapter.
Is this comment meaningful?
    self.config.layer.attention.n_heads % tp_size == 0
    ), "Only support when n_heads can be divided by tp_size"

    replace_with_vllm_attention(self.model, tp_degree=tp_size)
Ideally we should do this on the config, rather than on the model. The vllm attention module needs to support build from config.
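A sketch of the config-driven alternative this comment describes: select the attention implementation at build time from the config, so no post-hoc `replace_with_vllm_attention()` pass over a built model is needed. All names here are illustrative, not the actual torchtitan API:

```python
# Hypothetical registry pattern: the model's build path reads the
# attention choice from config and constructs the right module directly.
import torch.nn as nn


class EagerAttention(nn.Module):
    """Placeholder for the default attention implementation."""


class VLLMAttention(nn.Module):
    """Placeholder for a vLLM-compatible attention module built from config."""


ATTENTION_IMPLS = {"eager": EagerAttention, "vllm": VLLMAttention}


def build_attention(impl: str) -> nn.Module:
    # Looked up once at build time; no in-place module swapping later.
    return ATTENTION_IMPLS[impl]()
```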
    - class VLLMRolloutEngine:
    + class Generator(Actor, Configurable):
    titan_state = vllm_compat_to_torchtitan(vllm_compat_state)
    self._direct_weight_update(titan_state)

    vllm_gpu_memory_limit: float = 0.5
I think in another comment you said the trainer and generator were not co-located; in that case, why are we limiting this to 0.5? The default is at least 0.9 IIRC.
In the current PR, the trainer and generator are collocated.
    token_log_probs_list,
    prompt_token_ids_list,

    # Register TorchTitan model with vLLM before any engine creation
    from torchtitan.experiments.rl.unified.plugin import (
Why is this a nested import?
    - async def generate(self) -> None:
    + async def generate(self) -> Episode:
        """Generate trajectories and compute rewards/advantages."""
        logger.info(
Just a note that this will change in #2423, which allows proper passing of the prompt to the generate function, so flagging it here before anyone calls it out.
    lambda: self.state == GeneratorState.READY_TO_GENERATE
    )

    with torch.no_grad():
Is it actually necessary to do this under torch.no_grad? I don't see that pattern in other vLLM use cases
    return episode

    @endpoint
    async def set_reward_fn(self, reward_fn: Callable) -> None:
Apologies for the repeat comment, but when might we want to use this function?
    model_spec: ModelSpec,
    policy_optimization: PolicyOptimizationConfig,
    batch_invariant_mode: bool,
    hf_assets_path: str = "./tests/assets/tokenizer",
Is there a way to default to None or to a common location rather than a test asset?
    dcp.load(hf_state_dict, storage_reader=storage_reader)
    torchtitan_state_dict = self.sd_adapter.from_hf(hf_state_dict)

    from torch.distributed.checkpoint.state_dict import (
Why is this a nested import?
    # Compute loss
    - loss, loss_metrics = compute_policy_gradient_loss_vllm(
    + loss, loss_metrics, batch_token_log_probs = compute_policy_gradient_loss(
Can we compute the forward_backward first and then pass this to the loss to keep the loss function only computing the loss itself? This will match the eventual pattern we want with Trainer.
We can even put it directly in trainer.py for now until we have a more generic loss abstraction
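To illustrate the split these two comments describe, here is a minimal sketch: the trainer runs the forward pass and hands the resulting per-token log-probs to a loss function that does nothing but compute the loss. Shapes and names are illustrative, not the actual torchtitan signatures:

```python
# Sketch of the suggested separation: model access stays in the trainer,
# the loss function becomes a pure tensor-in/tensor-out computation.
import torch


def forward_log_probs(model, tokens):
    # Trainer-side: forward pass, then gather log-probs of the taken tokens.
    logits = model(tokens)                 # [B, T, V]
    logp = torch.log_softmax(logits, dim=-1)
    return logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # [B, T]


def policy_gradient_loss(new_logp, old_logp, advantages):
    # Pure loss: no model access, easy to unit-test and to swap out later.
    ratio = torch.exp(new_logp - old_logp)
    return -(ratio * advantages).mean()
```

Keeping the loss free of forward/backward logic also matches the eventual Trainer pattern mentioned above, since the trainer can later replace `policy_gradient_loss` without touching the forward path.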
This PR is a combination of 4 different PRs:
#2244
#2221
#2194
#2191
Current status:
Next Step: