diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 27550e977c..ae604b19aa 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -66,10 +66,10 @@ python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path None: + def update_weights(self, vllm_state: dict) -> None: """ Update vLLM model weights from vLLM-compat state dict. @@ -121,11 +123,8 @@ def update_weights(self, vllm_compat_state: dict) -> None: vLLM's reload_weights() API after updating the model path config. Args: - vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + vllm_state: vLLM model state dict """ - # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) - vllm_state = torchtitan_to_vllm(vllm_compat_state) - # Save to temp model directory import os @@ -134,7 +133,6 @@ def update_weights(self, vllm_compat_state: dict) -> None: # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure import glob - import json shard_files = sorted( glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) @@ -142,51 +140,20 @@ def update_weights(self, vllm_compat_state: dict) -> None: index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safe tensor files which we only + # right now we hardcoded to work with 2 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. 
# only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: logger.info(f"Saving weights to {checkpoint_path}") - if len(shard_files) == 2 and os.path.exists(index_file): - # Load the index to see which weights go in which shard - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data["weight_map"] - - # Split weights according to the index - shard1_weights = {} - shard2_weights = {} - - for key, value in vllm_state.items(): - shard_file = weight_map.get(key, shard_files[0]) - if "model-00001-of-00002" in shard_file: - shard1_weights[key] = value - else: - shard2_weights[key] = value - - # Ensure weights stay in bfloat16 - shard1_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard1_weights.items() - } - shard2_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard2_weights.items() - } - - # Save to the shard files - save_file(shard1_weights, shard_files[0]) - save_file(shard2_weights, shard_files[1]) - else: - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() @@ -194,28 +161,35 @@ def update_weights(self, vllm_compat_state: dict) -> None: f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) - # First time: create the engine - if self.llm is None: - self.llm = LLM( + # First time: create the engine using LLMEngine and EngineArgs + if self.engine is None: + generation = self.job_config.generation + + engine_args = 
EngineArgs( + # Model configuration model=self.temp_model_dir, - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, trust_remote_code=True, - max_model_len=2048, - dtype="bfloat16", - gpu_memory_utilization=0.1, # Reduced from 0.5 - distributed_executor_backend="external_launcher", # vllm do not spawn processes - seed=42, # Fixed seed for determinism - enforce_eager=True, - tensor_parallel_size=self.tp_size, # Explicitly single GPU + dtype=generation.dtype, + # Parallelism configuration + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, + # Seed + seed=generation.seed, + # HuggingFace overrides to use TorchTitan model. + # TODO: make this field configurable and align with model registration + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("Created new vLLM engine") + + logger.info("Initializing LLMEngine from EngineArgs...") + self.engine = LLMEngine.from_engine_args(engine_args) + logger.info("Created new vLLM LLMEngine") else: # Use collective_rpc to call reload_weights on all workers # This reloads weights from temp_model_dir without recreating the engine - self.llm.collective_rpc("reload_weights") + self.engine.collective_rpc("reload_weights") @torch.no_grad() def generate( @@ -228,7 +202,7 @@ def generate( list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] ]: """ - Generate samples using vLLM. + Generate samples using vLLM LLMEngine. 
Args: prompt_texts: List of prompt strings @@ -243,16 +217,36 @@ def generate( token_log_probs: List of per-token log prob lists for each completion prompt_token_ids: List of prompt token ID lists for each completion """ + logger.info( + f"Starting generation: {len(prompt_texts)} prompts, " + f"n_samples_per_prompt={n_samples_per_prompt}, " + f"max_tokens={max_new_tokens}, temp={temperature}" + ) + sampling_params = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, - n=n_samples_per_prompt, seed=42, logprobs=1, prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) - outputs = self.llm.generate(prompt_texts, sampling_params) + # Add requests to the engine + # For n_samples_per_prompt > 1, submit each prompt multiple times with different request id + request_id = 0 + prompt_indices = [] # Track which prompt each request corresponds to + for prompt_idx, prompt in enumerate(prompt_texts): + for sample_idx in range(n_samples_per_prompt): + self.engine.add_request(str(request_id), prompt, sampling_params) + prompt_indices.append(prompt_idx) + request_id += 1 + + # Step through engine until all requests are finished + all_outputs = [] + while self.engine.has_unfinished_requests(): + request_outputs = self.engine.step() + all_outputs.extend(request_outputs) # Extract completions and log probs completions = [] @@ -261,30 +255,35 @@ def generate( token_log_probs_list = [] prompt_token_ids_list = [] - for output in outputs: + for output in all_outputs: # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids - for sample in output.outputs: - completions.append(sample.text) + # Each output now has exactly 1 sample (we submitted multiple requests) + assert ( + len(output.outputs) == 1 + ), f"Expected 1 output, got {len(output.outputs)}" + sample = output.outputs[0] + + completions.append(sample.text) - # Store prompt tokens for this 
sample - prompt_token_ids_list.append(prompt_token_ids) + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) log_probs = torch.tensor(log_probs_list, dtype=torch.float32) @@ -298,8 +297,8 @@ def generate( def __del__(self): """Cleanup vLLM engine.""" - if hasattr(self, "llm"): - del self.llm + if hasattr(self, "engine"): + del self.engine torch.cuda.empty_cache() @@ -319,58 +318,46 @@ class Generator(Actor): computes rewards/advantages. 
Args: - model_path: Path to HuggingFace model + job_config: JobConfig dataclass containing all configuration prompt_texts: List of prompt strings expected_answers: List of expected answers - group_size: Number of samples per prompt - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - use_real_dataset: Whether using real dataset (GSM8K) - grpo_beta: Beta for GRPO advantages - use_stable_grpo: Whether to use stable GRPO - tp_size: Tensor Parallel size """ def __init__( self, - model_path: str, - prompt_texts: List[str], + job_config: JobConfig, + prompt_texts: List[ + str + ], # TODO: This field need to be removed once dataloader is implemented expected_answers: List[str], - group_size: int = 8, - max_new_tokens: int = 20, - temperature: float = 1.0, - use_real_dataset: bool = False, - grpo_beta: float = 0.1, - use_stable_grpo: bool = False, - tp_size: int = 1, ): - self.model_path = model_path + # Store job_config for accessing configuration + self.job_config = job_config self.prompt_texts = prompt_texts self.expected_answers = expected_answers - self.group_size = group_size - self.max_new_tokens = max_new_tokens - self.temperature = temperature - self.use_real_dataset = use_real_dataset - self.grpo_beta = grpo_beta - self.use_stable_grpo = use_stable_grpo - self.tp_size = tp_size + + # Extract needed fields from job_config + self.model_path = job_config.checkpoint.initial_load_path + self.max_new_tokens = job_config.generation.sampling.max_tokens + self.temperature = job_config.generation.sampling.temperature + self.group_size = job_config.rl.grpo_group_size + self.grpo_beta = job_config.rl.grpo_beta + self.use_stable_grpo = job_config.rl.use_stable_grpo # Initialize distributed environment for SPMD generator world_size = dist_utils.init_distributed( Comm(), ) - # Initialize vLLM engine - self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) + # Initialize vLLM engine with job_config + self.vllm_engine = 
VLLMRolloutEngine(job_config, self.model_path) # State machine self.state = GeneratorState.READY_TO_UPDATE self.cond = asyncio.Condition() self.policy_version = 0 - # Reward function - self.reward_fn = ( - math_reward_function if use_real_dataset else trivial_reward_function - ) + # Reward function. TODO: Use a real reward function + self.reward_fn = trivial_reward_function logger.info("Generator initialized with vLLM engine") @@ -401,6 +388,11 @@ async def generate(self) -> None: ) # Compute rewards + logger.info( + f"Computing rewards: {len(completions)} completions, " + f"{len(self.expected_answers)} expected answers, " + f"group_size={self.group_size}" + ) rewards = self.reward_fn( completions, self.expected_answers, self.group_size ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index 9ffb9f0f0a..ccfa10a708 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -11,16 +11,15 @@ import torch from monarch.actor import Actor, endpoint from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData -from torchtitan.experiments.rl.unified.models.parallelism_utils import ( +from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.unified.job_config import JobConfig +from torchtitan.experiments.rl.unified.models.utils import load_trainer_model from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm logger = logging.getLogger(__name__) @@ -30,34 +29,30 @@ class Trainer(Actor): Updates policy based on collected trajectories. 
Run model forward on trajectories, computes loss, and run backward. + TODO: Use torchtitan Trainer Args: - titan_checkpoint_path: Path to TorchTitan checkpoint - model_path: Path to HuggingFace model - learning_rate: Learning rate for optimizer - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model + job_config: JobConfig dataclass containing all configuration """ def __init__( self, - titan_checkpoint_path: str, - model_path: str, - learning_rate: float = 1e-5, - model_mode: str = ModelMode.VLLM_COMPAT, - ddp_size: int = 1, - tp_size: int = 1, + job_config: JobConfig, ): + # Extract needed fields from job_config + model_path = job_config.checkpoint.initial_load_path # path to HF checkpoint + learning_rate = job_config.optimizer.lr + self.ddp_size = job_config.parallelism.data_parallel_replicate_degree + self.tp_size = job_config.parallelism.tensor_parallel_degree + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - self.model = load_model( - titan_checkpoint_path, model_path, model_mode=model_mode - ) - self.ddp_size = ddp_size - self.tp_size = tp_size + # load trainer model and patch to vllm.Attention() + self.model = load_trainer_model(model_path) + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism @@ -82,14 +77,14 @@ def __init__( @endpoint async def get_weights(self) -> dict: - """Get vLLM-compatible weights for generator. + """Get vLLM weights for generator. 
Returns: - vLLM-compatible state dict + vLLM state dict """ titan_state = self.model.state_dict() - vllm_compat_state = torchtitan_to_vllm_compat(titan_state) - return vllm_compat_state + vllm_state = torchtitan_to_vllm(titan_state) + return vllm_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 43153fb70e..14e2542d16 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -5,113 +5,104 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse +""" +Example inference script using TorchTitan models with vLLM LLMEngine. + +This script uses JobConfig loaded from a TOML file to configure both +the vLLM engine and sampling parameters. + +Run: torchrun --nproc_per_node=2 \ + torchtitan/experiments/rl/unified/infer.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +""" + + +from torchtitan.config import ConfigManager # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 -from vllm import LLM, SamplingParams +from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.logger import init_logger logger = init_logger(__name__) -def parse_args(): - parser = argparse.ArgumentParser( - description="Run TorchTitan model inference with vLLM Engine", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--model-ckpt-path", - type=str, - default="torchtitan/experiments/rl/example_checkpoint", - help="Path to TorchTitan checkpoint directory", - ) - parser.add_argument( - "--prompt", - type=str, - default="Hello, my name is", - help="Prompt text for generation", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=100, - help="Maximum 
number of tokens to generate", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.8, - help="Sampling temperature", - ) - parser.add_argument( - "--tensor-parallel-size", - type=int, - default=1, - help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", +def infer(): + + config_manager = ConfigManager() + job_config = config_manager.parse_args() + + logger.info("Initializing vLLM LLMEngine with TorchTitan model") + logger.info(f"Model: {job_config.checkpoint.initial_load_path}") + logger.info( + f"Tensor Parallel Size: {job_config.generation.parallelism.tensor_parallel_degree}" ) - return parser.parse_args() + # Create EngineArgs from JobConfig + # Map TorchTitan parallelism to vLLM parallelism + generation = job_config.generation -def infer(): - args = parse_args() - - logger.info("Initializing vLLM with TorchTitan model") - logger.info(f"Model: {args.model_ckpt_path}") - logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") - - # Initialize vLLM with custom TorchTitan model - # The LLM initialization will internally: - # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) - # 2. Create TorchTitanVLLMModel instance - # 3. Create JobConfig and ParallelDims from vLLM config - # 4. Apply parallelization using parallelize_qwen3 - # 5. 
Load model weights and prepare for inference - # The tensor_parallel_size will be used by vLLM to configure parallelization - # and will be available in vllm_config in worker processes - logger.info("Creating vLLM LLM engine...") - - llm = LLM( - model=args.model_ckpt_path, # Model checkpoint path - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, - dtype="bfloat16", + engine_args = EngineArgs( + # Model configuration + model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - enforce_eager=True, # Use eager mode - tensor_parallel_size=args.tensor_parallel_size, - gpu_memory_utilization=0.5, + dtype=generation.dtype, + # Parallelism configuration + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, + # Seed + seed=generation.seed, + # HuggingFace overrides + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("vLLM engine initialized successfully") - logger.info(f"Prompt: {args.prompt}") + logger.info("Initializing LLMEngine from EngineArgs...") + engine = LLMEngine.from_engine_args(engine_args) - # Prepare prompt and sampling parameters - prompts = [args.prompt] + logger.info("vLLM LLMEngine initialized successfully") + + # Create sampling parameters from JobConfig + sampling = job_config.generation.sampling sampling_params = SamplingParams( - temperature=args.temperature, - top_p=0.95, - max_tokens=args.max_tokens, + temperature=sampling.temperature, + top_p=sampling.top_p, + max_tokens=sampling.max_tokens, ) - # Generate text - logger.info("Generating text...") - outputs = llm.generate( - prompts=prompts, - sampling_params=sampling_params, + logger.info( + f"Sampling params: temperature={sampling.temperature}, " + 
f"top_p={sampling.top_p}, max_tokens={sampling.max_tokens}" ) - # Print results - logger.info("Generation complete") - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text + # Example prompt + prompt = "Hello, my name is" + logger.info(f"Prompt: {prompt}") - print(f"\nPrompt: {prompt}") - print(f"Generated text: {generated_text!r}\n") + # Add request to engine + logger.info("Adding request to engine...") + request_id = "0" + engine.add_request(request_id, prompt, sampling_params) + + # Generate text by stepping through engine + logger.info("Generating text...") + while engine.has_unfinished_requests(): + request_outputs = engine.step() + + # Process finished requests + for request_output in request_outputs: + if request_output.finished: + prompt = request_output.prompt + generated_text = request_output.outputs[0].text + + # Print results + logger.info("Generation complete") + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") if __name__ == "__main__": diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py new file mode 100644 index 0000000000..fdbe086352 --- /dev/null +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Extended JobConfig for RL/Generation workloads. + +This module extends TorchTitan's base JobConfig with generation-specific +configurations needed for vLLM integration. +""" + +from dataclasses import dataclass, field + +from torchtitan.config.job_config import JobConfig as BaseJobConfig, Parallelism + + +@dataclass +class Sampling: + """ + Sampling configuration for vLLM generation. + + This dataclass contains sampling parameters used during text generation. 
+ These map directly to vLLM's SamplingParams. + """ + + temperature: float = 0.8 + """ + Temperature for sampling. Controls randomness in generation. + - 0.0: Deterministic (greedy decoding) + - Higher values: More random outputs + """ + + top_p: float = 0.95 + """ + Top-p (nucleus) sampling threshold. + Only tokens with cumulative probability up to top_p are considered. + """ + + max_tokens: int = 100 + """Maximum number of tokens to generate""" + + +@dataclass +class Generate: + """ + Generation configuration for vLLM engine. + + This dataclass contains essential vLLM-specific settings for generation. + """ + + dtype: str = "bfloat16" + """Data type for model weights (auto, float16, bfloat16, float32)""" + + gpu_memory_utilization: float = 0.5 + """Fraction of GPU memory to use for generation engine (0.0 to 1.0)""" + + distributed_executor_backend: str = "external_launcher" + """ + Backend for distributed execution. + 'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) + """ + seed: int = 42 + """Random seed for sampling""" + + enforce_eager: bool = True + """Whether to enforce eager execution (disable CUDA graphs)""" + + parallelism: Parallelism = field(default_factory=Parallelism) + """Parallelism configuration for generation""" + + sampling: Sampling = field(default_factory=Sampling) + """Sampling configuration for generation""" + + +@dataclass +class RL: + """Reinforcement Learning configuration for GRPO training.""" + + grpo_beta: float = 0.1 + """Beta parameter for GRPO (Group Relative Policy Optimization) algorithm""" + + use_stable_grpo: bool = False + """Whether to use stable version of GRPO algorithm""" + + grpo_group_size: int = 8 + """Number of samples in each GRPO group for policy optimization""" + + +@dataclass +class JobConfig(BaseJobConfig): + """ + Extended JobConfig with generation support. + + This extends TorchTitan's base JobConfig by adding `generation` field + for vLLM-specific generation configurations. 
+ """ + + generation: Generate = field(default_factory=Generate) + """Generation configuration for vLLM engine""" + rl: RL = field(default_factory=RL) diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index 0e5d6cde52..2738a95728 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -6,42 +6,22 @@ import logging -from enum import Enum import torch -from safetensors.torch import load_file from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import vllm_to_torchtitan + from torchtitan.models.qwen3.model.args import Qwen3ModelArgs from transformers import AutoConfig logger = logging.getLogger(__name__) -class ModelMode(str, Enum): - """ - Enum defining which TorchTitan model to use. - - Attributes: - UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified - training and inference. - VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, - ensuring bitwise determinism between training and inference. - STANDARD: Plain TorchTitan model without any modifications. - """ - - UNIFIED = "unified" - VLLM_COMPAT = "vllm_compat" - STANDARD = "standard" - - def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -102,26 +82,20 @@ def replace_with_vllm_compatible_flash_attention(model): ) -def load_model( - checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT -): +def load_trainer_model(model_path: str): """ Load TorchTitan model from checkpoint for trainer. 
Args: - checkpoint_path: Path to TorchTitan checkpoint model_path: Path to HuggingFace model (for config) - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model Returns: model: Loaded TorchTitan model for trainer. """ - # Load HuggingFace config # TODO: do not depend on transformers.AutoConfig, use qwen_args directly hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - # Create model args + # Create model args. model_args = Qwen3ModelArgs( dim=hf_config.hidden_size, n_layers=hf_config.num_hidden_layers, @@ -135,46 +109,27 @@ def load_model( ), hidden_dim=hf_config.intermediate_size, norm_eps=hf_config.rms_norm_eps, - rope_theta=hf_config.rope_theta, + rope_theta=getattr(hf_config.rope_parameters, "rope_theta", 1000000), max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), qk_norm=True, depth_init=True, eos_id=getattr(hf_config, "eos_token_id", 151645), ) + # convert to torchtitan state_dict. TODO: Use torchtitan components + titan_state_dict = vllm_to_torchtitan(model_path) - # state_dict is in standard TorchTitan format (w1, w2, w3) - state_dict = load_file(checkpoint_path) - - if model_mode == ModelMode.UNIFIED: - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Set global default dtype to bfloat16. This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention for training. 
- replace_with_vllm_compatible_flash_attention(model) - - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=True) - elif model_mode == ModelMode.VLLM_COMPAT: - # Create and load model that has bitwise determinism between training and inference - from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( - Qwen3VLLMCompatModel, - ) + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention to vllm compatible attention for backward capability. + # Only patch to vllm compatible attention during training. + replace_with_vllm_compatible_flash_attention(model) - model = Qwen3VLLMCompatModel(model_args) - # Convert to vLLM-compat format (merged gate_up_proj, down_proj) - vllm_compat_state = torchtitan_to_vllm_compat(state_dict) - model.load_state_dict(vllm_compat_state, strict=False) - else: - # Use standard TorchTitan model - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=False) + # Load standard TorchTitan format directly + model.load_state_dict(titan_state_dict, strict=True) model.to(torch.bfloat16) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml new file mode 100644 index 0000000000..2080e7a39f --- /dev/null +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -0,0 +1,62 @@ +[job] +dump_folder = "./outputs" +description = "Qwen 3 0.6B training" +custom_config_module = "torchtitan.experiments.rl.unified.job_config" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 
+enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "qwen3" +flavor = "0.6B" +hf_assets_path = "./assets/hf/Qwen3-0.6B" + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, 20% total steps + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4" + +[parallelism] # training parallelism plan +data_parallel_replicate_degree = 2 +data_parallel_shard_degree = 1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 + +[checkpoint] +initial_load_path = "/data/users/jianiw/model/qwen3-0.6b" + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[generation] +dtype = "bfloat16" +gpu_memory_utilization = 0.5 +distributed_executor_backend = "external_launcher" +seed = 42 +enforce_eager = true + +[generation.parallelism] +data_parallel_replicate_degree = 1 +tensor_parallel_degree = 1 + +[generation.sampling] +temperature = 0.8 +top_p = 0.95 +max_tokens = 100 diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_grpo.py similarity index 63% rename from torchtitan/experiments/rl/unified/simple_rl_multiprocess.py rename to torchtitan/experiments/rl/unified/simple_grpo.py index 3e914f3778..c874feadc4 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -14,7 +14,8 @@ The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. 
Command to run: -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_grpo.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml """ import asyncio import logging @@ -22,13 +23,9 @@ import torch from monarch.actor import this_host from monarch.utils import setup_env_for_distributed +from torchtitan.config.manager import ConfigManager from torchtitan.experiments.rl.unified.actors.generator import Generator from torchtitan.experiments.rl.unified.actors.trainer import Trainer -from torchtitan.experiments.rl.unified.models.utils import ModelMode -from torchtitan.experiments.rl.vllm_compat.simple_rl import ( - download_and_convert_model, - load_gsm8k_dataset, -) from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, vllm_is_batch_invariant, @@ -39,33 +36,21 @@ async def main(): """Run the distributed RL training loop using Monarch.""" - # Model Config - model_name = "Qwen/Qwen3-0.6B" - cache_dir = "./models" - output_dir = "./converted" - - # Training config - group_size = 8 - num_steps = 10 - learning_rate = 1e-5 - max_new_tokens = 20 - # GRPO config - use_stable_grpo = False - grpo_beta = 0.1 + # Step 1: Load job config using config manager + config_manager = ConfigManager() + job_config = config_manager.parse_args() - # Dataset config - use_real_dataset = False - num_dataset_samples = 5 + # RL Training config + num_steps = job_config.training.steps # Parallelism sizes - trainer_ddp_size = 2 - trainer_tp_size = 1 - generator_tp_size = 1 + trainer_ddp_size = job_config.parallelism.data_parallel_replicate_degree + trainer_tp_size = job_config.parallelism.tensor_parallel_degree + # TODO: add a flag to enable/disable batch_invariant init_batch_invariance() batch_invariant = vllm_is_batch_invariant() - mode = ModelMode.UNIFIED # Set up 
batch invariant if batch_invariant: @@ -78,37 +63,24 @@ async def main(): else: raise RuntimeError("Batch invariance NOT detected - using standard model") - # Download and convert model - titan_checkpoint_path, model_path = download_and_convert_model( - model_name, cache_dir, output_dir - ) - - # Load dataset - if use_real_dataset: - logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") - # TODO: Refactor into loading torchtitan dataset - prompt_texts, expected_answers = load_gsm8k_dataset( - split="train", num_samples=num_dataset_samples - ) - if prompt_texts is None or len(prompt_texts) == 0: - use_real_dataset = False - - if not use_real_dataset: - logger.info("Using default prompts") - prompts_with_answers = [ - ("The capital of France is", "paris"), - ("What is 7 times 8?", "56"), - ("The first president of the United States was", "washington"), - ("The chemical symbol for water is", "h2o"), - ("The largest planet in our solar system is", "jupiter"), - ] - prompt_texts = [p[0] for p in prompts_with_answers] - expected_answers = [p[1] for p in prompts_with_answers] + # Use fake dataset for test. TODO: Implement real RL dataloader. 
+ logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] logger.info(f"Loaded {len(prompt_texts)} prompts") # Create process meshes - trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) + trainer_mesh = this_host().spawn_procs( + per_host={"gpus": trainer_ddp_size * trainer_tp_size} + ) gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) # Set up distributed env vars so that actors are connected via c10d @@ -129,27 +101,15 @@ async def main(): trainer = trainer_mesh.spawn( "trainer", Trainer, - titan_checkpoint_path, - model_path, - learning_rate, - mode, - trainer_ddp_size, - trainer_tp_size, + job_config, # Pass full job_config ) generator = gen_mesh.spawn( "generator", Generator, - model_path, + job_config, # Pass full job_config prompt_texts, expected_answers, - group_size, - max_new_tokens, - 1.0, # temperature - use_real_dataset, - grpo_beta, - use_stable_grpo, - generator_tp_size, ) # Initialize generator with trainer weights @@ -163,6 +123,8 @@ async def main(): for step in range(num_steps): # Fully sync RL loop + # NOTE: This is only getting Trajectory generated from trainer 0, and trainer 1's data is ignored. + # .get() is a blocking method which makes the loop fully sync batch = generator.generate.call().get().item(gpus=0) metrics = trainer.step.call(batch).get().item(gpus=0) weights = trainer.get_weights.call().get().item(gpus=0)