
[rl][combo] Refactor simple RL loop with torchtitan components#2443

Open
wwwjn wants to merge 4 commits into main from refactor-rl

Conversation


@wwwjn wwwjn commented Feb 26, 2026

This PR is a combination of the following 4 PRs:
#2244
#2221
#2194
#2191

Current status:

  • Applied the same parallelism on trainer / generator, e.g., Trainer (TP=2) and Generator (TP=2)
    • NOTE: we now have a strong assumption that the trainer and generator must use the same parallelism and be collocated. This is because we unwrap the DTensor and assume we can always wrap it back with the same device mesh and placement.
    • We should be able to remove this constraint once we have more powerful weight sync. cc @daniellepintz
  • Weight transfer: simply unwrap the DTensor in the trainer before sending
  • Collocated trainer and generator

Next Step:

  • Patch the batch-invariant kernel; test whether the trainer / generator is batch-invariant
  • Support different parallelisms on the trainer / generator sides
  • Add CI as a guard

ghstack-source-id: d49aef3
Pull Request resolved: #2191

config sys for simple_grpo v1

ghstack-source-id: d49aef3
Pull Request resolved: #2192

config sys for simple_grpo v2

ghstack-source-id: d49aef3
Pull Request resolved: #2193
ghstack-source-id: 19d7ad4
Pull Request resolved: #2194
ghstack-source-id: f826258
Pull Request resolved: #2221
ghstack-source-id: 9f5bb36
Pull Request resolved: #2244
@meta-cla bot added the CLA Signed label Feb 26, 2026
@wwwjn changed the title from "[rl" to "[rl] Refactor simple RL loop with torchtitan components" Feb 26, 2026
@wwwjn changed the title from "[rl] Refactor simple RL loop with torchtitan components" to "[rl][combo] Refactor simple RL loop with torchtitan components" Feb 26, 2026
@wwwjn wwwjn requested review from Lucaskabela, acisseJZhong and tianyu-l and removed request for tianyu-l February 26, 2026 00:09


@dataclass(kw_only=True, slots=True)
class PolicyOptimizationConfig:
Contributor:

If it's GRPO-specific, the name should reflect that. Also, a better place to put this might be simple_grpo.py.



@dataclass(kw_only=True, slots=True)
class VLLMSamplingConfig:
Contributor:

This could be put into generator.py if that's the only place using it.


# Model-agnostic name used for vLLM model registration.
# Must match the hf_overrides["architectures"] value passed to EngineArgs.
VLLM_MODEL_NAME = "TorchTitanForCausalLM"
Contributor:

What does the "For" mean?

Suggested change
VLLM_MODEL_NAME = "TorchTitanForCausalLM"
VLLM_MODEL_NAME = "TorchTitanCausalLM"

Usage:
from torchtitan.experiments.rl.unified.plugin import register
register(model_spec)
Contributor:

need to update


config = ConfigManager().parse_args()

# Patch model_spec to use the RL-specific parallelize function.
Contributor:

Most things below should be put into RLTrainer, after config.build() gives you an instance of RLTrainer.

from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import (
vllm_compat_to_torchtitan,
)
vllm_attention_backend: str = "FLASH_ATTN"
Contributor:

If you call it VLLMGenerator.Config, you no longer need the "vllm_" prefix.

)
self.config = model_spec.model
logger.debug(f"Creating model with config: {self.config}")
self.model = self.config.build()
Contributor:

You may need to do meta init when the model is large; OK to put a TODO.
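For context, the meta-device init suggested here typically looks like the following sketch (the model and state dict are illustrative, not the PR's code):

```python
import torch
import torch.nn as nn

# Build the model on the meta device: no memory is allocated for weights,
# so even a very large model can be constructed instantly.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

# Later, materialize empty (uninitialized) storage on the real device...
model = model.to_empty(device="cpu")

# ...and fill it from a checkpoint instead of random init.
# Here a dummy all-zeros state dict stands in for real checkpoint weights.
state_dict = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}
model.load_state_dict(state_dict)
```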

max_position = 0

rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position)
rope_cache = self._extend_rope_cache_if_needed(self.model.freqs_cis, max_position)
Contributor:

I think we should build this capability into torchtitan models. cc @fegin @shuhuayu

Contributor:

@wwwjn maybe make an issue
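For context, a Llama-style complex freqs_cis cache and the kind of growth _extend_rope_cache_if_needed implies can be sketched as follows. Both function bodies are illustrative assumptions, not the PR's implementation:

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Standard Llama-style complex rotary cache of shape (end, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    return torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))

def extend_rope_cache_if_needed(
    freqs_cis: torch.Tensor, max_position: int, dim: int, theta: float = 10000.0
) -> torch.Tensor:
    # Recompute a longer cache when generation runs past the cached length.
    if max_position < freqs_cis.shape[0]:
        return freqs_cis
    return precompute_freqs_cis(dim, max_position + 1, theta)

cache = precompute_freqs_cis(dim=64, end=128)      # covers positions 0..127
cache = extend_rope_cache_if_needed(cache, max_position=256, dim=64)
```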

def load_weights(self, weights_iter):
"""
vLLM required API.
Load weights from HF checkpoint using the provided state dict adapter.
Contributor:

Is this comment meaningful?

self.config.layer.attention.n_heads % tp_size == 0
), "Only supported when n_heads is divisible by tp_size"

replace_with_vllm_attention(self.model, tp_degree=tp_size)
Contributor:

Ideally we should do this on the config, rather than on the model. The vllm attention module needs to support build from config.



class VLLMRolloutEngine:
class Generator(Actor, Configurable):
Member:

nit -> VllmGenerator


titan_state = vllm_compat_to_torchtitan(vllm_compat_state)
self._direct_weight_update(titan_state)
vllm_gpu_memory_limit: float = 0.5
Member:

I think in another comment, you said the trainer and generator were not co-located, in which case, why are we limiting this to 0.5? The default is at least 0.9 IIRC.

Contributor (Author):

In the current PR, the trainer and generator are collocated.

token_log_probs_list,
prompt_token_ids_list,
# Register TorchTitan model with vLLM before any engine creation
from torchtitan.experiments.rl.unified.plugin import (
Member:

Why is this a nested import?

async def generate(self) -> None:
"""Generate trajectories and compute rewards/advantages."""
logger.info(
async def generate(self) -> Episode:
Member:

Just a note: this will change in #2423 to allow proper passing of the prompt to the generate function, so no one needs to call this out.

lambda: self.state == GeneratorState.READY_TO_GENERATE
)

with torch.no_grad():
Member:

Is it actually necessary to do this under torch.no_grad? I don't see that pattern in other vLLM use cases.
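For context, the practical effect of torch.no_grad here would just be to skip autograd graph construction during the rollout forward pass; a minimal illustration:

```python
import torch

model = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)

y = model(x)
assert y.requires_grad  # a normal forward builds the autograd graph

with torch.no_grad():
    y = model(x)
assert not y.requires_grad  # no graph is kept, saving activation memory
```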

return episode

@endpoint
async def set_reward_fn(self, reward_fn: Callable) -> None:
Member:

Apologies for the repeat comment, but when might we want to use this function?

model_spec: ModelSpec,
policy_optimization: PolicyOptimizationConfig,
batch_invariant_mode: bool,
hf_assets_path: str = "./tests/assets/tokenizer",
Member:

Is there a way to default to None or to a common location rather than a test asset?

dcp.load(hf_state_dict, storage_reader=storage_reader)
torchtitan_state_dict = self.sd_adapter.from_hf(hf_state_dict)

from torch.distributed.checkpoint.state_dict import (
Member:

Why is this a nested import?


# Compute loss
loss, loss_metrics = compute_policy_gradient_loss_vllm(
loss, loss_metrics, batch_token_log_probs = compute_policy_gradient_loss(
Member:

Can we compute the forward_backward first and then pass the result to the loss, so the loss function computes only the loss itself? This would match the eventual pattern we want with Trainer.

Member:

We can even put it directly in trainer.py for now until we have a more generic loss abstraction
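For context, the group-relative advantage at the heart of a GRPO-style policy-gradient loss is a small, separable computation; a hedged sketch (grpo_advantages is illustrative, not the PR's compute_policy_gradient_loss):

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_prompts, group_size) -- one group of rollouts per prompt.
    Normalize each reward against its own group's mean and std."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# One prompt, four sampled completions with binary rewards.
rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
adv = grpo_advantages(rewards)
```

Keeping this as a pure function of the rewards is what makes it easy to move the forward/backward out of the loss, as suggested above.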

Labels: ciflow/8gpu, CLA Signed