diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index eff23e0edb..64a0f35ba1 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -36,23 +36,18 @@ uv pip install torch vllm xformers --pre \ python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=... ``` -4. Run inference: +4. Run inference with unified model definition: ```bash -python torchtitan/experiments/rl/unified/infer.py -``` - -Run with TP: -```bash -python torchtitan/experiments/rl/unified/infer.py --tensor-parallel-size 2 - +torchrun --nproc_per_node=<num_gpus> \ + torchtitan/experiments/rl/unified/infer.py ``` 5. Run simple rl loop -```bash -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py ``` -Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. We are working on support UNIFIED mode, -which uses a unified model definition for trainer and generator. +python3 torchtitan/experiments/rl/unified/simple_grpo.py \ + --trainer.checkpoint.initial_load_path=<path_to_hf_checkpoint> +``` +We use a unified model definition for the trainer and generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs. 
## TODO Work on batch invariance: diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py index 8a4efac66f..59fbb97610 100644 --- a/torchtitan/experiments/rl/unified/__init__.py +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -58,7 +58,7 @@ def __init__(self, *, vllm_config, prefix=""): # Register with vLLM ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) - logger.info( + logger.debug( f"Successfully registered {model_name} with vLLM using ModelSpec " f"(flavor={model_spec.flavor})" ) diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 90790e4c27..937f722ea9 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -4,31 +4,40 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import asyncio +import glob import logging import os +import shutil -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List import torch from monarch.actor import Actor, endpoint from safetensors.torch import save_file -from torchtitan.config import CommConfig +from torchtitan.config import CommConfig, Configurable, ParallelismConfig from torchtitan.distributed import utils as dist_utils # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 +from torchtitan.experiments.rl.unified.configs import ( + PolicyOptimizationConfig, + VLLMSamplingConfig, +) + from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_grpo_advantages, compute_grpo_advantages_stable, - math_reward_function, trivial_reward_function, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm -from vllm import LLM, SamplingParams +from 
torchtitan.protocols.model_spec import ModelSpec +from vllm import EngineArgs, LLMEngine, SamplingParams + from vllm.config import AttentionConfig +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from vllm.sampling_params import RequestOutputKind + from vllm.v1.attention.backends.registry import AttentionBackendEnum logger = logging.getLogger(__name__) @@ -58,36 +67,93 @@ class TrajectoryData: advantages: torch.Tensor -class VLLMRolloutEngine: +class Generator(Actor, Configurable): """ - vLLM engine for fast rollouts with weight updates. + Generates rollouts using vLLM engine. - Note: vLLM loads from model_config.model path, so we create a temporary - directory with updated weights and restart the engine. This is faster than - recreating temp dirs repeatedly and handles config/tokenizer files properly. + Maintains a vLLM engine that is synchronized with the Trainer + via weight sync. Generates completions for given prompts and + computes rewards/advantages. Args: - model_path: Path to HuggingFace model (for config/tokenizer) - temp_checkpoint_dir: Directory to save temporary weight checkpoints + config: Generator-specific configuration. + model_path: Path to the HF model checkpoint. + dump_folder: Root output folder for RL artifacts. + batch_invariant_mode: Enable batch-invariant mode for deterministic ops. + policy_optimization: GRPO hyperparameters. + prompt_texts: List of prompt strings. + expected_answers: List of expected answer strings. 
""" + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + """Generator actor configuration.""" + + dtype: str = "bfloat16" + """Data type for model weights, passed directly to vLLM (auto, float16, bfloat16, float32).""" + + gpu_memory_limit: float = 0.5 + """Fraction of GPU memory to use for the vLLM engine (0.0 to 1.0).""" + + enforce_eager: bool = True + """Disable CUDA graphs in vLLM (use eager execution).""" + + seed: int | None = None + """Random seed for reproducible generation. None means no fixed seed.""" + + parallelism: ParallelismConfig = field(default_factory=ParallelismConfig) + """Parallelism configuration for the vLLM engine.""" + + sampling: VLLMSamplingConfig = field(default_factory=VLLMSamplingConfig) + """Default sampling parameters for generation.""" + + vllm_attention_backend: str = "FLASH_ATTN" + """vLLM attention backend to use (e.g., FLASH_ATTN, XFORMERS).""" + def __init__( self, + config: Config, + *, + model_spec: ModelSpec, model_path: str, - temp_checkpoint_dir: str = "./converted", - tp_size: int = 1, + dump_folder: str, + batch_invariant_mode: bool, + policy_optimization: PolicyOptimizationConfig, + prompt_texts: list[str], + expected_answers: list[str], ): - self.base_model_path = model_path - self.temp_model_dir = os.path.abspath( - os.path.join(temp_checkpoint_dir, "vllm_temp_model") - ) - os.makedirs(self.temp_model_dir, exist_ok=True) + self.config = config + self.model_spec = model_spec - import glob + # Set vLLM environment variables from config before any vLLM initialization + if batch_invariant_mode: + os.environ["VLLM_BATCH_INVARIANT"] = "1" + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) - # Copy config/tokenizer files from base model to temp dir - import shutil + os.environ["VLLM_ATTENTION_BACKEND"] = config.vllm_attention_backend + + self.prompt_texts = prompt_texts + self.expected_answers = expected_answers + + # Extract needed fields from configs + self.model_path = model_path + 
self.max_new_tokens = config.sampling.max_tokens + self.temperature = config.sampling.temperature + self.group_size = policy_optimization.group_size + self.grpo_beta = policy_optimization.beta + self.use_stable_grpo = policy_optimization.use_stable_grpo + + # Initialize distributed environment for SPMD generator + world_size = dist_utils.init_distributed(CommConfig()) + + # Set up temp model directory for vLLM weight loading + self._base_model_path = model_path + self._temp_model_dir = os.path.abspath( + os.path.join(dump_folder, "vllm_temp_model") + ) + os.makedirs(self._temp_model_dir, exist_ok=True) + # Copy config/tokenizer files from base model to temp dir for file in [ "config.json", "tokenizer.json", @@ -98,396 +164,245 @@ def __init__( ]: src = os.path.join(model_path, file) if os.path.exists(src): - shutil.copy2(src, self.temp_model_dir) + shutil.copy2(src, self._temp_model_dir) # Copy the original model shard files if they exist # We'll overwrite these with our single model.safetensors later for shard_file in glob.glob(os.path.join(model_path, "model-*.safetensors")): - dst = os.path.join(self.temp_model_dir, os.path.basename(shard_file)) + dst = os.path.join(self._temp_model_dir, os.path.basename(shard_file)) shutil.copy2(shard_file, dst) # Copy index file if it exists index_file = os.path.join(model_path, "model.safetensors.index.json") if os.path.exists(index_file): - shutil.copy2(index_file, self.temp_model_dir) + shutil.copy2(index_file, self._temp_model_dir) - self.llm = None - self.tp_size = tp_size - logger.info("vLLM rollout engine initialized (will load on first use)") + self._engine: LLMEngine | None = None - def update_weights(self, vllm_compat_state: dict) -> None: - """ - Update vLLM model weights from vLLM-compat state dict. + self.policy_version = 0 - This converts weights to vLLM format, saves them, and reloads using - vLLM's reload_weights() API after updating the model path config. + # Reward function. 
TODO: Move reward calculation out of generator + self.reward_fn = trivial_reward_function - Args: - vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + logger.debug("Generator initialized (vLLM engine will load on first use)") + + def _update_vllm_model_weights(self, vllm_state: dict) -> None: """ - # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) - vllm_state = torchtitan_to_vllm(vllm_compat_state) + Update vLLM model weights from vLLM model state dict. This function is used + when updating vLLM model's weights from trainer's updated weights. + Args: + vllm_state: vLLM model state dict, a map from vLLM model's fqn names to weights + """ # Save to temp model directory - import os - - checkpoint_path = os.path.join(self.temp_model_dir, "model.safetensors") + checkpoint_path = os.path.join(self._temp_model_dir, "model.safetensors") # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure - import glob - import json - shard_files = sorted( - glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) + glob.glob(os.path.join(self._temp_model_dir, "model-*.safetensors")) ) - index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") + index_file = os.path.join(self._temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safe tensor files which we only + # right now we hardcoded to work with 1 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. 
- # only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: - logger.info(f"Saving weights to {checkpoint_path}") - if len(shard_files) == 2 and os.path.exists(index_file): - # Load the index to see which weights go in which shard - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data["weight_map"] - - # Split weights according to the index - shard1_weights = {} - shard2_weights = {} - - for key, value in vllm_state.items(): - shard_file = weight_map.get(key, shard_files[0]) - if "model-00001-of-00002" in shard_file: - shard1_weights[key] = value - else: - shard2_weights[key] = value - - # Ensure weights stay in bfloat16 - shard1_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard1_weights.items() - } - shard2_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard2_weights.items() - } - - # Save to the shard files - save_file(shard1_weights, shard_files[0]) - save_file(shard2_weights, shard_files[1]) - else: - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + logger.debug(f"Saving weights to {checkpoint_path}") + + # TODO: Check the detail of vLLM's dtype conversion journey + # Currently converting float32 to bfloat16 to match vLLM's attention and kv cache dtype + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() - logger.info( + logger.debug( f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) - # First time: create the engine - if self.llm is None: - self.llm = LLM( - model=self.temp_model_dir, - hf_overrides={ - # Override 
architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, + # First time: create the engine using LLMEngine and EngineArgs + if self._engine is None: + cfg = self.config + + engine_args = EngineArgs( + # Model configuration + model=self._temp_model_dir, trust_remote_code=True, - max_model_len=2048, - dtype="bfloat16", - gpu_memory_utilization=0.1, # Reduced from 0.5 - distributed_executor_backend="external_launcher", # vllm do not spawn processes - seed=42, # Fixed seed for determinism - enforce_eager=True, - tensor_parallel_size=self.tp_size, + dtype=cfg.dtype, + # Parallelism configuration + tensor_parallel_size=cfg.parallelism.tensor_parallel_degree, + # Use external_launcher because Monarch already spawns the worker processes + distributed_executor_backend="external_launcher", + # Memory and performance + gpu_memory_utilization=cfg.gpu_memory_limit, + enforce_eager=cfg.enforce_eager, + # Seed + seed=cfg.seed, + # HuggingFace overrides to use TorchTitan model. + # TODO: make this field configurable and align with model registration + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, attention_config=AttentionConfig( backend=AttentionBackendEnum.FLASH_ATTN, ), ) - logger.info("Created new vLLM engine") + + logger.debug("Initializing LLMEngine from EngineArgs...") + self._engine = LLMEngine.from_engine_args(engine_args) + logger.debug("Created new vLLM LLMEngine") else: # Direct parameter copy into model tensors. # This bypasses vLLM's reload_weights() which uses a layerwise # reload mechanism that moves params to meta device - from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - vllm_compat_to_torchtitan, - ) - - titan_state = vllm_compat_to_torchtitan(vllm_compat_state) - self._direct_weight_update(titan_state) - - def _direct_weight_update(self, titan_state: dict) -> None: - """Update model weights by copying directly into GPU parameters. 
- - Args: - titan_state: TorchTitan format state dict (w1/w2/w3, wq/wk/wv/wo, etc.) - """ - - # Access model from vLLM engine - model = self.llm.llm_engine.model_executor.driver_worker.get_model() - params = dict(model.named_parameters()) - - for name, new_weight in titan_state.items(): - # TorchTitanVLLMModelWrapper stores the model as self.model, - # so parameters have a "model." prefix - param_name = f"model.{name}" - if param_name in params: - param = params[param_name] - new_w = new_weight.to(device=param.device, dtype=param.dtype) - param.data.copy_(new_w) - - @torch.no_grad() - def generate( - self, - prompt_texts: list[str], - max_new_tokens: int = 20, - temperature: float = 1.0, - n_samples_per_prompt: int = 4, - ) -> tuple[ - list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] - ]: - """ - Generate samples using vLLM. + from torchtitan.experiments.rl.vllm_compat.weights import vllm_to_torchtitan + + titan_state = vllm_to_torchtitan(vllm_state) + model = self._engine.model_executor.driver_worker.get_model() + params = dict(model.named_parameters()) + + for name, new_weight in titan_state.items(): + # TorchTitanVLLMModelWrapper stores the model as self.model, + # so parameters have a "model." prefix + param_name = f"model.{name}" + if param_name in params: + param = params[param_name] + new_w = new_weight.to(device=param.device, dtype=param.dtype) + param.data.copy_(new_w) + + def _compute_rewards_and_advantages( + self, completions: list[str] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute rewards and GRPO advantages for generated completions. + TODO: Move this function out of generator for encapsulation. Args: - prompt_texts: List of prompt strings - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - n_samples_per_prompt: Number of samples per prompt + completions: List of completion strings. 
Returns: - completions: List of completion strings - log_probs: [batch] - Sum of log probs for each completion - token_ids: List of token ID lists for each completion (generated tokens only) - token_log_probs: List of per-token log prob lists for each completion - prompt_token_ids: List of prompt token ID lists for each completion + rewards: Raw rewards tensor. + advantages: GRPO advantage tensor. """ - sampling_params = SamplingParams( - temperature=temperature, - max_tokens=max_new_tokens, - n=n_samples_per_prompt, - seed=42, - logprobs=1, - prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + logger.debug( + f"Computing rewards: {len(completions)} completions, " + f"{len(self.expected_answers)} expected answers, " + f"group_size={self.group_size}" ) + rewards = self.reward_fn(completions, self.expected_answers, self.group_size) - outputs = self.llm.generate(prompt_texts, sampling_params) - - # Extract completions and log probs - completions = [] - log_probs_list = [] - token_ids_list = [] - token_log_probs_list = [] - prompt_token_ids_list = [] - - for output in outputs: - # Extract prompt token IDs from the output - prompt_token_ids = output.prompt_token_ids - - for sample in output.outputs: - completions.append(sample.text) - - # Store prompt tokens for this sample - prompt_token_ids_list.append(prompt_token_ids) - - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) - - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) - - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) - - log_probs = torch.tensor(log_probs_list, dtype=torch.float32) - - return ( - completions, - log_probs, - token_ids_list, - token_log_probs_list, - prompt_token_ids_list, - ) - - def __del__(self): 
- """Cleanup vLLM engine.""" - if hasattr(self, "llm"): - del self.llm - torch.cuda.empty_cache() - - -class GeneratorState: - """States for the Generator's state machine.""" - - READY_TO_GENERATE = "READY_TO_GENERATE" - READY_TO_UPDATE = "READY_TO_UPDATE" - - -class Generator(Actor): - """ - Generates rollouts using vLLM engine. - - Maintains a vLLM engine that is synchronized with the Trainer - via weight sync. Generates completions for given prompts and - computes rewards/advantages. - - Args: - model_path: Path to HuggingFace model - prompt_texts: List of prompt strings - expected_answers: List of expected answers - group_size: Number of samples per prompt - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - use_real_dataset: Whether using real dataset (GSM8K) - grpo_beta: Beta for GRPO advantages - use_stable_grpo: Whether to use stable GRPO - tp_size: Tensor Parallel size - """ - - def __init__( - self, - model_path: str, - prompt_texts: List[str], - expected_answers: List[str], - group_size: int = 8, - max_new_tokens: int = 20, - temperature: float = 1.0, - use_real_dataset: bool = False, - grpo_beta: float = 0.1, - use_stable_grpo: bool = False, - tp_size: int = 1, - ): - self.model_path = model_path - self.prompt_texts = prompt_texts - self.expected_answers = expected_answers - self.group_size = group_size - self.max_new_tokens = max_new_tokens - self.temperature = temperature - self.use_real_dataset = use_real_dataset - self.grpo_beta = grpo_beta - self.use_stable_grpo = use_stable_grpo - self.tp_size = tp_size - - # Initialize distributed environment for SPMD generator - world_size = dist_utils.init_distributed( - CommConfig(), - ) - # Initialize vLLM engine - self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) - - # State machine - self.state = GeneratorState.READY_TO_UPDATE - self.cond = asyncio.Condition() - self.policy_version = 0 + # Normalize rewards + reward_mean = rewards.mean() + reward_std = 
rewards.std() + if reward_std > 1e-8: + rewards_normalized = (rewards - reward_mean) / reward_std + else: + rewards_normalized = rewards - reward_mean - # Reward function - self.reward_fn = ( - math_reward_function if use_real_dataset else trivial_reward_function - ) + # Compute advantages using GRPO + if self.use_stable_grpo: + advantages = compute_grpo_advantages_stable( + rewards_normalized, self.group_size + ) + else: + advantages = compute_grpo_advantages( + rewards_normalized, self.group_size, beta=self.grpo_beta + ) - logger.info("Generator initialized with vLLM engine") + return rewards, advantages @endpoint - async def generate(self) -> None: - """Generate trajectories and compute rewards/advantages.""" - logger.info( + async def generate(self) -> TrajectoryData: + """Generate trajectories and compute rewards/advantages. + Called by the orchestrator (simple_grpo.py). + """ + logger.debug( f"{os.getpid()=} Generating start generate (policy v{self.policy_version})..." ) - async with self.cond: - # Wait until ready to generate (weights have been updated) - await self.cond.wait_for( - lambda: self.state == GeneratorState.READY_TO_GENERATE - ) + with torch.no_grad(): # Generate samples using vLLM - ( - completions, - vllm_log_probs, - vllm_token_ids, - vllm_token_log_probs, - prompt_token_ids, - ) = self.vllm_engine.generate( - self.prompt_texts, - self.max_new_tokens, - self.temperature, - n_samples_per_prompt=self.group_size, + sampling_params = SamplingParams( + temperature=self.temperature, + max_tokens=self.max_new_tokens, + n=self.group_size, + seed=self.config.seed, + logprobs=1, + prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - # Compute rewards - rewards = self.reward_fn( - completions, self.expected_answers, self.group_size - ) - - # Normalize rewards - reward_mean = rewards.mean() - reward_std = rewards.std() - if reward_std > 1e-8: - 
rewards_normalized = (rewards - reward_mean) / reward_std - else: - rewards_normalized = rewards - reward_mean - - # Compute advantages using GRPO - if self.use_stable_grpo: - advantages = compute_grpo_advantages_stable( - rewards_normalized, self.group_size - ) - else: - advantages = compute_grpo_advantages( - rewards_normalized, self.group_size, beta=self.grpo_beta - ) - - # Create trajectory data - trajectory = TrajectoryData( - policy_version=self.policy_version, - completions=completions, - vllm_token_ids=vllm_token_ids, - vllm_token_log_probs=vllm_token_log_probs, - prompt_token_ids=prompt_token_ids, - rewards=rewards, - advantages=advantages, - ) - - # Signal ready for update - self.state = GeneratorState.READY_TO_UPDATE - self.cond.notify_all() + for request_id, prompt in enumerate(self.prompt_texts): + self._engine.add_request(str(request_id), prompt, sampling_params) + + # Step through engine until all requests are finished + all_outputs = [] + while self._engine.has_unfinished_requests(): + request_outputs = self._engine.step() + all_outputs.extend(request_outputs) + + # Extract completions and log probs + completions = [] + token_ids_list = [] + token_log_probs_list = [] + prompt_token_ids_list = [] + + for output in all_outputs: + prompt_token_ids = output.prompt_token_ids + + for sample in output.outputs: + completions.append(sample.text) + prompt_token_ids_list.append( + prompt_token_ids + ) # Store prompt tokens for this sample + token_ids_list.append( + sample.token_ids + ) # Extract token IDs (generated tokens only) + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] # Extract per-token log probs + token_log_probs_list.append(per_token_log_probs) + + # Compute rewards and advantages + rewards, advantages = self._compute_rewards_and_advantages(completions) + + # Create trajectory data + trajectory = TrajectoryData( + policy_version=self.policy_version, + completions=completions, + 
vllm_token_ids=token_ids_list, + vllm_token_log_probs=token_log_probs_list, + prompt_token_ids=prompt_token_ids_list, + rewards=rewards, + advantages=advantages, + ) - logger.info( - f"{os.getpid()=} Generating finish generate (policy v{self.policy_version})..." - ) - return trajectory + logger.debug( + f"{os.getpid()=} Generating finish generate (policy v{self.policy_version})..." + ) + return trajectory @endpoint async def update(self, version: int, vllm_compat_state: dict) -> None: """Update generate weights. + Called by the orchestrator (simple_grpo.py). Args: version: New policy version number vllm_compat_state: vLLM-compatible state dict """ - async with self.cond: - self.vllm_engine.update_weights(vllm_compat_state) - # Update version and state - self.policy_version = version - self.state = GeneratorState.READY_TO_GENERATE - self.cond.notify_all() - logger.info( - f"{os.getpid()=} Generator updating weights to policy v{version}..." - ) + # TODO: remove the helper function (_update_vllm_model_weights) once we clean up the weight updates + self._update_vllm_model_weights(vllm_compat_state) + self.policy_version = version + logger.debug( + f"{os.getpid()=} Generator updating weights to policy v{version}..." 
+ ) + + def __del__(self): + """Cleanup vLLM engine.""" + if hasattr(self, "_engine"): + del self._engine + torch.cuda.empty_cache() diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index b7118ce189..c859faeead 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -6,69 +6,134 @@ import logging import os +from dataclasses import dataclass, field from typing import Any, Optional import torch from monarch.actor import Actor, endpoint +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.lr_scheduler import LRSchedulersContainer +from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.config import Configurable +from torchtitan.config.configs import ( + ActivationCheckpointConfig, + ParallelismConfig, + TrainingConfig, +) from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData +from torchtitan.experiments.rl.unified.configs import PolicyOptimizationConfig from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.unified.models.utils import ( + replace_with_vllm_compatible_flash_attention, +) from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, +from torchtitan.experiments.rl.vllm_compat.weights.converter import ( + torchtitan_to_vllm, + vllm_to_torchtitan, ) +from torchtitan.protocols.model_spec import ModelSpec logger = logging.getLogger(__name__) -class Trainer(Actor): +class PolicyTrainer(Actor, Configurable): """ - Updates policy based on collected trajectories. + Updates policy based on collected Episodes. 
+ + Run model forward on Episodes, computes loss, and run backward. + Receives the top-level ``RLTrainer.Config`` and reads policy trainer + settings (batch_invariant_mode, grpo) directly from it, plus model / + optimizer / parallelism settings from the nested ``config.trainer``. - Run model forward on trajectories, computes loss, and run backward. + TODO: Use torchtitan PolicyTrainer for model init and parallelism. Args: - titan_checkpoint_path: Path to TorchTitan checkpoint - model_path: Path to HuggingFace model - learning_rate: Learning rate for optimizer - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model + config: PolicyTrainer.Config for model/optimizer/parallelism settings. + policy_optimization: GRPO hyperparameters. """ + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + """PolicyTrainer configuration for optimizer, training, and parallelism.""" + + optimizer: OptimizersContainer.Config = field( + default_factory=OptimizersContainer.Config + ) + lr_scheduler: LRSchedulersContainer.Config = field( + default_factory=LRSchedulersContainer.Config + ) + training: TrainingConfig = field(default_factory=TrainingConfig) + parallelism: ParallelismConfig = field(default_factory=ParallelismConfig) + checkpoint: CheckpointManager.Config = field( + default_factory=CheckpointManager.Config + ) + activation_checkpoint: ActivationCheckpointConfig = field( + default_factory=ActivationCheckpointConfig + ) + def __init__( self, - titan_checkpoint_path: str, - model_path: str, - learning_rate: float = 1e-5, - model_mode: str = ModelMode.VLLM_COMPAT, - ddp_size: int = 1, - tp_size: int = 1, + config: Config, + *, + model_spec: ModelSpec, + policy_optimization: PolicyOptimizationConfig, + batch_invariant_mode: bool, ): + self.config = config + self.model_spec = model_spec + + # Extract needed fields from config + model_path = config.checkpoint.initial_load_path # path 
to HF checkpoint + learning_rate = config.optimizer.lr + self.ddp_size = config.parallelism.data_parallel_replicate_degree + self.tp_size = config.parallelism.tensor_parallel_degree + + # GRPO settings + self.group_size = policy_optimization.group_size + self.grpo_beta = policy_optimization.beta + self.use_stable_grpo = policy_optimization.use_stable_grpo + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - self.model = load_model( - titan_checkpoint_path, model_path, model_mode=model_mode - ) - self.ddp_size = ddp_size - self.tp_size = tp_size + # Step1: Load trainer model from HF/vLLM checkpoint. TODO: Use torchtitan components + model_config = model_spec.model + titan_state_dict = vllm_to_torchtitan(model_path) + + # If weight tying is enabled but output.weight is missing from the checkpoint + if model_config.enable_weight_tying and "output.weight" not in titan_state_dict: + titan_state_dict["output.weight"] = titan_state_dict[ + "tok_embeddings.weight" + ] + + self.model = model_config.build() + self.model.load_state_dict(titan_state_dict, strict=True) + + # Step2: Replace attention kernel be to vLLM's attention. + if batch_invariant_mode: + replace_with_vllm_compatible_flash_attention(self.model) + # vLLM's Attention requires bfloat16 inputs. 
+ # TODO: Refine the dtype journey in trainer / generator + self.model.to(torch.bfloat16) + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism - # TODO: right now it only works for qwen3 model, need to formalize this to use parallelize_fn from ModelSpec - from torchtitan.models.llama3.parallelize import apply_ddp + # TODO: right now it only works for qwen3 model, need to formalize this to use parallize_fn from model_spec + if self.ddp_size > 1: + from torchtitan.models.llama3.parallelize import apply_ddp - apply_ddp( - self.model, - self.parallel_dims.get_mesh("dp_replicate"), - enable_compile=False, - ) + apply_ddp( + self.model, + self.parallel_dims.get_mesh("dp_replicate"), + enable_compile=False, + ) self.model = self.model.to(device) self.model.train() @@ -78,18 +143,22 @@ def __init__( self.policy_version = 0 self.generator: Optional[Any] = None - logger.info("Trainer initialized with TorchTitan model") + logger.debug( + f"PolicyTrainer initialized: " + f"group_size={self.group_size}, grpo_beta={self.grpo_beta}, " + f"use_stable_grpo={self.use_stable_grpo}" + ) @endpoint async def get_weights(self) -> dict: - """Get vLLM-compatible weights for generator. + """Get vLLM weights for generator. 
Returns: - vLLM-compatible state dict + vLLM state dict """ titan_state = self.model.state_dict() - vllm_compat_state = torchtitan_to_vllm_compat(titan_state) - return vllm_compat_state + vllm_state = torchtitan_to_vllm(titan_state) + return vllm_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: @@ -98,8 +167,8 @@ async def step(self, trajectory: TrajectoryData) -> dict: Returns: Training metrics """ - logger.info( - f"{os.getpid()=} Trainer starts to train {self.policy_version} on traj:" + logger.debug( + f"{os.getpid()=} PolicyTrainer starts to train {self.policy_version} on traj:" ) # Compute loss loss, loss_metrics = compute_policy_gradient_loss_vllm( @@ -119,8 +188,6 @@ async def step(self, trajectory: TrajectoryData) -> dict: self.policy_version += 1 - # TODO: save dcp checkpoint to file here instead of sending weight dicts - # Return metrics metrics = { "loss": loss.item(), @@ -132,5 +199,5 @@ async def step(self, trajectory: TrajectoryData) -> dict: "policy_version": self.policy_version, **loss_metrics, } - logger.info(f"{os.getpid()=} Trainer finish step {self.policy_version}") + logger.debug(f"{os.getpid()=} PolicyTrainer finish step {self.policy_version}") return metrics diff --git a/torchtitan/experiments/rl/unified/config_registry.py b/torchtitan/experiments/rl/unified/config_registry.py new file mode 100644 index 0000000000..12ef11ce64 --- /dev/null +++ b/torchtitan/experiments/rl/unified/config_registry.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Config entry points for the RL/unified experiment. + +Each function returns a complete ``RLTrainer.Config`` and is discoverable by +``ConfigManager`` via ``--module rl.unified --config ``. 
+""" + +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.lr_scheduler import LRSchedulersContainer +from torchtitan.components.optimizer import OptimizersContainer +from torchtitan.config.configs import ( + ActivationCheckpointConfig, + ParallelismConfig, + TrainingConfig, +) +from torchtitan.experiments.rl.unified.actors.generator import Generator +from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer +from torchtitan.experiments.rl.unified.configs import ( + PolicyOptimizationConfig, + VLLMSamplingConfig, +) +from torchtitan.experiments.rl.unified.simple_grpo import RLTrainer +from torchtitan.models.qwen3 import model_registry + + +def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: + """GRPO training config for Qwen3-0.6B.""" + return RLTrainer.Config( + model_spec=model_registry("0.6B"), + num_steps=10, + batch_invariant_mode=True, + trainer=PolicyTrainer.Config( + optimizer=OptimizersContainer.Config(lr=1e-6), + lr_scheduler=LRSchedulersContainer.Config( + warmup_steps=2, + decay_type="linear", + ), + training=TrainingConfig( + local_batch_size=4, + seq_len=4096, + ), + parallelism=ParallelismConfig( + tensor_parallel_degree=1, + data_parallel_replicate_degree=2, + ), + checkpoint=CheckpointManager.Config( + initial_load_path="torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B", + initial_load_model_only=True, + initial_load_in_hf=True, + ), + activation_checkpoint=ActivationCheckpointConfig( + mode="selective", + selective_ac_option="op", + ), + ), + policy_optimization=PolicyOptimizationConfig( + beta=0.1, + group_size=8, + use_stable_grpo=False, + ), + generator=Generator.Config( + dtype="bfloat16", + gpu_memory_limit=0.5, + enforce_eager=True, + seed=42, + parallelism=ParallelismConfig( + tensor_parallel_degree=1, + ), + sampling=VLLMSamplingConfig( + temperature=0.8, + top_p=0.95, + max_tokens=100, + ), + vllm_attention_backend="FLASH_ATTN", + ), + ) + + +def rl_grpo_qwen3_debug() -> 
RLTrainer.Config: + """Debug config for quick iteration — small model, few steps.""" + return RLTrainer.Config( + model_spec=model_registry("debugmodel"), + num_steps=5, + batch_invariant_mode=False, + trainer=PolicyTrainer.Config( + optimizer=OptimizersContainer.Config(lr=8e-4), + lr_scheduler=LRSchedulersContainer.Config( + warmup_steps=2, + decay_type="linear", + ), + training=TrainingConfig( + local_batch_size=2, + seq_len=2048, + ), + parallelism=ParallelismConfig( + tensor_parallel_degree=1, + data_parallel_replicate_degree=1, + ), + checkpoint=CheckpointManager.Config( + interval=5, + ), + ), + policy_optimization=PolicyOptimizationConfig( + beta=0.1, + group_size=4, + use_stable_grpo=False, + ), + generator=Generator.Config( + gpu_memory_limit=0.3, + enforce_eager=True, + parallelism=ParallelismConfig( + tensor_parallel_degree=1, + ), + sampling=VLLMSamplingConfig( + temperature=1.0, + max_tokens=50, + ), + ), + ) diff --git a/torchtitan/experiments/rl/unified/configs.py b/torchtitan/experiments/rl/unified/configs.py new file mode 100644 index 0000000000..6f92391ca2 --- /dev/null +++ b/torchtitan/experiments/rl/unified/configs.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Leaf data configs for RL training. + +These are plain dataclasses (not ``Configurable``) since they are just data, +not buildable components. +""" + +from dataclasses import dataclass + + +@dataclass(kw_only=True, slots=True) +class VLLMSamplingConfig: + """Sampling parameters passed to vLLM's SamplingParams.""" + + temperature: float = 0.8 + """Sampling temperature. 
0.0 = greedy, higher = more random.""" + + top_p: float = 0.95 + """Nucleus sampling threshold.""" + + max_tokens: int = 100 + """Maximum number of tokens to generate per completion.""" + + +@dataclass(kw_only=True, slots=True) +class PolicyOptimizationConfig: + """Hyperparameters for Group Relative Policy Optimization.""" + + beta: float = 0.1 + """Temperature for GRPO exponential advantage weighting.""" + + group_size: int = 8 + """Number of completions per prompt for group-relative ranking.""" + + use_stable_grpo: bool = False + """Use stable mean-centering GRPO instead of exponential weighting.""" diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py deleted file mode 100755 index 32927f880c..0000000000 --- a/torchtitan/experiments/rl/unified/infer.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os - -# Must set spawn method before any CUDA operations or vLLM imports -# CUDA cannot be re-initialized in forked subprocesses -# See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - -import argparse - -# Import unified module - this automatically registers TorchTitan models with vLLM -from torchtitan.experiments.rl import unified # noqa: F401 -from vllm import LLM, SamplingParams -from vllm.logger import init_logger - - -logger = init_logger(__name__) - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Run TorchTitan model inference with vLLM Engine", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--model-ckpt-path", - type=str, - default="torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B/", - help="Path to TorchTitan checkpoint directory", - ) - parser.add_argument( - "--prompt", - type=str, - default="Hello, my name is", - help="Prompt text for generation", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.8, - help="Sampling temperature", - ) - parser.add_argument( - "--tensor-parallel-size", - type=int, - default=1, - help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", - ) - return parser.parse_args() - - -def infer(): - args = parse_args() - - logger.info("Initializing vLLM with TorchTitan model") - logger.info(f"Model: {args.model_ckpt_path}") - logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") - - # Initialize vLLM with custom TorchTitan model - # The LLM initialization will internally: - # 1. Load ModelSpec for Qwen3 (from models/__init__.py register()) - # 2. Create TorchTitanVLLMModel instance - # 3. Create JobConfig and ParallelDims from vLLM config - # 4. 
Apply parallelization using parallelize_qwen3 - # 5. Load model weights and prepare for inference - # The tensor_parallel_size will be used by vLLM to configure parallelization - # and will be available in vllm_config in worker processes - logger.info("Creating vLLM LLM engine...") - - llm = LLM( - model=args.model_ckpt_path, # Model checkpoint path - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, - dtype="bfloat16", - trust_remote_code=True, - enforce_eager=True, # Use eager mode - tensor_parallel_size=args.tensor_parallel_size, - gpu_memory_utilization=0.5, - ) - - logger.info("vLLM engine initialized successfully") - logger.info(f"Prompt: {args.prompt}") - - # Prepare prompt and sampling parameters - prompts = [args.prompt] - sampling_params = SamplingParams( - temperature=args.temperature, - top_p=0.95, - max_tokens=args.max_tokens, - ) - - # Generate text - logger.info("Generating text...") - outputs = llm.generate( - prompts=prompts, - sampling_params=sampling_params, - ) - - # Print results - logger.info("Generation complete") - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - - print(f"\nPrompt: {prompt}") - print(f"Generated text: {generated_text!r}\n") - - -if __name__ == "__main__": - infer() diff --git a/torchtitan/experiments/rl/unified/inference_example.py b/torchtitan/experiments/rl/unified/inference_example.py new file mode 100755 index 0000000000..e9963bfe24 --- /dev/null +++ b/torchtitan/experiments/rl/unified/inference_example.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Example inference script using TorchTitan models with vLLM LLMEngine. 
+ +This script uses the RL unified config_registry to configure both +the vLLM engine and sampling parameters. + +Run: torchrun --nproc_per_node= \ + torchtitan/experiments/rl/unified/infer.py +""" +import os + +# Must set spawn method before any CUDA operations or vLLM imports +# CUDA cannot be re-initialized in forked subprocesses +# See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.rl import unified # noqa: F401 + +from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b + +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +def generate(): + + config = rl_grpo_qwen3_0_6b() + gen_config = config.generator + model_path = config.trainer.checkpoint.initial_load_path + + logger.debug("Initializing vLLM LLMEngine with TorchTitan model") + logger.debug(f"Model: {model_path}") + logger.debug( + f"Tensor Parallel Size: {gen_config.parallelism.tensor_parallel_degree}" + ) + + # Create EngineArgs from config + engine_args = EngineArgs( + # Model configuration + model=model_path, + trust_remote_code=True, + dtype=gen_config.dtype, + # Parallelism configuration + tensor_parallel_size=gen_config.parallelism.tensor_parallel_degree, + # Use external_launcher only when launched via torchrun (multi-GPU); + # for single-GPU, let vLLM pick the default executor. 
+ distributed_executor_backend=("external_launcher"), + # Memory and performance + gpu_memory_utilization=gen_config.gpu_memory_limit, + enforce_eager=gen_config.enforce_eager, + # Seed + seed=gen_config.seed, + # HuggingFace overrides + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, + ) + + logger.debug("Initializing LLMEngine from EngineArgs...") + engine = LLMEngine.from_engine_args(engine_args) + + logger.debug("vLLM LLMEngine initialized successfully") + + # Create sampling parameters from config + sampling = gen_config.sampling + sampling_params = SamplingParams( + temperature=sampling.temperature, + top_p=sampling.top_p, + max_tokens=sampling.max_tokens, + ) + + logger.debug( + f"Sampling params: temperature={sampling.temperature}, " + f"top_p={sampling.top_p}, max_tokens={sampling.max_tokens}" + ) + + # Example prompt + prompt = "Hello, my name is" + logger.debug(f"Prompt: {prompt}") + + # Add request to engine + logger.debug("Adding request to engine...") + request_id = "0" + engine.add_request(request_id, prompt, sampling_params) + + # Generate text by stepping through engine + logger.debug("Generating text...") + while engine.has_unfinished_requests(): + request_outputs = engine.step() + + # Process finished requests + for request_output in request_outputs: + if request_output.finished: + prompt = request_output.prompt + generated_text = request_output.outputs[0].text + + # Print results + logger.debug("Generation complete") + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") + + +if __name__ == "__main__": + generate() diff --git a/torchtitan/experiments/rl/unified/infra/parallelism_utils.py b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py index a59a82f646..73b4c06ea0 100644 --- a/torchtitan/experiments/rl/unified/infra/parallelism_utils.py +++ b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py @@ -14,12 +14,10 @@ import torch.distributed as dist from torchtitan.config import 
CommConfig, ParallelismConfig, TrainingConfig -from torchtitan.trainer import Trainer - -JobConfig = Trainer.Config from torchtitan.distributed import utils as dist_utils from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.trainer import Trainer from vllm.config import VllmConfig from vllm.logger import init_logger @@ -60,7 +58,7 @@ def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDi world_size=world_size, ) - logger.info( + logger.debug( f"Created ParallelDims from vLLM config: " f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" @@ -98,29 +96,31 @@ def create_trainer_parallel_dims(ddp_size, tp_size) -> ParallelDims: ) -def create_job_config_from_vllm_config( +def create_trainer_config_from_vllm_config( vllm_config: VllmConfig, - model_name: str = "qwen3", hf_assets_path: str = "/path/to/hf/assets", -) -> JobConfig: +) -> Trainer.Config: """ - Create TorchTitan JobConfig from vLLM configuration. + Create a Trainer.Config from vLLM configuration. + + Maps vLLM parallelism and training settings to a Trainer.Config so that + TorchTitan's parallelize functions can be called with the correct kwargs. 
Args: vllm_config: vLLM configuration object containing model, parallel, and cache configs - model_name: Model name to use (default: "qwen3") - hf_assets_path: Path to HuggingFace assets directory (default: "/path/to/hf/assets") + hf_assets_path: Path to HuggingFace assets directory Returns: - JobConfig object with settings mapped from vLLM config + Trainer.Config with settings mapped from vLLM config + + TODO: Remove this function once explicitly register vllm model instead of import """ - # Create JobConfig with defaults - job_config = JobConfig() + config = Trainer.Config() - job_config.hf_assets_path = hf_assets_path + config.hf_assets_path = hf_assets_path parallel_config = vllm_config.parallel_config - job_config.parallelism = ParallelismConfig( + config.parallelism = ParallelismConfig( data_parallel_replicate_degree=parallel_config.data_parallel_size, data_parallel_shard_degree=1, # vLLM doesn't use FSDP sharding in inference context_parallel_degree=parallel_config.decode_context_parallel_size, @@ -130,9 +130,9 @@ def create_job_config_from_vllm_config( expert_tensor_parallel_degree=1, # Not used in vLLM inference yet ) - job_config.training = TrainingConfig( + config.training = TrainingConfig( local_batch_size=1, # Inference typically processes one batch at a time steps=1, # Single step for inference ) - return job_config + return config diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index d6c2622eff..133e33988c 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -6,40 +6,15 @@ import logging -from enum import Enum -import torch -from safetensors.torch import load_file from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - 
torchtitan_to_vllm_compat, -) -from torchtitan.models.qwen3.model import Qwen3Model -from transformers import AutoConfig logger = logging.getLogger(__name__) -class ModelMode(str, Enum): - """ - Enum defining which TorchTitan model to use. - - Attributes: - UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified - training and inference. - VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, - ensuring bitwise determinism between training and inference. - STANDARD: Plain TorchTitan model without any modifications. - """ - - UNIFIED = "unified" - VLLM_COMPAT = "vllm_compat" - STANDARD = "standard" - - def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -70,18 +45,19 @@ def replace_with_vllm_attention(model, tp_degree=1): raise ValueError(f"Layer {layer_name} must have .attention attribute") # GQA + head_dim = model_args.layer.attention.head_dim vllm_attn = VLLMAttention( hidden_size=model_args.dim, num_heads=model_args.layer.attention.n_heads // tp_degree, num_kv_heads=num_kv_heads, - head_dim=model_args.head_dim, + head_dim=head_dim, layer_name=layer_name, - scale=model_args.head_dim**-0.5, + scale=head_dim**-0.5, ) layer.attention.inner_attention = vllm_attn - logger.info( + logger.debug( f"Successfully replaced TorchTitan attention with VLLMAttention " f"({len(model.layers)} layers)" ) @@ -106,93 +82,7 @@ def replace_with_vllm_compatible_flash_attention(model): layer.attention.inner_attention = vllm_attn - logger.info( + logger.debug( f"Successfully replaced TorchTitan attention with VLLMCompatibleFlashAttention " f"({len(model.layers)} layers)" ) - - -def load_model( - checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT -): - """ - Load TorchTitan model from checkpoint for trainer. 
- - Args: - checkpoint_path: Path to TorchTitan checkpoint - model_path: Path to HuggingFace model (for config) - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model - - Returns: - model: Loaded TorchTitan model for trainer. - """ - # Load HuggingFace config - # TODO: do not depend on transformers.AutoConfig, use qwen_args directly - hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - - # Create model args - model_args = Qwen3Model.Config( - dim=hf_config.hidden_size, - n_layers=hf_config.num_hidden_layers, - n_heads=hf_config.num_attention_heads, - n_kv_heads=hf_config.num_key_value_heads, - vocab_size=hf_config.vocab_size, - head_dim=getattr( - hf_config, - "head_dim", - hf_config.hidden_size // hf_config.num_attention_heads, - ), - hidden_dim=hf_config.intermediate_size, - norm_eps=hf_config.rms_norm_eps, - rope_theta=hf_config.rope_theta, - max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), - qk_norm=True, - depth_init=True, - eos_id=getattr(hf_config, "eos_token_id", 151645), - enable_weight_tying=getattr(hf_config, "tie_word_embeddings", False), - ) - - # state_dict is in standard TorchTitan format (w1, w2, w3) - state_dict = load_file(checkpoint_path) - - # If weight tying is enabled but output.weight is missing from the checkpoint - # (HF models with tie_word_embeddings=True may not store lm_head.weight), - # synthesize it from tok_embeddings.weight so load_state_dict(strict=True) works. - if model_args.enable_weight_tying and "output.weight" not in state_dict: - state_dict["output.weight"] = state_dict["tok_embeddings.weight"] - - if model_mode == ModelMode.UNIFIED: - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Set global default dtype to bfloat16. 
This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention for training. - replace_with_vllm_compatible_flash_attention(model) - - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=True) - elif model_mode == ModelMode.VLLM_COMPAT: - # Create and load model that has bitwise determinism between training and inference - from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( - Qwen3VLLMCompatModel, - ) - - model = Qwen3VLLMCompatModel(model_args) - # Convert to vLLM-compat format (merged gate_up_proj, down_proj) - vllm_compat_state = torchtitan_to_vllm_compat(state_dict) - model.load_state_dict(vllm_compat_state, strict=False) - else: - # Use standard TorchTitan model - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=False) - - model.to(torch.bfloat16) - - return model diff --git a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py index a335581ac6..c9505a6b55 100644 --- a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -11,7 +11,7 @@ TorchTitan models for vLLM. 
""" -from functools import partial +import dataclasses import torch import torch.nn as nn @@ -22,12 +22,11 @@ ) from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( - create_job_config_from_vllm_config, create_parallel_dims_from_vllm_config, + create_trainer_config_from_vllm_config, ) from torchtitan.experiments.rl.unified.models.utils import replace_with_vllm_attention -from torchtitan.models.qwen3.model import precompute_rope_cache from torchtitan.protocols.model import BaseModel from torchtitan.protocols.model_spec import ParallelizeFunction from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter @@ -77,23 +76,20 @@ def __init__( # Use TorchTitan model config directly (no HF config mapping) self.config = model_config - logger.info(f"Creating model with config: {model_config}") + logger.debug(f"Creating model with config: {model_config}") self.model = model_config.build() - # Setup RoPE cache extension function if provided - self.rope_cache_extension_fn = partial( - precompute_rope_cache, - dim=self.config.head_dim, - base=self.config.rope_theta, - ) + # RoPE config from model for cache extension + self.rope_config = self.config.rope - # Create ParallelDims and JobConfig from vLLM config at runtime + # Create ParallelDims and Trainer.Config from vLLM config at runtime # vLLM config contains the tensor_parallel_size from command-line args # and this will be consistent across all worker processes self.parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) - self.parallel_config = create_job_config_from_vllm_config( + self.trainer_config = create_trainer_config_from_vllm_config( vllm_config=vllm_config, ) + # Replace attention with vLLM paged attention tp_size = self.parallel_dims.tp if tp_size > 1: @@ -106,10 +102,16 @@ def __init__( # NOTE: We need to apply parallelize within model.__init__ because vllm # doesn't separate model creation and parallelism application and instead # requires parallelization to be done 
inside model constructor. + cfg = self.trainer_config self.model = parallelize_fn( model=self.model, parallel_dims=self.parallel_dims, - job_config=self.parallel_config, + training=cfg.training, + model_converters=cfg.model_converters, + parallelism=cfg.parallelism, + compile_config=cfg.compile, + ac_config=cfg.activation_checkpoint, + dump_folder=cfg.dump_folder, ) def _extend_rope_cache_if_needed( @@ -131,15 +133,6 @@ def _extend_rope_cache_if_needed( if required_len <= rope_cache.shape[0]: return rope_cache - # If no extension function provided, return original cache - if self.rope_cache_extension_fn is None: - logger.warning( - f"RoPE cache extension needed (required_len={required_len}, " - f"current_len={rope_cache.shape[0]}) but no rope_cache_extension_fn provided. " - "Returning original cache." - ) - return rope_cache - # Handle DTensor case is_dtensor = isinstance(rope_cache, DTensor) if is_dtensor: @@ -151,16 +144,12 @@ def _extend_rope_cache_if_needed( device = rope_cache.device dtype = rope_cache.dtype - # Use provided extension function - try: - extended_cache = self.rope_cache_extension_fn(self.config, required_len) - extended_cache = extended_cache.to(device=device, dtype=dtype) - except Exception as e: - logger.warning( - f"Failed to extend RoPE cache using rope_cache_extension_fn: {e}. " - "Returning original cache." 
- ) - return rope_cache + # Build a new RoPE module with extended max_seq_len + extended_rope_config = dataclasses.replace( + self.rope_config, max_seq_len=required_len + ) + extended_rope = extended_rope_config.build() + extended_cache = extended_rope.cache.to(device=device, dtype=dtype) # Convert back to DTensor if needed if is_dtensor: @@ -216,9 +205,7 @@ def forward( # Get RoPE cache (handle model-specific attribute names) # Use hasattr to avoid ambiguous boolean value error with tensors - if hasattr(self.model, "rope_cache"): - rope_attr = self.model.rope_cache - elif hasattr(self.model, "freqs_cis"): + if hasattr(self.model, "freqs_cis"): rope_attr = self.model.freqs_cis else: rope_attr = None diff --git a/torchtitan/experiments/rl/unified/simple_grpo.py b/torchtitan/experiments/rl/unified/simple_grpo.py new file mode 100644 index 0000000000..8c9ab2f1f4 --- /dev/null +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multiprocess RL training loop using Monarch Actors. + +This demonstrates: +1. Distributed actor architecture with Generator (vLLM) and PolicyTrainer (TorchTitan) components +2. File based weight synchronization between trainer and generator + +The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. 
+ +Command to run: +python3 torchtitan/experiments/rl/unified/simple_grpo.py +""" + +import asyncio +import logging +from dataclasses import dataclass, field + +import torch + +import torchtitan.experiments.rl.unified # noqa: F401 — registers models with vLLM +from monarch.actor import this_host +from monarch.utils import setup_env_for_distributed +from torchtitan.config import Configurable +from torchtitan.experiments.rl.unified.actors.generator import Generator +from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer +from torchtitan.experiments.rl.unified.configs import PolicyOptimizationConfig +from torchtitan.protocols.model_spec import ModelSpec + +logger = logging.getLogger(__name__) + + +class RLTrainer(Configurable): + """Top-level RL training orchestrator. + + Composes a ``PolicyTrainer`` (for model construction, parallelism, + optimizer, checkpoint, etc.) with RL-specific settings and a + Generator Monarch actor. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + """Master config for RL training. + + ``trainer`` holds the PolicyTrainer.Config that handles model + construction, parallelism, optimizer, LR scheduler, checkpoint, etc. + The remaining fields are RL-specific. + """ + + model_spec: ModelSpec | None = None + """Model specification shared by trainer and generator. + Set programmatically via config_registry (not from CLI).""" + + trainer: PolicyTrainer.Config = field(default_factory=PolicyTrainer.Config) + """PolicyTrainer config. 
Controls optimizer, training, parallelism, + lr_scheduler, checkpoint, activation_checkpoint.""" + + num_steps: int = 10 + """Number of RL training steps.""" + + dump_folder: str = "outputs/rl" + """Root output folder for RL artifacts (temp weights, logs, etc.).""" + + batch_invariant_mode: bool = True + """Enable batch-invariant mode for deterministic NCCL collective + operations and bitwise-reproducible forward/backward passes.""" + + policy_optimization: PolicyOptimizationConfig = field( + default_factory=PolicyOptimizationConfig + ) + """Policy optimization hyperparameters.""" + + generator: Generator.Config = field(default_factory=Generator.Config) + """Generator actor configuration (vLLM engine, sampling).""" + + +async def main(): + """Run the distributed RL training loop using Monarch.""" + + # Load config from config_registry + from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b + + # TODO: Make simple_grpo.py take --config as input + config = rl_grpo_qwen3_0_6b() + trainer_cfg = config.trainer + + # Compute world size for trainer and generator + # TODO: refine the world size computation and check + trainer_ddp_size = trainer_cfg.parallelism.data_parallel_replicate_degree + trainer_tp_size = trainer_cfg.parallelism.tensor_parallel_degree + + # RL Training config + num_steps = config.num_steps + + # Use fake dataset for test. TODO: Implement real RL dataloader. 
+ logger.debug("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] + + logger.debug(f"Loaded {len(prompt_texts)} prompts") + + # Create process meshes + trainer_mesh = this_host().spawn_procs( + per_host={"gpus": trainer_ddp_size * trainer_tp_size} + ) + gen_tp_size = config.generator.parallelism.tensor_parallel_degree + gen_mesh = this_host().spawn_procs(per_host={"gpus": gen_tp_size}) + + # Set up distributed env vars so that actors are connected via c10d + await setup_env_for_distributed( + trainer_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29500, # TODO: figure out what to set + ) + + # Set up distributed env vars so that actors are connected via c10d + await setup_env_for_distributed( + gen_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29501, # TODO: figure out what to set + ) + + # Spawn actors on trainer and generator mesh + trainer = trainer_mesh.spawn( + "trainer", + PolicyTrainer, + config.trainer, + model_spec=config.model_spec, + policy_optimization=config.policy_optimization, + batch_invariant_mode=config.batch_invariant_mode, + ) + + generator = gen_mesh.spawn( + "generator", + Generator, + config.generator, + model_spec=config.model_spec, + model_path=trainer_cfg.checkpoint.initial_load_path, + dump_folder=config.dump_folder, + batch_invariant_mode=config.batch_invariant_mode, + policy_optimization=config.policy_optimization, + prompt_texts=prompt_texts, + expected_answers=expected_answers, + ) + + # Initialize generator with trainer weights + initial_weights = trainer.get_weights.call().get().item(gpus=0) + await 
generator.update.call(0, initial_weights) + + # Training loop + logger.info("\n" + "=" * 80) + logger.info(f"Starting RL training for {num_steps} steps") + logger.info("=" * 80) + + for step in range(num_steps): + # Fully sync RL loop + # NOTE: This is only getting Trajectory generated from trainer 0, and trainer 1's data is ignored. + # .get() is a blocking method which makes the loop fully sync + batch = generator.generate.call().get().item(gpus=0) + metrics = trainer.step.call(batch).get().item(gpus=0) + weights = trainer.get_weights.call().get().item(gpus=0) + await generator.update.call(metrics["policy_version"], weights) + + logger.info( + f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " + f"Reward: {metrics['reward_mean']:+.3f}" + ) + logger.debug(f" Sample: {metrics['sample_completion']}...") + + # Check for divergence + if not torch.isfinite(torch.tensor(metrics["loss"])): + logger.debug("\n" + "!" * 80) + logger.debug("ERROR: Loss is NaN/Inf! Training diverged.") + logger.debug("!" * 80) + break + + logger.info("\n" + "=" * 80) + logger.info("RL Training complete") + logger.info("=" * 80) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py deleted file mode 100644 index b221dad8d7..0000000000 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Multiprocess RL training loop using Monarch Actors. - -This demonstrates: -1. Distributed actor architecture with Generator (vLLM) and Trainer (TorchTitan) components -2. 
File based weight synchronization between trainer and generator - -The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. - -Command to run: -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py -""" -import asyncio -import logging - -import torch -from monarch.actor import this_host -from monarch.utils import setup_env_for_distributed -from torchtitan.experiments.rl.unified.actors.generator import Generator -from torchtitan.experiments.rl.unified.actors.trainer import Trainer -from torchtitan.experiments.rl.unified.models.utils import ModelMode -from torchtitan.experiments.rl.vllm_compat.simple_rl import ( - download_and_convert_model, - load_gsm8k_dataset, -) -from vllm.model_executor.layers.batch_invariant import ( - init_batch_invariance, - vllm_is_batch_invariant, -) -from vllm.v1.attention.backends.registry import AttentionBackendEnum - -logger = logging.getLogger(__name__) - - -async def main(): - """Run the distributed RL training loop using Monarch.""" - # Model Config - model_name = "Qwen/Qwen3-0.6B" - cache_dir = "./models" - output_dir = "./converted" - - # Training config - group_size = 8 - num_steps = 10 - learning_rate = 1e-5 - max_new_tokens = 20 - - # GRPO config - use_stable_grpo = False - grpo_beta = 0.1 - - # Dataset config - use_real_dataset = False - num_dataset_samples = 5 - - # Parallelism sizes - trainer_ddp_size = 2 - trainer_tp_size = 1 - generator_tp_size = 1 - - init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) - batch_invariant = vllm_is_batch_invariant() - mode = ModelMode.UNIFIED - - # Set up batch invariant - if batch_invariant: - logger.info("Batch invariance detected - using vLLM-compatible model") - from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( - enable_batch_invariant_backward_mode, - ) - - enable_batch_invariant_backward_mode() - else: - raise RuntimeError("Batch invariance 
NOT detected - using standard model") - - # Download and convert model - titan_checkpoint_path, model_path = download_and_convert_model( - model_name, cache_dir, output_dir - ) - - # Load dataset - if use_real_dataset: - logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") - # TODO: Refactor into loading torchtitan dataset - prompt_texts, expected_answers = load_gsm8k_dataset( - split="train", num_samples=num_dataset_samples - ) - if prompt_texts is None or len(prompt_texts) == 0: - use_real_dataset = False - - if not use_real_dataset: - logger.info("Using default prompts") - prompts_with_answers = [ - ("The capital of France is", "paris"), - ("What is 7 times 8?", "56"), - ("The first president of the United States was", "washington"), - ("The chemical symbol for water is", "h2o"), - ("The largest planet in our solar system is", "jupiter"), - ] - prompt_texts = [p[0] for p in prompts_with_answers] - expected_answers = [p[1] for p in prompts_with_answers] - - logger.info(f"Loaded {len(prompt_texts)} prompts") - - # Create process meshes - trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) - gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) - - # Set up distributed env vars so that actors are connected via c10d - await setup_env_for_distributed( - trainer_mesh, - master_addr="localhost", # TODO: figure out what to set - master_port=29500, # TODO: figure out what to set - ) - - # Set up distributed env vars so that actors are connected via c10d - await setup_env_for_distributed( - gen_mesh, - master_addr="localhost", # TODO: figure out what to set - master_port=29501, # TODO: figure out what to set - ) - - # Spawn actors on trainer and generator mesh - trainer = trainer_mesh.spawn( - "trainer", - Trainer, - titan_checkpoint_path, - model_path, - learning_rate, - mode, - trainer_ddp_size, - trainer_tp_size, - ) - - generator = gen_mesh.spawn( - "generator", - Generator, - model_path, - prompt_texts, - expected_answers, - 
group_size, - max_new_tokens, - 1.0, # temperature - use_real_dataset, - grpo_beta, - use_stable_grpo, - generator_tp_size, - ) - - # Initialize generator with trainer weights - initial_weights = trainer.get_weights.call().get().item(gpus=0) - await generator.update.call(0, initial_weights) - - # Training loop - logger.info("\n" + "=" * 80) - logger.info(f"Starting RL training for {num_steps} steps") - logger.info("=" * 80) - - for step in range(num_steps): - # Fully sync RL loop - batch = generator.generate.call().get().item(gpus=0) - metrics = trainer.step.call(batch).get().item(gpus=0) - weights = trainer.get_weights.call().get().item(gpus=0) - await generator.update.call(metrics["policy_version"], weights) - - logger.info( - f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " - f"Reward: {metrics['reward_mean']:+.3f}" - ) - logger.info(f" Sample: {metrics['sample_completion']}...") - - # Check for divergence - if not torch.isfinite(torch.tensor(metrics["loss"])): - logger.info("\n" + "!" * 80) - logger.info("ERROR: Loss is NaN/Inf! Training diverged.") - logger.info("!" * 80) - break - - logger.info("\n" + "=" * 80) - logger.info("RL Training complete") - logger.info("=" * 80) - - -if __name__ == "__main__": - asyncio.run(main())