From 7f3d05f12a6e504789894dc665e5150c6704168b Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 30 Dec 2025 10:29:57 -0800 Subject: [PATCH 1/3] config sys v1 --- torchtitan/experiments/rl/unified/infer.py | 159 +++++++++--------- .../experiments/rl/unified/job_config.py | 88 ++++++++++ .../rl/unified/run_configs/qwen3_0.6b.toml | 62 +++++++ .../rl/unified/simple_rl_multiprocess.py | 2 + 4 files changed, 227 insertions(+), 84 deletions(-) create mode 100644 torchtitan/experiments/rl/unified/job_config.py create mode 100644 torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 43153fb70e..136b264a72 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -5,113 +5,104 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import argparse +""" +Example inference script using TorchTitan models with vLLM LLMEngine. + +This script uses JobConfig loaded from a TOML file to configure both +the vLLM engine and sampling parameters. 
+ +Run: torchrun --nproc_per_node=2 \ + torchtitan/experiments/rl/unified/infer.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +""" + + +from torchtitan.config import ConfigManager # Import unified module - this automatically registers TorchTitan models with vLLM from torchtitan.experiments.rl import unified # noqa: F401 -from vllm import LLM, SamplingParams +from vllm import EngineArgs, LLMEngine, SamplingParams from vllm.logger import init_logger logger = init_logger(__name__) -def parse_args(): - parser = argparse.ArgumentParser( - description="Run TorchTitan model inference with vLLM Engine", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--model-ckpt-path", - type=str, - default="torchtitan/experiments/rl/example_checkpoint", - help="Path to TorchTitan checkpoint directory", - ) - parser.add_argument( - "--prompt", - type=str, - default="Hello, my name is", - help="Prompt text for generation", - ) - parser.add_argument( - "--max-tokens", - type=int, - default=100, - help="Maximum number of tokens to generate", - ) - parser.add_argument( - "--temperature", - type=float, - default=0.8, - help="Sampling temperature", - ) - parser.add_argument( - "--tensor-parallel-size", - type=int, - default=1, - help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", +def infer(): + + config_manager = ConfigManager() + job_config = config_manager.parse_args() + + logger.info("Initializing vLLM LLMEngine with TorchTitan model") + logger.info(f"Model: {job_config.checkpoint.initial_load_path}") + logger.info( + f"Tensor Parallel Size: {job_config.inference.parallelism.tensor_parallel_degree}" ) - return parser.parse_args() + # Create EngineArgs from JobConfig + # Map TorchTitan parallelism to vLLM parallelism + inference = job_config.inference -def infer(): - args = parse_args() - - logger.info("Initializing vLLM with TorchTitan model") - logger.info(f"Model: 
{args.model_ckpt_path}") - logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") - - # Initialize vLLM with custom TorchTitan model - # The LLM initialization will internally: - # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) - # 2. Create TorchTitanVLLMModel instance - # 3. Create JobConfig and ParallelDims from vLLM config - # 4. Apply parallelization using parallelize_qwen3 - # 5. Load model weights and prepare for inference - # The tensor_parallel_size will be used by vLLM to configure parallelization - # and will be available in vllm_config in worker processes - logger.info("Creating vLLM LLM engine...") - - llm = LLM( - model=args.model_ckpt_path, # Model checkpoint path - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, - dtype="bfloat16", + engine_args = EngineArgs( + # Model configuration + model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - enforce_eager=True, # Use eager mode - tensor_parallel_size=args.tensor_parallel_size, - gpu_memory_utilization=0.5, + dtype=inference.dtype, + # Parallelism configuration + tensor_parallel_size=inference.parallelism.tensor_parallel_degree, + distributed_executor_backend=inference.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=inference.gpu_memory_utilization, + enforce_eager=inference.enforce_eager, + # Seed + seed=inference.seed, + # HuggingFace overrides + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) - logger.info("vLLM engine initialized successfully") - logger.info(f"Prompt: {args.prompt}") + logger.info("Initializing LLMEngine from EngineArgs...") + engine = LLMEngine.from_engine_args(engine_args) - # Prepare prompt and sampling parameters - prompts = [args.prompt] + logger.info("vLLM LLMEngine initialized successfully") + + # Create sampling parameters from JobConfig + sampling = job_config.inference.sampling 
sampling_params = SamplingParams( - temperature=args.temperature, - top_p=0.95, - max_tokens=args.max_tokens, + temperature=sampling.temperature, + top_p=sampling.top_p, + max_tokens=sampling.max_tokens, ) - # Generate text - logger.info("Generating text...") - outputs = llm.generate( - prompts=prompts, - sampling_params=sampling_params, + logger.info( + f"Sampling params: temperature={sampling.temperature}, " + f"top_p={sampling.top_p}, max_tokens={sampling.max_tokens}" ) - # Print results - logger.info("Generation complete") - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text + # Example prompt + prompt = "Hello, my name is" + logger.info(f"Prompt: {prompt}") - print(f"\nPrompt: {prompt}") - print(f"Generated text: {generated_text!r}\n") + # Add request to engine + logger.info("Adding request to engine...") + request_id = "0" + engine.add_request(request_id, prompt, sampling_params) + + # Generate text by stepping through engine + logger.info("Generating text...") + while engine.has_unfinished_requests(): + request_outputs = engine.step() + + # Process finished requests + for request_output in request_outputs: + if request_output.finished: + prompt = request_output.prompt + generated_text = request_output.outputs[0].text + + # Print results + logger.info("Generation complete") + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") if __name__ == "__main__": diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py new file mode 100644 index 0000000000..47e41bc831 --- /dev/null +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Extended JobConfig for RL/Inference workloads. 
+ +This module extends TorchTitan's base JobConfig with inference-specific +configurations needed for vLLM integration. +""" + +from dataclasses import dataclass, field + +from torchtitan.config.job_config import JobConfig as BaseJobConfig, Parallelism + + +@dataclass +class Sampling: + """ + Sampling configuration for vLLM generation. + + This dataclass contains sampling parameters used during text generation. + These map directly to vLLM's SamplingParams. + """ + + temperature: float = 0.8 + """ + Temperature for sampling. Controls randomness in generation. + - 0.0: Deterministic (greedy decoding) + - Higher values: More random outputs + """ + + top_p: float = 0.95 + """ + Top-p (nucleus) sampling threshold. + Only tokens with cumulative probability up to top_p are considered. + """ + + max_tokens: int = 100 + """Maximum number of tokens to generate""" + + +@dataclass +class Inference: + """ + Inference configuration for vLLM engine. + + This dataclass contains essential vLLM-specific settings for inference. + """ + + dtype: str = "bfloat16" + """Data type for model weights (auto, float16, bfloat16, float32)""" + + gpu_memory_utilization: float = 0.5 + """Fraction of GPU memory to use for Inference engine (0.0 to 1.0)""" + + distributed_executor_backend: str = "external_launcher" + """ + Backend for distributed execution. + 'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) + """ + + seed: int = 42 + """Random seed for sampling""" + + enforce_eager: bool = True + """Whether to enforce eager execution (disable CUDA graphs)""" + + parallelism: Parallelism = field(default_factory=Parallelism) + """Parallelism configuration for inference""" + + sampling: Sampling = field(default_factory=Sampling) + """Sampling configuration for inference""" + + +@dataclass +class JobConfig(BaseJobConfig): + """ + Extended JobConfig with inference support. 
+ + This extends TorchTitan's base JobConfig by adding `inference` and `sampling` fields + for vLLM-specific inference and generation configurations. + """ + + inference: Inference = field(default_factory=Inference) + """Inference configuration for vLLM engine""" diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml new file mode 100644 index 0000000000..6166c50a91 --- /dev/null +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -0,0 +1,62 @@ +[job] +dump_folder = "./outputs" +description = "Qwen 3 0.6B training" +custom_config_module = "torchtitan.experiments.rl.unified.job_config" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "qwen3" +flavor = "0.6B" +hf_assets_path = "./assets/hf/Qwen3-0.6B" + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, 20% total steps + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4" + +[parallelism] # training parallelism plan +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 + +[checkpoint] +initial_load_path = "/data/users/jianiw/model/qwen3-0.6b" + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[inference] +dtype = "bfloat16" +gpu_memory_utilization = 0.5 +distributed_executor_backend = "external_launcher" +seed = 42 +enforce_eager = true + +[inference.parallelism] +data_parallel_replicate_degree = 1 +tensor_parallel_degree = 2 + +[inference.sampling] +temperature = 0.8 +top_p = 0.95 +max_tokens = 100 diff --git 
a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py index 3e914f3778..28581e2747 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -163,6 +163,8 @@ async def main(): for step in range(num_steps): # Fully sync RL loop + # NOTE: This is only getting Trajectory generated from trainer 0, and trainer 1's data is ignored. + # .get() is a blocking method which makes the loop fully sync batch = generator.generate.call().get().item(gpus=0) metrics = trainer.step.call(batch).get().item(gpus=0) weights = trainer.get_weights.call().get().item(gpus=0) From 1092b4d772e8906bc49886fa9908a157ad6c0d1e Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 30 Dec 2025 14:25:21 -0800 Subject: [PATCH 2/3] config sys for simple_grpo v1 --- torchtitan/experiments/rl/unified/README.md | 6 +- .../rl/unified/actors/generator.py | 178 ++++++++---------- .../experiments/rl/unified/actors/trainer.py | 39 ++-- .../experiments/rl/unified/job_config.py | 16 +- .../experiments/rl/unified/models/utils.py | 81 ++------ .../rl/unified/run_configs/qwen3_0.6b.toml | 6 +- ...mple_rl_multiprocess.py => simple_grpo.py} | 94 +++------ 7 files changed, 160 insertions(+), 260 deletions(-) rename torchtitan/experiments/rl/unified/{simple_rl_multiprocess.py => simple_grpo.py} (64%) diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 27550e977c..ae604b19aa 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -66,10 +66,10 @@ python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path None: @@ -134,7 +136,6 @@ def update_weights(self, vllm_compat_state: dict) -> None: # Update the shard files that vLLM will actually load # We need to split our weights to match the original 2-shard structure import glob - import json shard_files = 
sorted( glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) @@ -142,51 +143,20 @@ def update_weights(self, vllm_compat_state: dict) -> None: index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") # TODO: need to replace this with Torchtitan's checkpoint save and load - # right now we hardcoded to work with 2 safe tensor files which we only + # right now we hardcoded to work with 2 safetensor files which we only # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore # to achieve the weight communication. # only generator rank 0 saves the weight if torch.distributed.get_rank() == 0: logger.info(f"Saving weights to {checkpoint_path}") - if len(shard_files) == 2 and os.path.exists(index_file): - # Load the index to see which weights go in which shard - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data["weight_map"] - - # Split weights according to the index - shard1_weights = {} - shard2_weights = {} - - for key, value in vllm_state.items(): - shard_file = weight_map.get(key, shard_files[0]) - if "model-00001-of-00002" in shard_file: - shard1_weights[key] = value - else: - shard2_weights[key] = value - - # Ensure weights stay in bfloat16 - shard1_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard1_weights.items() - } - shard2_weights = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in shard2_weights.items() - } - - # Save to the shard files - save_file(shard1_weights, shard_files[0]) - save_file(shard2_weights, shard_files[1]) - else: - # Ensure weights stay in bfloat16 - vllm_state = { - k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v - for k, v in vllm_state.items() - } - # Fallback: save as single file - save_file(vllm_state, checkpoint_path) + + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } 
+ # Fallback: save as single file + save_file(vllm_state, checkpoint_path) # Synchronize all ranks before reloading to ensure rank 0 finished writing torch.distributed.barrier() @@ -194,28 +164,37 @@ def update_weights(self, vllm_compat_state: dict) -> None: f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" ) - # First time: create the engine - if self.llm is None: - self.llm = LLM( + # First time: create the engine using LLMEngine and EngineArgs + if self.engine is None: + inference = self.job_config.inference + + # + engine_args = EngineArgs( + # Model configuration model=self.temp_model_dir, - hf_overrides={ - # Override architectures to use our registered TorchTitan model class - "architectures": ["Qwen3TorchTitanForCausalLM"], - }, trust_remote_code=True, - max_model_len=2048, - dtype="bfloat16", - gpu_memory_utilization=0.1, # Reduced from 0.5 - distributed_executor_backend="external_launcher", # vllm do not spawn processes - seed=42, # Fixed seed for determinism - enforce_eager=True, - tensor_parallel_size=self.tp_size, # Explicitly single GPU + dtype=inference.dtype, + # Parallelism configuration + tensor_parallel_size=inference.parallelism.tensor_parallel_degree, + distributed_executor_backend=inference.distributed_executor_backend, + # Memory and performance + gpu_memory_utilization=inference.gpu_memory_utilization, + enforce_eager=inference.enforce_eager, + # Seed + seed=inference.seed, + # HuggingFace overrides to use TorchTitan model. 
+ # TODO: make this field configurable and align with model registration + hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, + max_model_len=2048, # TODO: make this configurable ) - logger.info("Created new vLLM engine") + + logger.info("Initializing LLMEngine from EngineArgs...") + self.engine = LLMEngine.from_engine_args(engine_args) + logger.info("Created new vLLM LLMEngine") else: # Use collective_rpc to call reload_weights on all workers # This reloads weights from temp_model_dir without recreating the engine - self.llm.collective_rpc("reload_weights") + self.engine.collective_rpc("reload_weights") @torch.no_grad() def generate( @@ -228,7 +207,7 @@ def generate( list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] ]: """ - Generate samples using vLLM. + Generate samples using vLLM LLMEngine. Args: prompt_texts: List of prompt strings @@ -252,7 +231,16 @@ def generate( prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs ) - outputs = self.llm.generate(prompt_texts, sampling_params) + # Add requests to the engine + for i, prompt in enumerate(prompt_texts): + request_id = str(i) + self.engine.add_request(request_id, prompt, sampling_params) + + # Step through engine until all requests are finished + all_outputs = [] + while self.engine.has_unfinished_requests(): + request_outputs = self.engine.step() + all_outputs.extend(request_outputs) # Extract completions and log probs completions = [] @@ -261,7 +249,7 @@ def generate( token_log_probs_list = [] prompt_token_ids_list = [] - for output in outputs: + for output in all_outputs: # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids @@ -298,8 +286,8 @@ def generate( def __del__(self): """Cleanup vLLM engine.""" - if hasattr(self, "llm"): - del self.llm + if hasattr(self, "engine"): + del self.engine torch.cuda.empty_cache() @@ -319,58 +307,46 @@ class Generator(Actor): computes rewards/advantages. 
Args: - model_path: Path to HuggingFace model + job_config: JobConfig dataclass containing all configuration prompt_texts: List of prompt strings expected_answers: List of expected answers - group_size: Number of samples per prompt - max_new_tokens: Max tokens to generate - temperature: Sampling temperature - use_real_dataset: Whether using real dataset (GSM8K) - grpo_beta: Beta for GRPO advantages - use_stable_grpo: Whether to use stable GRPO - tp_size: Tensor Parallel size """ def __init__( self, - model_path: str, - prompt_texts: List[str], + job_config: JobConfig, + prompt_texts: List[ + str + ], # TODO: This field need to be removed once dataloader is implemented expected_answers: List[str], - group_size: int = 8, - max_new_tokens: int = 20, - temperature: float = 1.0, - use_real_dataset: bool = False, - grpo_beta: float = 0.1, - use_stable_grpo: bool = False, - tp_size: int = 1, ): - self.model_path = model_path + # Store job_config for accessing configuration + self.job_config = job_config self.prompt_texts = prompt_texts self.expected_answers = expected_answers - self.group_size = group_size - self.max_new_tokens = max_new_tokens - self.temperature = temperature - self.use_real_dataset = use_real_dataset - self.grpo_beta = grpo_beta - self.use_stable_grpo = use_stable_grpo - self.tp_size = tp_size + + # Extract needed fields from job_config + self.model_path = os.path.join(job_config.job.dump_folder, "models") + self.max_new_tokens = job_config.inference.sampling.max_tokens + self.temperature = job_config.inference.sampling.temperature + self.group_size = job_config.rl.grpo_group_size + self.grpo_beta = job_config.rl.grpo_beta + self.use_stable_grpo = job_config.rl.use_stable_grpo # Initialize distributed environment for SPMD generator world_size = dist_utils.init_distributed( Comm(), ) - # Initialize vLLM engine - self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) + # Initialize vLLM engine with job_config + self.vllm_engine = 
VLLMRolloutEngine(job_config, self.model_path) # State machine self.state = GeneratorState.READY_TO_UPDATE self.cond = asyncio.Condition() self.policy_version = 0 - # Reward function - self.reward_fn = ( - math_reward_function if use_real_dataset else trivial_reward_function - ) + # Reward function. TODO: Use a real reward function + self.reward_fn = trivial_reward_function logger.info("Generator initialized with vLLM engine") diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index 9ffb9f0f0a..ac5addf10e 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -11,16 +11,15 @@ import torch from monarch.actor import Actor, endpoint from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData -from torchtitan.experiments.rl.unified.models.parallelism_utils import ( +from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( create_trainer_parallel_dims, ) -from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.unified.job_config import JobConfig +from torchtitan.experiments.rl.unified.models.utils import load_trainer_model from torchtitan.experiments.rl.vllm_compat.simple_rl import ( compute_policy_gradient_loss_vllm, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm logger = logging.getLogger(__name__) @@ -30,34 +29,30 @@ class Trainer(Actor): Updates policy based on collected trajectories. Run model forward on trajectories, computes loss, and run backward. + TODO: Use torchtitan Trainer Args: - titan_checkpoint_path: Path to TorchTitan checkpoint - model_path: Path to HuggingFace model - learning_rate: Learning rate for optimizer - model_mode: Indicates which model to use. 
Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model + job_config: JobConfig dataclass containing all configuration """ def __init__( self, - titan_checkpoint_path: str, - model_path: str, - learning_rate: float = 1e-5, - model_mode: str = ModelMode.VLLM_COMPAT, - ddp_size: int = 1, - tp_size: int = 1, + job_config: JobConfig, ): + # Extract needed fields from job_config + model_path = job_config.checkpoint.initial_load_path # path to HF checkpoint + learning_rate = job_config.optimizer.lr + self.ddp_size = job_config.parallelism.data_parallel_replicate_degree + self.tp_size = job_config.parallelism.tensor_parallel_degree + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device local_rank = int(os.environ["LOCAL_RANK"]) device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(local_rank) - self.model = load_model( - titan_checkpoint_path, model_path, model_mode=model_mode - ) - self.ddp_size = ddp_size - self.tp_size = tp_size + # load trainer model and patch to vllm.Attention() + self.model = load_trainer_model(model_path) + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) # apply PT-D Parallelism @@ -88,7 +83,7 @@ async def get_weights(self) -> dict: vLLM-compatible state dict """ titan_state = self.model.state_dict() - vllm_compat_state = torchtitan_to_vllm_compat(titan_state) + vllm_compat_state = torchtitan_to_vllm(titan_state) return vllm_compat_state @endpoint diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py index 47e41bc831..b9e028c023 100644 --- a/torchtitan/experiments/rl/unified/job_config.py +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -61,7 +61,6 @@ class Inference: Backend for distributed execution. 
'external_launcher' means vLLM does not spawn processes (use torchrun/external launcher) """ - seed: int = 42 """Random seed for sampling""" @@ -75,6 +74,20 @@ class Inference: """Sampling configuration for inference""" +@dataclass +class RL: + """Reinforcement Learning configuration for GRPO training.""" + + grpo_beta: float = 0.1 + """Beta parameter for GRPO (Group Relative Policy Optimization) algorithm""" + + use_stable_grpo: bool = False + """Whether to use stable version of GRPO algorithm""" + + grpo_group_size: int = 8 + """Number of samples in each GRPO group for policy optimization""" + + @dataclass class JobConfig(BaseJobConfig): """ @@ -86,3 +99,4 @@ class JobConfig(BaseJobConfig): inference: Inference = field(default_factory=Inference) """Inference configuration for vLLM engine""" + rl: RL = field(default_factory=RL) diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py index 0e5d6cde52..2738a95728 100644 --- a/torchtitan/experiments/rl/unified/models/utils.py +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -6,42 +6,22 @@ import logging -from enum import Enum import torch -from safetensors.torch import load_file from torchtitan.experiments.rl.unified.models.attention import VLLMAttention from torchtitan.experiments.rl.vllm_compat.models.attention import ( VLLMCompatibleFlashAttention, ) -from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( - torchtitan_to_vllm_compat, -) +from torchtitan.experiments.rl.vllm_compat.weights.converter import vllm_to_torchtitan + from torchtitan.models.qwen3.model.args import Qwen3ModelArgs from transformers import AutoConfig logger = logging.getLogger(__name__) -class ModelMode(str, Enum): - """ - Enum defining which TorchTitan model to use. - - Attributes: - UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified - training and inference. 
- VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, - ensuring bitwise determinism between training and inference. - STANDARD: Plain TorchTitan model without any modifications. - """ - - UNIFIED = "unified" - VLLM_COMPAT = "vllm_compat" - STANDARD = "standard" - - def replace_with_vllm_attention(model, tp_degree=1): """ Replace TorchTitan attention with vLLM's Attention. @@ -102,26 +82,20 @@ def replace_with_vllm_compatible_flash_attention(model): ) -def load_model( - checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT -): +def load_trainer_model(model_path: str): """ Load TorchTitan model from checkpoint for trainer. Args: - checkpoint_path: Path to TorchTitan checkpoint model_path: Path to HuggingFace model (for config) - model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, - or plain Torchtitan model Returns: model: Loaded TorchTitan model for trainer. """ - # Load HuggingFace config # TODO: do not depend on transformers.AutoConfig, use qwen_args directly hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) - # Create model args + # Create model args. model_args = Qwen3ModelArgs( dim=hf_config.hidden_size, n_layers=hf_config.num_hidden_layers, @@ -135,46 +109,27 @@ def load_model( ), hidden_dim=hf_config.intermediate_size, norm_eps=hf_config.rms_norm_eps, - rope_theta=hf_config.rope_theta, + rope_theta=getattr(hf_config.rope_parameters, "rope_theta", 1000000), max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), qk_norm=True, depth_init=True, eos_id=getattr(hf_config, "eos_token_id", 151645), ) + # convert to torchtitan state_dict. 
TODO: Use torchtitan components + titan_state_dict = vllm_to_torchtitan(model_path) - # state_dict is in standard TorchTitan format (w1, w2, w3) - state_dict = load_file(checkpoint_path) - - if model_mode == ModelMode.UNIFIED: - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Set global default dtype to bfloat16. This is needed because vLLM's Attention - # layer uses torch.get_default_dtype() and it doesn't support float32 - torch.set_default_dtype(torch.bfloat16) - # NOTE: Override attention to vllm compatible attention for backward capability. - # Only patch to vllm compatible attention for training. - replace_with_vllm_compatible_flash_attention(model) - - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=True) - elif model_mode == ModelMode.VLLM_COMPAT: - # Create and load model that has bitwise determinism between training and inference - from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( - Qwen3VLLMCompatModel, - )

 - model = Qwen3VLLMCompatModel(model_args) - # Convert to vLLM-compat format (merged gate_up_proj, down_proj) - vllm_compat_state = torchtitan_to_vllm_compat(state_dict) - model.load_state_dict(vllm_compat_state, strict=False) - else: - # Use standard TorchTitan model - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=False) + + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention with the vLLM-compatible attention for backward-pass (autograd) capability. + # Only patch to vllm compatible attention during training. 
+ replace_with_vllm_compatible_flash_attention(model) - model = Qwen3VLLMCompatModel(model_args) - # Convert to vLLM-compat format (merged gate_up_proj, down_proj) - vllm_compat_state = torchtitan_to_vllm_compat(state_dict) - model.load_state_dict(vllm_compat_state, strict=False) - else: - # Use standard TorchTitan model - from torchtitan.models.qwen3 import Qwen3Model - - model = Qwen3Model(model_args) - # Load standard TorchTitan format directly - model.load_state_dict(state_dict, strict=False) + # Load standard TorchTitan format directly + model.load_state_dict(titan_state_dict, strict=True) model.to(torch.bfloat16) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml index 6166c50a91..a004545a55 100644 --- a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -29,8 +29,8 @@ steps = 10 dataset = "c4" [parallelism] # training parallelism plan -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 +data_parallel_replicate_degree = 2 +data_parallel_shard_degree = 1 fsdp_reshard_after_forward = "default" # default / never / always tensor_parallel_degree = 1 @@ -54,7 +54,7 @@ enforce_eager = true [inference.parallelism] data_parallel_replicate_degree = 1 -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 [inference.sampling] temperature = 0.8 diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_grpo.py similarity index 64% rename from torchtitan/experiments/rl/unified/simple_rl_multiprocess.py rename to torchtitan/experiments/rl/unified/simple_grpo.py index 28581e2747..c874feadc4 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_grpo.py @@ -14,7 +14,8 @@ The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. 
Command to run: -VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_grpo.py \ + --job.config_file torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml """ import asyncio import logging @@ -22,13 +23,9 @@ import torch from monarch.actor import this_host from monarch.utils import setup_env_for_distributed +from torchtitan.config.manager import ConfigManager from torchtitan.experiments.rl.unified.actors.generator import Generator from torchtitan.experiments.rl.unified.actors.trainer import Trainer -from torchtitan.experiments.rl.unified.models.utils import ModelMode -from torchtitan.experiments.rl.vllm_compat.simple_rl import ( - download_and_convert_model, - load_gsm8k_dataset, -) from vllm.model_executor.layers.batch_invariant import ( init_batch_invariance, vllm_is_batch_invariant, @@ -39,33 +36,21 @@ async def main(): """Run the distributed RL training loop using Monarch.""" - # Model Config - model_name = "Qwen/Qwen3-0.6B" - cache_dir = "./models" - output_dir = "./converted" - - # Training config - group_size = 8 - num_steps = 10 - learning_rate = 1e-5 - max_new_tokens = 20 - # GRPO config - use_stable_grpo = False - grpo_beta = 0.1 + # Step 1: Load job config using config manager + config_manager = ConfigManager() + job_config = config_manager.parse_args() - # Dataset config - use_real_dataset = False - num_dataset_samples = 5 + # RL Training config + num_steps = job_config.training.steps # Parallelism sizes - trainer_ddp_size = 2 - trainer_tp_size = 1 - generator_tp_size = 1 + trainer_ddp_size = job_config.parallelism.data_parallel_replicate_degree + trainer_tp_size = job_config.parallelism.tensor_parallel_degree + # TODO: add a flag to enable/disable batch_invariant init_batch_invariance() batch_invariant = vllm_is_batch_invariant() - mode = ModelMode.UNIFIED # Set up 
batch invariant if batch_invariant: @@ -78,37 +63,24 @@ async def main(): else: raise RuntimeError("Batch invariance NOT detected - using standard model") - # Download and convert model - titan_checkpoint_path, model_path = download_and_convert_model( - model_name, cache_dir, output_dir - ) - - # Load dataset - if use_real_dataset: - logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") - # TODO: Refactor into loading torchtitan dataset - prompt_texts, expected_answers = load_gsm8k_dataset( - split="train", num_samples=num_dataset_samples - ) - if prompt_texts is None or len(prompt_texts) == 0: - use_real_dataset = False - - if not use_real_dataset: - logger.info("Using default prompts") - prompts_with_answers = [ - ("The capital of France is", "paris"), - ("What is 7 times 8?", "56"), - ("The first president of the United States was", "washington"), - ("The chemical symbol for water is", "h2o"), - ("The largest planet in our solar system is", "jupiter"), - ] - prompt_texts = [p[0] for p in prompts_with_answers] - expected_answers = [p[1] for p in prompts_with_answers] + # Use fake dataset for test. TODO: Implement real RL dataloader. 
+ logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] logger.info(f"Loaded {len(prompt_texts)} prompts") # Create process meshes - trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) + trainer_mesh = this_host().spawn_procs( + per_host={"gpus": trainer_ddp_size * trainer_tp_size} + ) gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) # Set up distributed env vars so that actors are connected via c10d @@ -129,27 +101,15 @@ async def main(): trainer = trainer_mesh.spawn( "trainer", Trainer, - titan_checkpoint_path, - model_path, - learning_rate, - mode, - trainer_ddp_size, - trainer_tp_size, + job_config, # Pass full job_config ) generator = gen_mesh.spawn( "generator", Generator, - model_path, + job_config, # Pass full job_config prompt_texts, expected_answers, - group_size, - max_new_tokens, - 1.0, # temperature - use_real_dataset, - grpo_beta, - use_stable_grpo, - generator_tp_size, ) # Initialize generator with trainer weights From ac1447113d5cec788d03a02abc76420da31be162 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 30 Dec 2025 19:28:26 -0800 Subject: [PATCH 3/3] config sys for simple_grpo v2 --- .../rl/unified/actors/generator.py | 94 +++++++++++-------- .../experiments/rl/unified/actors/trainer.py | 8 +- torchtitan/experiments/rl/unified/infer.py | 18 ++-- .../experiments/rl/unified/job_config.py | 26 ++--- .../rl/unified/run_configs/qwen3_0.6b.toml | 6 +- 5 files changed, 84 insertions(+), 68 deletions(-) diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 56a0ac41b5..39002d5f2f 100644 --- 
a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -27,8 +27,8 @@ compute_grpo_advantages_stable, trivial_reward_function, ) -from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.sampling_params import RequestOutputKind logger = logging.getLogger(__name__) @@ -72,7 +72,7 @@ class VLLMRolloutEngine: def __init__( self, - job_config: "JobConfig", + job_config: JobConfig, model_path: str, ): # Store job_config for accessing configuration @@ -115,7 +115,7 @@ def __init__( self.engine = None logger.info("vLLM rollout engine initialized (will load on first use)") - def update_weights(self, vllm_compat_state: dict) -> None: + def update_weights(self, vllm_state: dict) -> None: """ Update vLLM model weights from vLLM-compat state dict. @@ -123,11 +123,8 @@ def update_weights(self, vllm_compat_state: dict) -> None: vLLM's reload_weights() API after updating the model path config. 
Args: - vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + vllm_state: vLLM model state dict """ - # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) - vllm_state = torchtitan_to_vllm(vllm_compat_state) - # Save to temp model directory import os @@ -166,26 +163,24 @@ def update_weights(self, vllm_compat_state: dict) -> None: # First time: create the engine using LLMEngine and EngineArgs if self.engine is None: - inference = self.job_config.inference + generation = self.job_config.generation - # engine_args = EngineArgs( # Model configuration model=self.temp_model_dir, trust_remote_code=True, - dtype=inference.dtype, + dtype=generation.dtype, # Parallelism configuration - tensor_parallel_size=inference.parallelism.tensor_parallel_degree, - distributed_executor_backend=inference.distributed_executor_backend, + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + distributed_executor_backend=generation.distributed_executor_backend, # Memory and performance - gpu_memory_utilization=inference.gpu_memory_utilization, - enforce_eager=inference.enforce_eager, + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, # Seed - seed=inference.seed, + seed=generation.seed, # HuggingFace overrides to use TorchTitan model. 
# TODO: make this field configurable and align with model registration hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, - max_model_len=2048, # TODO: make this configurable ) logger.info("Initializing LLMEngine from EngineArgs...") @@ -222,19 +217,30 @@ def generate( token_log_probs: List of per-token log prob lists for each completion prompt_token_ids: List of prompt token ID lists for each completion """ + logger.info( + f"Starting generation: {len(prompt_texts)} prompts, " + f"n_samples_per_prompt={n_samples_per_prompt}, " + f"max_tokens={max_new_tokens}, temp={temperature}" + ) + sampling_params = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, - n=n_samples_per_prompt, seed=42, logprobs=1, prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + output_kind=RequestOutputKind.FINAL_ONLY, # Only return completed outputs ) # Add requests to the engine - for i, prompt in enumerate(prompt_texts): - request_id = str(i) - self.engine.add_request(request_id, prompt, sampling_params) + # For n_samples_per_prompt > 1, submit each prompt multiple times with different request id + request_id = 0 + prompt_indices = [] # Track which prompt each request corresponds to + for prompt_idx, prompt in enumerate(prompt_texts): + for sample_idx in range(n_samples_per_prompt): + self.engine.add_request(str(request_id), prompt, sampling_params) + prompt_indices.append(prompt_idx) + request_id += 1 # Step through engine until all requests are finished all_outputs = [] @@ -253,26 +259,31 @@ def generate( # Extract prompt token IDs from the output prompt_token_ids = output.prompt_token_ids - for sample in output.outputs: - completions.append(sample.text) + # Each output now has exactly 1 sample (we submitted multiple requests) + assert ( + len(output.outputs) == 1 + ), f"Expected 1 output, got {len(output.outputs)}" + sample = output.outputs[0] + + completions.append(sample.text) - # Store prompt tokens for this sample - 
prompt_token_ids_list.append(prompt_token_ids) + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) - # Extract token IDs (generated tokens only) - token_ids = sample.token_ids - token_ids_list.append(token_ids) + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) - # Extract per-token log probs - per_token_log_probs = [ - list(logprob_dict.values())[0].logprob - for logprob_dict in sample.logprobs - ] - token_log_probs_list.append(per_token_log_probs) + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) - # Sum log probs across generated tokens - total_log_prob = sum(per_token_log_probs) - log_probs_list.append(total_log_prob) + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) log_probs = torch.tensor(log_probs_list, dtype=torch.float32) @@ -326,9 +337,9 @@ def __init__( self.expected_answers = expected_answers # Extract needed fields from job_config - self.model_path = os.path.join(job_config.job.dump_folder, "models") - self.max_new_tokens = job_config.inference.sampling.max_tokens - self.temperature = job_config.inference.sampling.temperature + self.model_path = job_config.checkpoint.initial_load_path + self.max_new_tokens = job_config.generation.sampling.max_tokens + self.temperature = job_config.generation.sampling.temperature self.group_size = job_config.rl.grpo_group_size self.grpo_beta = job_config.rl.grpo_beta self.use_stable_grpo = job_config.rl.use_stable_grpo @@ -377,6 +388,11 @@ async def generate(self) -> None: ) # Compute rewards + logger.info( + f"Computing rewards: {len(completions)} completions, " + f"{len(self.expected_answers)} expected answers, " + f"group_size={self.group_size}" + ) rewards = self.reward_fn( completions, 
self.expected_answers, self.group_size ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py index ac5addf10e..ccfa10a708 100644 --- a/torchtitan/experiments/rl/unified/actors/trainer.py +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -77,14 +77,14 @@ def __init__( @endpoint async def get_weights(self) -> dict: - """Get vLLM-compatible weights for generator. + """Get vLLM weights for generator. Returns: - vLLM-compatible state dict + vLLM state dict """ titan_state = self.model.state_dict() - vllm_compat_state = torchtitan_to_vllm(titan_state) - return vllm_compat_state + vllm_state = torchtitan_to_vllm(titan_state) + return vllm_state @endpoint async def step(self, trajectory: TrajectoryData) -> dict: diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py index 136b264a72..14e2542d16 100755 --- a/torchtitan/experiments/rl/unified/infer.py +++ b/torchtitan/experiments/rl/unified/infer.py @@ -37,26 +37,26 @@ def infer(): logger.info("Initializing vLLM LLMEngine with TorchTitan model") logger.info(f"Model: {job_config.checkpoint.initial_load_path}") logger.info( - f"Tensor Parallel Size: {job_config.inference.parallelism.tensor_parallel_degree}" + f"Tensor Parallel Size: {job_config.generation.parallelism.tensor_parallel_degree}" ) # Create EngineArgs from JobConfig # Map TorchTitan parallelism to vLLM parallelism - inference = job_config.inference + generation = job_config.generation engine_args = EngineArgs( # Model configuration model=job_config.checkpoint.initial_load_path, trust_remote_code=True, - dtype=inference.dtype, + dtype=generation.dtype, # Parallelism configuration - tensor_parallel_size=inference.parallelism.tensor_parallel_degree, - distributed_executor_backend=inference.distributed_executor_backend, + tensor_parallel_size=generation.parallelism.tensor_parallel_degree, + 
distributed_executor_backend=generation.distributed_executor_backend, # Memory and performance - gpu_memory_utilization=inference.gpu_memory_utilization, - enforce_eager=inference.enforce_eager, + gpu_memory_utilization=generation.gpu_memory_utilization, + enforce_eager=generation.enforce_eager, # Seed - seed=inference.seed, + seed=generation.seed, # HuggingFace overrides hf_overrides={"architectures": ["Qwen3TorchTitanForCausalLM"]}, ) @@ -67,7 +67,7 @@ def infer(): logger.info("vLLM LLMEngine initialized successfully") # Create sampling parameters from JobConfig - sampling = job_config.inference.sampling + sampling = job_config.generation.sampling sampling_params = SamplingParams( temperature=sampling.temperature, top_p=sampling.top_p, diff --git a/torchtitan/experiments/rl/unified/job_config.py b/torchtitan/experiments/rl/unified/job_config.py index b9e028c023..fdbe086352 100644 --- a/torchtitan/experiments/rl/unified/job_config.py +++ b/torchtitan/experiments/rl/unified/job_config.py @@ -5,9 +5,9 @@ # LICENSE file in the root directory of this source tree. """ -Extended JobConfig for RL/Inference workloads. +Extended JobConfig for RL/Generation workloads. -This module extends TorchTitan's base JobConfig with inference-specific +This module extends TorchTitan's base JobConfig with generation-specific configurations needed for vLLM integration. """ @@ -43,18 +43,18 @@ class Sampling: @dataclass -class Inference: +class Generate: """ - Inference configuration for vLLM engine. + Generation configuration for vLLM engine. - This dataclass contains essential vLLM-specific settings for inference. + This dataclass contains essential vLLM-specific settings for generation. 
""" dtype: str = "bfloat16" """Data type for model weights (auto, float16, bfloat16, float32)""" gpu_memory_utilization: float = 0.5 - """Fraction of GPU memory to use for Inference engine (0.0 to 1.0)""" + """Fraction of GPU memory to use for generation engine (0.0 to 1.0)""" distributed_executor_backend: str = "external_launcher" """ @@ -68,10 +68,10 @@ class Inference: """Whether to enforce eager execution (disable CUDA graphs)""" parallelism: Parallelism = field(default_factory=Parallelism) - """Parallelism configuration for inference""" + """Parallelism configuration for generation""" sampling: Sampling = field(default_factory=Sampling) - """Sampling configuration for inference""" + """Sampling configuration for generation""" @dataclass @@ -91,12 +91,12 @@ class RL: @dataclass class JobConfig(BaseJobConfig): """ - Extended JobConfig with inference support. + Extended JobConfig with generation support. - This extends TorchTitan's base JobConfig by adding `inference` and `sampling` fields - for vLLM-specific inference and generation configurations. + This extends TorchTitan's base JobConfig by adding `generation` field + for vLLM-specific generation configurations. 
""" - inference: Inference = field(default_factory=Inference) - """Inference configuration for vLLM engine""" + generation: Generate = field(default_factory=Generate) + """Generation configuration for vLLM engine""" rl: RL = field(default_factory=RL) diff --git a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml index a004545a55..2080e7a39f 100644 --- a/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml +++ b/torchtitan/experiments/rl/unified/run_configs/qwen3_0.6b.toml @@ -45,18 +45,18 @@ selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac ba enable=false components = ["model", "loss"] -[inference] +[generation] dtype = "bfloat16" gpu_memory_utilization = 0.5 distributed_executor_backend = "external_launcher" seed = 42 enforce_eager = true -[inference.parallelism] +[generation.parallelism] data_parallel_replicate_degree = 1 tensor_parallel_degree = 1 -[inference.sampling] +[generation.sampling] temperature = 0.8 top_p = 0.95 max_tokens = 100