Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions torchtitan/experiments/rl/unified/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"""
Unified approach for running TorchTitan models with vLLM inference.

To manually load the plugin:
from torchtitan.experiments.rl.unified import plugin
plugin.register()
To register TorchTitan models with vLLM:
from torchtitan.experiments.rl.unified.plugin import register
register(model_spec)
"""

from torchtitan.experiments.rl.unified.models.vllm_wrapper import (
Expand Down
175 changes: 35 additions & 140 deletions torchtitan/experiments/rl/unified/actors/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import glob
import logging
import os
import shutil

from dataclasses import dataclass, field
from typing import List

import torch
from monarch.actor import Actor, endpoint
from safetensors.torch import save_file
from torchtitan.config import CommConfig, Configurable, ParallelismConfig
from torchtitan.distributed import utils as dist_utils

Expand Down Expand Up @@ -75,7 +72,6 @@ class Generator(Actor, Configurable):
Args:
config: Generator-specific configuration.
model_path: Path to the HF model checkpoint.
dump_folder: Root output folder for RL artifacts.
batch_invariant_mode: Enable batch-invariant mode for deterministic ops.
policy_optimization: GRPO hyperparameters.
prompt_texts: List of prompt strings.
Expand All @@ -95,8 +91,8 @@ class Config(Configurable.Config):
enforce_eager: bool = True
"""Disable CUDA graphs in vLLM (use eager execution)."""

seed: int | None = None
"""Random seed for reproducible generation. None means no fixed seed."""
seed: int = 42
"""Random seed for reproducible generation."""

parallelism: ParallelismConfig = field(default_factory=ParallelismConfig)
"""Parallelism configuration for the vLLM engine."""
Expand All @@ -113,7 +109,6 @@ def __init__(
*,
model_spec: ModelSpec,
model_path: str,
dump_folder: str,
batch_invariant_mode: bool,
policy_optimization: PolicyOptimizationConfig,
prompt_texts: list[str],
Expand All @@ -129,7 +124,6 @@ def __init__(
)

register_model_to_vllm_model_registry(model_spec)
self._vllm_model_name = VLLM_MODEL_NAME

# Set vLLM environment variables from config before any vLLM initialization
if batch_invariant_mode:
Expand All @@ -148,135 +142,39 @@ def __init__(
self.group_size = policy_optimization.group_size
self.grpo_beta = policy_optimization.beta
self.use_stable_grpo = policy_optimization.use_stable_grpo

# Initialize distributed environment for SPMD generator
world_size = dist_utils.init_distributed(CommConfig())

# Set up temp model directory for vLLM weight loading
self._base_model_path = model_path
self._temp_model_dir = os.path.abspath(
os.path.join(dump_folder, "vllm_temp_model")

# Build vLLM engine
engine_args = EngineArgs(
model=model_path,
trust_remote_code=True,
dtype=config.dtype,
tensor_parallel_size=config.parallelism.tensor_parallel_degree,
distributed_executor_backend="external_launcher",
gpu_memory_utilization=config.gpu_memory_limit,
enforce_eager=config.enforce_eager,
seed=config.seed,
hf_overrides={"architectures": [VLLM_MODEL_NAME]},
attention_config=AttentionConfig(
backend=AttentionBackendEnum.FLASH_ATTN,
),
)
os.makedirs(self._temp_model_dir, exist_ok=True)

# Copy config/tokenizer files from base model to temp dir
for file in [
"config.json",
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"merges.txt",
"vocab.json",
]:
src = os.path.join(model_path, file)
if os.path.exists(src):
shutil.copy2(src, self._temp_model_dir)

# Copy the original model shard files if they exist
# We'll overwrite these with our single model.safetensors later
for shard_file in glob.glob(os.path.join(model_path, "model-*.safetensors")):
dst = os.path.join(self._temp_model_dir, os.path.basename(shard_file))
shutil.copy2(shard_file, dst)

# Copy index file if it exists
index_file = os.path.join(model_path, "model.safetensors.index.json")
if os.path.exists(index_file):
shutil.copy2(index_file, self._temp_model_dir)

self._engine: LLMEngine | None = None

logger.info("Initializing LLMEngine from EngineArgs...")
self._engine = LLMEngine.from_engine_args(engine_args)
logger.info("vLLM rollout engine initialized")

self.policy_version = 0

# Reward function. TODO: Move reward calculation out of generator
self.reward_fn = trivial_reward_function

logger.debug("Generator initialized (vLLM engine will load on first use)")
logger.info("Generator initialized with vLLM engine")

def _update_vllm_model_weights(self, vllm_state: dict) -> None:
def _get_model(self):
"""Access the model from the vLLM engine.
Returns a TorchTitanVLLMModelWrapper instance.
"""
Update vLLM model weights from vLLM model state dict. This function is used
when updating vLLM model's weights from trainer's updated weights.

Args:
vllm_state: vLLM model state dict, a map from vLLM model's fqn names to weights
"""
# Save to temp model directory
checkpoint_path = os.path.join(self._temp_model_dir, "model.safetensors")

# Update the shard files that vLLM will actually load
# We need to split our weights to match the original 2-shard structure
shard_files = sorted(
glob.glob(os.path.join(self._temp_model_dir, "model-*.safetensors"))
)
index_file = os.path.join(self._temp_model_dir, "model.safetensors.index.json")

# TODO: need to replace this with Torchtitan's checkpoint save and load
# right now we hardcoded to work with 1 safetensor files which we only
# tested on Qwen3 0.6B model. In the longer term, need to use TorchStore
# to achieve the weight communication.
if torch.distributed.get_rank() == 0:
logger.debug(f"Saving weights to {checkpoint_path}")

# TODO: Check the detail of vLLM's dtype conversion journey
# Currently converting float32 to bfloat16 to match vLLM's attention and kv cache dtype
vllm_state = {
k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v
for k, v in vllm_state.items()
}
save_file(vllm_state, checkpoint_path)

# Synchronize all ranks before reloading to ensure rank 0 finished writing
torch.distributed.barrier()
logger.debug(
f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save"
)

# First time: create the engine using LLMEngine and EngineArgs
if self._engine is None:
cfg = self.config

engine_args = EngineArgs(
# Model configuration
model=self._temp_model_dir,
trust_remote_code=True,
dtype=cfg.dtype,
# Parallelism configuration
tensor_parallel_size=cfg.parallelism.tensor_parallel_degree,
# Use external_launcher because Monarch already spawns the worker processes
distributed_executor_backend="external_launcher",
# Memory and performance
gpu_memory_utilization=cfg.gpu_memory_limit,
enforce_eager=cfg.enforce_eager,
# Seed
seed=cfg.seed,
# HuggingFace overrides to use registered TorchTitan model
hf_overrides={"architectures": [self._vllm_model_name]},
attention_config=AttentionConfig(
backend=AttentionBackendEnum.FLASH_ATTN,
),
)

logger.debug("Initializing LLMEngine from EngineArgs...")
self._engine = LLMEngine.from_engine_args(engine_args)
logger.debug("Created new vLLM LLMEngine")
else:
# Direct parameter copy into model tensors.
# This bypasses vLLM's reload_weights() which uses a layerwise
# reload mechanism that moves params to meta device
from torchtitan.experiments.rl.vllm_compat.weights import vllm_to_torchtitan

titan_state = vllm_to_torchtitan(vllm_state)
model = self._engine.model_executor.driver_worker.get_model()
params = dict(model.named_parameters())

for name, new_weight in titan_state.items():
# TorchTitanVLLMModelWrapper stores the model as self.model,
# so parameters have a "model." prefix
param_name = f"model.{name}"
if param_name in params:
param = params[param_name]
new_w = new_weight.to(device=param.device, dtype=param.dtype)
param.data.copy_(new_w)
return self._engine.model_executor.driver_worker.get_model()

def _compute_rewards_and_advantages(
self, completions: list[str]
Expand Down Expand Up @@ -359,16 +257,12 @@ async def generate(self) -> TrajectoryData:

for sample in output.outputs:
completions.append(sample.text)
prompt_token_ids_list.append(
prompt_token_ids
) # Store prompt tokens for this sample
token_ids_list.append(
sample.token_ids
) # Extract token IDs (generated tokens only)
prompt_token_ids_list.append(prompt_token_ids)
token_ids_list.append(sample.token_ids)
per_token_log_probs = [
list(logprob_dict.values())[0].logprob
for logprob_dict in sample.logprobs
] # Extract per-token log probs
]
token_log_probs_list.append(per_token_log_probs)

# Compute rewards and advantages
Expand All @@ -391,18 +285,19 @@ async def generate(self) -> TrajectoryData:
return trajectory

@endpoint
async def update(self, version: int, vllm_compat_state: dict) -> None:
"""Update generate weights.
async def update(self, version: int, state_dict: dict) -> None:
"""Update generator weights.
Called by the orchestrator (simple_grpo.py).

Args:
version: New policy version number
vllm_compat_state: vLLM-compatible state dict
state_dict: Model state dict
"""
# TODO: remove the helper function (_update_vllm_model_weights) once we clean up the weight updates
self._update_vllm_model_weights(vllm_compat_state)
load_weights = self._get_model().load_weights_from_state_dict(state_dict)
self.policy_version = version
logger.debug(
f"Updated weights into vLLM engine model. "
f"Number of parameters: {len(load_weights)}"
f"{os.getpid()=} Generator updating weights to policy v{version}..."
)

Expand Down
7 changes: 3 additions & 4 deletions torchtitan/experiments/rl/unified/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,13 @@ def __init__(

@endpoint
async def get_weights(self) -> dict:
"""Get vLLM weights for generator.
"""Get model weights for generator.

Returns:
vLLM state dict
model state dict
"""
titan_state = self.model.state_dict()
vllm_state = torchtitan_to_vllm(titan_state)
return vllm_state
return titan_state

@endpoint
async def step(self, trajectory: TrajectoryData) -> dict:
Expand Down
Loading
Loading