From c810835e13bc498a44338c70951fef019314ea14 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 12 Jan 2026 16:20:08 -0800 Subject: [PATCH] refactor save and load model weights using DCP [ghstack-poisoned] --- .../rl/unified/actors/generator.py | 4 +- .../experiments/rl/unified/actors/trainer.py | 11 +--- .../rl/unified/models/vllm_wrapper.py | 57 +++++++++++-------- 3 files changed, 38 insertions(+), 34 deletions(-) diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index efba2ac2c7..219c3bccf9 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -57,7 +57,7 @@ class TrajectoryData: advantages: torch.Tensor -class VLLMRolloutEngine: +class VLLMGenerator: """ vLLM engine for fast rollouts with weight updates. @@ -355,7 +355,7 @@ def __init__( Comm(), ) # Initialize vLLM engine with job_config - self.vllm_engine = VLLMRolloutEngine(job_config, self.model_path) + self.vllm_engine = VLLMGenerator(job_config, self.model_path) # State machine self.state = GeneratorState.READY_TO_UPDATE diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index ccfa10a708..fc15e5d426 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -19,7 +19,6 @@ from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm logger = logging.getLogger(__name__) @@ -52,7 +51,6 @@ def __init__( # load trainer model and patch to vllm.Attention() self.model = load_trainer_model(model_path) - self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism @@ -77,14 +75,13 @@ def __init__( @endpoint async def get_weights(self) -> dict: - """Get vLLM weights 
for generator. + """Get model weights for generator. Returns: - vLLM state dict + model state dict """ titan_state = self.model.state_dict() - vllm_state = torchtitan_to_vllm(titan_state) - return vllm_state + return titan_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: @@ -114,8 +111,6 @@ async def step(self, trajectory: TrajectoryData) -> dict: self.policy_version += 1 - # TODO: save dcp checkpoint to file here instead of sending weight dicts - # Return metrics metrics = { "loss": loss.item(), diff --git a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py index f3ae7f348a..699593b9b7 100644 --- a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -275,37 +275,19 @@ def compute_logits( return logits - def load_weights(self, weights_iter): + def load_weights_from_state_dict(self, titan_state_dict): """ - Load weights from HF checkpoint using the provided state dict adapter. - vLLM engine would call this function to load model weights. 
-
-        Args:
-            weights_iter: Iterator of (name, tensor) pairs from HF checkpoint
-
-        Returns:
-            Set of loaded parameter names
+        Load model weights directly from a TorchTitan state dict.
         """
-        # Collect weights from iterator
-        hf_state_dict = {}
-        for name, tensor in weights_iter:
-            hf_state_dict[name] = tensor
-
-        # Use adapter to convert HF → TorchTitan format
-        adapter = self.state_dict_adapter(
-            model_args=self.config,
-            hf_assets_path=None,
-        )
-        torchtitan_state_dict = adapter.from_hf(hf_state_dict)
         model_state_dict = {k: v for k, v in self.model.state_dict().items()}

         # Convert to DTensor if target is DTensor
-        for name, tensor in torchtitan_state_dict.items():
+        for name, tensor in titan_state_dict.items():
             if name in model_state_dict and isinstance(model_state_dict[name], DTensor):
                 target_dtensor = model_state_dict[name]
                 device_mesh = target_dtensor.device_mesh
-                torchtitan_state_dict[name] = DTensor.from_local(
+                titan_state_dict[name] = DTensor.from_local(
                     tensor.to(device_mesh.device_type),
                     device_mesh=device_mesh,
                     placements=[Replicate()],
@@ -314,10 +296,37 @@ def load_weights(self, weights_iter):

         # Load state dict
         set_model_state_dict(
             model=self.model,
-            model_state_dict=torchtitan_state_dict,
+            model_state_dict=titan_state_dict,
             options=StateDictOptions(strict=False),
         )

-        loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()}
+        loaded_params = titan_state_dict.keys()

         return loaded_params
+
+    def load_weights(self, weights_iter):
+        """
+        Load weights from HF checkpoint using the provided state dict adapter.
+        vLLM engine would call this function to load model weights.
+
+        Args:
+            weights_iter: Iterator of (name, tensor) pairs from HF checkpoint
+
+        Returns:
+            Set of loaded parameter names
+        """
+
+        # Since our model weights are already loaded during initialization,
+        # we need to return the names of all parameters that have been loaded
+        # so vLLM's safety check passes.
+ loaded_param_names = set() + for name, _ in self.model.named_parameters(): + loaded_param_names.add(name) + + logger.info( + f"Weights already loaded during model initialization. \ + Returning {len(loaded_param_names)} loaded parameter names to satisfy vLLM safety check." + ) + + # Return the names of all loaded parameters so vLLM knows they were handled + return loaded_param_names