diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py index 2af3bcea68..3bdae48e4f 100644 --- a/torchtitan/experiments/rl/unified/__init__.py +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -7,9 +7,9 @@ """ Unified approach for running TorchTitan models with vLLM inference. -To manually load the plugin: - from torchtitan.experiments.rl.unified import plugin - plugin.register() +To register TorchTitan models with vLLM: + from torchtitan.experiments.rl.unified.plugin import register + register(model_spec) """ from torchtitan.experiments.rl.unified.models.vllm_wrapper import ( diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 68c9459a73..9c4b3b267c 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -4,17 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import glob import logging import os -import shutil from dataclasses import dataclass, field from typing import List import torch from monarch.actor import Actor, endpoint -from safetensors.torch import save_file from torchtitan.config import CommConfig, Configurable, ParallelismConfig from torchtitan.distributed import utils as dist_utils @@ -75,7 +72,6 @@ class Generator(Actor, Configurable): Args: config: Generator-specific configuration. model_path: Path to the HF model checkpoint. - dump_folder: Root output folder for RL artifacts. batch_invariant_mode: Enable batch-invariant mode for deterministic ops. policy_optimization: GRPO hyperparameters. prompt_texts: List of prompt strings. @@ -95,8 +91,8 @@ class Config(Configurable.Config): enforce_eager: bool = True """Disable CUDA graphs in vLLM (use eager execution).""" - seed: int | None = None - """Random seed for reproducible generation. 
None means no fixed seed.""" + seed: int = 42 + """Random seed for reproducible generation.""" parallelism: ParallelismConfig = field(default_factory=ParallelismConfig) """Parallelism configuration for the vLLM engine.""" @@ -113,7 +109,6 @@ def __init__( *, model_spec: ModelSpec, model_path: str, - dump_folder: str, batch_invariant_mode: bool, policy_optimization: PolicyOptimizationConfig, prompt_texts: list[str], @@ -129,7 +124,6 @@ def __init__( ) register_model_to_vllm_model_registry(model_spec) - self._vllm_model_name = VLLM_MODEL_NAME # Set vLLM environment variables from config before any vLLM initialization if batch_invariant_mode: @@ -148,135 +142,39 @@ def __init__( self.group_size = policy_optimization.group_size self.grpo_beta = policy_optimization.beta self.use_stable_grpo = policy_optimization.use_stable_grpo - - # Initialize distributed environment for SPMD generator - world_size = dist_utils.init_distributed(CommConfig()) - - # Set up temp model directory for vLLM weight loading - self._base_model_path = model_path - self._temp_model_dir = os.path.abspath( - os.path.join(dump_folder, "vllm_temp_model") + + # Build vLLM engine + engine_args = EngineArgs( + model=model_path, + trust_remote_code=True, + dtype=config.dtype, + tensor_parallel_size=config.parallelism.tensor_parallel_degree, + distributed_executor_backend="external_launcher", + gpu_memory_utilization=config.gpu_memory_limit, + enforce_eager=config.enforce_eager, + seed=config.seed, + hf_overrides={"architectures": [VLLM_MODEL_NAME]}, + attention_config=AttentionConfig( + backend=AttentionBackendEnum.FLASH_ATTN, + ), ) - os.makedirs(self._temp_model_dir, exist_ok=True) - - # Copy config/tokenizer files from base model to temp dir - for file in [ - "config.json", - "tokenizer.json", - "tokenizer_config.json", - "special_tokens_map.json", - "merges.txt", - "vocab.json", - ]: - src = os.path.join(model_path, file) - if os.path.exists(src): - shutil.copy2(src, self._temp_model_dir) - - # Copy 
def _get_model(self):
    """Return the model object held by the vLLM engine's driver worker.

    The model is reached through
    ``self._engine.model_executor.driver_worker.get_model()``.

    Returns:
        The model instance owned by the engine. This is expected to be the
        TorchTitanVLLMModelWrapper selected through the ``architectures``
        override passed to ``EngineArgs`` in ``__init__``.
    """
    return self._engine.model_executor.driver_worker.get_model()


@endpoint
async def update(self, version: int, state_dict: dict) -> None:
    """Update generator weights.

    Called by the orchestrator (simple_grpo.py). The weights are pushed
    straight into the model living inside the vLLM engine, then the local
    policy version is bumped.

    Args:
        version: New policy version number
        state_dict: Model state dict
    """
    # load_weights_from_state_dict returns the collection of parameter names
    # it handled; it is only used here for the size reported in the log.
    loaded_params = self._get_model().load_weights_from_state_dict(state_dict)
    self.policy_version = version

    # Each fragment ends with an explicit separator: adjacent string literals
    # are joined verbatim, so a missing trailing space glues the parameter
    # count directly onto the pid field.
    logger.debug(
        f"{os.getpid()=} Generator updated weights to policy v{version}. "
        f"Updated weights into vLLM engine model. "
        f"Number of parameters: {len(loaded_params)}"
    )
_extend_rope_cache_if_needed( self, rope_cache: torch.Tensor, max_position: int ) -> torch.Tensor: @@ -209,11 +213,13 @@ def _extend_rope_cache_if_needed( return rope_cache def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings.""" + """vLLM required API. + Convert input token IDs to embeddings.""" return self.model.tok_embeddings(input_ids) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - """Convert input token IDs to embeddings (deprecated vLLM interface).""" + """vLLM required API. + Convert input token IDs to embeddings (deprecated vLLM interface).""" return self.embed_input_ids(input_ids) def forward( @@ -224,6 +230,7 @@ def forward( **kwargs, ) -> torch.Tensor: """ + vLLM required API. Forward pass with vLLM interface. Args: @@ -287,7 +294,9 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata=None, ) -> torch.Tensor | None: - """Compute logits from hidden states.""" + """vLLM required API. + Compute logits from hidden states.""" + # When TP is applied, we return the full tensor (plain tensor) to vLLM engine # at the end of TorchTitanVLLMModelWrapper.forward(). # We need to wrap the input from vLLM engine back to DTensor with Replicate() placement. @@ -304,37 +313,19 @@ def compute_logits( return logits - def load_weights(self, weights_iter): + def load_weights_from_state_dict(self, trainer_state_dict): """ - Load weights from HF checkpoint using the provided state dict adapter. - vLLM engine would call this function to load model weights. - - Args: - weights_iter: Iterator of (name, tensor) pairs from HF checkpoint - - Returns: - Set of loaded parameter names + Load model weights directly from a state dict containing DTensor. 
""" - # Collect weights from iterator - hf_state_dict = {} - for name, tensor in weights_iter: - hf_state_dict[name] = tensor - - # Use adapter to convert HF → TorchTitan format - adapter = self.state_dict_adapter( - model_config=self.config, - hf_assets_path=None, - ) - torchtitan_state_dict = adapter.from_hf(hf_state_dict) model_state_dict = {k: v for k, v in self.model.state_dict().items()} - # Convert to DTensor if target is DTensor - for name, tensor in torchtitan_state_dict.items(): + # Convert to DTensor if target is DTensor (when the target model is sharded) + for name, tensor in trainer_state_dict.items(): if name in model_state_dict and isinstance(model_state_dict[name], DTensor): target_dtensor = model_state_dict[name] device_mesh = target_dtensor.device_mesh - torchtitan_state_dict[name] = DTensor.from_local( + trainer_state_dict[name] = DTensor.from_local( tensor.to(device_mesh.device_type), device_mesh=device_mesh, placements=[Replicate()], @@ -343,10 +334,63 @@ def load_weights(self, weights_iter): # Load state dict set_model_state_dict( model=self.model, - model_state_dict=torchtitan_state_dict, + model_state_dict=trainer_state_dict, options=StateDictOptions(strict=False), ) - loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} + loaded_params = trainer_state_dict.keys() return loaded_params + + def _initial_load_weights(self, checkpoint_path): + """ + Helper function to load torchtitan model weights from HF checkpoint when initialize this model. 
+ + Args: + checkpoint_path: Path to the HuggingFace checkpoint directory + """ + # Create adapter instance + adapter = self.state_dict_adapter( + model_config=self.config, + hf_assets_path=None, + ) + + # Get HF storage reader from adapter + storage_reader = adapter.get_hf_storage_reader(checkpoint_path) + + # Load HF state dict using DCP + hf_state_dict = adapter.to_hf(self.model.state_dict()) + dcp.load(hf_state_dict, storage_reader=storage_reader) + + # Convert HF state dict to TorchTitan format + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + + return self.load_weights_from_state_dict(torchtitan_state_dict) + + def load_weights(self, weights_iter): + """ + vLLM required API. + Load weights from HF checkpoint using the provided state dict adapter. + vLLM engine would call this function to load model weights. + + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Set of loaded parameter names + """ + + # Since our model weights are already loaded during initialization, + # we need to return the names of all parameters that have been loaded + # so vLLM's safety check passes. + loaded_param_names = set() + for name, _ in self.model.named_parameters(): + loaded_param_names.add("model." + name) + + logger.info( + f"Weights already loaded during model initialization. \ + Returning {len(loaded_param_names)} loaded parameter names to satisfy vLLM safety check." 
+ ) + + # Return the names of all loaded parameters so vLLM knows they were handled + return loaded_param_names diff --git a/torchtitan/experiments/rl/unified/simple_grpo.py b/torchtitan/experiments/rl/unified/simple_grpo.py index 0c0326446d..77d43a8be7 100644 --- a/torchtitan/experiments/rl/unified/simple_grpo.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -159,7 +159,6 @@ async def main(): config.generator, model_spec=config.model_spec, model_path=config.trainer.hf_assets_path, - dump_folder=config.dump_folder, batch_invariant_mode=config.batch_invariant_mode, policy_optimization=config.policy_optimization, prompt_texts=prompt_texts, @@ -168,7 +167,7 @@ async def main(): # Initialize generator with trainer weights initial_weights = trainer.get_weights.call().get().item(gpus=0) - await generator.update.call(0, initial_weights) + generator.update.call(0, initial_weights).get() # Training loop logger.info("\n" + "=" * 80) @@ -178,11 +177,11 @@ async def main(): for step in range(num_steps): # Fully sync RL loop # NOTE: This is only getting Trajectory generated from trainer 0, and trainer 1's data is ignored. - # .get() is a blocking method which makes the loop fully sync + # .get() is a monarch synchronize API which makes the loop fully sync batch = generator.generate.call().get().item(gpus=0) metrics = trainer.step.call(batch).get().item(gpus=0) weights = trainer.get_weights.call().get().item(gpus=0) - await generator.update.call(metrics["policy_version"], weights) + generator.update.call(metrics["policy_version"], weights).get() logger.info( f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | "