From 5bc3c983e036662c9b83f76dcda4f644c92bca39 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 2 Jan 2026 08:59:28 -0800 Subject: [PATCH 01/10] config sys v1 [ghstack-poisoned] --- torchtitan/experiments/rl/unified/infer.py | 159 +++++++++--------- .../experiments/rl/unified/job_config.py | 88 ++++++++++ .../rl/unified/run_configs/qwen3_0.6b.toml | 62 +++++++ .../rl/unified/simple_rl_multiprocess.py | 2 + 4 files changed, 227 insertions(+), 84 deletions(-) create mode 100644 torchtitan/experiments/rl/unified/job_config.py create mode 100644 torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 43153fb70e..136b264a72 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -5,113 +5,104 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse +""" +Example inference script using TorchTitan models with vLLM LLMEngine. + +This script uses JobConfig loaded from a TOML file to configure both +the vLLM engine and sampling parameters. + +Run: torchrun --nproc_per_node=2 \ + torchtitan/experiments/rl/unified/infer.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +""" + + +from torchtitan.config import ConfigManager # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 -from vllm import LLM, SamplingParams +from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.logger import init_logger logger = init_logger(__name__) -def parse_args(): - parser = argparse.ArgumentParser( - description="Run TorchTitan model inference with vLLM Engine", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--model-ckpt-path", - type=str, - default="torchtitan/experiments/rl/example_checkpoint", - help="Path to TorchTitan checkpoint directory", - ) - parser.add_argument( - "--prompt", - type=str, - default="Hello, my name is", - help="Prompt text for generation", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.8, - help="Sampling temperature", - ) - parser.add_argument( - "--tensor-parallel-size", - type=int, - default=1, - help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", +def infer(): + + config_manager = ConfigManager() + job_config = config_manager.parse_args() + + logger.info("Initializing vLLM LLMEngine with TorchTitan model") + logger.info(f"Model: {job_config.checkpoint.initial_load_path}") + logger.info( + f"Tensor Parallel Size: {job_config.inference.parallelism.tensor_parallel_degree}" ) - return parser.parse_args() + # Create EngineArgs from JobConfig + # Map TorchTitan parallelism to vLLM parallelism + inference = job_config.inference -def infer(): - args = parse_args() - - logger.info("Initializing vLLM with TorchTitan model") - logger.info(f"Model: {args.model_ckpt_path}") - logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") - - # Initialize vLLM with custom TorchTitan model - # The LLM initialization will internally: - # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) - # 2. Create TorchTitanVLLMModel instance - # 3. Create JobConfig and ParallelDims from vLLM config - # 4. Apply parallelization using parallelize_qwen3 - # 5. Load model weights and prepare for inference - # The tensor_parallel_size will be used by vLLM to configure parallelization - # and will be available in vllm_config in worker processes - logger.info("Creating vLLM LLM engine...") - - llm = LLM( - model=args.model_ckpt_path, # Model checkpoint path - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, - dtype="bfloat16", + engine_args = EngineArgs( + # Model configuration + model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - enforce_eager=True, # Use eager mode - tensor_parallel_size=args.tensor_parallel_size, - gpu_memory_utilization=0.5, + dtype=inference.dtype, + # Parallelism configuration + tensor_parallel_size=inference.parallelism.tensor_parallel_degree, + distributed_executor_backend=inference.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=inference.gpu_memory_utilization, + enforce_eager=inference.enforce_eager, + # Seed + seed=inference.seed, + # HuggingFace overrides + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("vLLM engine initialized successfully") - logger.info(f"Prompt: {args.prompt}") + logger.info("Initializing LLMEngine from EngineArgs...") + engine = LLMEngine.from_engine_args(engine_args) - # Prepare prompt and sampling parameters - prompts = [args.prompt] + logger.info("vLLM LLMEngine initialized successfully") + + # Create sampling parameters from JobConfig + sampling = job_config.inference.sampling sampling_params = SamplingParams( - temperature=args.temperature, - top_p=0.95, - max_tokens=args.max_tokens, + temperature=sampling.temperature, + top_p=sampling.top_p, + max_tokens=sampling.max_tokens, ) - # Generate text - logger.info("Generating text...") - outputs = llm.generate( - prompts=prompts, - sampling_params=sampling_params, + logger.info( + f"Sampling params: temperature={sampling.temperature}, " + f"top_p={sampling.top_p}, max_tokens={sampling.max_tokens}" ) - # Print results - logger.info("Generation complete") - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text + # Example prompt + prompt = "Hello, my name is" + logger.info(f"Prompt: {prompt}") - print(f"\nPrompt: {prompt}") - print(f"Generated text: {generated_text!r}\n") + # Add request to engine + logger.info("Adding request to engine...") + request_id = "0" + engine.add_request(request_id, prompt, sampling_params) + + # Generate text by stepping through engine + logger.info("Generating text...") + while engine.has_unfinished_requests(): + request_outputs = engine.step() + + # Process finished requests + for request_output in request_outputs: + if request_output.finished: + prompt = request_output.prompt + generated_text = request_output.outputs[0].text + + # Print results + logger.info("Generation complete") + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") if __name__ == "__main__": diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py new file mode 100644 index 0000000000..47e41bc831 --- /dev/null +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Extended JobConfig for RL/Inference workloads. + +This module extends TorchTitan's base JobConfig with inference-specific +configurations needed for vLLM integration. +""" + +from dataclasses import dataclass, field + +from torchtitan.config.job_config import JobConfig as BaseJobConfig, Parallelism + + +@dataclass +class Sampling: + """ + Sampling configuration for vLLM generation. + + This dataclass contains sampling parameters used during text generation. + These map directly to vLLM's SamplingParams. + """ + + temperature: float = 0.8 + """ + Temperature for sampling. Controls randomness in generation. + - 0.0: Deterministic (greedy decoding) + - Higher values: More random outputs + """ + + top_p: float = 0.95 + """ + Top-p (nucleus) sampling threshold. + Only tokens with cumulative probability up to top_p are considered. + """ + + max_tokens: int = 100 + """Maximum number of tokens to generate""" + + +@dataclass +class Inference: + """ + Inference configuration for vLLM engine. + + This dataclass contains essential vLLM-specific settings for inference. + """ + + dtype: str = "bfloat16" + """Data type for model weights (auto, float16, bfloat16, float32)""" + + gpu_memory_utilization: float = 0.5 + """Fraction of GPU memory to use for Inference engine (0.0 to 1.0)""" + + distributed_executor_backend: str = "external_launcher" + """ + Backend for distributed execution. + 'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) + """ + + seed: int = 42 + """Random seed for sampling""" + + enforce_eager: bool = True + """Whether to enforce eager execution (disable CUDA graphs)""" + + parallelism: Parallelism = field(default_factory=Parallelism) + """Parallelism configuration for inference""" + + sampling: Sampling = field(default_factory=Sampling) + """Sampling configuration for inference""" + + +@dataclass +class JobConfig(BaseJobConfig): + """ + Extended JobConfig with inference support. + + This extends TorchTitan's base JobConfig by adding `inference` and `sampling` fields + for vLLM-specific inference and generation configurations. + """ + + inference: Inference = field(default_factory=Inference) + """Inference configuration for vLLM engine""" diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml new file mode 100644 index 0000000000..6166c50a91 --- /dev/null +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -0,0 +1,62 @@ +[job] +dump_folder = "./outputs" +description = "Qwen 3 0.6B training" +custom_config_module = "torchtitan.experiments.rl.unified.job_config" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "qwen3" +flavor = "0.6B" +hf_assets_path = "./assets/hf/Qwen3-0.6B" + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, 20% total steps + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4" + +[parallelism] # training parallelism plan +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 + +[checkpoint] +initial_load_path = "/data/users/jianiw/model/qwen3-0.6b" + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[inference] +dtype = "bfloat16" +gpu_memory_utilization = 0.5 +distributed_executor_backend = "external_launcher" +seed = 42 +enforce_eager = true + +[inference.parallelism] +data_parallel_replicate_degree = 1 +tensor_parallel_degree = 2 + +[inference.sampling] +temperature = 0.8 +top_p = 0.95 +max_tokens = 100 diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py index 3e914f3778..28581e2747 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -163,6 +163,8 @@ async def main(): for step in range(num_steps): # Fully sync RL loop + # NOTE: This is only getting Trajectory generated from trainer 0, and trainer 1's data is ignored. + # .get() is a blocking method which makes the loop fully sync batch = generator.generate.call().get().item(gpus=0) metrics = trainer.step.call(batch).get().item(gpus=0) weights = trainer.get_weights.call().get().item(gpus=0) From 9727ee8b73e3e5165d35e43f0c863b3662c5e07c Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Fri, 2 Jan 2026 09:21:16 -0800 Subject: [PATCH 02/10] Update on "config sys v1" [ghstack-poisoned] --- torchtitan/experiments/rl/unified/README.md | 6 +- .../rl/unified/actors/generator.py | 240 +++++++++--------- .../experiments/rl/unified/actors/trainer.py | 45 ++-- torchtitan/experiments/rl/unified/infer.py | 18 +- .../experiments/rl/unified/job_config.py | 42 ++- .../experiments/rl/unified/models/utils.py | 81 ++---- .../rl/unified/run_configs/qwen3_0.6b.toml | 12 +- ...mple_rl_multiprocess.py => simple_grpo.py} | 94 ++----- 8 files changed, 227 insertions(+), 311 deletions(-) rename torchtitan/experiments/rl/unified/{simple_rl_multiprocess.py => simple_grpo.py} (64%) diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 27550e977c..ae604b19aa 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -66,10 +66,10 @@ python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path None: + def update_weights(self, vllm_state: dict) -> None: """ Update vLLM model weights from vLLM-compat state dict. @@ -121,11 +123,8 @@ def update_weights(self, vllm_compat_state: dict) -> None: vLLM's reload_weights() API after updating the model path config. Args: - vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + vllm_state: vLLM model state dict """ - # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) - vllm_state = torchtitan_to_vllm(vllm_compat_state) - # Save to temp model directory import os @@ -134,7 +133,6 @@ def update_weights(self, vllm_compat_state: dict) -> None: # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure import glob - import json shard_files = sorted( glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) @@ -142,51 +140,20 @@ def update_weights(self, vllm_compat_state: dict) -> None: index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safe tensor files which we only + # right now we hardcoded to work with 2 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. # only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: logger.info(f"Saving weights to {checkpoint_path}") - if len(shard_files) == 2 and os.path.exists(index_file): - # Load the index to see which weights go in which shard - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data["weight_map"] - - # Split weights according to the index - shard1_weights = {} - shard2_weights = {} - - for key, value in vllm_state.items(): - shard_file = weight_map.get(key, shard_files[0]) - if "model-00001-of-00002" in shard_file: - shard1_weights[key] = value - else: - shard2_weights[key] = value - - # Ensure weights stay in bfloat16 - shard1_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard1_weights.items() - } - shard2_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard2_weights.items() - } - - # Save to the shard files - save_file(shard1_weights, shard_files[0]) - save_file(shard2_weights, shard_files[1]) - else: - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() @@ -194,28 +161,35 @@ def update_weights(self, vllm_compat_state: dict) -> None: f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) - # First time: create the engine - if self.llm is None: - self.llm = LLM( + # First time: create the engine using LLMEngine and EngineArgs + if self.engine is None: + generation = self.job_config.generation + + engine_args = EngineArgs( + # Model configuration model=self.temp_model_dir, - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, trust_remote_code=True, - max_model_len=2048, - dtype="bfloat16", - gpu_memory_utilization=0.1, # Reduced from 0.5 - distributed_executor_backend="external_launcher", # vllm do not spawn processes - seed=42, # Fixed seed for determinism - enforce_eager=True, - tensor_parallel_size=self.tp_size, # Explicitly single GPU + dtype=generation.dtype, + # Parallelism configuration + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, + # Seed + seed=generation.seed, + # HuggingFace overrides to use TorchTitan model. + # TODO: make this field configurable and align with model registration + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("Created new vLLM engine") + + logger.info("Initializing LLMEngine from EngineArgs...") + self.engine = LLMEngine.from_engine_args(engine_args) + logger.info("Created new vLLM LLMEngine") else: # Use collective_rpc to call reload_weights on all workers # This reloads weights from temp_model_dir without recreating the engine - self.llm.collective_rpc("reload_weights") + self.engine.collective_rpc("reload_weights") @torch.no_grad() def generate( @@ -228,7 +202,7 @@ def generate( list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] ]: """ - Generate samples using vLLM. + Generate samples using vLLM LLMEngine. Args: prompt_texts: List of prompt strings @@ -243,16 +217,36 @@ def generate( token_log_probs: List of per-token log prob lists for each completion prompt_token_ids: List of prompt token ID lists for each completion """ + logger.info( + f"Starting generation: {len(prompt_texts)} prompts, " + f"n_samples_per_prompt={n_samples_per_prompt}, " + f"max_tokens={max_new_tokens}, temp={temperature}" + ) + sampling_params = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, - n=n_samples_per_prompt, seed=42, logprobs=1, prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - outputs = self.llm.generate(prompt_texts, sampling_params) + # Add requests to the engine + # For n_samples_per_prompt > 1, submit each prompt multiple times with different request id + request_id = 0 + prompt_indices = [] # Track which prompt each request corresponds to + for prompt_idx, prompt in enumerate(prompt_texts): + for sample_idx in range(n_samples_per_prompt): + self.engine.add_request(str(request_id), prompt, sampling_params) + prompt_indices.append(prompt_idx) + request_id += 1 + + # Step through engine until all requests are finished + all_outputs = [] + while self.engine.has_unfinished_requests(): + request_outputs = self.engine.step() + all_outputs.extend(request_outputs) # Extract completions and log probs completions = [] @@ -261,30 +255,35 @@ def generate( token_log_probs_list = [] prompt_token_ids_list = [] - for output in outputs: + for output in all_outputs: # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids - for sample in output.outputs: - completions.append(sample.text) + # Each output now has exactly 1 sample (we submitted multiple requests) + assert ( + len(output.outputs) == 1 + ), f"Expected 1 output, got {len(output.outputs)}" + sample = output.outputs[0] + + completions.append(sample.text) - # Store prompt tokens for this sample - prompt_token_ids_list.append(prompt_token_ids) + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) log_probs = torch.tensor(log_probs_list, dtype=torch.float32) @@ -298,8 +297,8 @@ def generate( def __del__(self): """Cleanup vLLM engine.""" - if hasattr(self, "llm"): - del self.llm + if hasattr(self, "engine"): + del self.engine torch.cuda.empty_cache() @@ -319,58 +318,46 @@ class Generator(Actor): computes rewards/advantages. Args: - model_path: Path to HuggingFace model + job_config: JobConfig dataclass containing all configuration prompt_texts: List of prompt strings expected_answers: List of expected answers - group_size: Number of samples per prompt - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - use_real_dataset: Whether using real dataset (GSM8K) - grpo_beta: Beta for GRPO advantages - use_stable_grpo: Whether to use stable GRPO - tp_size: Tensor Parallel size """ def __init__( self, - model_path: str, - prompt_texts: List[str], + job_config: JobConfig, + prompt_texts: List[ + str + ], # TODO: This field need to be removed once dataloader is implemented expected_answers: List[str], - group_size: int = 8, - max_new_tokens: int = 20, - temperature: float = 1.0, - use_real_dataset: bool = False, - grpo_beta: float = 0.1, - use_stable_grpo: bool = False, - tp_size: int = 1, ): - self.model_path = model_path + # Store job_config for accessing configuration + self.job_config = job_config self.prompt_texts = prompt_texts self.expected_answers = expected_answers - self.group_size = group_size - self.max_new_tokens = max_new_tokens - self.temperature = temperature - self.use_real_dataset = use_real_dataset - self.grpo_beta = grpo_beta - self.use_stable_grpo = use_stable_grpo - self.tp_size = tp_size + + # Extract needed fields from job_config + self.model_path = job_config.checkpoint.initial_load_path + self.max_new_tokens = job_config.generation.sampling.max_tokens + self.temperature = job_config.generation.sampling.temperature + self.group_size = job_config.rl.grpo_group_size + self.grpo_beta = job_config.rl.grpo_beta + self.use_stable_grpo = job_config.rl.use_stable_grpo # Initialize distributed environment for SPMD generator world_size = dist_utils.init_distributed( Comm(), ) - # Initialize vLLM engine - self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) + # Initialize vLLM engine with job_config + self.vllm_engine = VLLMRolloutEngine(job_config, self.model_path) # State machine self.state = GeneratorState.READY_TO_UPDATE self.cond = asyncio.Condition() self.policy_version = 0 - # Reward function - self.reward_fn = ( - math_reward_function if use_real_dataset else trivial_reward_function - ) + # Reward function. TODO: Use a real reward function + self.reward_fn = trivial_reward_function logger.info("Generator initialized with vLLM engine") @@ -401,6 +388,11 @@ async def generate(self) -> None: ) # Compute rewards + logger.info( + f"Computing rewards: {len(completions)} completions, " + f"{len(self.expected_answers)} expected answers, " + f"group_size={self.group_size}" + ) rewards = self.reward_fn( completions, self.expected_answers, self.group_size ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index 9ffb9f0f0a..ccfa10a708 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -11,16 +11,15 @@ import torch from monarch.actor import Actor, endpoint from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData -from torchtitan.experiments.rl.unified.models.parallelism_utils import ( +from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.unified.job_config import JobConfig +from torchtitan.experiments.rl.unified.models.utils import load_trainer_model from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm logger = logging.getLogger(__name__) @@ -30,34 +29,30 @@ class Trainer(Actor): Updates policy based on collected trajectories. Run model forward on trajectories, computes loss, and run backward. + TODO: Use torchtitan Trainer Args: - titan_checkpoint_path: Path to TorchTitan checkpoint - model_path: Path to HuggingFace model - learning_rate: Learning rate for optimizer - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model + job_config: JobConfig dataclass containing all configuration """ def __init__( self, - titan_checkpoint_path: str, - model_path: str, - learning_rate: float = 1e-5, - model_mode: str = ModelMode.VLLM_COMPAT, - ddp_size: int = 1, - tp_size: int = 1, + job_config: JobConfig, ): + # Extract needed fields from job_config + model_path = job_config.checkpoint.initial_load_path # path to HF checkpoint + learning_rate = job_config.optimizer.lr + self.ddp_size = job_config.parallelism.data_parallel_replicate_degree + self.tp_size = job_config.parallelism.tensor_parallel_degree + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - self.model = load_model( - titan_checkpoint_path, model_path, model_mode=model_mode - ) - self.ddp_size = ddp_size - self.tp_size = tp_size + # load trainer model and patch to vllm.Attention() + self.model = load_trainer_model(model_path) + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism @@ -82,14 +77,14 @@ def __init__( @endpoint async def get_weights(self) -> dict: - """Get vLLM-compatible weights for generator. + """Get vLLM weights for generator. Returns: - vLLM-compatible state dict + vLLM state dict """ titan_state = self.model.state_dict() - vllm_compat_state = torchtitan_to_vllm_compat(titan_state) - return vllm_compat_state + vllm_state = torchtitan_to_vllm(titan_state) + return vllm_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 136b264a72..14e2542d16 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -37,26 +37,26 @@ def infer(): logger.info("Initializing vLLM LLMEngine with TorchTitan model") logger.info(f"Model: {job_config.checkpoint.initial_load_path}") logger.info( - f"Tensor Parallel Size: {job_config.inference.parallelism.tensor_parallel_degree}" + f"Tensor Parallel Size: {job_config.generation.parallelism.tensor_parallel_degree}" ) # Create EngineArgs from JobConfig # Map TorchTitan parallelism to vLLM parallelism - inference = job_config.inference + generation = job_config.generation engine_args = EngineArgs( # Model configuration model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - dtype=inference.dtype, + dtype=generation.dtype, # Parallelism configuration - tensor_parallel_size=inference.parallelism.tensor_parallel_degree, - distributed_executor_backend=inference.distributed_executor_backend, + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, # Memory and performance - gpu_memory_utilization=inference.gpu_memory_utilization, - enforce_eager=inference.enforce_eager, + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, # Seed - seed=inference.seed, + seed=generation.seed, # HuggingFace overrides hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) @@ -67,7 +67,7 @@ def infer(): logger.info("vLLM LLMEngine initialized successfully") # Create sampling parameters from JobConfig - sampling = job_config.inference.sampling + sampling = job_config.generation.sampling sampling_params = SamplingParams( temperature=sampling.temperature, top_p=sampling.top_p, diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py index 47e41bc831..fdbe086352 100644 --- a/torchtitan/experiments/rl/unified/job_config.py +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. """ -Extended JobConfig for RL/Inference workloads. +Extended JobConfig for RL/Generation workloads. -This module extends TorchTitan's base JobConfig with inference-specific +This module extends TorchTitan's base JobConfig with generation-specific configurations needed for vLLM integration. """ @@ -43,25 +43,24 @@ class Sampling: @dataclass -class Inference: +class Generate: """ - Inference configuration for vLLM engine. + Generation configuration for vLLM engine. - This dataclass contains essential vLLM-specific settings for inference. + This dataclass contains essential vLLM-specific settings for generation. """ dtype: str = "bfloat16" """Data type for model weights (auto, float16, bfloat16, float32)""" gpu_memory_utilization: float = 0.5 - """Fraction of GPU memory to use for Inference engine (0.0 to 1.0)""" + """Fraction of GPU memory to use for generation engine (0.0 to 1.0)""" distributed_executor_backend: str = "external_launcher" """ Backend for distributed execution. 'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) """ - seed: int = 42 """Random seed for sampling""" @@ -69,20 +68,35 @@ class Inference: """Whether to enforce eager execution (disable CUDA graphs)""" parallelism: Parallelism = field(default_factory=Parallelism) - """Parallelism configuration for inference""" + """Parallelism configuration for generation""" sampling: Sampling = field(default_factory=Sampling) - """Sampling configuration for inference""" + """Sampling configuration for generation""" + + +@dataclass +class RL: + """Reinforcement Learning configuration for GRPO training.""" + + grpo_beta: int = 0.1 + """Beta parameter for GRPO (Group Relative Policy Optimization) algorithm""" + + use_stable_grpo: bool = False + """Whether to use stable version of GRPO algorithm""" + + grpo_group_size: int = 8 + """Number of samples in each GRPO group for policy optimization""" @dataclass class JobConfig(BaseJobConfig): """ - Extended JobConfig with inference support. + Extended JobConfig with generation support. - This extends TorchTitan's base JobConfig by adding `inference` and `sampling` fields - for vLLM-specific inference and generation configurations. + This extends TorchTitan's base JobConfig by adding `generation` field + for vLLM-specific generation configurations. """ - inference: Inference = field(default_factory=Inference) - """Inference configuration for vLLM engine""" + generation: Generate = field(default_factory=Generate) + """Generation configuration for vLLM engine""" + rl: RL = field(default_factory=RL) diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index 0e5d6cde52..2738a95728 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -6,42 +6,22 @@ import logging -from enum import Enum import torch -from safetensors.torch import load_file from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import vllm_to_torchtitan + from torchtitan.models.qwen3.model.args import Qwen3ModelArgs from transformers import AutoConfig logger = logging.getLogger(__name__) -class ModelMode(str, Enum): - """ - Enum defining which TorchTitan model to use. - - Attributes: - UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified - training and inference. - VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, - ensuring bitwise determinism between training and inference. - STANDARD: Plain TorchTitan model without any modifications. - """ - - UNIFIED = "unified" - VLLM_COMPAT = "vllm_compat" - STANDARD = "standard" - - def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -102,26 +82,20 @@ def replace_with_vllm_compatible_flash_attention(model): ) -def load_model( - checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT -): +def load_trainer_model(model_path: str): """ Load TorchTitan model from checkpoint for trainer. Args: - checkpoint_path: Path to TorchTitan checkpoint model_path: Path to HuggingFace model (for config) - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model Returns: model: Loaded TorchTitan model for trainer. """ - # Load HuggingFace config # TODO: do not depend on transformers.AutoConfig, use qwen_args directly hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - # Create model args + # Create model args. model_args = Qwen3ModelArgs( dim=hf_config.hidden_size, n_layers=hf_config.num_hidden_layers, @@ -135,46 +109,27 @@ def load_model( ), hidden_dim=hf_config.intermediate_size, norm_eps=hf_config.rms_norm_eps, - rope_theta=hf_config.rope_theta, + rope_theta=getattr(hf_config.rope_parameters, "rope_theta", 1000000), max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), qk_norm=True, depth_init=True, eos_id=getattr(hf_config, "eos_token_id", 151645), ) + # convert to torchtitan state_dict. TODO: Use torchtitan components + titan_state_dict = vllm_to_torchtitan(model_path) - # state_dict is in standard TorchTitan format (w1, w2, w3) - state_dict = load_file(checkpoint_path) - - if model_mode == ModelMode.UNIFIED: - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Set global default dtype to bfloat16. This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention for training. - replace_with_vllm_compatible_flash_attention(model) - - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=True) - elif model_mode == ModelMode.VLLM_COMPAT: - # Create and load model that has bitwise determinism between training and inference - from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( - Qwen3VLLMCompatModel, - ) + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention to vllm compatible attention for backward capability. + # Only patch to vllm compatible attention during training. + replace_with_vllm_compatible_flash_attention(model) - model = Qwen3VLLMCompatModel(model_args) - # Convert to vLLM-compat format (merged gate_up_proj, down_proj) - vllm_compat_state = torchtitan_to_vllm_compat(state_dict) - model.load_state_dict(vllm_compat_state, strict=False) - else: - # Use standard TorchTitan model - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=False) + # Load standard TorchTitan format directly + model.load_state_dict(titan_state_dict, strict=True) model.to(torch.bfloat16) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml index 6166c50a91..2080e7a39f 100644 --- a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -29,8 +29,8 @@ steps = 10 dataset = "c4" [parallelism] # training parallelism plan -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 +data_parallel_replicate_degree = 2 +data_parallel_shard_degree = 1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 @@ -45,18 +45,18 @@ selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac ba enable=false components = ["model", "loss"] -[inference] +[generation] dtype = "bfloat16" gpu_memory_utilization = 0.5 distributed_executor_backend = "external_launcher" seed = 42 enforce_eager = true -[inference.parallelism] +[generation.parallelism] data_parallel_replicate_degree = 1 -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 -[inference.sampling] +[generation.sampling] temperature = 0.8 top_p = 0.95 max_tokens = 100 diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_grpo.py similarity index 64% rename from torchtitan/experiments/rl/unified/simple_rl_multiprocess.py rename to torchtitan/experiments/rl/unified/simple_grpo.py index 28581e2747..c874feadc4 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -14,7 +14,8 @@ The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. Command to run: -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_grpo.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml """ import asyncio import logging @@ -22,13 +23,9 @@ import torch from monarch.actor import this_host from monarch.utils import setup_env_for_distributed +from torchtitan.config.manager import ConfigManager from torchtitan.experiments.rl.unified.actors.generator import Generator from torchtitan.experiments.rl.unified.actors.trainer import Trainer -from torchtitan.experiments.rl.unified.models.utils import ModelMode -from torchtitan.experiments.rl.vllm_compat.simple_rl import ( - download_and_convert_model, - load_gsm8k_dataset, -) from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, vllm_is_batch_invariant, @@ -39,33 +36,21 @@ async def main(): """Run the distributed RL training loop using Monarch.""" - # Model Config - model_name = "Qwen/Qwen3-0.6B" - cache_dir = "./models" - output_dir = "./converted" - - # Training config - group_size = 8 - num_steps = 10 - learning_rate = 1e-5 - max_new_tokens = 20 - # GRPO config - use_stable_grpo = False - grpo_beta = 0.1 + # Step 1: Load job config using config manager + config_manager = ConfigManager() + job_config = config_manager.parse_args() - # Dataset config - use_real_dataset = False - num_dataset_samples = 5 + # RL Training config + num_steps = job_config.training.steps # Parallelism sizes - trainer_ddp_size = 2 - trainer_tp_size = 1 - generator_tp_size = 1 + trainer_ddp_size = job_config.parallelism.data_parallel_replicate_degree + trainer_tp_size = job_config.parallelism.tensor_parallel_degree + # TODO: add a flag to enable/disable batch_invariant init_batch_invariance() batch_invariant = vllm_is_batch_invariant() - mode = ModelMode.UNIFIED # Set up batch invariant if batch_invariant: @@ -78,37 +63,24 @@ async def main(): else: raise RuntimeError("Batch invariance NOT detected - using standard model") - # Download and convert model - titan_checkpoint_path, model_path = download_and_convert_model( - model_name, cache_dir, output_dir - ) - - # Load dataset - if use_real_dataset: - logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") - # TODO: Refactor into loading torchtitan dataset - prompt_texts, expected_answers = load_gsm8k_dataset( - split="train", num_samples=num_dataset_samples - ) - if prompt_texts is None or len(prompt_texts) == 0: - use_real_dataset = False - - if not use_real_dataset: - logger.info("Using default prompts") - prompts_with_answers = [ - ("The capital of France is", "paris"), - ("What is 7 times 8?", "56"), - ("The first president of the United States was", "washington"), - ("The chemical symbol for water is", "h2o"), - ("The largest planet in our solar system is", "jupiter"), - ] - prompt_texts = [p[0] for p in prompts_with_answers] - expected_answers = [p[1] for p in prompts_with_answers] + # Use fake dataset for test. TODO: Implement real RL dataloader. + logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] logger.info(f"Loaded {len(prompt_texts)} prompts") # Create process meshes - trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) + trainer_mesh = this_host().spawn_procs( + per_host={"gpus": trainer_ddp_size * trainer_tp_size} + ) gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) # Set up distributed env vars so that actors are connected via c10d @@ -129,27 +101,15 @@ async def main(): trainer = trainer_mesh.spawn( "trainer", Trainer, - titan_checkpoint_path, - model_path, - learning_rate, - mode, - trainer_ddp_size, - trainer_tp_size, + job_config, # Pass full job_config ) generator = gen_mesh.spawn( "generator", Generator, - model_path, + job_config, # Pass full job_config prompt_texts, expected_answers, - group_size, - max_new_tokens, - 1.0, # temperature - use_real_dataset, - grpo_beta, - use_stable_grpo, - generator_tp_size, ) # Initialize generator with trainer weights From f54db0c99f2747092a65e19e63c56d54e0a3bc3d Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 12 Jan 2026 16:13:47 -0800 Subject: [PATCH 03/10] Update on "[rl] Using JobConfig as the centralized config system for inference and simple GRPO" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add job_config.py to extend current JobConfig. Now an issue is `trainer`'s config and `generator`'s config are not symmetric, eg `Parallelism` and `Generation.parallelism` 2. Use job config system as the centralized / source-of-truth config, loading config from `run_configs/qwen3_0.6b.toml` file. 3. Refactor the generator to use EngineArgs() and LLMEngine(), instead of LLM() 4. Rename simple_rl_multiprocess -> simple_grpo to be more descriptive 5. Clean up unused code branch Test: (trainer ddp = 2, n_generator =1) Screenshot 2025-12-30 at 7 34 00 PM Following-up refactors: - Refactor2: vllm model register - using setup.py and plugin instead of import - Refactor3: Weight updater, by directly passing state_dict (DTensor) between trainer and generator - Refactor4: Use torchtitan Trainer, modularize each component [ghstack-poisoned] --- torchtitan/experiments/rl/unified/README.md | 6 +- .../rl/unified/actors/generator.py | 240 +++++++++--------- .../experiments/rl/unified/actors/trainer.py | 45 ++-- torchtitan/experiments/rl/unified/infer.py | 18 +- .../experiments/rl/unified/job_config.py | 42 +-- .../experiments/rl/unified/models/utils.py | 81 ++++-- .../rl/unified/run_configs/qwen3_0.6b.toml | 12 +- ...mple_grpo.py => simple_rl_multiprocess.py} | 94 +++++-- 8 files changed, 311 insertions(+), 227 deletions(-) rename torchtitan/experiments/rl/unified/{simple_grpo.py => simple_rl_multiprocess.py} (64%) diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index ae604b19aa..27550e977c 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -66,10 +66,10 @@ python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path None: + def update_weights(self, vllm_compat_state: dict) -> None: """ Update vLLM model weights from vLLM-compat state dict. @@ -123,8 +121,11 @@ def update_weights(self, vllm_state: dict) -> None: vLLM's reload_weights() API after updating the model path config. Args: - vllm_state: vLLM model state dict + vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) """ + # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) + vllm_state = torchtitan_to_vllm(vllm_compat_state) + # Save to temp model directory import os @@ -133,6 +134,7 @@ def update_weights(self, vllm_state: dict) -> None: # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure import glob + import json shard_files = sorted( glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) @@ -140,20 +142,51 @@ def update_weights(self, vllm_state: dict) -> None: index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safetensor files which we only + # right now we hardcoded to work with 2 safe tensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. # only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: logger.info(f"Saving weights to {checkpoint_path}") - - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + if len(shard_files) == 2 and os.path.exists(index_file): + # Load the index to see which weights go in which shard + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data["weight_map"] + + # Split weights according to the index + shard1_weights = {} + shard2_weights = {} + + for key, value in vllm_state.items(): + shard_file = weight_map.get(key, shard_files[0]) + if "model-00001-of-00002" in shard_file: + shard1_weights[key] = value + else: + shard2_weights[key] = value + + # Ensure weights stay in bfloat16 + shard1_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard1_weights.items() + } + shard2_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard2_weights.items() + } + + # Save to the shard files + save_file(shard1_weights, shard_files[0]) + save_file(shard2_weights, shard_files[1]) + else: + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() @@ -161,35 +194,28 @@ def update_weights(self, vllm_state: dict) -> None: f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) - # First time: create the engine using LLMEngine and EngineArgs - if self.engine is None: - generation = self.job_config.generation - - engine_args = EngineArgs( - # Model configuration + # First time: create the engine + if self.llm is None: + self.llm = LLM( model=self.temp_model_dir, + hf_overrides={ + # Override architectures to use our registered TorchTitan model class + "architectures": ["Qwen3TorchTitanForCausalLM"], + }, trust_remote_code=True, - dtype=generation.dtype, - # Parallelism configuration - tensor_parallel_size=generation.parallelism.tensor_parallel_degree, - distributed_executor_backend=generation.distributed_executor_backend, - # Memory and performance - gpu_memory_utilization=generation.gpu_memory_utilization, - enforce_eager=generation.enforce_eager, - # Seed - seed=generation.seed, - # HuggingFace overrides to use TorchTitan model. - # TODO: make this field configurable and align with model registration - hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, + max_model_len=2048, + dtype="bfloat16", + gpu_memory_utilization=0.1, # Reduced from 0.5 + distributed_executor_backend="external_launcher", # vllm do not spawn processes + seed=42, # Fixed seed for determinism + enforce_eager=True, + tensor_parallel_size=self.tp_size, # Explicitly single GPU ) - - logger.info("Initializing LLMEngine from EngineArgs...") - self.engine = LLMEngine.from_engine_args(engine_args) - logger.info("Created new vLLM LLMEngine") + logger.info("Created new vLLM engine") else: # Use collective_rpc to call reload_weights on all workers # This reloads weights from temp_model_dir without recreating the engine - self.engine.collective_rpc("reload_weights") + self.llm.collective_rpc("reload_weights") @torch.no_grad() def generate( @@ -202,7 +228,7 @@ def generate( list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] ]: """ - Generate samples using vLLM LLMEngine. + Generate samples using vLLM. Args: prompt_texts: List of prompt strings @@ -217,36 +243,16 @@ def generate( token_log_probs: List of per-token log prob lists for each completion prompt_token_ids: List of prompt token ID lists for each completion """ - logger.info( - f"Starting generation: {len(prompt_texts)} prompts, " - f"n_samples_per_prompt={n_samples_per_prompt}, " - f"max_tokens={max_new_tokens}, temp={temperature}" - ) - sampling_params = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, + n=n_samples_per_prompt, seed=42, logprobs=1, prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs - output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - # Add requests to the engine - # For n_samples_per_prompt > 1, submit each prompt multiple times with different request id - request_id = 0 - prompt_indices = [] # Track which prompt each request corresponds to - for prompt_idx, prompt in enumerate(prompt_texts): - for sample_idx in range(n_samples_per_prompt): - self.engine.add_request(str(request_id), prompt, sampling_params) - prompt_indices.append(prompt_idx) - request_id += 1 - - # Step through engine until all requests are finished - all_outputs = [] - while self.engine.has_unfinished_requests(): - request_outputs = self.engine.step() - all_outputs.extend(request_outputs) + outputs = self.llm.generate(prompt_texts, sampling_params) # Extract completions and log probs completions = [] @@ -255,35 +261,30 @@ def generate( token_log_probs_list = [] prompt_token_ids_list = [] - for output in all_outputs: + for output in outputs: # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids - # Each output now has exactly 1 sample (we submitted multiple requests) - assert ( - len(output.outputs) == 1 - ), f"Expected 1 output, got {len(output.outputs)}" - sample = output.outputs[0] - - completions.append(sample.text) + for sample in output.outputs: + completions.append(sample.text) - # Store prompt tokens for this sample - prompt_token_ids_list.append(prompt_token_ids) + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) log_probs = torch.tensor(log_probs_list, dtype=torch.float32) @@ -297,8 +298,8 @@ def generate( def __del__(self): """Cleanup vLLM engine.""" - if hasattr(self, "engine"): - del self.engine + if hasattr(self, "llm"): + del self.llm torch.cuda.empty_cache() @@ -318,46 +319,58 @@ class Generator(Actor): computes rewards/advantages. Args: - job_config: JobConfig dataclass containing all configuration + model_path: Path to HuggingFace model prompt_texts: List of prompt strings expected_answers: List of expected answers + group_size: Number of samples per prompt + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + use_real_dataset: Whether using real dataset (GSM8K) + grpo_beta: Beta for GRPO advantages + use_stable_grpo: Whether to use stable GRPO + tp_size: Tensor Parallel size """ def __init__( self, - job_config: JobConfig, - prompt_texts: List[ - str - ], # TODO: This field need to be removed once dataloader is implemented + model_path: str, + prompt_texts: List[str], expected_answers: List[str], + group_size: int = 8, + max_new_tokens: int = 20, + temperature: float = 1.0, + use_real_dataset: bool = False, + grpo_beta: float = 0.1, + use_stable_grpo: bool = False, + tp_size: int = 1, ): - # Store job_config for accessing configuration - self.job_config = job_config + self.model_path = model_path self.prompt_texts = prompt_texts self.expected_answers = expected_answers - - # Extract needed fields from job_config - self.model_path = job_config.checkpoint.initial_load_path - self.max_new_tokens = job_config.generation.sampling.max_tokens - self.temperature = job_config.generation.sampling.temperature - self.group_size = job_config.rl.grpo_group_size - self.grpo_beta = job_config.rl.grpo_beta - self.use_stable_grpo = job_config.rl.use_stable_grpo + self.group_size = group_size + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.use_real_dataset = use_real_dataset + self.grpo_beta = grpo_beta + self.use_stable_grpo = use_stable_grpo + self.tp_size = tp_size # Initialize distributed environment for SPMD generator world_size = dist_utils.init_distributed( Comm(), ) - # Initialize vLLM engine with job_config - self.vllm_engine = VLLMRolloutEngine(job_config, self.model_path) + # Initialize vLLM engine + self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) # State machine self.state = GeneratorState.READY_TO_UPDATE self.cond = asyncio.Condition() self.policy_version = 0 - # Reward function. TODO: Use a real reward function - self.reward_fn = trivial_reward_function + # Reward function + self.reward_fn = ( + math_reward_function if use_real_dataset else trivial_reward_function + ) logger.info("Generator initialized with vLLM engine") @@ -388,11 +401,6 @@ async def generate(self) -> None: ) # Compute rewards - logger.info( - f"Computing rewards: {len(completions)} completions, " - f"{len(self.expected_answers)} expected answers, " - f"group_size={self.group_size}" - ) rewards = self.reward_fn( completions, self.expected_answers, self.group_size ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index ccfa10a708..9ffb9f0f0a 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -11,15 +11,16 @@ import torch from monarch.actor import Actor, endpoint from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData -from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( +from torchtitan.experiments.rl.unified.models.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.job_config import JobConfig -from torchtitan.experiments.rl.unified.models.utils import load_trainer_model +from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) logger = logging.getLogger(__name__) @@ -29,30 +30,34 @@ class Trainer(Actor): Updates policy based on collected trajectories. Run model forward on trajectories, computes loss, and run backward. - TODO: Use torchtitan Trainer Args: - job_config: JobConfig dataclass containing all configuration + titan_checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to HuggingFace model + learning_rate: Learning rate for optimizer + model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, + or plain Torchtitan model """ def __init__( self, - job_config: JobConfig, + titan_checkpoint_path: str, + model_path: str, + learning_rate: float = 1e-5, + model_mode: str = ModelMode.VLLM_COMPAT, + ddp_size: int = 1, + tp_size: int = 1, ): - # Extract needed fields from job_config - model_path = job_config.checkpoint.initial_load_path # path to HF checkpoint - learning_rate = job_config.optimizer.lr - self.ddp_size = job_config.parallelism.data_parallel_replicate_degree - self.tp_size = job_config.parallelism.tensor_parallel_degree - # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - # load trainer model and patch to vllm.Attention() - self.model = load_trainer_model(model_path) - + self.model = load_model( + titan_checkpoint_path, model_path, model_mode=model_mode + ) + self.ddp_size = ddp_size + self.tp_size = tp_size self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism @@ -77,14 +82,14 @@ def __init__( @endpoint async def get_weights(self) -> dict: - """Get vLLM weights for generator. + """Get vLLM-compatible weights for generator. Returns: - vLLM state dict + vLLM-compatible state dict """ titan_state = self.model.state_dict() - vllm_state = torchtitan_to_vllm(titan_state) - return vllm_state + vllm_compat_state = torchtitan_to_vllm_compat(titan_state) + return vllm_compat_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 14e2542d16..136b264a72 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -37,26 +37,26 @@ def infer(): logger.info("Initializing vLLM LLMEngine with TorchTitan model") logger.info(f"Model: {job_config.checkpoint.initial_load_path}") logger.info( - f"Tensor Parallel Size: {job_config.generation.parallelism.tensor_parallel_degree}" + f"Tensor Parallel Size: {job_config.inference.parallelism.tensor_parallel_degree}" ) # Create EngineArgs from JobConfig # Map TorchTitan parallelism to vLLM parallelism - generation = job_config.generation + inference = job_config.inference engine_args = EngineArgs( # Model configuration model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - dtype=generation.dtype, + dtype=inference.dtype, # Parallelism configuration - tensor_parallel_size=generation.parallelism.tensor_parallel_degree, - distributed_executor_backend=generation.distributed_executor_backend, + tensor_parallel_size=inference.parallelism.tensor_parallel_degree, + distributed_executor_backend=inference.distributed_executor_backend, # Memory and performance - gpu_memory_utilization=generation.gpu_memory_utilization, - enforce_eager=generation.enforce_eager, + gpu_memory_utilization=inference.gpu_memory_utilization, + enforce_eager=inference.enforce_eager, # Seed - seed=generation.seed, + seed=inference.seed, # HuggingFace overrides hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) @@ -67,7 +67,7 @@ def infer(): logger.info("vLLM LLMEngine initialized successfully") # Create sampling parameters from JobConfig - sampling = job_config.generation.sampling + sampling = job_config.inference.sampling sampling_params = SamplingParams( temperature=sampling.temperature, top_p=sampling.top_p, diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py index fdbe086352..47e41bc831 100644 --- a/torchtitan/experiments/rl/unified/job_config.py +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. """ -Extended JobConfig for RL/Generation workloads. +Extended JobConfig for RL/Inference workloads. -This module extends TorchTitan's base JobConfig with generation-specific +This module extends TorchTitan's base JobConfig with inference-specific configurations needed for vLLM integration. """ @@ -43,24 +43,25 @@ class Sampling: @dataclass -class Generate: +class Inference: """ - Generation configuration for vLLM engine. + Inference configuration for vLLM engine. - This dataclass contains essential vLLM-specific settings for generation. + This dataclass contains essential vLLM-specific settings for inference. """ dtype: str = "bfloat16" """Data type for model weights (auto, float16, bfloat16, float32)""" gpu_memory_utilization: float = 0.5 - """Fraction of GPU memory to use for generation engine (0.0 to 1.0)""" + """Fraction of GPU memory to use for Inference engine (0.0 to 1.0)""" distributed_executor_backend: str = "external_launcher" """ Backend for distributed execution. 'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) """ + seed: int = 42 """Random seed for sampling""" @@ -68,35 +69,20 @@ class Generate: """Whether to enforce eager execution (disable CUDA graphs)""" parallelism: Parallelism = field(default_factory=Parallelism) - """Parallelism configuration for generation""" + """Parallelism configuration for inference""" sampling: Sampling = field(default_factory=Sampling) - """Sampling configuration for generation""" - - -@dataclass -class RL: - """Reinforcement Learning configuration for GRPO training.""" - - grpo_beta: int = 0.1 - """Beta parameter for GRPO (Group Relative Policy Optimization) algorithm""" - - use_stable_grpo: bool = False - """Whether to use stable version of GRPO algorithm""" - - grpo_group_size: int = 8 - """Number of samples in each GRPO group for policy optimization""" + """Sampling configuration for inference""" @dataclass class JobConfig(BaseJobConfig): """ - Extended JobConfig with generation support. + Extended JobConfig with inference support. - This extends TorchTitan's base JobConfig by adding `generation` field - for vLLM-specific generation configurations. + This extends TorchTitan's base JobConfig by adding `inference` and `sampling` fields + for vLLM-specific inference and generation configurations. """ - generation: Generate = field(default_factory=Generate) - """Generation configuration for vLLM engine""" - rl: RL = field(default_factory=RL) + inference: Inference = field(default_factory=Inference) + """Inference configuration for vLLM engine""" diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index 2738a95728..0e5d6cde52 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -6,22 +6,42 @@ import logging +from enum import Enum import torch +from safetensors.torch import load_file from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import vllm_to_torchtitan - +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) from torchtitan.models.qwen3.model.args import Qwen3ModelArgs from transformers import AutoConfig logger = logging.getLogger(__name__) +class ModelMode(str, Enum): + """ + Enum defining which TorchTitan model to use. + + Attributes: + UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified + training and inference. + VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, + ensuring bitwise determinism between training and inference. + STANDARD: Plain TorchTitan model without any modifications. + """ + + UNIFIED = "unified" + VLLM_COMPAT = "vllm_compat" + STANDARD = "standard" + + def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -82,20 +102,26 @@ def replace_with_vllm_compatible_flash_attention(model): ) -def load_trainer_model(model_path: str): +def load_model( + checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT +): """ Load TorchTitan model from checkpoint for trainer. Args: + checkpoint_path: Path to TorchTitan checkpoint model_path: Path to HuggingFace model (for config) + model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, + or plain Torchtitan model Returns: model: Loaded TorchTitan model for trainer. """ + # Load HuggingFace config # TODO: do not depend on transformers.AutoConfig, use qwen_args directly hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - # Create model args. + # Create model args model_args = Qwen3ModelArgs( dim=hf_config.hidden_size, n_layers=hf_config.num_hidden_layers, @@ -109,27 +135,46 @@ def load_trainer_model(model_path: str): ), hidden_dim=hf_config.intermediate_size, norm_eps=hf_config.rms_norm_eps, - rope_theta=getattr(hf_config.rope_parameters, "rope_theta", 1000000), + rope_theta=hf_config.rope_theta, max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), qk_norm=True, depth_init=True, eos_id=getattr(hf_config, "eos_token_id", 151645), ) - # convert to torchtitan state_dict. TODO: Use torchtitan components - titan_state_dict = vllm_to_torchtitan(model_path) - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Set global default dtype to bfloat16. This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention during training. - replace_with_vllm_compatible_flash_attention(model) + # state_dict is in standard TorchTitan format (w1, w2, w3) + state_dict = load_file(checkpoint_path) + + if model_mode == ModelMode.UNIFIED: + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention to vllm compatible attention for backward capability. + # Only patch to vllm compatible attention for training. + replace_with_vllm_compatible_flash_attention(model) + + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=True) + elif model_mode == ModelMode.VLLM_COMPAT: + # Create and load model that has bitwise determinism between training and inference + from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( + Qwen3VLLMCompatModel, + ) - # Load standard TorchTitan format directly - model.load_state_dict(titan_state_dict, strict=True) + model = Qwen3VLLMCompatModel(model_args) + # Convert to vLLM-compat format (merged gate_up_proj, down_proj) + vllm_compat_state = torchtitan_to_vllm_compat(state_dict) + model.load_state_dict(vllm_compat_state, strict=False) + else: + # Use standard TorchTitan model + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=False) model.to(torch.bfloat16) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml index 2080e7a39f..6166c50a91 100644 --- a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -29,8 +29,8 @@ steps = 10 dataset = "c4" [parallelism] # training parallelism plan -data_parallel_replicate_degree = 2 -data_parallel_shard_degree = 1 +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 @@ -45,18 +45,18 @@ selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac ba enable=false components = ["model", "loss"] -[generation] +[inference] dtype = "bfloat16" gpu_memory_utilization = 0.5 distributed_executor_backend = "external_launcher" seed = 42 enforce_eager = true -[generation.parallelism] +[inference.parallelism] data_parallel_replicate_degree = 1 -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 -[generation.sampling] +[inference.sampling] temperature = 0.8 top_p = 0.95 max_tokens = 100 diff --git a/torchtitan/experiments/rl/unified/simple_grpo.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py similarity index 64% rename from torchtitan/experiments/rl/unified/simple_grpo.py rename to torchtitan/experiments/rl/unified/simple_rl_multiprocess.py index c874feadc4..28581e2747 100644 --- a/torchtitan/experiments/rl/unified/simple_grpo.py +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -14,8 +14,7 @@ The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. Command to run: -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_grpo.py \ - --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py """ import asyncio import logging @@ -23,9 +22,13 @@ import torch from monarch.actor import this_host from monarch.utils import setup_env_for_distributed -from torchtitan.config.manager import ConfigManager from torchtitan.experiments.rl.unified.actors.generator import Generator from torchtitan.experiments.rl.unified.actors.trainer import Trainer +from torchtitan.experiments.rl.unified.models.utils import ModelMode +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + download_and_convert_model, + load_gsm8k_dataset, +) from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, vllm_is_batch_invariant, @@ -36,21 +39,33 @@ async def main(): """Run the distributed RL training loop using Monarch.""" + # Model Config + model_name = "Qwen/Qwen3-0.6B" + cache_dir = "./models" + output_dir = "./converted" + + # Training config + group_size = 8 + num_steps = 10 + learning_rate = 1e-5 + max_new_tokens = 20 - # Step 1: Load job config using config manager - config_manager = ConfigManager() - job_config = config_manager.parse_args() + # GRPO config + use_stable_grpo = False + grpo_beta = 0.1 - # RL Training config - num_steps = job_config.training.steps + # Dataset config + use_real_dataset = False + num_dataset_samples = 5 # Parallelism sizes - trainer_ddp_size = job_config.parallelism.data_parallel_replicate_degree - trainer_tp_size = job_config.parallelism.tensor_parallel_degree + trainer_ddp_size = 2 + trainer_tp_size = 1 + generator_tp_size = 1 - # TODO: add a flag to enable/disable batch_invariant init_batch_invariance() batch_invariant = vllm_is_batch_invariant() + mode = ModelMode.UNIFIED # Set up batch invariant if batch_invariant: @@ -63,24 +78,37 @@ async def main(): else: raise RuntimeError("Batch invariance NOT detected - using standard model") - # Use fake dataset for test. TODO: Implement real RL dataloader. - logger.info("Using default prompts") - prompts_with_answers = [ - ("The capital of France is", "paris"), - ("What is 7 times 8?", "56"), - ("The first president of the United States was", "washington"), - ("The chemical symbol for water is", "h2o"), - ("The largest planet in our solar system is", "jupiter"), - ] - prompt_texts = [p[0] for p in prompts_with_answers] - expected_answers = [p[1] for p in prompts_with_answers] + # Download and convert model + titan_checkpoint_path, model_path = download_and_convert_model( + model_name, cache_dir, output_dir + ) + + # Load dataset + if use_real_dataset: + logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") + # TODO: Refactor into loading torchtitan dataset + prompt_texts, expected_answers = load_gsm8k_dataset( + split="train", num_samples=num_dataset_samples + ) + if prompt_texts is None or len(prompt_texts) == 0: + use_real_dataset = False + + if not use_real_dataset: + logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] logger.info(f"Loaded {len(prompt_texts)} prompts") # Create process meshes - trainer_mesh = this_host().spawn_procs( - per_host={"gpus": trainer_ddp_size * trainer_tp_size} - ) + trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) # Set up distributed env vars so that actors are connected via c10d @@ -101,15 +129,27 @@ async def main(): trainer = trainer_mesh.spawn( "trainer", Trainer, - job_config, # Pass full job_config + titan_checkpoint_path, + model_path, + learning_rate, + mode, + trainer_ddp_size, + trainer_tp_size, ) generator = gen_mesh.spawn( "generator", Generator, - job_config, # Pass full job_config + model_path, prompt_texts, expected_answers, + group_size, + max_new_tokens, + 1.0, # temperature + use_real_dataset, + grpo_beta, + use_stable_grpo, + generator_tp_size, ) # Initialize generator with trainer weights From 7daa5f73678d274f4c420f98a0f917c34190154e Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 12 Jan 2026 16:20:06 -0800 Subject: [PATCH 04/10] Update on "[rl] Using JobConfig as the centralized config system for inference and simple GRPO" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add job_config.py to extend current JobConfig. Now an issue is `trainer`'s config and `generator`'s config are not symmetric, eg `Parallelism` and `Generation.parallelism` 2. Use job config system as the centralized / source-of-truth config, loading config from `run_configs/qwen3_0.6b.toml` file. 3. Refactor the generator to use EngineArgs() and LLMEngine(), instead of LLM() 4. Rename simple_rl_multiprocess -> simple_grpo to be more descriptive 5. Clean up unused code branch Test: (trainer ddp = 2, n_generator =1) Screenshot 2025-12-30 at 7 34 00 PM Following-up refactors: - Refactor2: vllm model register - using setup.py and plugin instead of import - Refactor3: Weight updater, by directly passing state_dict (DTensor) between trainer and generator - Refactor4: Use torchtitan Trainer, modularize each component [ghstack-poisoned] --- torchtitan/experiments/rl/unified/README.md | 6 +- .../rl/unified/actors/generator.py | 240 +++++++++--------- .../experiments/rl/unified/actors/trainer.py | 45 ++-- torchtitan/experiments/rl/unified/infer.py | 18 +- .../experiments/rl/unified/job_config.py | 42 ++- .../experiments/rl/unified/models/utils.py | 81 ++---- .../rl/unified/run_configs/qwen3_0.6b.toml | 12 +- ...mple_rl_multiprocess.py => simple_grpo.py} | 94 ++----- 8 files changed, 227 insertions(+), 311 deletions(-) rename torchtitan/experiments/rl/unified/{simple_rl_multiprocess.py => simple_grpo.py} (64%) diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 27550e977c..ae604b19aa 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -66,10 +66,10 @@ python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path None: + def update_weights(self, vllm_state: dict) -> None: """ Update vLLM model weights from vLLM-compat state dict. @@ -121,11 +123,8 @@ def update_weights(self, vllm_compat_state: dict) -> None: vLLM's reload_weights() API after updating the model path config. Args: - vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + vllm_state: vLLM model state dict """ - # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) - vllm_state = torchtitan_to_vllm(vllm_compat_state) - # Save to temp model directory import os @@ -134,7 +133,6 @@ def update_weights(self, vllm_compat_state: dict) -> None: # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure import glob - import json shard_files = sorted( glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) @@ -142,51 +140,20 @@ def update_weights(self, vllm_compat_state: dict) -> None: index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safe tensor files which we only + # right now we hardcoded to work with 2 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. # only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: logger.info(f"Saving weights to {checkpoint_path}") - if len(shard_files) == 2 and os.path.exists(index_file): - # Load the index to see which weights go in which shard - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data["weight_map"] - - # Split weights according to the index - shard1_weights = {} - shard2_weights = {} - - for key, value in vllm_state.items(): - shard_file = weight_map.get(key, shard_files[0]) - if "model-00001-of-00002" in shard_file: - shard1_weights[key] = value - else: - shard2_weights[key] = value - - # Ensure weights stay in bfloat16 - shard1_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard1_weights.items() - } - shard2_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard2_weights.items() - } - - # Save to the shard files - save_file(shard1_weights, shard_files[0]) - save_file(shard2_weights, shard_files[1]) - else: - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() @@ -194,28 +161,35 @@ def update_weights(self, vllm_compat_state: dict) -> None: f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) - # First time: create the engine - if self.llm is None: - self.llm = LLM( + # First time: create the engine using LLMEngine and EngineArgs + if self.engine is None: + generation = self.job_config.generation + + engine_args = EngineArgs( + # Model configuration model=self.temp_model_dir, - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, trust_remote_code=True, - max_model_len=2048, - dtype="bfloat16", - gpu_memory_utilization=0.1, # Reduced from 0.5 - distributed_executor_backend="external_launcher", # vllm do not spawn processes - seed=42, # Fixed seed for determinism - enforce_eager=True, - tensor_parallel_size=self.tp_size, # Explicitly single GPU + dtype=generation.dtype, + # Parallelism configuration + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, + # Seed + seed=generation.seed, + # HuggingFace overrides to use TorchTitan model. + # TODO: make this field configurable and align with model registration + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("Created new vLLM engine") + + logger.info("Initializing LLMEngine from EngineArgs...") + self.engine = LLMEngine.from_engine_args(engine_args) + logger.info("Created new vLLM LLMEngine") else: # Use collective_rpc to call reload_weights on all workers # This reloads weights from temp_model_dir without recreating the engine - self.llm.collective_rpc("reload_weights") + self.engine.collective_rpc("reload_weights") @torch.no_grad() def generate( @@ -228,7 +202,7 @@ def generate( list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] ]: """ - Generate samples using vLLM. + Generate samples using vLLM LLMEngine. Args: prompt_texts: List of prompt strings @@ -243,16 +217,36 @@ def generate( token_log_probs: List of per-token log prob lists for each completion prompt_token_ids: List of prompt token ID lists for each completion """ + logger.info( + f"Starting generation: {len(prompt_texts)} prompts, " + f"n_samples_per_prompt={n_samples_per_prompt}, " + f"max_tokens={max_new_tokens}, temp={temperature}" + ) + sampling_params = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, - n=n_samples_per_prompt, seed=42, logprobs=1, prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - outputs = self.llm.generate(prompt_texts, sampling_params) + # Add requests to the engine + # For n_samples_per_prompt > 1, submit each prompt multiple times with different request id + request_id = 0 + prompt_indices = [] # Track which prompt each request corresponds to + for prompt_idx, prompt in enumerate(prompt_texts): + for sample_idx in range(n_samples_per_prompt): + self.engine.add_request(str(request_id), prompt, sampling_params) + prompt_indices.append(prompt_idx) + request_id += 1 + + # Step through engine until all requests are finished + all_outputs = [] + while self.engine.has_unfinished_requests(): + request_outputs = self.engine.step() + all_outputs.extend(request_outputs) # Extract completions and log probs completions = [] @@ -261,30 +255,35 @@ def generate( token_log_probs_list = [] prompt_token_ids_list = [] - for output in outputs: + for output in all_outputs: # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids - for sample in output.outputs: - completions.append(sample.text) + # Each output now has exactly 1 sample (we submitted multiple requests) + assert ( + len(output.outputs) == 1 + ), f"Expected 1 output, got {len(output.outputs)}" + sample = output.outputs[0] + + completions.append(sample.text) - # Store prompt tokens for this sample - prompt_token_ids_list.append(prompt_token_ids) + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) log_probs = torch.tensor(log_probs_list, dtype=torch.float32) @@ -298,8 +297,8 @@ def generate( def __del__(self): """Cleanup vLLM engine.""" - if hasattr(self, "llm"): - del self.llm + if hasattr(self, "engine"): + del self.engine torch.cuda.empty_cache() @@ -319,58 +318,46 @@ class Generator(Actor): computes rewards/advantages. Args: - model_path: Path to HuggingFace model + job_config: JobConfig dataclass containing all configuration prompt_texts: List of prompt strings expected_answers: List of expected answers - group_size: Number of samples per prompt - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - use_real_dataset: Whether using real dataset (GSM8K) - grpo_beta: Beta for GRPO advantages - use_stable_grpo: Whether to use stable GRPO - tp_size: Tensor Parallel size """ def __init__( self, - model_path: str, - prompt_texts: List[str], + job_config: JobConfig, + prompt_texts: List[ + str + ], # TODO: This field need to be removed once dataloader is implemented expected_answers: List[str], - group_size: int = 8, - max_new_tokens: int = 20, - temperature: float = 1.0, - use_real_dataset: bool = False, - grpo_beta: float = 0.1, - use_stable_grpo: bool = False, - tp_size: int = 1, ): - self.model_path = model_path + # Store job_config for accessing configuration + self.job_config = job_config self.prompt_texts = prompt_texts self.expected_answers = expected_answers - self.group_size = group_size - self.max_new_tokens = max_new_tokens - self.temperature = temperature - self.use_real_dataset = use_real_dataset - self.grpo_beta = grpo_beta - self.use_stable_grpo = use_stable_grpo - self.tp_size = tp_size + + # Extract needed fields from job_config + self.model_path = job_config.checkpoint.initial_load_path + self.max_new_tokens = job_config.generation.sampling.max_tokens + self.temperature = job_config.generation.sampling.temperature + self.group_size = job_config.rl.grpo_group_size + self.grpo_beta = job_config.rl.grpo_beta + self.use_stable_grpo = job_config.rl.use_stable_grpo # Initialize distributed environment for SPMD generator world_size = dist_utils.init_distributed( Comm(), ) - # Initialize vLLM engine - self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) + # Initialize vLLM engine with job_config + self.vllm_engine = VLLMRolloutEngine(job_config, self.model_path) # State machine self.state = GeneratorState.READY_TO_UPDATE self.cond = asyncio.Condition() self.policy_version = 0 - # Reward function - self.reward_fn = ( - math_reward_function if use_real_dataset else trivial_reward_function - ) + # Reward function. TODO: Use a real reward function + self.reward_fn = trivial_reward_function logger.info("Generator initialized with vLLM engine") @@ -401,6 +388,11 @@ async def generate(self) -> None: ) # Compute rewards + logger.info( + f"Computing rewards: {len(completions)} completions, " + f"{len(self.expected_answers)} expected answers, " + f"group_size={self.group_size}" + ) rewards = self.reward_fn( completions, self.expected_answers, self.group_size ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index 9ffb9f0f0a..ccfa10a708 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -11,16 +11,15 @@ import torch from monarch.actor import Actor, endpoint from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData -from torchtitan.experiments.rl.unified.models.parallelism_utils import ( +from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.unified.job_config import JobConfig +from torchtitan.experiments.rl.unified.models.utils import load_trainer_model from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm logger = logging.getLogger(__name__) @@ -30,34 +29,30 @@ class Trainer(Actor): Updates policy based on collected trajectories. Run model forward on trajectories, computes loss, and run backward. + TODO: Use torchtitan Trainer Args: - titan_checkpoint_path: Path to TorchTitan checkpoint - model_path: Path to HuggingFace model - learning_rate: Learning rate for optimizer - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model + job_config: JobConfig dataclass containing all configuration """ def __init__( self, - titan_checkpoint_path: str, - model_path: str, - learning_rate: float = 1e-5, - model_mode: str = ModelMode.VLLM_COMPAT, - ddp_size: int = 1, - tp_size: int = 1, + job_config: JobConfig, ): + # Extract needed fields from job_config + model_path = job_config.checkpoint.initial_load_path # path to HF checkpoint + learning_rate = job_config.optimizer.lr + self.ddp_size = job_config.parallelism.data_parallel_replicate_degree + self.tp_size = job_config.parallelism.tensor_parallel_degree + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - self.model = load_model( - titan_checkpoint_path, model_path, model_mode=model_mode - ) - self.ddp_size = ddp_size - self.tp_size = tp_size + # load trainer model and patch to vllm.Attention() + self.model = load_trainer_model(model_path) + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism @@ -82,14 +77,14 @@ def __init__( @endpoint async def get_weights(self) -> dict: - """Get vLLM-compatible weights for generator. + """Get vLLM weights for generator. Returns: - vLLM-compatible state dict + vLLM state dict """ titan_state = self.model.state_dict() - vllm_compat_state = torchtitan_to_vllm_compat(titan_state) - return vllm_compat_state + vllm_state = torchtitan_to_vllm(titan_state) + return vllm_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 136b264a72..14e2542d16 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -37,26 +37,26 @@ def infer(): logger.info("Initializing vLLM LLMEngine with TorchTitan model") logger.info(f"Model: {job_config.checkpoint.initial_load_path}") logger.info( - f"Tensor Parallel Size: {job_config.inference.parallelism.tensor_parallel_degree}" + f"Tensor Parallel Size: {job_config.generation.parallelism.tensor_parallel_degree}" ) # Create EngineArgs from JobConfig # Map TorchTitan parallelism to vLLM parallelism - inference = job_config.inference + generation = job_config.generation engine_args = EngineArgs( # Model configuration model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - dtype=inference.dtype, + dtype=generation.dtype, # Parallelism configuration - tensor_parallel_size=inference.parallelism.tensor_parallel_degree, - distributed_executor_backend=inference.distributed_executor_backend, + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, # Memory and performance - gpu_memory_utilization=inference.gpu_memory_utilization, - enforce_eager=inference.enforce_eager, + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, # Seed - seed=inference.seed, + seed=generation.seed, # HuggingFace overrides hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) @@ -67,7 +67,7 @@ def infer(): logger.info("vLLM LLMEngine initialized successfully") # Create sampling parameters from JobConfig - sampling = job_config.inference.sampling + sampling = job_config.generation.sampling sampling_params = SamplingParams( temperature=sampling.temperature, top_p=sampling.top_p, diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py index 47e41bc831..fdbe086352 100644 --- a/torchtitan/experiments/rl/unified/job_config.py +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. """ -Extended JobConfig for RL/Inference workloads. +Extended JobConfig for RL/Generation workloads. -This module extends TorchTitan's base JobConfig with inference-specific +This module extends TorchTitan's base JobConfig with generation-specific configurations needed for vLLM integration. """ @@ -43,25 +43,24 @@ class Sampling: @dataclass -class Inference: +class Generate: """ - Inference configuration for vLLM engine. + Generation configuration for vLLM engine. - This dataclass contains essential vLLM-specific settings for inference. + This dataclass contains essential vLLM-specific settings for generation. """ dtype: str = "bfloat16" """Data type for model weights (auto, float16, bfloat16, float32)""" gpu_memory_utilization: float = 0.5 - """Fraction of GPU memory to use for Inference engine (0.0 to 1.0)""" + """Fraction of GPU memory to use for generation engine (0.0 to 1.0)""" distributed_executor_backend: str = "external_launcher" """ Backend for distributed execution. 'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) """ - seed: int = 42 """Random seed for sampling""" @@ -69,20 +68,35 @@ class Inference: """Whether to enforce eager execution (disable CUDA graphs)""" parallelism: Parallelism = field(default_factory=Parallelism) - """Parallelism configuration for inference""" + """Parallelism configuration for generation""" sampling: Sampling = field(default_factory=Sampling) - """Sampling configuration for inference""" + """Sampling configuration for generation""" + + +@dataclass +class RL: + """Reinforcement Learning configuration for GRPO training.""" + + grpo_beta: int = 0.1 + """Beta parameter for GRPO (Group Relative Policy Optimization) algorithm""" + + use_stable_grpo: bool = False + """Whether to use stable version of GRPO algorithm""" + + grpo_group_size: int = 8 + """Number of samples in each GRPO group for policy optimization""" @dataclass class JobConfig(BaseJobConfig): """ - Extended JobConfig with inference support. + Extended JobConfig with generation support. - This extends TorchTitan's base JobConfig by adding `inference` and `sampling` fields - for vLLM-specific inference and generation configurations. + This extends TorchTitan's base JobConfig by adding `generation` field + for vLLM-specific generation configurations. """ - inference: Inference = field(default_factory=Inference) - """Inference configuration for vLLM engine""" + generation: Generate = field(default_factory=Generate) + """Generation configuration for vLLM engine""" + rl: RL = field(default_factory=RL) diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index 0e5d6cde52..2738a95728 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -6,42 +6,22 @@ import logging -from enum import Enum import torch -from safetensors.torch import load_file from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import vllm_to_torchtitan + from torchtitan.models.qwen3.model.args import Qwen3ModelArgs from transformers import AutoConfig logger = logging.getLogger(__name__) -class ModelMode(str, Enum): - """ - Enum defining which TorchTitan model to use. - - Attributes: - UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified - training and inference. - VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, - ensuring bitwise determinism between training and inference. - STANDARD: Plain TorchTitan model without any modifications. - """ - - UNIFIED = "unified" - VLLM_COMPAT = "vllm_compat" - STANDARD = "standard" - - def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -102,26 +82,20 @@ def replace_with_vllm_compatible_flash_attention(model): ) -def load_model( - checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT -): +def load_trainer_model(model_path: str): """ Load TorchTitan model from checkpoint for trainer. Args: - checkpoint_path: Path to TorchTitan checkpoint model_path: Path to HuggingFace model (for config) - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model Returns: model: Loaded TorchTitan model for trainer. """ - # Load HuggingFace config # TODO: do not depend on transformers.AutoConfig, use qwen_args directly hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - # Create model args + # Create model args. model_args = Qwen3ModelArgs( dim=hf_config.hidden_size, n_layers=hf_config.num_hidden_layers, @@ -135,46 +109,27 @@ def load_model( ), hidden_dim=hf_config.intermediate_size, norm_eps=hf_config.rms_norm_eps, - rope_theta=hf_config.rope_theta, + rope_theta=getattr(hf_config.rope_parameters, "rope_theta", 1000000), max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), qk_norm=True, depth_init=True, eos_id=getattr(hf_config, "eos_token_id", 151645), ) + # convert to torchtitan state_dict. TODO: Use torchtitan components + titan_state_dict = vllm_to_torchtitan(model_path) - # state_dict is in standard TorchTitan format (w1, w2, w3) - state_dict = load_file(checkpoint_path) - - if model_mode == ModelMode.UNIFIED: - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Set global default dtype to bfloat16. This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention for training. - replace_with_vllm_compatible_flash_attention(model) - - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=True) - elif model_mode == ModelMode.VLLM_COMPAT: - # Create and load model that has bitwise determinism between training and inference - from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( - Qwen3VLLMCompatModel, - ) + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention to vllm compatible attention for backward capability. + # Only patch to vllm compatible attention during training. + replace_with_vllm_compatible_flash_attention(model) - model = Qwen3VLLMCompatModel(model_args) - # Convert to vLLM-compat format (merged gate_up_proj, down_proj) - vllm_compat_state = torchtitan_to_vllm_compat(state_dict) - model.load_state_dict(vllm_compat_state, strict=False) - else: - # Use standard TorchTitan model - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=False) + # Load standard TorchTitan format directly + model.load_state_dict(titan_state_dict, strict=True) model.to(torch.bfloat16) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml index 6166c50a91..2080e7a39f 100644 --- a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -29,8 +29,8 @@ steps = 10 dataset = "c4" [parallelism] # training parallelism plan -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 +data_parallel_replicate_degree = 2 +data_parallel_shard_degree = 1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 @@ -45,18 +45,18 @@ selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac ba enable=false components = ["model", "loss"] -[inference] +[generation] dtype = "bfloat16" gpu_memory_utilization = 0.5 distributed_executor_backend = "external_launcher" seed = 42 enforce_eager = true -[inference.parallelism] +[generation.parallelism] data_parallel_replicate_degree = 1 -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 -[inference.sampling] +[generation.sampling] temperature = 0.8 top_p = 0.95 max_tokens = 100 diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_grpo.py similarity index 64% rename from torchtitan/experiments/rl/unified/simple_rl_multiprocess.py rename to torchtitan/experiments/rl/unified/simple_grpo.py index 28581e2747..c874feadc4 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -14,7 +14,8 @@ The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. Command to run: -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_grpo.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml """ import asyncio import logging @@ -22,13 +23,9 @@ import torch from monarch.actor import this_host from monarch.utils import setup_env_for_distributed +from torchtitan.config.manager import ConfigManager from torchtitan.experiments.rl.unified.actors.generator import Generator from torchtitan.experiments.rl.unified.actors.trainer import Trainer -from torchtitan.experiments.rl.unified.models.utils import ModelMode -from torchtitan.experiments.rl.vllm_compat.simple_rl import ( - download_and_convert_model, - load_gsm8k_dataset, -) from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, vllm_is_batch_invariant, @@ -39,33 +36,21 @@ async def main(): """Run the distributed RL training loop using Monarch.""" - # Model Config - model_name = "Qwen/Qwen3-0.6B" - cache_dir = "./models" - output_dir = "./converted" - - # Training config - group_size = 8 - num_steps = 10 - learning_rate = 1e-5 - max_new_tokens = 20 - # GRPO config - use_stable_grpo = False - grpo_beta = 0.1 + # Step 1: Load job config using config manager + config_manager = ConfigManager() + job_config = config_manager.parse_args() - # Dataset config - use_real_dataset = False - num_dataset_samples = 5 + # RL Training config + num_steps = job_config.training.steps # Parallelism sizes - trainer_ddp_size = 2 - trainer_tp_size = 1 - generator_tp_size = 1 + trainer_ddp_size = job_config.parallelism.data_parallel_replicate_degree + trainer_tp_size = job_config.parallelism.tensor_parallel_degree + # TODO: add a flag to enable/disable batch_invariant init_batch_invariance() batch_invariant = vllm_is_batch_invariant() - mode = ModelMode.UNIFIED # Set up batch invariant if batch_invariant: @@ -78,37 +63,24 @@ async def main(): else: raise RuntimeError("Batch invariance NOT detected - using standard model") - # Download and convert model - titan_checkpoint_path, model_path = download_and_convert_model( - model_name, cache_dir, output_dir - ) - - # Load dataset - if use_real_dataset: - logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") - # TODO: Refactor into loading torchtitan dataset - prompt_texts, expected_answers = load_gsm8k_dataset( - split="train", num_samples=num_dataset_samples - ) - if prompt_texts is None or len(prompt_texts) == 0: - use_real_dataset = False - - if not use_real_dataset: - logger.info("Using default prompts") - prompts_with_answers = [ - ("The capital of France is", "paris"), - ("What is 7 times 8?", "56"), - ("The first president of the United States was", "washington"), - ("The chemical symbol for water is", "h2o"), - ("The largest planet in our solar system is", "jupiter"), - ] - prompt_texts = [p[0] for p in prompts_with_answers] - expected_answers = [p[1] for p in prompts_with_answers] + # Use fake dataset for test. TODO: Implement real RL dataloader. + logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] logger.info(f"Loaded {len(prompt_texts)} prompts") # Create process meshes - trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) + trainer_mesh = this_host().spawn_procs( + per_host={"gpus": trainer_ddp_size * trainer_tp_size} + ) gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) # Set up distributed env vars so that actors are connected via c10d @@ -129,27 +101,15 @@ async def main(): trainer = trainer_mesh.spawn( "trainer", Trainer, - titan_checkpoint_path, - model_path, - learning_rate, - mode, - trainer_ddp_size, - trainer_tp_size, + job_config, # Pass full job_config ) generator = gen_mesh.spawn( "generator", Generator, - model_path, + job_config, # Pass full job_config prompt_texts, expected_answers, - group_size, - max_new_tokens, - 1.0, # temperature - use_real_dataset, - grpo_beta, - use_stable_grpo, - generator_tp_size, ) # Initialize generator with trainer weights From 0f77caf13ee21c902372d71f8411ed2daace1134 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 12 Jan 2026 16:34:38 -0800 Subject: [PATCH 05/10] Update on "[rl] Using JobConfig as the centralized config system for inference and simple GRPO" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Add job_config.py to extend current JobConfig. Now an issue is `trainer`'s config and `generator`'s config are not symmetric, eg `Parallelism` and `Generation.parallelism` 2. Use job config system as the centralized / source-of-truth config, loading config from `run_configs/qwen3_0.6b.toml` file. 3. Refactor the generator to use EngineArgs() and LLMEngine(), instead of LLM() 4. Rename simple_rl_multiprocess -> simple_grpo to be more descriptive 5. Clean up unused code branch Test: (trainer ddp = 2, n_generator =1) Screenshot 2025-12-30 at 7 34 00 PM Following-up refactors: - Refactor2: vllm model register - using setup.py and plugin instead of import - Refactor3: Weight updater, by directly passing state_dict (DTensor) between trainer and generator - Refactor4: Use torchtitan Trainer, modularize each component [ghstack-poisoned] --- torchtitan/experiments/rl/unified/README.md | 2 +- torchtitan/experiments/rl/unified/infer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index ae604b19aa..699489c72e 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -69,7 +69,7 @@ python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path Date: Mon, 23 Feb 2026 01:59:29 -0800 Subject: [PATCH 06/10] Update [ghstack-poisoned] --- torchtitan/experiments/rl/unified/README.md | 3 +- .../rl/unified/actors/generator.py | 154 ++++++++++------- .../experiments/rl/unified/actors/trainer.py | 51 ++++-- .../experiments/rl/unified/config_registry.py | 143 ++++++++++++++++ torchtitan/experiments/rl/unified/configs.py | 157 ++++++++++++++++++ .../experiments/rl/unified/configurable.py | 66 ++++++++ .../experiments/rl/unified/simple_grpo.py | 41 +++-- 7 files changed, 518 insertions(+), 97 deletions(-) create mode 100644 torchtitan/experiments/rl/unified/config_registry.py create mode 100644 torchtitan/experiments/rl/unified/configs.py create mode 100644 torchtitan/experiments/rl/unified/configurable.py diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 00963df5e1..224ac68db7 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -49,8 +49,7 @@ python torchtitan/experiments/rl/unified/infer.py --tensor-parallel-size 2 5. Run simple rl loop ``` -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_grpo.py \ - --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +python3 torchtitan/experiments/rl/unified/simple_grpo.py --checkpoint.initial_load_path= ``` We use a unified model definition for the trainer and generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs. diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index cb5ea7fa15..4a32637a35 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -8,28 +8,34 @@ import logging import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List import torch from monarch.actor import Actor, endpoint from safetensors.torch import save_file -from torchtitan.config.job_config import Comm + +# TODO: Replace with ``from torchtitan.config.configs import ParallelismConfig`` +# once the config branch lands. +from torchtitan.config.job_config import Parallelism as ParallelismConfig from torchtitan.distributed import utils as dist_utils # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 -from torchtitan.experiments.rl.unified.job_config import JobConfig +from torchtitan.experiments.rl.unified.configs import RLTrainer, VLLMSamplingConfig +# TODO: Replace with ``from torchtitan.config import Configurable`` +# once the config branch lands. +from torchtitan.experiments.rl.unified.configurable import Configurable from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_grpo_advantages, compute_grpo_advantages_stable, trivial_reward_function, ) +from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.config import AttentionConfig -from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.model_executor.layers.batch_invariant import init_batch_invariance from vllm.sampling_params import RequestOutputKind @@ -62,7 +68,7 @@ class TrajectoryData: advantages: torch.Tensor -class VLLMRolloutEngine: +class VLLMEngine(Configurable): """ vLLM engine for fast rollouts with weight updates. @@ -70,22 +76,44 @@ class VLLMRolloutEngine: directory with updated weights and restart the engine. This is faster than recreating temp dirs repeatedly and handles config/tokenizer files properly. - Args: - job_config: JobConfig dataclass containing all configuration - model_path: Path to HuggingFace model (for config/tokenizer) + Constructed via ``config.build(model_path=..., dump_folder=...)`` + which calls ``VLLMEngine(config=..., model_path=..., dump_folder=...)``. """ + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + """vLLM engine configuration for rollout generation.""" + + dtype: str = "bfloat16" + """Data type for model weights (auto, float16, bfloat16, float32).""" + + gpu_memory_limit: float = 0.5 + """Fraction of GPU memory to use for the vLLM engine (0.0 to 1.0).""" + + enforce_eager: bool = True + """Disable CUDA graphs in vLLM (use eager execution).""" + + seed: int = 42 + """Random seed for reproducible generation.""" + + parallelism: ParallelismConfig = field(default_factory=ParallelismConfig) + """Parallelism configuration for the vLLM engine.""" + + sampling: VLLMSamplingConfig = field(default_factory=VLLMSamplingConfig) + """Default sampling parameters for generation.""" + def __init__( self, - job_config: JobConfig, + config: Config, + *, model_path: str, - ): - # Store job_config for accessing configuration - self.job_config = job_config + dump_folder: str, + ) -> None: + self.config = config self.base_model_path = model_path self.temp_model_dir = os.path.abspath( - os.path.join(job_config.job.dump_folder, "vllm_temp_model") + os.path.join(dump_folder, "vllm_temp_model") ) os.makedirs(self.temp_model_dir, exist_ok=True) @@ -145,7 +173,7 @@ def update_weights(self, vllm_state: dict) -> None: index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safetensor files which we only + # right now we hardcoded to work with 1 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. # only generator rank 0 saves the weight @@ -157,7 +185,6 @@ def update_weights(self, vllm_state: dict) -> None: k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v for k, v in vllm_state.items() } - # Fallback: save as single file save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing @@ -168,21 +195,21 @@ def update_weights(self, vllm_state: dict) -> None: # First time: create the engine using LLMEngine and EngineArgs if self.engine is None: - generation = self.job_config.generation + cfg = self.config engine_args = EngineArgs( # Model configuration model=self.temp_model_dir, trust_remote_code=True, - dtype=generation.dtype, + dtype=cfg.dtype, # Parallelism configuration - tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + tensor_parallel_size=cfg.parallelism.tensor_parallel_degree, distributed_executor_backend="external_launcher", # Memory and performance - gpu_memory_utilization=generation.gpu_memory_utilization, - enforce_eager=generation.enforce_eager, + gpu_memory_utilization=cfg.gpu_memory_limit, + enforce_eager=cfg.enforce_eager, # Seed - seed=self.job_config.debug.seed, + seed=cfg.seed, # HuggingFace overrides to use TorchTitan model. # TODO: make this field configurable and align with model registration hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, @@ -198,9 +225,7 @@ def update_weights(self, vllm_state: dict) -> None: # Direct parameter copy into model tensors. # This bypasses vLLM's reload_weights() which uses a layerwise # reload mechanism that moves params to meta device - from torchtitan.experiments.rl.vllm_compat.weights import ( - vllm_to_torchtitan, - ) + from torchtitan.experiments.rl.vllm_compat.weights import vllm_to_torchtitan titan_state = vllm_to_torchtitan(vllm_state) self._direct_weight_update(titan_state) @@ -213,7 +238,7 @@ def _direct_weight_update(self, titan_state: dict) -> None: """ # Access model from vLLM engine - model = self.llm.llm_engine.model_executor.driver_worker.get_model() + model = self.engine.model_executor.driver_worker.get_model() params = dict(model.named_parameters()) for name, new_weight in titan_state.items(): @@ -266,15 +291,9 @@ def generate( output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - # Add requests to the engine - # For n_samples_per_prompt > 1, submit each prompt multiple times with different request id - request_id = 0 - prompt_indices = [] # Track which prompt each request corresponds to - for prompt_idx, prompt in enumerate(prompt_texts): - for sample_idx in range(n_samples_per_prompt): - self.engine.add_request(str(request_id), prompt, sampling_params) - prompt_indices.append(prompt_idx) - request_id += 1 + # Add one request per prompt; vLLM handles n_samples_per_prompt via n= + for request_id, prompt in enumerate(prompt_texts): + self.engine.add_request(str(request_id), prompt, sampling_params) # Step through engine until all requests are finished all_outputs = [] @@ -290,7 +309,6 @@ def generate( prompt_token_ids_list = [] for output in all_outputs: - # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids # Each output now has exactly 1 sample (we submitted multiple requests) @@ -343,7 +361,7 @@ class GeneratorState: READY_TO_UPDATE = "READY_TO_UPDATE" -class Generator(Actor): +class Generator(Actor, Configurable): """ Generates rollouts using vLLM engine. @@ -351,47 +369,61 @@ class Generator(Actor): via weight sync. Generates completions for given prompts and computes rewards/advantages. - Args: - job_config: JobConfig dataclass containing all configuration - prompt_texts: List of prompt strings - expected_answers: List of expected answers + Constructed via ``config.build(rl_config=..., prompt_texts=..., expected_answers=...)`` + which calls ``Generator(config=..., rl_config=..., prompt_texts=..., expected_answers=...)``. """ + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + """Generator actor configuration.""" + + vllm_engine: VLLMEngine.Config = field(default_factory=VLLMEngine.Config) + """vLLM rollout engine configuration.""" + + vllm_attention_backend: str = "FLASH_ATTN" + """vLLM attention backend to use (e.g., FLASH_ATTN, XFORMERS).""" + def __init__( self, - job_config: JobConfig, - prompt_texts: List[ - str - ], # TODO: This field need to be removed once dataloader is implemented - expected_answers: List[str], + config: Config, + *, + rl_config: RLTrainer.Config, + prompt_texts: list[str], + expected_answers: list[str], ): + self.config = config + self.rl_config = rl_config + # Set vLLM environment variables from config before any vLLM initialization - policy_opt = job_config.policy_optimization - if policy_opt.vllm_batch_invariant: + if rl_config.batch_invariant_mode: os.environ["VLLM_BATCH_INVARIANT"] = "1" init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) - os.environ["VLLM_ATTENTION_BACKEND"] = policy_opt.vllm_attention_backend + os.environ["VLLM_ATTENTION_BACKEND"] = config.vllm_attention_backend - # Store job_config for accessing configuration - self.job_config = job_config self.prompt_texts = prompt_texts self.expected_answers = expected_answers - # Extract needed fields from job_config - self.model_path = job_config.checkpoint.initial_load_path - self.max_new_tokens = job_config.generation.sampling.max_tokens - self.temperature = job_config.generation.sampling.temperature - self.group_size = job_config.policy_optimization.grpo_group_size - self.grpo_beta = job_config.policy_optimization.grpo_beta - self.use_stable_grpo = job_config.policy_optimization.use_stable_grpo + # Extract needed fields from configs + self.model_path = rl_config.trainer.checkpoint.initial_load_path + self.max_new_tokens = config.vllm_engine.sampling.max_tokens + self.temperature = config.vllm_engine.sampling.temperature + self.group_size = rl_config.policy_optimization.group_size + self.grpo_beta = rl_config.policy_optimization.beta + self.use_stable_grpo = rl_config.policy_optimization.use_stable # Initialize distributed environment for SPMD generator - world_size = dist_utils.init_distributed( - Comm(), + # TODO: Replace rl_config.trainer.comm with rl_config.trainer.comm + # once the config branch lands (Comm -> CommConfig). + from torchtitan.config.job_config import Comm + + world_size = dist_utils.init_distributed(Comm()) + + # Build vLLM engine from its config + self.vllm_engine = config.vllm_engine.build( + model_path=self.model_path, + dump_folder=rl_config.dump_folder, ) - # Initialize vLLM engine with job_config - self.vllm_engine = VLLMRolloutEngine(job_config, self.model_path) # State machine self.state = GeneratorState.READY_TO_UPDATE diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index ccfa10a708..6c3f4c4d16 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -11,10 +11,10 @@ import torch from monarch.actor import Actor, endpoint from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData +from torchtitan.experiments.rl.unified.configs import RLTrainer from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.job_config import JobConfig from torchtitan.experiments.rl.unified.models.utils import load_trainer_model from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, @@ -24,26 +24,38 @@ logger = logging.getLogger(__name__) -class Trainer(Actor): +class RLPolicyTrainer(Actor): """ Updates policy based on collected trajectories. Run model forward on trajectories, computes loss, and run backward. - TODO: Use torchtitan Trainer + Receives the top-level ``RLTrainer.Config`` and reads policy trainer + settings (batch_invariant_mode, grpo) directly from it, plus model / + optimizer / parallelism settings from the nested ``config.trainer``. + + TODO: Use torchtitan Trainer for model init and parallelisation. Args: - job_config: JobConfig dataclass containing all configuration + config: Top-level RLTrainer.Config containing all configuration. """ def __init__( self, - job_config: JobConfig, + config: RLTrainer.Config, ): - # Extract needed fields from job_config - model_path = job_config.checkpoint.initial_load_path # path to HF checkpoint - learning_rate = job_config.optimizer.lr - self.ddp_size = job_config.parallelism.data_parallel_replicate_degree - self.tp_size = job_config.parallelism.tensor_parallel_degree + self.config = config + trainer_cfg = config.trainer + + # Extract needed fields from config + model_path = trainer_cfg.checkpoint.initial_load_path # path to HF checkpoint + learning_rate = trainer_cfg.optimizer.lr + self.ddp_size = trainer_cfg.parallelism.data_parallel_replicate_degree + self.tp_size = trainer_cfg.parallelism.tensor_parallel_degree + + # GRPO settings from top-level config + self.group_size = config.policy_optimization.group_size + self.grpo_beta = config.policy_optimization.beta + self.use_stable_grpo = config.policy_optimization.use_stable # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) @@ -57,13 +69,14 @@ def __init__( # apply PT-D Parallelism # TODO: right now it only works for qwen3 model, need to formalize this to use parallize_fn from train_spec - from torchtitan.models.llama3.infra.parallelize import apply_ddp + if self.ddp_size > 1: + from torchtitan.models.llama3.infra.parallelize import apply_ddp - apply_ddp( - self.model, - self.parallel_dims.get_mesh("dp_replicate"), - enable_compile=False, - ) + apply_ddp( + self.model, + self.parallel_dims.get_mesh("dp_replicate"), + enable_compile=False, + ) self.model = self.model.to(device) self.model.train() @@ -73,7 +86,11 @@ def __init__( self.policy_version = 0 self.generator: Optional[Any] = None - logger.info("Trainer initialized with TorchTitan model") + logger.info( + f"RLPolicyTrainer initialized: " + f"group_size={self.group_size}, grpo_beta={self.grpo_beta}, " + f"use_stable_grpo={self.use_stable_grpo}" + ) @endpoint async def get_weights(self) -> dict: diff --git a/torchtitan/experiments/rl/unified/config_registry.py b/torchtitan/experiments/rl/unified/config_registry.py new file mode 100644 index 0000000000..2f96b44578 --- /dev/null +++ b/torchtitan/experiments/rl/unified/config_registry.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Config entry points for the RL/unified experiment. + +Each function returns a complete ``RLTrainer.Config`` and is discoverable by +``ConfigManager`` via ``--module rl.unified --config ``. + +TODO: Once the config branch lands, replace the ``JobConfig`` sub-dataclass +imports (``Checkpoint``, ``Optimizer``, ``LRScheduler``, ``Training``, +``Parallelism``, ``ActivationCheckpoint``) with their config-branch +counterparts (``CheckpointManager.Config``, ``OptimizersContainer.Config``, +etc.) and replace ``from torchtitan.experiments.rl.unified.job_config import +JobConfig`` with ``from torchtitan.trainer import Trainer``. +""" + +from torchtitan.config.job_config import ( + ActivationCheckpoint, + Checkpoint, + LRScheduler, + Model, + Optimizer, + Parallelism, + Training, +) + +from torchtitan.experiments.rl.unified.actors.generator import Generator, VLLMEngine +from torchtitan.experiments.rl.unified.configs import ( + PolicyOptimizationConfig, + RLTrainer, + VLLMSamplingConfig, +) +from torchtitan.experiments.rl.unified.job_config import JobConfig + + +def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: + """GRPO training config for Qwen3-0.6B.""" + return RLTrainer.Config( + trainer=JobConfig( + model=Model( + name="qwen3", + flavor="0.6B", + ), + optimizer=Optimizer(lr=1e-6), + lr_scheduler=LRScheduler( + warmup_steps=2, + decay_type="linear", + ), + training=Training( + local_batch_size=4, + seq_len=4096, + steps=10, + ), + parallelism=Parallelism( + tensor_parallel_degree=1, + data_parallel_replicate_degree=2, + ), + checkpoint=Checkpoint( + initial_load_path="/data/users/jianiw/model/qwen3-0.6b", + initial_load_model_only=True, + initial_load_in_hf=True, + ), + activation_checkpoint=ActivationCheckpoint( + mode="selective", + selective_ac_option="op", + ), + ), + batch_invariant_mode=True, + policy_optimization=PolicyOptimizationConfig( + beta=0.1, + group_size=8, + use_stable=False, + ), + generator=Generator.Config( + vllm_engine=VLLMEngine.Config( + dtype="bfloat16", + gpu_memory_limit=0.5, + enforce_eager=True, + seed=42, + parallelism=Parallelism( + tensor_parallel_degree=1, + ), + sampling=VLLMSamplingConfig( + temperature=0.8, + top_p=0.95, + max_tokens=100, + ), + ), + vllm_attention_backend="FLASH_ATTN", + ), + ) + + +def rl_grpo_qwen3_debug() -> RLTrainer.Config: + """Debug config for quick iteration — small model, few steps.""" + return RLTrainer.Config( + trainer=JobConfig( + model=Model( + name="qwen3", + flavor="debugmodel", + ), + optimizer=Optimizer(lr=8e-4), + lr_scheduler=LRScheduler( + warmup_steps=2, + decay_type="linear", + ), + training=Training( + local_batch_size=2, + seq_len=2048, + steps=5, + ), + parallelism=Parallelism( + tensor_parallel_degree=1, + data_parallel_replicate_degree=1, + ), + checkpoint=Checkpoint( + interval=5, + ), + ), + batch_invariant_mode=False, + policy_optimization=PolicyOptimizationConfig( + beta=0.1, + group_size=4, + use_stable=False, + ), + generator=Generator.Config( + vllm_engine=VLLMEngine.Config( + gpu_memory_limit=0.3, + enforce_eager=True, + parallelism=Parallelism( + tensor_parallel_degree=1, + ), + sampling=VLLMSamplingConfig( + temperature=1.0, + max_tokens=50, + ), + ), + ), + ) diff --git a/torchtitan/experiments/rl/unified/configs.py b/torchtitan/experiments/rl/unified/configs.py new file mode 100644 index 0000000000..3859117b97 --- /dev/null +++ b/torchtitan/experiments/rl/unified/configs.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Top-level RL training configuration. + +This module defines: +- Leaf data configs: VLLMSamplingConfig, PolicyOptimizationConfig (plain dataclasses, no build()) +- RLTrainer: top-level Configurable that composes a canonical JobConfig + (as a sub-component for model/optimizer/parallelism) with RL-specific fields. + +Config hierarchy:: + + RLTrainer.Config + ├── trainer: JobConfig # canonical trainer sub-component + │ ├── model_spec, training, parallelism, optimizer, lr_scheduler, + │ │ checkpoint, compile, activation_checkpoint, comm, debug, ... + ├── dump_folder: str # root output folder + ├── batch_invariant_mode: bool # policy trainer setting + ├── policy_optimization: PolicyOptimizationConfig + │ ├── beta, group_size, use_stable + └── generator: Generator.Config # rollout generator actor + ├── vllm_engine: VLLMEngine.Config + │ ├── dtype, gpu_memory_utilization, enforce_eager, seed + │ ├── parallelism: ParallelismConfig + │ └── sampling: VLLMSamplingConfig + └── vllm_attention_backend: str +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +# TODO: Replace with ``from torchtitan.config import Configurable`` +# once the config branch lands. +from torchtitan.experiments.rl.unified.configurable import Configurable + +# TODO: Replace with ``from torchtitan.trainer import Trainer`` +# once the config branch lands. For now we use the existing JobConfig +# as the trainer config type. +from torchtitan.experiments.rl.unified.job_config import JobConfig + + +# --------------------------------------------------------------------------- +# Leaf data configs (plain dataclasses, not Configurable — no build() needed) +# --------------------------------------------------------------------------- + + +@dataclass(kw_only=True, slots=True) +class VLLMSamplingConfig: + """Sampling parameters passed to vLLM's SamplingParams.""" + + temperature: float = 0.8 + """Sampling temperature. 0.0 = greedy, higher = more random.""" + + top_p: float = 0.95 + """Nucleus sampling threshold.""" + + max_tokens: int = 100 + """Maximum number of tokens to generate per completion.""" + + +@dataclass(kw_only=True, slots=True) +class PolicyOptimizationConfig: + """Hyperparameters for Group Relative Policy Optimization.""" + + beta: float = 0.1 + """Temperature for GRPO exponential advantage weighting.""" + + group_size: int = 8 + """Number of completions per prompt for group-relative ranking.""" + + use_stable: bool = False + """Use stable mean-centering GRPO instead of exponential weighting.""" + + +# --------------------------------------------------------------------------- +# Top-level RL orchestrator +# --------------------------------------------------------------------------- + + +class RLTrainer(Configurable): + """Top-level RL training orchestrator. + + Initialises a canonical ``Trainer`` (for model construction, parallelism, + optimizer, checkpoint, etc.) as a sub-component, then creates an + Generator Monarch actor on a process mesh and runs the + generate → train loop. + + Policy trainer settings (GRPO hyperparameters, batch invariance) live + directly on this config because the RLTrainer *is* the policy trainer. + + Constructed via ``config.build()`` which calls ``RLTrainer(config=...)``. + """ + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + """Master config for RL training. + + ``trainer`` holds the canonical ``JobConfig`` that handles model + construction, parallelism, optimizer, LR scheduler, checkpoint, etc. + The remaining fields are RL-specific. + """ + + # -- Canonical trainer as a sub-component -- + + trainer: JobConfig = field(default_factory=JobConfig) + """Canonical TorchTitan trainer config. Controls model_spec, training, + parallelism, optimizer, lr_scheduler, checkpoint, compile, + activation_checkpoint, comm, debug, and all other standard fields.""" + + # -- Top-level RL settings -- + + dump_folder: str = "outputs/rl" + """Root output folder for RL artifacts (temp weights, logs, etc.).""" + + batch_invariant_mode: bool = True + """Enable batch-invariant mode for deterministic NCCL collective + operations and bitwise-reproducible forward/backward passes.""" + + policy_optimization: PolicyOptimizationConfig = field( + default_factory=PolicyOptimizationConfig + ) + """Policy optimization hyperparameters (GRPO).""" + + # -- Generator actor config -- + # Lazy default factory breaks the circular import: + # configs.py ↔ actors/generator.py + + generator: Generator.Config = field( # type: ignore[name-defined] # noqa: F821 + default_factory=lambda: _default_rl_generator_config(), + ) + """Generator actor configuration (vLLM engine, sampling).""" + + def __init__(self, config: Config): + self.config = config + + # TODO: Once the config branch lands and ``trainer`` becomes a + # ``Trainer.Config`` (a Configurable), replace the line below with: + # self.trainer = config.trainer.build() + # That will handle distributed init, model construction, + # parallelisation, optimizer/LR-scheduler creation, checkpoint + # loading, etc. + + +# --------------------------------------------------------------------------- +# Lazy default factory — resolved at runtime to break circular import +# --------------------------------------------------------------------------- + + +def _default_rl_generator_config(): + from torchtitan.experiments.rl.unified.actors.generator import Generator + + return Generator.Config() diff --git a/torchtitan/experiments/rl/unified/configurable.py b/torchtitan/experiments/rl/unified/configurable.py new file mode 100644 index 0000000000..69aa328e15 --- /dev/null +++ b/torchtitan/experiments/rl/unified/configurable.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Local copy of the Configurable base class. + +This mirrors ``torchtitan.config.Configurable`` from the ``config`` branch. +Once that branch lands, replace all imports of this module with:: + + from torchtitan.config import Configurable +""" + +from dataclasses import dataclass, fields +from typing import ClassVar + + +class Configurable: + """Base class for all configurable components. + + Every configurable class: + - Inherits from Configurable (or Module for nn.Module components) + - Defines a nested Config(Configurable.Config) with @dataclass(kw_only=True, slots=True) + - Gets build() auto-wired via __init_subclass__ (no manual override needed) + - Accepts __init__(self, config: Config, **runtime_kwargs) + + Enforcement: Configurable.__init_subclass__ checks that every Config uses + @dataclass(kw_only=True, slots=True). This check runs on the OUTER class + (not Config.__init_subclass__) because @dataclass(slots=True) replaces the + class, so Config.__init_subclass__ sees the pre-decorator version. + """ + + @dataclass(kw_only=True, slots=True) + class Config: + _owner: ClassVar[type | None] = None + + def build(self, **kwargs): + """Construct the owning class. Auto-wired by __init_subclass__.""" + if self._owner is None: + raise NotImplementedError( + f"{type(self).__name__} has no owner class. " + "Define Config inside a Configurable subclass." + ) + return self._owner(config=self, **kwargs) + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if "Config" in cls.__dict__: + config_cls = cls.__dict__["Config"] + if issubclass(config_cls, Configurable.Config): + # Enforce @dataclass(kw_only=True, slots=True) + if "__slots__" not in config_cls.__dict__: + raise TypeError( + f"{cls.__name__}.Config must use " + "@dataclass(kw_only=True, slots=True)" + ) + for f in fields(config_cls): + if not f.kw_only: + raise TypeError( + f"{cls.__name__}.Config field '{f.name}' " + "must be keyword-only" + ) + # Auto-wire build() to construct this class + config_cls._owner = cls diff --git a/torchtitan/experiments/rl/unified/simple_grpo.py b/torchtitan/experiments/rl/unified/simple_grpo.py index 3bd8cd77ff..e32b66d0e2 100644 --- a/torchtitan/experiments/rl/unified/simple_grpo.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -14,18 +14,20 @@ The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. Command to run: -python3 torchtitan/experiments/rl/unified/simple_grpo.py \ - --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +python3 torchtitan/experiments/rl/unified/simple_grpo.py """ + import asyncio import logging import torch + +import torchtitan.experiments.rl.unified # noqa: F401 — registers models with vLLM from monarch.actor import this_host from monarch.utils import setup_env_for_distributed -from torchtitan.config.manager import ConfigManager from torchtitan.experiments.rl.unified.actors.generator import Generator -from torchtitan.experiments.rl.unified.actors.trainer import Trainer +from torchtitan.experiments.rl.unified.actors.trainer import RLPolicyTrainer +from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b logger = logging.getLogger(__name__) @@ -33,17 +35,20 @@ async def main(): """Run the distributed RL training loop using Monarch.""" - # Step 1: Load job config using config manager - config_manager = ConfigManager() - job_config = config_manager.parse_args() + # Step 1: Load config from config_registry + # TODO: Once the config branch lands, replace with: + # from torchtitan.config.manager import ConfigManager + # config = ConfigManager().parse_args() + config = rl_grpo_qwen3_0_6b() + trainer_cfg = config.trainer - # compute world size for trainer and generator + # Compute world size for trainer and generator # TODO: refine the world size computation and check - trainer_ddp_size = job_config.parallelism.data_parallel_replicate_degree - trainer_tp_size = job_config.parallelism.tensor_parallel_degree + trainer_ddp_size = trainer_cfg.parallelism.data_parallel_replicate_degree + trainer_tp_size = trainer_cfg.parallelism.tensor_parallel_degree # RL Training config - num_steps = job_config.training.steps + num_steps = trainer_cfg.training.steps # Use fake dataset for test. TODO: Implement real RL dataloader. logger.info("Using default prompts") @@ -63,7 +68,8 @@ async def main(): trainer_mesh = this_host().spawn_procs( per_host={"gpus": trainer_ddp_size * trainer_tp_size} ) - gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) + gen_tp_size = config.generator.vllm_engine.parallelism.tensor_parallel_degree + gen_mesh = this_host().spawn_procs(per_host={"gpus": gen_tp_size}) # Set up distributed env vars so that actors are connected via c10d await setup_env_for_distributed( @@ -82,16 +88,17 @@ async def main(): # Spawn actors on trainer and generator mesh trainer = trainer_mesh.spawn( "trainer", - Trainer, - job_config, # Pass full job_config + RLPolicyTrainer, + config, ) generator = gen_mesh.spawn( "generator", Generator, - job_config, # Pass full job_config - prompt_texts, - expected_answers, + config.generator, + rl_config=config, + prompt_texts=prompt_texts, + expected_answers=expected_answers, ) # Initialize generator with trainer weights From 7bf6b7471c2e9a9ad5c092bc93e0b205bcfb43c3 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 23 Feb 2026 11:18:22 -0800 Subject: [PATCH 07/10] Update [ghstack-poisoned] --- .../rl/unified/actors/generator.py | 38 +++---- .../experiments/rl/unified/actors/trainer.py | 4 +- .../experiments/rl/unified/config_registry.py | 2 +- torchtitan/experiments/rl/unified/configs.py | 23 +--- .../experiments/rl/unified/job_config.py | 100 ------------------ .../rl/unified/run_configs/qwen3_0.6b.toml | 67 ------------ .../experiments/rl/unified/simple_grpo.py | 4 +- 7 files changed, 26 insertions(+), 212 deletions(-) delete mode 100644 torchtitan/experiments/rl/unified/job_config.py delete mode 100644 torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 4a32637a35..8ef0a19f84 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -285,6 +285,7 @@ def generate( sampling_params = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, + n=n_samples_per_prompt, seed=42, logprobs=1, prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs @@ -311,31 +312,26 @@ def generate( for output in all_outputs: prompt_token_ids = output.prompt_token_ids - # Each output now has exactly 1 sample (we submitted multiple requests) - assert ( - len(output.outputs) == 1 - ), f"Expected 1 output, got {len(output.outputs)}" - sample = output.outputs[0] + for sample in output.outputs: + completions.append(sample.text) - completions.append(sample.text) + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) - # Store prompt tokens for this sample - prompt_token_ids_list.append(prompt_token_ids) + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) - - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) log_probs = torch.tensor(log_probs_list, dtype=torch.float32) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index 6c3f4c4d16..aa9f2bd3da 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -class RLPolicyTrainer(Actor): +class Trainer(Actor): """ Updates policy based on collected trajectories. @@ -87,7 +87,7 @@ def __init__( self.generator: Optional[Any] = None logger.info( - f"RLPolicyTrainer initialized: " + f"Trainer initialized: " f"group_size={self.group_size}, grpo_beta={self.grpo_beta}, " f"use_stable_grpo={self.use_stable_grpo}" ) diff --git a/torchtitan/experiments/rl/unified/config_registry.py b/torchtitan/experiments/rl/unified/config_registry.py index 2f96b44578..1a6515166c 100644 --- a/torchtitan/experiments/rl/unified/config_registry.py +++ b/torchtitan/experiments/rl/unified/config_registry.py @@ -21,6 +21,7 @@ from torchtitan.config.job_config import ( ActivationCheckpoint, Checkpoint, + JobConfig, LRScheduler, Model, Optimizer, @@ -34,7 +35,6 @@ RLTrainer, VLLMSamplingConfig, ) -from torchtitan.experiments.rl.unified.job_config import JobConfig def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: diff --git a/torchtitan/experiments/rl/unified/configs.py b/torchtitan/experiments/rl/unified/configs.py index 3859117b97..83df69a540 100644 --- a/torchtitan/experiments/rl/unified/configs.py +++ b/torchtitan/experiments/rl/unified/configs.py @@ -34,19 +34,14 @@ from dataclasses import dataclass, field -# TODO: Replace with ``from torchtitan.config import Configurable`` -# once the config branch lands. -from torchtitan.experiments.rl.unified.configurable import Configurable - # TODO: Replace with ``from torchtitan.trainer import Trainer`` # once the config branch lands. For now we use the existing JobConfig # as the trainer config type. -from torchtitan.experiments.rl.unified.job_config import JobConfig - +from torchtitan.config.job_config import JobConfig -# --------------------------------------------------------------------------- -# Leaf data configs (plain dataclasses, not Configurable — no build() needed) -# --------------------------------------------------------------------------- +# TODO: Replace with ``from torchtitan.config import Configurable`` +# once the config branch lands. +from torchtitan.experiments.rl.unified.configurable import Configurable @dataclass(kw_only=True, slots=True) @@ -77,11 +72,6 @@ class PolicyOptimizationConfig: """Use stable mean-centering GRPO instead of exponential weighting.""" -# --------------------------------------------------------------------------- -# Top-level RL orchestrator -# --------------------------------------------------------------------------- - - class RLTrainer(Configurable): """Top-level RL training orchestrator. @@ -146,11 +136,6 @@ def __init__(self, config: Config): # loading, etc. -# --------------------------------------------------------------------------- -# Lazy default factory — resolved at runtime to break circular import -# --------------------------------------------------------------------------- - - def _default_rl_generator_config(): from torchtitan.experiments.rl.unified.actors.generator import Generator diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py deleted file mode 100644 index 2cb6a7fec3..0000000000 --- a/torchtitan/experiments/rl/unified/job_config.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Extended JobConfig for RL/Generation workloads. - -This module extends TorchTitan's base JobConfig with generation-specific -configurations needed for vLLM integration. -""" - -from dataclasses import dataclass, field - -from torchtitan.config.job_config import JobConfig as BaseJobConfig, Parallelism - - -@dataclass -class VLLMSamplingParams: - """ - Sampling configuration for vLLM generation. - - This dataclass contains sampling parameters used during text generation. - These map directly to vLLM's SamplingParams. - """ - - temperature: float = 0.8 - """ - Temperature for sampling. Controls randomness in generation. - - 0.0: Deterministic (greedy decoding) - - Higher values: More random outputs - """ - - top_p: float = 0.95 - """ - Top-p (nucleus) sampling threshold. - Only tokens with cumulative probability up to top_p are considered. - """ - - max_tokens: int = 100 - """Maximum number of tokens to generate""" - - -@dataclass -class Generation: - """ - Generation configuration for vLLM engine. - - This dataclass contains essential vLLM-specific settings for generation. - """ - - dtype: str = "bfloat16" - """Data type for model weights (auto, float16, bfloat16, float32)""" - - gpu_memory_utilization: float = 0.5 - """Fraction of GPU memory to use for generation engine (0.0 to 1.0)""" - - enforce_eager: bool = True - """Whether to enforce eager execution (disable CUDA graphs)""" - - parallelism: Parallelism = field(default_factory=Parallelism) - """Parallelism configuration for generation""" - - sampling: VLLMSamplingParams = field(default_factory=VLLMSamplingParams) - """Sampling configuration for generation""" - - -@dataclass -class PolicyOptimization: - """Policy optimization configuration for GRPO training.""" - - grpo_beta: int = 0.1 - """Beta parameter for GRPO (Group Relative Policy Optimization) algorithm""" - - use_stable_grpo: bool = False - """Whether to use stable version of GRPO algorithm""" - - grpo_group_size: int = 8 - """Number of samples in each GRPO group for policy optimization""" - - vllm_batch_invariant: bool = True - """Enable vLLM batch invariant mode for deterministic backward pass""" - - vllm_attention_backend: str = "FLASH_ATTN" - """vLLM attention backend to use (e.g., FLASH_ATTN, XFORMERS)""" - - -@dataclass -class JobConfig(BaseJobConfig): - """ - Extended JobConfig with generation support. - - This extends TorchTitan's base JobConfig by adding `generation` field - for vLLM-specific generation configurations. - """ - - generation: Generation = field(default_factory=Generation) - """Generation configuration for vLLM engine""" - policy_optimization: PolicyOptimization = field(default_factory=PolicyOptimization) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml deleted file mode 100644 index 31c27b6d49..0000000000 --- a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +++ /dev/null @@ -1,67 +0,0 @@ -[job] -dump_folder = "./outputs" -description = "Qwen 3 0.6B training" -custom_config_module = "torchtitan.experiments.rl.unified.job_config" - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 100 - -[metrics] -log_freq = 1 -enable_tensorboard = false -save_tb_folder = "tb" - -[model] -name = "qwen3" -flavor = "0.6B" -hf_assets_path = "./assets/hf/Qwen3-0.6B" - -[lr_scheduler] -warmup_steps = 2 # lr scheduler warm up, 20% total steps - - -[policy_optimization] -grpo_beta = 0.1 -use_stable_grpo = false -grpo_group_size = 8 - - -[training] -local_batch_size = 4 -seq_len = 4096 -max_norm = 1.0 # grad norm clipping -steps = 10 -dataset = "c4" - -[parallelism] # training parallelism plan -data_parallel_replicate_degree = 2 -data_parallel_shard_degree = 1 -fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 - -[checkpoint] -initial_load_path = "/data/users/jianiw/model/qwen3-0.6b" - -[activation_checkpoint] -mode = "none" # ["none", "selective", "full"] -selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy - -[compile] -enable=false -components = ["model", "loss"] - -[generation] -dtype = "bfloat16" -gpu_memory_utilization = 0.5 -enforce_eager = true - -[generation.parallelism] -data_parallel_replicate_degree = 1 -tensor_parallel_degree = 1 - -[generation.sampling] -temperature = 0.8 -top_p = 0.95 -max_tokens = 100 diff --git a/torchtitan/experiments/rl/unified/simple_grpo.py b/torchtitan/experiments/rl/unified/simple_grpo.py index e32b66d0e2..c2f7df641e 100644 --- a/torchtitan/experiments/rl/unified/simple_grpo.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -26,7 +26,7 @@ from monarch.actor import this_host from monarch.utils import setup_env_for_distributed from torchtitan.experiments.rl.unified.actors.generator import Generator -from torchtitan.experiments.rl.unified.actors.trainer import RLPolicyTrainer +from torchtitan.experiments.rl.unified.actors.trainer import Trainer from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b logger = logging.getLogger(__name__) @@ -88,7 +88,7 @@ async def main(): # Spawn actors on trainer and generator mesh trainer = trainer_mesh.spawn( "trainer", - RLPolicyTrainer, + Trainer, config, ) From 9f579c1224f8cfe6b42ec3e56e1e6ee0ec95d166 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 23 Feb 2026 18:00:40 -0800 Subject: [PATCH 08/10] Update [ghstack-poisoned] --- torchtitan/experiments/rl/unified/actors/trainer.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index 6828b65476..b5a978fa2a 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -40,6 +40,11 @@ class PolicyTrainer(Actor, Configurable): Updates policy based on collected trajectories. Run model forward on trajectories, computes loss, and run backward. + Receives the top-level ``RLTrainer.Config`` and reads policy trainer + settings (batch_invariant_mode, grpo) directly from it, plus model / + optimizer / parallelism settings from the nested ``config.trainer``. + + TODO: Use torchtitan Trainer for model init and parallelisation. TODO: Use torchtitan PolicyTrainer for model init and parallelism. From c9471120c1ca66de9ed840547b94801b3f9f26e3 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 24 Feb 2026 00:55:13 -0800 Subject: [PATCH 09/10] Update [ghstack-poisoned] --- torchtitan/experiments/rl/unified/README.md | 14 +++--- .../experiments/rl/unified/config_registry.py | 2 +- torchtitan/experiments/rl/unified/infer.py | 45 +++++++++---------- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index fbc4d7ef70..64a0f35ba1 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -36,20 +36,16 @@ uv pip install torch vllm xformers --pre \ python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=... ``` -4. Run inference: +4. Run inference with unified model definition: ```bash -python torchtitan/experiments/rl/unified/infer.py -``` - -Run with TP: -```bash -python torchtitan/experiments/rl/unified/infer.py --tensor-parallel-size 2 - +torchrun --nproc_per_node= \ + torchtitan/experiments/rl/unified/infer.py ``` 5. Run simple rl loop ``` -python3 torchtitan/experiments/rl/unified/simple_grpo.py --checkpoint.initial_load_path= +python3 torchtitan/experiments/rl/unified/simple_grpo.py \ + --trainer.checkpoint.initial_load_path= ``` We use a unified model definition for the trainer and generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs. diff --git a/torchtitan/experiments/rl/unified/config_registry.py b/torchtitan/experiments/rl/unified/config_registry.py index ac0e33590a..01bd1eb537 100644 --- a/torchtitan/experiments/rl/unified/config_registry.py +++ b/torchtitan/experiments/rl/unified/config_registry.py @@ -49,7 +49,7 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: data_parallel_replicate_degree=2, ), checkpoint=CheckpointManager.Config( - initial_load_path="/data/users/jianiw/model/qwen3-0.6b", + initial_load_path="torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B", initial_load_model_only=True, initial_load_in_hf=True, ), diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index d4845813d9..0a54be356a 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -8,12 +8,11 @@ """ Example inference script using TorchTitan models with vLLM LLMEngine. -This script uses JobConfig loaded from a TOML file to configure both +This script uses the RL unified config_registry to configure both the vLLM engine and sampling parameters. -Run: torchrun --nproc_per_node=2 \ - torchtitan/experiments/rl/unified/infer.py \ - --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +Run: torchrun --nproc_per_node= \ + torchtitan/experiments/rl/unified/infer.py """ import os @@ -22,11 +21,11 @@ # See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" -from torchtitan.config import ConfigManager - # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 +from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b + from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.logger import init_logger @@ -36,32 +35,32 @@ def generate(): - config_manager = ConfigManager() - job_config = config_manager.parse_args() + config = rl_grpo_qwen3_0_6b() + vllm_config = config.generator.vllm_engine + model_path = config.trainer.checkpoint.initial_load_path logger.info("Initializing vLLM LLMEngine with TorchTitan model") - logger.info(f"Model: {job_config.checkpoint.initial_load_path}") + logger.info(f"Model: {model_path}") logger.info( - f"Tensor Parallel Size: {job_config.generation.parallelism.tensor_parallel_degree}" + f"Tensor Parallel Size: {vllm_config.parallelism.tensor_parallel_degree}" ) - # Create EngineArgs from JobConfig - # Map TorchTitan parallelism to vLLM parallelism - generation = job_config.generation - + # Create EngineArgs from config engine_args = EngineArgs( # Model configuration - model=job_config.checkpoint.initial_load_path, + model=model_path, trust_remote_code=True, - dtype=generation.dtype, + dtype=vllm_config.dtype, # Parallelism configuration - tensor_parallel_size=generation.parallelism.tensor_parallel_degree, - distributed_executor_backend="external_launcher", + tensor_parallel_size=vllm_config.parallelism.tensor_parallel_degree, + # Use external_launcher only when launched via torchrun (multi-GPU); + # for single-GPU, let vLLM pick the default executor. + distributed_executor_backend=("external_launcher"), # Memory and performance - gpu_memory_utilization=generation.gpu_memory_utilization, - enforce_eager=generation.enforce_eager, + gpu_memory_utilization=vllm_config.gpu_memory_limit, + enforce_eager=vllm_config.enforce_eager, # Seed - seed=job_config.debug.seed, + seed=vllm_config.seed, # HuggingFace overrides hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) @@ -71,8 +70,8 @@ def generate(): logger.info("vLLM LLMEngine initialized successfully") - # Create sampling parameters from JobConfig - sampling = job_config.generation.sampling + # Create sampling parameters from config + sampling = vllm_config.sampling sampling_params = SamplingParams( temperature=sampling.temperature, top_p=sampling.top_p, From ec37dab616662054c107f79aa14621d553189729 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 25 Feb 2026 13:54:22 -0800 Subject: [PATCH 10/10] Update [ghstack-poisoned] --- torchtitan/experiments/rl/unified/__init__.py | 2 +- .../rl/unified/actors/generator.py | 523 +++++++----------- .../experiments/rl/unified/actors/trainer.py | 55 +- .../experiments/rl/unified/config_registry.py | 48 +- .../{infer.py => inference_example.py} | 36 +- .../rl/unified/infra/parallelism_utils.py | 2 +- .../experiments/rl/unified/models/utils.py | 45 +- .../rl/unified/models/vllm_wrapper.py | 10 +- .../experiments/rl/unified/simple_grpo.py | 19 +- 9 files changed, 295 insertions(+), 445 deletions(-) rename torchtitan/experiments/rl/unified/{infer.py => inference_example.py} (77%) diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py index 8a4efac66f..59fbb97610 100644 --- a/torchtitan/experiments/rl/unified/__init__.py +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -58,7 +58,7 @@ def __init__(self, *, vllm_config, prefix=""): # Register with vLLM ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) - logger.info( + logger.debug( f"Successfully registered {model_name} with vLLM using ModelSpec " f"(flavor={model_spec.flavor})" ) diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 70bf41eddc..937f722ea9 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import asyncio +import glob import logging import os +import shutil from dataclasses import dataclass, field from typing import List @@ -14,8 +15,7 @@ import torch from monarch.actor import Actor, endpoint from safetensors.torch import save_file -from torchtitan.config import CommConfig, Configurable -from torchtitan.config.configs import ParallelismConfig +from torchtitan.config import CommConfig, Configurable, ParallelismConfig from torchtitan.distributed import utils as dist_utils # Import unified module - this automatically registers TorchTitan models with vLLM @@ -67,23 +67,30 @@ class TrajectoryData: advantages: torch.Tensor -class VLLMEngine(Configurable): +class Generator(Actor, Configurable): """ - vLLM engine for fast rollouts with weight updates. + Generates rollouts using vLLM engine. - Note: vLLM loads from model_config.model path, so we create a temporary - directory with updated weights and restart the engine. This is faster than - recreating temp dirs repeatedly and handles config/tokenizer files properly. + Maintains a vLLM engine that is synchronized with the Trainer + via weight sync. Generates completions for given prompts and + computes rewards/advantages. - Constructed via ``VLLMEngine(config, model_path=..., dump_folder=...)``. + Args: + config: Generator-specific configuration. + model_path: Path to the HF model checkpoint. + dump_folder: Root output folder for RL artifacts. + batch_invariant_mode: Enable batch-invariant mode for deterministic ops. + policy_optimization: GRPO hyperparameters. + prompt_texts: List of prompt strings. + expected_answers: List of expected answer strings. """ @dataclass(kw_only=True, slots=True) class Config(Configurable.Config): - """vLLM engine configuration for rollout generation.""" + """Generator actor configuration.""" dtype: str = "bfloat16" - """Data type for model weights (auto, float16, bfloat16, float32).""" + """Data type for model weights, passed directly to vLLM (auto, float16, bfloat16, float32).""" gpu_memory_limit: float = 0.5 """Fraction of GPU memory to use for the vLLM engine (0.0 to 1.0).""" @@ -91,8 +98,8 @@ class Config(Configurable.Config): enforce_eager: bool = True """Disable CUDA graphs in vLLM (use eager execution).""" - seed: int = 42 - """Random seed for reproducible generation.""" + seed: int | None = None + """Random seed for reproducible generation. None means no fixed seed.""" parallelism: ParallelismConfig = field(default_factory=ParallelismConfig) """Parallelism configuration for the vLLM engine.""" @@ -100,26 +107,53 @@ class Config(Configurable.Config): sampling: VLLMSamplingConfig = field(default_factory=VLLMSamplingConfig) """Default sampling parameters for generation.""" + vllm_attention_backend: str = "FLASH_ATTN" + """vLLM attention backend to use (e.g., FLASH_ATTN, XFORMERS).""" + def __init__( self, config: Config, *, + model_spec: ModelSpec, model_path: str, dump_folder: str, - ) -> None: + batch_invariant_mode: bool, + policy_optimization: PolicyOptimizationConfig, + prompt_texts: list[str], + expected_answers: list[str], + ): self.config = config - self.base_model_path = model_path + self.model_spec = model_spec + + # Set vLLM environment variables from config before any vLLM initialization + if batch_invariant_mode: + os.environ["VLLM_BATCH_INVARIANT"] = "1" + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) + + os.environ["VLLM_ATTENTION_BACKEND"] = config.vllm_attention_backend + + self.prompt_texts = prompt_texts + self.expected_answers = expected_answers + + # Extract needed fields from configs + self.model_path = model_path + self.max_new_tokens = config.sampling.max_tokens + self.temperature = config.sampling.temperature + self.group_size = policy_optimization.group_size + self.grpo_beta = policy_optimization.beta + self.use_stable_grpo = policy_optimization.use_stable_grpo - self.temp_model_dir = os.path.abspath( + # Initialize distributed environment for SPMD generator + world_size = dist_utils.init_distributed(CommConfig()) + + # Set up temp model directory for vLLM weight loading + self._base_model_path = model_path + self._temp_model_dir = os.path.abspath( os.path.join(dump_folder, "vllm_temp_model") ) - os.makedirs(self.temp_model_dir, exist_ok=True) - - import glob + os.makedirs(self._temp_model_dir, exist_ok=True) # Copy config/tokenizer files from base model to temp dir - import shutil - for file in [ "config.json", "tokenizer.json", @@ -130,55 +164,55 @@ def __init__( ]: src = os.path.join(model_path, file) if os.path.exists(src): - shutil.copy2(src, self.temp_model_dir) + shutil.copy2(src, self._temp_model_dir) # Copy the original model shard files if they exist # We'll overwrite these with our single model.safetensors later for shard_file in glob.glob(os.path.join(model_path, "model-*.safetensors")): - dst = os.path.join(self.temp_model_dir, os.path.basename(shard_file)) + dst = os.path.join(self._temp_model_dir, os.path.basename(shard_file)) shutil.copy2(shard_file, dst) # Copy index file if it exists index_file = os.path.join(model_path, "model.safetensors.index.json") if os.path.exists(index_file): - shutil.copy2(index_file, self.temp_model_dir) + shutil.copy2(index_file, self._temp_model_dir) - self.engine = None - logger.info("vLLM rollout engine initialized (will load on first use)") + self._engine: LLMEngine | None = None - def update_weights(self, vllm_state: dict) -> None: - """ - Update vLLM model weights from vLLM-compat state dict. + self.policy_version = 0 + + # Reward function. TODO: Move reward calculation out of generator + self.reward_fn = trivial_reward_function + + logger.debug("Generator initialized (vLLM engine will load on first use)") - This converts weights to vLLM format, saves them, and reloads using - vLLM's reload_weights() API after updating the model path config. + def _update_vllm_model_weights(self, vllm_state: dict) -> None: + """ + Update vLLM model weights from vLLM model state dict. This function is used + when updating vLLM model's weights from trainer's updated weights. Args: - vllm_state: vLLM model state dict + vllm_state: vLLM model state dict, a map from vLLM model's fqn names to weights """ # Save to temp model directory - import os - - checkpoint_path = os.path.join(self.temp_model_dir, "model.safetensors") + checkpoint_path = os.path.join(self._temp_model_dir, "model.safetensors") # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure - import glob - shard_files = sorted( - glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) + glob.glob(os.path.join(self._temp_model_dir, "model-*.safetensors")) ) - index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") + index_file = os.path.join(self._temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load # right now we hardcoded to work with 1 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. - # only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: - logger.info(f"Saving weights to {checkpoint_path}") + logger.debug(f"Saving weights to {checkpoint_path}") - # Ensure weights stay in bfloat16 + # TODO: Check the detail of vLLM's dtype conversion journey + # Currently converting float32 to bfloat16 to match vLLM's attention and kv cache dtype vllm_state = { k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v for k, v in vllm_state.items() @@ -187,21 +221,22 @@ def update_weights(self, vllm_state: dict) -> None: # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() - logger.info( + logger.debug( f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) # First time: create the engine using LLMEngine and EngineArgs - if self.engine is None: + if self._engine is None: cfg = self.config engine_args = EngineArgs( # Model configuration - model=self.temp_model_dir, + model=self._temp_model_dir, trust_remote_code=True, dtype=cfg.dtype, # Parallelism configuration tensor_parallel_size=cfg.parallelism.tensor_parallel_degree, + # Use external_launcher because Monarch already spawns the worker processes distributed_executor_backend="external_launcher", # Memory and performance gpu_memory_utilization=cfg.gpu_memory_limit, @@ -216,9 +251,9 @@ def update_weights(self, vllm_state: dict) -> None: ), ) - logger.info("Initializing LLMEngine from EngineArgs...") - self.engine = LLMEngine.from_engine_args(engine_args) - logger.info("Created new vLLM LLMEngine") + logger.debug("Initializing LLMEngine from EngineArgs...") + self._engine = LLMEngine.from_engine_args(engine_args) + logger.debug("Created new vLLM LLMEngine") else: # Direct parameter copy into model tensors. # This bypasses vLLM's reload_weights() which uses a layerwise @@ -226,304 +261,148 @@ def update_weights(self, vllm_state: dict) -> None: from torchtitan.experiments.rl.vllm_compat.weights import vllm_to_torchtitan titan_state = vllm_to_torchtitan(vllm_state) - self._direct_weight_update(titan_state) - - def _direct_weight_update(self, titan_state: dict) -> None: - """Update model weights by copying directly into GPU parameters. + model = self._engine.model_executor.driver_worker.get_model() + params = dict(model.named_parameters()) + + for name, new_weight in titan_state.items(): + # TorchTitanVLLMModelWrapper stores the model as self.model, + # so parameters have a "model." prefix + param_name = f"model.{name}" + if param_name in params: + param = params[param_name] + new_w = new_weight.to(device=param.device, dtype=param.dtype) + param.data.copy_(new_w) + + def _compute_rewards_and_advantages( + self, completions: list[str] + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute rewards and GRPO advantages for generated completions. + TODO: Move this function out of generator for encapsulation. Args: - titan_state: TorchTitan format state dict (w1/w2/w3, wq/wk/wv/wo, etc.) - """ - - # Access model from vLLM engine - model = self.engine.model_executor.driver_worker.get_model() - params = dict(model.named_parameters()) - - for name, new_weight in titan_state.items(): - # TorchTitanVLLMModelWrapper stores the model as self.model, - # so parameters have a "model." prefix - param_name = f"model.{name}" - if param_name in params: - param = params[param_name] - new_w = new_weight.to(device=param.device, dtype=param.dtype) - param.data.copy_(new_w) - - @torch.no_grad() - def generate( - self, - prompt_texts: list[str], - max_new_tokens: int = 20, - temperature: float = 1.0, - n_samples_per_prompt: int = 4, - ) -> tuple[ - list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] - ]: - """ - Generate samples using vLLM LLMEngine. - - Args: - prompt_texts: List of prompt strings - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - n_samples_per_prompt: Number of samples per prompt + completions: List of completion strings. Returns: - completions: List of completion strings - log_probs: [batch] - Sum of log probs for each completion - token_ids: List of token ID lists for each completion (generated tokens only) - token_log_probs: List of per-token log prob lists for each completion - prompt_token_ids: List of prompt token ID lists for each completion + rewards: Raw rewards tensor. + advantages: GRPO advantage tensor. """ - logger.info( - f"Starting generation: {len(prompt_texts)} prompts, " - f"n_samples_per_prompt={n_samples_per_prompt}, " - f"max_tokens={max_new_tokens}, temp={temperature}" - ) - - sampling_params = SamplingParams( - temperature=temperature, - max_tokens=max_new_tokens, - n=n_samples_per_prompt, - seed=42, - logprobs=1, - prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs - output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs - ) - - # Add one request per prompt; vLLM handles n_samples_per_prompt via n= - for request_id, prompt in enumerate(prompt_texts): - self.engine.add_request(str(request_id), prompt, sampling_params) - - # Step through engine until all requests are finished - all_outputs = [] - while self.engine.has_unfinished_requests(): - request_outputs = self.engine.step() - all_outputs.extend(request_outputs) - - # Extract completions and log probs - completions = [] - log_probs_list = [] - token_ids_list = [] - token_log_probs_list = [] - prompt_token_ids_list = [] - - for output in all_outputs: - prompt_token_ids = output.prompt_token_ids - - for sample in output.outputs: - completions.append(sample.text) - - # Store prompt tokens for this sample - prompt_token_ids_list.append(prompt_token_ids) - - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) - - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) - - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) - - log_probs = torch.tensor(log_probs_list, dtype=torch.float32) - - return ( - completions, - log_probs, - token_ids_list, - token_log_probs_list, - prompt_token_ids_list, + logger.debug( + f"Computing rewards: {len(completions)} completions, " + f"{len(self.expected_answers)} expected answers, " + f"group_size={self.group_size}" ) + rewards = self.reward_fn(completions, self.expected_answers, self.group_size) - def __del__(self): - """Cleanup vLLM engine.""" - if hasattr(self, "engine"): - del self.engine - torch.cuda.empty_cache() - - -class GeneratorState: - """States for the Generator's state machine.""" - - READY_TO_GENERATE = "READY_TO_GENERATE" - READY_TO_UPDATE = "READY_TO_UPDATE" - - -class Generator(Actor, Configurable): - """ - Generates rollouts using vLLM engine. - - Maintains a vLLM engine that is synchronized with the Trainer - via weight sync. Generates completions for given prompts and - computes rewards/advantages. - - Args: - config: Generator-specific configuration. - model_path: Path to the HF model checkpoint. - dump_folder: Root output folder for RL artifacts. - batch_invariant_mode: Enable batch-invariant mode for deterministic ops. - policy_optimization: GRPO hyperparameters. - prompt_texts: List of prompt strings. - expected_answers: List of expected answer strings. - """ - - @dataclass(kw_only=True, slots=True) - class Config(Configurable.Config): - """Generator actor configuration.""" - - vllm_engine: VLLMEngine.Config = field(default_factory=VLLMEngine.Config) - """vLLM rollout engine configuration.""" - - vllm_attention_backend: str = "FLASH_ATTN" - """vLLM attention backend to use (e.g., FLASH_ATTN, XFORMERS).""" - - def __init__( - self, - config: Config, - *, - model_spec: ModelSpec, - model_path: str, - dump_folder: str, - batch_invariant_mode: bool, - policy_optimization: PolicyOptimizationConfig, - prompt_texts: list[str], - expected_answers: list[str], - ): - self.config = config - self.model_spec = model_spec - - # Set vLLM environment variables from config before any vLLM initialization - if batch_invariant_mode: - os.environ["VLLM_BATCH_INVARIANT"] = "1" - init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) - - os.environ["VLLM_ATTENTION_BACKEND"] = config.vllm_attention_backend - - self.prompt_texts = prompt_texts - self.expected_answers = expected_answers - - # Extract needed fields from configs - self.model_path = model_path - self.max_new_tokens = config.vllm_engine.sampling.max_tokens - self.temperature = config.vllm_engine.sampling.temperature - self.group_size = policy_optimization.group_size - self.grpo_beta = policy_optimization.beta - self.use_stable_grpo = policy_optimization.use_stable_grpo - - # Initialize distributed environment for SPMD generator - world_size = dist_utils.init_distributed(CommConfig()) - - # Build vLLM engine - self.vllm_engine = VLLMEngine( - config.vllm_engine, - model_path=self.model_path, - dump_folder=dump_folder, - ) - - # State machine - self.state = GeneratorState.READY_TO_UPDATE - self.cond = asyncio.Condition() - self.policy_version = 0 + # Normalize rewards + reward_mean = rewards.mean() + reward_std = rewards.std() + if reward_std > 1e-8: + rewards_normalized = (rewards - reward_mean) / reward_std + else: + rewards_normalized = rewards - reward_mean - # Reward function. TODO: Use a real reward function - self.reward_fn = trivial_reward_function + # Compute advantages using GRPO + if self.use_stable_grpo: + advantages = compute_grpo_advantages_stable( + rewards_normalized, self.group_size + ) + else: + advantages = compute_grpo_advantages( + rewards_normalized, self.group_size, beta=self.grpo_beta + ) - logger.info("Generator initialized with vLLM engine") + return rewards, advantages @endpoint - async def generate(self) -> None: - """Generate trajectories and compute rewards/advantages.""" - logger.info( + async def generate(self) -> TrajectoryData: + """Generate trajectories and compute rewards/advantages. + Called by the orchestrator (simple_grpo.py). + """ + logger.debug( f"{os.getpid()=} Generating start generate (policy v{self.policy_version})..." ) - async with self.cond: - # Wait until ready to generate (weights have been updated) - await self.cond.wait_for( - lambda: self.state == GeneratorState.READY_TO_GENERATE - ) + with torch.no_grad(): # Generate samples using vLLM - ( - completions, - vllm_log_probs, - vllm_token_ids, - vllm_token_log_probs, - prompt_token_ids, - ) = self.vllm_engine.generate( - self.prompt_texts, - self.max_new_tokens, - self.temperature, - n_samples_per_prompt=self.group_size, - ) - - # Compute rewards - logger.info( - f"Computing rewards: {len(completions)} completions, " - f"{len(self.expected_answers)} expected answers, " - f"group_size={self.group_size}" - ) - rewards = self.reward_fn( - completions, self.expected_answers, self.group_size - ) - - # Normalize rewards - reward_mean = rewards.mean() - reward_std = rewards.std() - if reward_std > 1e-8: - rewards_normalized = (rewards - reward_mean) / reward_std - else: - rewards_normalized = rewards - reward_mean - - # Compute advantages using GRPO - if self.use_stable_grpo: - advantages = compute_grpo_advantages_stable( - rewards_normalized, self.group_size - ) - else: - advantages = compute_grpo_advantages( - rewards_normalized, self.group_size, beta=self.grpo_beta - ) - - # Create trajectory data - trajectory = TrajectoryData( - policy_version=self.policy_version, - completions=completions, - vllm_token_ids=vllm_token_ids, - vllm_token_log_probs=vllm_token_log_probs, - prompt_token_ids=prompt_token_ids, - rewards=rewards, - advantages=advantages, + sampling_params = SamplingParams( + temperature=self.temperature, + max_tokens=self.max_new_tokens, + n=self.group_size, + seed=self.config.seed, + logprobs=1, + prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - # Signal ready for update - self.state = GeneratorState.READY_TO_UPDATE - self.cond.notify_all() + for request_id, prompt in enumerate(self.prompt_texts): + self._engine.add_request(str(request_id), prompt, sampling_params) + + # Step through engine until all requests are finished + all_outputs = [] + while self._engine.has_unfinished_requests(): + request_outputs = self._engine.step() + all_outputs.extend(request_outputs) + + # Extract completions and log probs + completions = [] + token_ids_list = [] + token_log_probs_list = [] + prompt_token_ids_list = [] + + for output in all_outputs: + prompt_token_ids = output.prompt_token_ids + + for sample in output.outputs: + completions.append(sample.text) + prompt_token_ids_list.append( + prompt_token_ids + ) # Store prompt tokens for this sample + token_ids_list.append( + sample.token_ids + ) # Extract token IDs (generated tokens only) + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] # Extract per-token log probs + token_log_probs_list.append(per_token_log_probs) + + # Compute rewards and advantages + rewards, advantages = self._compute_rewards_and_advantages(completions) + + # Create trajectory data + trajectory = TrajectoryData( + policy_version=self.policy_version, + completions=completions, + vllm_token_ids=token_ids_list, + vllm_token_log_probs=token_log_probs_list, + prompt_token_ids=prompt_token_ids_list, + rewards=rewards, + advantages=advantages, + ) - logger.info( - f"{os.getpid()=} Generating finish generate (policy v{self.policy_version})..." - ) - return trajectory + logger.debug( + f"{os.getpid()=} Generating finish generate (policy v{self.policy_version})..." + ) + return trajectory @endpoint async def update(self, version: int, vllm_compat_state: dict) -> None: """Update generate weights. + Called by the orchestrator (simple_grpo.py). Args: version: New policy version number vllm_compat_state: vLLM-compatible state dict """ - async with self.cond: - self.vllm_engine.update_weights(vllm_compat_state) - # Update version and state - self.policy_version = version - self.state = GeneratorState.READY_TO_GENERATE - self.cond.notify_all() - logger.info( - f"{os.getpid()=} Generator updating weights to policy v{version}..." - ) + # TODO: remove the helper function (_update_vllm_model_weights) once we clean up the weight updates + self._update_vllm_model_weights(vllm_compat_state) + self.policy_version = version + logger.debug( + f"{os.getpid()=} Generator updating weights to policy v{version}..." + ) + + def __del__(self): + """Cleanup vLLM engine.""" + if hasattr(self, "_engine"): + del self._engine + torch.cuda.empty_cache() diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index b5a978fa2a..c859faeead 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -25,11 +25,16 @@ from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.models.utils import load_trainer_model +from torchtitan.experiments.rl.unified.models.utils import ( + replace_with_vllm_compatible_flash_attention, +) from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm +from torchtitan.experiments.rl.vllm_compat.weights.converter import ( + torchtitan_to_vllm, + vllm_to_torchtitan, +) from torchtitan.protocols.model_spec import ModelSpec logger = logging.getLogger(__name__) @@ -37,15 +42,13 @@ class PolicyTrainer(Actor, Configurable): """ - Updates policy based on collected trajectories. + Updates policy based on collected Episodes. - Run model forward on trajectories, computes loss, and run backward. + Run model forward on Episodes, computes loss, and run backward. Receives the top-level ``RLTrainer.Config`` and reads policy trainer settings (batch_invariant_mode, grpo) directly from it, plus model / optimizer / parallelism settings from the nested ``config.trainer``. - TODO: Use torchtitan Trainer for model init and parallelisation. - TODO: Use torchtitan PolicyTrainer for model init and parallelism. Args: @@ -55,10 +58,7 @@ class PolicyTrainer(Actor, Configurable): @dataclass(kw_only=True, slots=True) class Config(Configurable.Config): - """PolicyTrainer configuration for optimizer, training, and parallelism. - - TODO: Remove this config once the Trainer is replaced by torchtitan Trainer - """ + """PolicyTrainer configuration for optimizer, training, and parallelism.""" optimizer: OptimizersContainer.Config = field( default_factory=OptimizersContainer.Config @@ -78,8 +78,10 @@ class Config(Configurable.Config): def __init__( self, config: Config, - policy_optimization: PolicyOptimizationConfig, + *, model_spec: ModelSpec, + policy_optimization: PolicyOptimizationConfig, + batch_invariant_mode: bool, ): self.config = config self.model_spec = model_spec @@ -100,13 +102,30 @@ def __init__( device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - # load trainer model and patch to vllm.Attention() - self.model = load_trainer_model(model_path, model_spec.model) + # Step1: Load trainer model from HF/vLLM checkpoint. TODO: Use torchtitan components + model_config = model_spec.model + titan_state_dict = vllm_to_torchtitan(model_path) + + # If weight tying is enabled but output.weight is missing from the checkpoint + if model_config.enable_weight_tying and "output.weight" not in titan_state_dict: + titan_state_dict["output.weight"] = titan_state_dict[ + "tok_embeddings.weight" + ] + + self.model = model_config.build() + self.model.load_state_dict(titan_state_dict, strict=True) + + # Step2: Replace attention kernel be to vLLM's attention. + if batch_invariant_mode: + replace_with_vllm_compatible_flash_attention(self.model) + # vLLM's Attention requires bfloat16 inputs. + # TODO: Refine the dtype journey in trainer / generator + self.model.to(torch.bfloat16) self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism - # TODO: right now it only works for qwen3 model, need to formalize this to use parallize_fn from train_spec + # TODO: right now it only works for qwen3 model, need to formalize this to use parallize_fn from model_spec if self.ddp_size > 1: from torchtitan.models.llama3.parallelize import apply_ddp @@ -124,7 +143,7 @@ def __init__( self.policy_version = 0 self.generator: Optional[Any] = None - logger.info( + logger.debug( f"PolicyTrainer initialized: " f"group_size={self.group_size}, grpo_beta={self.grpo_beta}, " f"use_stable_grpo={self.use_stable_grpo}" @@ -148,7 +167,7 @@ async def step(self, trajectory: TrajectoryData) -> dict: Returns: Training metrics """ - logger.info( + logger.debug( f"{os.getpid()=} PolicyTrainer starts to train {self.policy_version} on traj:" ) # Compute loss @@ -169,8 +188,6 @@ async def step(self, trajectory: TrajectoryData) -> dict: self.policy_version += 1 - # TODO: save dcp checkpoint to file here instead of sending weight dicts - # Return metrics metrics = { "loss": loss.item(), @@ -182,5 +199,5 @@ async def step(self, trajectory: TrajectoryData) -> dict: "policy_version": self.policy_version, **loss_metrics, } - logger.info(f"{os.getpid()=} PolicyTrainer finish step {self.policy_version}") + logger.debug(f"{os.getpid()=} PolicyTrainer finish step {self.policy_version}") return metrics diff --git a/torchtitan/experiments/rl/unified/config_registry.py b/torchtitan/experiments/rl/unified/config_registry.py index 01bd1eb537..12ef11ce64 100644 --- a/torchtitan/experiments/rl/unified/config_registry.py +++ b/torchtitan/experiments/rl/unified/config_registry.py @@ -19,7 +19,7 @@ ParallelismConfig, TrainingConfig, ) -from torchtitan.experiments.rl.unified.actors.generator import Generator, VLLMEngine +from torchtitan.experiments.rl.unified.actors.generator import Generator from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer from torchtitan.experiments.rl.unified.configs import ( PolicyOptimizationConfig, @@ -34,6 +34,7 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: return RLTrainer.Config( model_spec=model_registry("0.6B"), num_steps=10, + batch_invariant_mode=True, trainer=PolicyTrainer.Config( optimizer=OptimizersContainer.Config(lr=1e-6), lr_scheduler=LRSchedulersContainer.Config( @@ -58,26 +59,23 @@ def rl_grpo_qwen3_0_6b() -> RLTrainer.Config: selective_ac_option="op", ), ), - batch_invariant_mode=True, policy_optimization=PolicyOptimizationConfig( beta=0.1, group_size=8, use_stable_grpo=False, ), generator=Generator.Config( - vllm_engine=VLLMEngine.Config( - dtype="bfloat16", - gpu_memory_limit=0.5, - enforce_eager=True, - seed=42, - parallelism=ParallelismConfig( - tensor_parallel_degree=1, - ), - sampling=VLLMSamplingConfig( - temperature=0.8, - top_p=0.95, - max_tokens=100, - ), + dtype="bfloat16", + gpu_memory_limit=0.5, + enforce_eager=True, + seed=42, + parallelism=ParallelismConfig( + tensor_parallel_degree=1, + ), + sampling=VLLMSamplingConfig( + temperature=0.8, + top_p=0.95, + max_tokens=100, ), vllm_attention_backend="FLASH_ATTN", ), @@ -89,6 +87,7 @@ def rl_grpo_qwen3_debug() -> RLTrainer.Config: return RLTrainer.Config( model_spec=model_registry("debugmodel"), num_steps=5, + batch_invariant_mode=False, trainer=PolicyTrainer.Config( optimizer=OptimizersContainer.Config(lr=8e-4), lr_scheduler=LRSchedulersContainer.Config( @@ -107,23 +106,20 @@ def rl_grpo_qwen3_debug() -> RLTrainer.Config: interval=5, ), ), - batch_invariant_mode=False, policy_optimization=PolicyOptimizationConfig( beta=0.1, group_size=4, use_stable_grpo=False, ), generator=Generator.Config( - vllm_engine=VLLMEngine.Config( - gpu_memory_limit=0.3, - enforce_eager=True, - parallelism=ParallelismConfig( - tensor_parallel_degree=1, - ), - sampling=VLLMSamplingConfig( - temperature=1.0, - max_tokens=50, - ), + gpu_memory_limit=0.3, + enforce_eager=True, + parallelism=ParallelismConfig( + tensor_parallel_degree=1, + ), + sampling=VLLMSamplingConfig( + temperature=1.0, + max_tokens=50, ), ), ) diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/inference_example.py similarity index 77% rename from torchtitan/experiments/rl/unified/infer.py rename to torchtitan/experiments/rl/unified/inference_example.py index 0a54be356a..e9963bfe24 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/inference_example.py @@ -36,13 +36,13 @@ def generate(): config = rl_grpo_qwen3_0_6b() - vllm_config = config.generator.vllm_engine + gen_config = config.generator model_path = config.trainer.checkpoint.initial_load_path - logger.info("Initializing vLLM LLMEngine with TorchTitan model") - logger.info(f"Model: {model_path}") - logger.info( - f"Tensor Parallel Size: {vllm_config.parallelism.tensor_parallel_degree}" + logger.debug("Initializing vLLM LLMEngine with TorchTitan model") + logger.debug(f"Model: {model_path}") + logger.debug( + f"Tensor Parallel Size: {gen_config.parallelism.tensor_parallel_degree}" ) # Create EngineArgs from config @@ -50,50 +50,50 @@ def generate(): # Model configuration model=model_path, trust_remote_code=True, - dtype=vllm_config.dtype, + dtype=gen_config.dtype, # Parallelism configuration - tensor_parallel_size=vllm_config.parallelism.tensor_parallel_degree, + tensor_parallel_size=gen_config.parallelism.tensor_parallel_degree, # Use external_launcher only when launched via torchrun (multi-GPU); # for single-GPU, let vLLM pick the default executor. distributed_executor_backend=("external_launcher"), # Memory and performance - gpu_memory_utilization=vllm_config.gpu_memory_limit, - enforce_eager=vllm_config.enforce_eager, + gpu_memory_utilization=gen_config.gpu_memory_limit, + enforce_eager=gen_config.enforce_eager, # Seed - seed=vllm_config.seed, + seed=gen_config.seed, # HuggingFace overrides hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("Initializing LLMEngine from EngineArgs...") + logger.debug("Initializing LLMEngine from EngineArgs...") engine = LLMEngine.from_engine_args(engine_args) - logger.info("vLLM LLMEngine initialized successfully") + logger.debug("vLLM LLMEngine initialized successfully") # Create sampling parameters from config - sampling = vllm_config.sampling + sampling = gen_config.sampling sampling_params = SamplingParams( temperature=sampling.temperature, top_p=sampling.top_p, max_tokens=sampling.max_tokens, ) - logger.info( + logger.debug( f"Sampling params: temperature={sampling.temperature}, " f"top_p={sampling.top_p}, max_tokens={sampling.max_tokens}" ) # Example prompt prompt = "Hello, my name is" - logger.info(f"Prompt: {prompt}") + logger.debug(f"Prompt: {prompt}") # Add request to engine - logger.info("Adding request to engine...") + logger.debug("Adding request to engine...") request_id = "0" engine.add_request(request_id, prompt, sampling_params) # Generate text by stepping through engine - logger.info("Generating text...") + logger.debug("Generating text...") while engine.has_unfinished_requests(): request_outputs = engine.step() @@ -104,7 +104,7 @@ def generate(): generated_text = request_output.outputs[0].text # Print results - logger.info("Generation complete") + logger.debug("Generation complete") print(f"\nPrompt: {prompt}") print(f"Generated text: {generated_text!r}\n") diff --git a/torchtitan/experiments/rl/unified/infra/parallelism_utils.py b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py index 61f063f517..73b4c06ea0 100644 --- a/torchtitan/experiments/rl/unified/infra/parallelism_utils.py +++ b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py @@ -58,7 +58,7 @@ def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDi world_size=world_size, ) - logger.info( + logger.debug( f"Created ParallelDims from vLLM config: " f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index a810522729..133e33988c 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -7,13 +7,10 @@ import logging -import torch from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import vllm_to_torchtitan -from torchtitan.protocols.model import BaseModel logger = logging.getLogger(__name__) @@ -60,7 +57,7 @@ def replace_with_vllm_attention(model, tp_degree=1): layer.attention.inner_attention = vllm_attn - logger.info( + logger.debug( f"Successfully replaced TorchTitan attention with VLLMAttention " f"({len(model.layers)} layers)" ) @@ -85,45 +82,7 @@ def replace_with_vllm_compatible_flash_attention(model): layer.attention.inner_attention = vllm_attn - logger.info( + logger.debug( f"Successfully replaced TorchTitan attention with VLLMCompatibleFlashAttention " f"({len(model.layers)} layers)" ) - - -def load_trainer_model(model_path: str, model_config: BaseModel.Config): - """ - Load TorchTitan model from checkpoint for trainer. - - Args: - model_path: Path to HuggingFace model (for weights) - model_config: Model config from model_spec (e.g., Qwen3Model.Config) - - Returns: - model: Loaded TorchTitan model for trainer. - """ - model_args = model_config - - # convert to torchtitan state_dict. TODO: Use torchtitan components - titan_state_dict = vllm_to_torchtitan(model_path) - - # If weight tying is enabled but output.weight is missing from the checkpoint - # (HF models with tie_word_embeddings=True may not store lm_head.weight), - # synthesize it from tok_embeddings.weight so load_state_dict(strict=True) works. - if model_args.enable_weight_tying and "output.weight" not in titan_state_dict: - titan_state_dict["output.weight"] = titan_state_dict["tok_embeddings.weight"] - - model = model_args.build() - # Set global default dtype to bfloat16. This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention during training. - replace_with_vllm_compatible_flash_attention(model) - - # Load standard TorchTitan format directly - model.load_state_dict(titan_state_dict, strict=True) - - model.to(torch.bfloat16) - - return model diff --git a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py index 16fd5a1943..c9505a6b55 100644 --- a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -76,7 +76,7 @@ def __init__( # Use TorchTitan model config directly (no HF config mapping) self.config = model_config - logger.info(f"Creating model with config: {model_config}") + logger.debug(f"Creating model with config: {model_config}") self.model = model_config.build() # RoPE config from model for cache extension @@ -86,7 +86,7 @@ def __init__( # vLLM config contains the tensor_parallel_size from command-line args # and this will be consistent across all worker processes self.parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) - self.parallel_config = create_trainer_config_from_vllm_config( + self.trainer_config = create_trainer_config_from_vllm_config( vllm_config=vllm_config, ) @@ -102,7 +102,7 @@ def __init__( # NOTE: We need to apply parallelize within model.__init__ because vllm # doesn't separate model creation and parallelism application and instead # requires parallelization to be done inside model constructor. - cfg = self.parallel_config + cfg = self.trainer_config self.model = parallelize_fn( model=self.model, parallel_dims=self.parallel_dims, @@ -205,9 +205,7 @@ def forward( # Get RoPE cache (handle model-specific attribute names) # Use hasattr to avoid ambiguous boolean value error with tensors - if hasattr(self.model, "rope_cache"): - rope_attr = self.model.rope_cache - elif hasattr(self.model, "freqs_cis"): + if hasattr(self.model, "freqs_cis"): rope_attr = self.model.freqs_cis else: rope_attr = None diff --git a/torchtitan/experiments/rl/unified/simple_grpo.py b/torchtitan/experiments/rl/unified/simple_grpo.py index 88523ca40c..8c9ab2f1f4 100644 --- a/torchtitan/experiments/rl/unified/simple_grpo.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -98,7 +98,7 @@ async def main(): num_steps = config.num_steps # Use fake dataset for test. TODO: Implement real RL dataloader. - logger.info("Using default prompts") + logger.debug("Using default prompts") prompts_with_answers = [ ("The capital of France is", "paris"), ("What is 7 times 8?", "56"), @@ -109,13 +109,13 @@ async def main(): prompt_texts = [p[0] for p in prompts_with_answers] expected_answers = [p[1] for p in prompts_with_answers] - logger.info(f"Loaded {len(prompt_texts)} prompts") + logger.debug(f"Loaded {len(prompt_texts)} prompts") # Create process meshes trainer_mesh = this_host().spawn_procs( per_host={"gpus": trainer_ddp_size * trainer_tp_size} ) - gen_tp_size = config.generator.vllm_engine.parallelism.tensor_parallel_degree + gen_tp_size = config.generator.parallelism.tensor_parallel_degree gen_mesh = this_host().spawn_procs(per_host={"gpus": gen_tp_size}) # Set up distributed env vars so that actors are connected via c10d @@ -137,8 +137,9 @@ async def main(): "trainer", PolicyTrainer, config.trainer, - config.policy_optimization, - config.model_spec, + model_spec=config.model_spec, + policy_optimization=config.policy_optimization, + batch_invariant_mode=config.batch_invariant_mode, ) generator = gen_mesh.spawn( @@ -176,13 +177,13 @@ async def main(): f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " f"Reward: {metrics['reward_mean']:+.3f}" ) - logger.info(f" Sample: {metrics['sample_completion']}...") + logger.debug(f" Sample: {metrics['sample_completion']}...") # Check for divergence if not torch.isfinite(torch.tensor(metrics["loss"])): - logger.info("\n" + "!" * 80) - logger.info("ERROR: Loss is NaN/Inf! Training diverged.") - logger.info("!" * 80) + logger.debug("\n" + "!" * 80) + logger.debug("ERROR: Loss is NaN/Inf! Training diverged.") + logger.debug("!" * 80) break logger.info("\n" + "=" * 80)