From ffd835abf1b7f4c072573d6937c07831e280155e Mon Sep 17 00:00:00 2001 From: wuxibin Date: Thu, 18 Dec 2025 15:39:58 +0800 Subject: [PATCH 1/3] [ray] feat: resource pool create one placement group across all nodes --- .../test_high_level_scheduling_api.py | 31 +++- .../test_split_resource_pool.py | 8 +- verl/single_controller/ray/base.py | 135 +++++++----------- 3 files changed, 80 insertions(+), 94 deletions(-) diff --git a/tests/single_controller/test_high_level_scheduling_api.py b/tests/single_controller/test_high_level_scheduling_api.py index 487eb37e344..617e8c283ba 100644 --- a/tests/single_controller/test_high_level_scheduling_api.py +++ b/tests/single_controller/test_high_level_scheduling_api.py @@ -17,7 +17,7 @@ import ray from verl.single_controller.base.worker import Worker -from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool +from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, split_resource_pool from verl.utils.device import get_device_name @@ -68,15 +68,14 @@ def test(): del ref_wg gc.collect() # make sure ray actors are deleted - [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()] + ray.util.remove_placement_group(resource_pool.get_placement_group()) print("wait 5s to remove placemeng_group") time.sleep(5) # test single-node-multi-partition print("test single-node-multi-partition") - rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm") - ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref") - total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool) + total_resource_pool = RayResourcePool([8], use_gpu=True, name_prefix="ref") + rm_resource_pool, ref_resource_pool = split_resource_pool(total_resource_pool, split_size=4) assert rm_resource_pool.world_size == 4 assert ref_resource_pool.world_size == 4 @@ -101,3 +100,25 @@ def test(): assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)] ray.shutdown() + + +def test_multi_nodes(): + ray.init() + class_with_args = RayClassWithInitArgs(cls=TestActor) + resource_pool = RayResourcePool([4, 4]) + assert resource_pool.world_size == 8 + + # actor worker group + actor_wg = RayWorkerGroup(resource_pool, class_with_args) + assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)] + + # split resource pool for rollout (world_size=2) + rollout_pools = split_resource_pool(resource_pool, split_size=2) + assert len(rollout_pools) == 4 + for idx, rollout_pool in enumerate(rollout_pools): + assert rollout_pool.world_size == 2 + assert rollout_pool.start_bundle_index == idx * 2 + rollout_wg = RayWorkerGroup(rollout_pool, class_with_args) + assert rollout_wg.execute_all_sync("get_cuda_visible_devices") == [str(idx * 2 + i) for i in range(2)] + + ray.shutdown() diff --git a/tests/single_controller/test_split_resource_pool.py b/tests/single_controller/test_split_resource_pool.py index 2cb32606cf3..d89dcdf949e 100644 --- a/tests/single_controller/test_split_resource_pool.py +++ b/tests/single_controller/test_split_resource_pool.py @@ -51,7 +51,7 @@ def test_split_resource_pool_with_split_size(): ray.init() # assume we have 2 nodes, with 4 GPUs each global_resource_pool = RayResourcePool(process_on_nodes=[4, 4]) - global_resource_pool.get_placement_groups(device_name=get_device_name()) + global_resource_pool.get_placement_group(device_name=get_device_name()) # first 4 gpus for 
actor_1, last 4 gpus for actor_2 actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(resource_pool=global_resource_pool, split_size=4) @@ -79,7 +79,7 @@ def test_split_resource_pool_with_split_size_list(): ray.init() # assume we have 4 nodes, with 2 GPUs each global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2]) - global_resource_pool.get_placement_groups(device_name=get_device_name()) + global_resource_pool.get_placement_group(device_name=get_device_name()) # first 2 gpus for actor_1, last 6 gpus for actor_2 actor_1_resource_pool, actor_2_resource_pool = split_resource_pool( @@ -113,7 +113,7 @@ def test_split_resource_pool_with_split_size_list_cross_nodes(): ray.init() # assume we have 4 nodes, with 2 GPUs each global_resource_pool = RayResourcePool(process_on_nodes=[4, 4]) - global_resource_pool.get_placement_groups(device_name=get_device_name()) + global_resource_pool.get_placement_group(device_name=get_device_name()) # first 2 gpus for actor_1, last 6 gpus for actor_2 actor_1_resource_pool, actor_2_resource_pool = split_resource_pool( @@ -149,7 +149,7 @@ def test_split_resource_pool_with_split_twice(): # assume we have 4 nodes, with 2 GPUs each global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2]) - global_resource_pool.get_placement_groups(device_name=get_device_name()) + global_resource_pool.get_placement_group(device_name=get_device_name()) # actors with [2, 1, 1, 1, 1, 2] (split twice) rp_1, rp_2, rp_3 = split_resource_pool( diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 2f657c08221..107bcf2b288 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -108,13 +108,22 @@ def __init__( self.use_gpu = use_gpu # print(f"in RayProcessDispatchConfiguration: name_prefix = {name_prefix}") self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix - self.pgs = None + self.pg = None self.detached = detached self.accelerator_type = accelerator_type - def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): - if self.pgs is not None: - return self.pgs + def get_placement_group(self, name=None, device_name="cuda") -> PlacementGroup: + """Create a placement group across all nodes. + + Args: + name: The name of the placement group. If None, a random name will be generated. + device_name: The device type to be used in the placement group. Defaults to "cuda". + + Returns: + The created placement group. 
+ """ + if self.pg is not None: + return self.pg pg_name_prefix = ( name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" @@ -130,39 +139,42 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c bundle[device_name] = 1 if self.accelerator_type is not None: bundle[self.accelerator_type] = 1e-4 - pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store] + bundles = [bundle.copy() for process_count in self._store for _ in range(process_count)] lifetime = "detached" if self.detached else None + self.pg = placement_group(bundles=bundles, name=pg_name_prefix, lifetime=lifetime) + ray.get(self.pg.ready()) + return self.pg - pgs = [ - placement_group(bundles=bundles, strategy=strategy, name=pg_name_prefix + str(idx), lifetime=lifetime) - for idx, bundles in enumerate(pg_scheme) - ] - - ray.get([pg.ready() for pg in pgs]) - - self.pgs = sort_placement_group_by_node_ip(pgs) - return pgs + @property + def start_bundle_index(self): + """The start bundle index of the placement group.""" + return 0 class SubRayResourcePool(RayResourcePool): def __init__( self, - placement_groups: list[PlacementGroup], + placement_group: PlacementGroup, start_bundle_index: int, subgroup_world_size: int, **kwargs, ) -> None: super().__init__(**kwargs) - self.pgs = placement_groups - self.start_bundle_index = start_bundle_index - self.subgroup_world_size = subgroup_world_size + self.pg = placement_group + self._start_bundle_index = start_bundle_index + self._subgroup_world_size = subgroup_world_size @property def world_size(self): - return self.subgroup_world_size + return self._subgroup_world_size + + @property + def start_bundle_index(self): + return self._start_bundle_index +# TODO: not used? def extract_pg_from_exist( resource_pools: dict[str, RayResourcePool], src_role_names: list[str], resource_pool: RayResourcePool ) -> list: @@ -222,14 +234,14 @@ def split_resource_pool( start_bundle_idx_list = np.cumsum([0] + split_size_list[:-1]) # ensure resource_pool.pgs has been initialized - placement_groups = resource_pool.get_placement_groups() + placement_group = resource_pool.get_placement_group() split_resource_pools = [ SubRayResourcePool( process_on_nodes=resource_pool.store, use_gpu=resource_pool.use_gpu, name_prefix=f"{resource_pool.name_prefix}_split_{split_idx}", max_colocate_count=resource_pool.max_colocate_count, - placement_groups=placement_groups, + placement_group=placement_group, start_bundle_index=start_bundle_idx_list[split_idx], subgroup_world_size=split_size_list[split_idx], ) @@ -238,6 +250,7 @@ def split_resource_pool( return split_resource_pools +# TODO: not used? 
def merge_resource_pool(rp1: RayResourcePool, rp2: RayResourcePool) -> RayResourcePool: assert rp1.use_gpu == rp2.use_gpu, "Both RayResourcePool must either use_gpu or not" assert rp1.max_colocate_count == rp2.max_colocate_count, "Both RayResourcePool must has the same max_colocate_count" @@ -346,7 +359,7 @@ def __init__( self, resource_pool: RayResourcePool = None, ray_cls_with_init: RayClassWithInitArgs = None, - bin_pack: bool = True, + bin_pack: bool = False, name_prefix: str = None, detached=False, worker_names=None, @@ -390,14 +403,6 @@ def __init__( if self._is_init_with_detached_workers: self._init_with_detached_workers(worker_names=worker_names, worker_handles=worker_handles) - elif isinstance(resource_pool, SubRayResourcePool): - self._init_with_subresource_pool( - resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - bin_pack=bin_pack, - detached=detached, - worker_env=self.customized_worker_env, - ) else: self._init_with_resource_pool( resource_pool=resource_pool, @@ -434,13 +439,13 @@ def _init_with_detached_workers(self, worker_names, worker_handles): self._workers = workers self._world_size = len(worker_names) - def _get_master_addr_port(self, pg): + def _get_master_addr_port(self, pg, bundle_index): """Get master addr and port for this worker group""" if self._master_addr is None and self._master_port is None: self._master_addr, self._master_port = ray.get( get_master_addr_port.options( scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, placement_group_bundle_index=0 + placement_group=pg, placement_group_bundle_index=bundle_index ), ).remote() ) @@ -462,64 +467,21 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d detached: Whether workers should be detached """ self.resource_pool = resource_pool - - strategy = "PACK" - if bin_pack: - strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) + pg = resource_pool.get_placement_group(device_name=self.device_name) world_size = resource_pool.world_size self._world_size = world_size - # cia.add_kwarg("_world_size", world_size) - - rank = -1 local_world_size = resource_pool.store[0] - for pg_idx, pg in enumerate(sort_placement_group_by_node_ip(pgs)): - assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " - if pg_idx == 0: - self._get_master_addr_port(pg) - - for local_rank in range(local_world_size): - rank += 1 - self._create_worker( - rank=rank, - pg_idx=pg_idx, - pg=pg, - local_rank=local_rank, - resource_pool=resource_pool, - ray_cls_with_init=ray_cls_with_init, - worker_env=worker_env, - detached=detached, - ) + start_bundle_index = resource_pool.start_bundle_index - def _init_with_subresource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached, worker_env=None): - """Initialize the worker group by creating new workers from a resource pool or sub resource pool. 
- Args: - resource_pool: Resource pool for worker allocation - ray_cls_with_init: Class with initialization arguments for workers - bin_pack: Whether to use strict bin packing for resource allocation - detached: Whether workers should be detached - """ - strategy = "PACK" - if bin_pack: - strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) - world_size = resource_pool.world_size - self._world_size = world_size + for rank in range(world_size): + if rank == 0: + self._get_master_addr_port(pg, bundle_index=start_bundle_index) - rank = -1 - local_world_size = resource_pool.store[0] - self._get_master_addr_port(pgs[0]) - for curr_rank in range(resource_pool.start_bundle_index, resource_pool.start_bundle_index + world_size): - pg_idx = curr_rank // local_world_size - pg = pgs[pg_idx] - local_rank = curr_rank % local_world_size - assert local_world_size <= pg.bundle_count, f"when generating for {self.name_prefix}, for the " - - rank += 1 + local_rank = rank % local_world_size self._create_worker( - rank=rank, - pg_idx=pg_idx, pg=pg, + bundle_index=start_bundle_index + rank, + rank=rank, local_rank=local_rank, resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, @@ -527,7 +489,9 @@ def _init_with_subresource_pool(self, resource_pool, ray_cls_with_init, bin_pack detached=detached, ) - def _create_worker(self, rank, pg_idx, pg, local_rank, resource_pool, ray_cls_with_init, worker_env, detached): + def _create_worker( + self, pg, bundle_index, rank, local_rank, resource_pool, ray_cls_with_init, worker_env, detached + ): world_size = resource_pool.world_size use_gpu = resource_pool.use_gpu local_world_size = resource_pool.store[0] @@ -558,6 +522,7 @@ def _create_worker(self, rank, pg_idx, pg, local_rank, resource_pool, ray_cls_wi cia_name = type(ray_cls_with_init.cls).__name__ match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" + pg_idx = rank // local_world_size name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. 
Worker_2:5 if self.profile_steps and self.device_name == "cuda": @@ -579,7 +544,7 @@ def _create_worker(self, rank, pg_idx, pg, local_rank, resource_pool, ray_cls_wi # create a worker worker = ray_cls_with_init( placement_group=pg, - placement_group_bundle_idx=local_rank, + placement_group_bundle_idx=bundle_index, use_gpu=use_gpu, num_gpus=num_gpus, device_name=self.device_name, From 4a71aabcc0df27c80697f1e8d2e0f3c851d64844 Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Mon, 12 Jan 2026 15:45:09 +0800 Subject: [PATCH 2/3] [ray] feat: implement IP-based worker sorting to fix ray placement group PACK strategy topology issue Signed-off-by: jianjunzhong --- .../vllm_rollout/vllm_async_server.py | 13 +- verl/experimental/reward_loop/reward_model.py | 16 +- verl/single_controller/base/__init__.py | 4 +- verl/single_controller/base/worker.py | 129 ++++++++++++++ verl/single_controller/ray/base.py | 167 +++++++++++++++++- verl/utils/device.py | 17 ++ verl/utils/net_utils.py | 12 ++ verl/workers/fsdp_workers.py | 1 - verl/workers/rollout/replica.py | 16 +- .../rollout/vllm_rollout/vllm_async_server.py | 18 +- 10 files changed, 360 insertions(+), 33 deletions(-) diff --git a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py index 4c3793275a7..90d50fcd984 100644 --- a/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py +++ b/verl/experimental/fully_async_policy/vllm_rollout/vllm_async_server.py @@ -45,8 +45,19 @@ def __init__( node_rank: int, gpus_per_node: int, nnodes: int, + cuda_visible_devices: str, ): - super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes) + super().__init__( + config, + model_config, + rollout_mode, + workers, + replica_rank, + node_rank, + gpus_per_node, + nnodes, + cuda_visible_devices, + ) # for cancel LLMServer self.paused = False diff --git a/verl/experimental/reward_loop/reward_model.py b/verl/experimental/reward_loop/reward_model.py index 2bc05e1eea1..1a3477ae64a 100644 --- a/verl/experimental/reward_loop/reward_model.py +++ b/verl/experimental/reward_loop/reward_model.py @@ -16,7 +16,7 @@ import logging import os -from verl.single_controller.ray.base import RayResourcePool, split_resource_pool +from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup from verl.workers.config import HFModelConfig, RewardModelConfig from verl.workers.rollout.replica import get_rollout_replica_class @@ -75,14 +75,14 @@ def _initialize_llm_servers(self): for replica_rank in range(num_replicas) ] if self.resource_pool: - split_resource_pools = split_resource_pool(self.resource_pool, split_size=rollout_world_size) - assert len(split_resource_pools) == len(self.rollout_replicas) - self._run_all( - [ - server.init_colocated(resource_pool) - for server, resource_pool in zip(self.rollout_replicas, split_resource_pools, strict=True) - ] + ray_cls_with_init = self.rollout_replicas[0].get_ray_class_with_init_args() + worker_group = RayWorkerGroup( + resource_pool=self.resource_pool, + ray_cls_with_init=ray_cls_with_init, + bin_pack=False, + name_prefix="rollout_reward_colocate", ) + self._run_all([server.init_colocated(worker_group) for server in self.rollout_replicas]) else: self._run_all([server.init_standalone() for server in self.rollout_replicas]) self.server_handles = [server._server_handle for server in self.rollout_replicas] diff --git a/verl/single_controller/base/__init__.py 
b/verl/single_controller/base/__init__.py index b24bd9942b8..18189f6fe26 100644 --- a/verl/single_controller/base/__init__.py +++ b/verl/single_controller/base/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .worker import Worker +from .worker import TwoPhaseInitWorker, Worker, WorkerMeta from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup -__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"] +__all__ = ["Worker", "TwoPhaseInitWorker", "WorkerMeta", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"] diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index cffaf5d30a0..659c7c7df28 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -31,6 +31,18 @@ from .decorator import Dispatch, Execute, register +@dataclass +class WorkerMeta: + """Metadata for a worker, used for sorting and environment setup.""" + + worker: "ray.actor.ActorHandle" + worker_name: str = "" + bundle_index: int = 0 + ip: str = "" + node_id: str = "" + gpu_ids: list = None + + @dataclass class DistRankInfo: tp_rank: int @@ -71,6 +83,123 @@ def get_availale_master_addr_port(self): def get_available_master_addr_port(self): return self._get_node_ip().strip("[]"), str(self._get_free_port()) + @staticmethod + def get_worker_info() -> tuple[str, str, list[int]]: + """Get the IP address, node ID and GPU IDs for this worker. + + Returns: + tuple: A tuple of (ip, node_id, gpu_ids) where: + - ip is the IP address of the node + - node_id is the Ray node ID + - gpu_ids is a list of GPU IDs assigned to this worker + """ + from verl.utils.device import get_ray_device_key + + ip = WorkerHelper._get_node_ip() + node_id = ray.get_runtime_context().get_node_id() + device_key = get_ray_device_key() + + try: + gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key] + except Exception as e: + raise RuntimeError( + f"Current platform does not support Ray accelerators. Device key: {device_key}. Error: {e}" + ) from e + + return ip, node_id, gpu_ids + + +class TwoPhaseInitWorker(WorkerHelper): + """Base class for all workers that provides two-phase initialization. + + This class provides environment setup and initialization methods that are + required by RayWorkerGroup for IP-based worker sorting and distributed training. + Workers that are needed to perform two-phase initialization should inherit from + this class. + + Two-phase initialization: + 1. __init__: Only store parameters, do NOT initialize heavy resources + 2. setup_worker_environment: Set up distributed environment (rank, devices, master info) + 3. init_worker: Called after environment setup, initialize resources here + """ + + def __init__(self) -> None: + """Initialize the worker base with environment setup flag.""" + self._environment_setup = False + + def setup_worker_environment( + self, rank: int, local_rank: int, visible_devices: str, master_addr: str, master_port: str + ): + """Setup the worker's distributed environment including rank, devices, and master info. + + This method is called after workers are sorted by IP to ensure correct + topology. It sets up all necessary environment variables for distributed training. + + Args: + rank: The new global rank for this worker. + local_rank: The new local rank for this worker. + visible_devices: The comma-separated visible device IDs for this worker. 
+ master_addr: The IP address of the master node for distributed initialization. + master_port: The port number of the master node for distributed initialization. + """ + import os + + # Update rank environment variables + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(local_rank) + + # Update device visibility environment variable + device_keyword = get_visible_devices_keyword().upper() + os.environ[device_keyword] = visible_devices + + # Update master address environment variables + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = master_port + + # Mark that environment setup has been completed + self._environment_setup = True + + def init_worker(self): + """Initialize the worker after environment setup. + + This method should be overridden by subclasses to perform initialization + that depends on the correct rank assignment. The __init__ method should + only store passed parameters as class attributes, and actual initialization + (like torch.distributed.init_process_group) should be done in this method. + + This is called after workers are created and their environment is set up based + on IP sorting to ensure correct topology. + + Note: + Subclasses must call super().init_worker() at the beginning of their + implementation to ensure setup_worker_environment has been called. The subclass's + own initialization logic should follow after this call. + + Raises: + RuntimeError: If setup_worker_environment has not been called before this method. + """ + if not self.is_environment_setup(): + raise RuntimeError( + "init_worker() called before setup_worker_environment(). " + "setup_worker_environment() must be called first to ensure correct environment setup." + ) + + def set_environment_setup(self, setup: bool): + """Set the environment setup flag. + + Args: + setup: Whether environment setup has been completed. + """ + self._environment_setup = setup + + def is_environment_setup(self) -> bool: + """Check if environment setup has been completed. + + Returns: + bool: True if environment setup has been completed, False otherwise. + """ + return self._environment_setup + # we assume that in each WorkerGroup, there is a Master Worker class Worker(WorkerHelper): diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 107bcf2b288..d22ac069e7b 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -25,9 +25,17 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy from verl.protocol import DataProto, _padding_size_key -from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup +from verl.single_controller.base import ( + ClassWithInitArgs, + ResourcePool, + TwoPhaseInitWorker, + Worker, + WorkerGroup, + WorkerMeta, +) from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch -from verl.utils.device import get_device_name +from verl.utils.device import get_device_name, get_visible_devices_keyword +from verl.utils.net_utils import get_ip from verl.utils.py_functional import temp_env_var __all__ = ["Worker"] @@ -458,27 +466,92 @@ def _get_master_addr_port(self, pg, bundle_index): ) def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, detached, worker_env=None): - """Initialize the worker group by creating new workers from a resource pool. + """Initialize the worker group from a resource pool. 
+ + This method wraps worker classes with TwoPhaseInitWorker if needed, creates workers, + sorts them by IP for correct topology, sets up distributed environment variables, + and calls init_worker for actual initialization. Args: resource_pool: Resource pool for worker allocation ray_cls_with_init: Class with initialization arguments for workers bin_pack: Whether to use strict bin packing for resource allocation detached: Whether workers should be detached + worker_env: Optional environment variables for workers """ + from collections import defaultdict + self.resource_pool = resource_pool pg = resource_pool.get_placement_group(device_name=self.device_name) world_size = resource_pool.world_size self._world_size = world_size local_world_size = resource_pool.store[0] start_bundle_index = resource_pool.start_bundle_index + use_gpu = resource_pool.use_gpu + worker_cls = _unwrap_ray_remote(ray_cls_with_init.cls) + + # Step 1: Extracting the worker class and wrapping it with TwoPhaseInitWorker if needed + if use_gpu and not issubclass(worker_cls, TwoPhaseInitWorker): + # Create a dynamic wrapper class + original_worker_cls = worker_cls + + class TwoPhaseInitWrapper(original_worker_cls, TwoPhaseInitWorker): + def __init__(self, *args, **kwargs): + """Initialize the wrapper with deferred initialization.""" + TwoPhaseInitWorker.__init__(self) + self._init_args = args + self._init_kwargs = kwargs + + def setup_worker_environment( + self, rank: int, local_rank: int, visible_devices: str, master_addr: str, master_port: str + ): + """Setup worker environment and log environment variable changes.""" + # Store original environment variables + env_vars = ["RANK", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"] + device_keyword = get_visible_devices_keyword().upper() + if device_keyword: + env_vars.append(device_keyword) + + original_env = {} + for var in env_vars: + original_env[var] = os.environ.get(var, "NOT_SET") + + # Call parent's setup_worker_environment + TwoPhaseInitWorker.setup_worker_environment( + self, rank, local_rank, visible_devices, master_addr, master_port + ) + # Store modified environment variables + modified_env = {} + for var in env_vars: + modified_env[var] = os.environ.get(var, "NOT_SET") + + # Log environment variable changes + log_str = f"Setting worker (rank={rank}) environment: " + for var in env_vars: + original = original_env[var] + modified = modified_env[var] + if original != modified: + log_str += f"{var}: {original} -> {modified} (CHANGED), " + else: + log_str += f"{var}: {modified} (unchanged), " + logger.info(log_str.removesuffix(", ")) + + def init_worker(self): + """Initialize the worker after rank adjustment.""" + TwoPhaseInitWorker.init_worker(self) + original_worker_cls.__init__(self, *self._init_args, **self._init_kwargs) + + ray_cls_with_init.cls = ray.remote(TwoPhaseInitWrapper) + + # Step 2: Create all workers + worker_meta_list: list[WorkerMeta] = [] for rank in range(world_size): if rank == 0: self._get_master_addr_port(pg, bundle_index=start_bundle_index) local_rank = rank % local_world_size - self._create_worker( + worker, worker_name = self._create_worker( pg=pg, bundle_index=start_bundle_index + rank, rank=rank, @@ -488,10 +561,92 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d worker_env=worker_env, detached=detached, ) + worker_meta_list.append( + WorkerMeta(worker=worker, worker_name=worker_name, bundle_index=start_bundle_index + rank) + ) + + if not use_gpu: + self._workers = [item.worker for item in worker_meta_list] + 
self._worker_names = [item.worker_name for item in worker_meta_list] + if issubclass(worker_cls, TwoPhaseInitWorker): + ray.get([worker.set_environment_setup.remote(True) for worker in self._workers]) + ray.get([worker.init_worker.remote() for worker in self._workers]) + return + + # Step 3: Collect worker metadata (IP, node_id, gpu_ids) + new_worker_cls = _unwrap_ray_remote(ray_cls_with_init.cls) + assert issubclass(new_worker_cls, TwoPhaseInitWorker), ( + f"Two-phase initialization only supports TwoPhaseInitWorker, but got {new_worker_cls}" + ) + + worker_infos = ray.get([meta.worker.get_worker_info.remote() for meta in worker_meta_list]) + for meta, (ip, node_id, gpu_ids) in zip(worker_meta_list, worker_infos, strict=True): + meta.ip = ip + meta.node_id = node_id + meta.gpu_ids = [int(x) for x in gpu_ids] + + # Step 4: Sort workers by IP for correct topology + ip_counts: dict[str, int] = {} + for ip, _, _ in worker_infos: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + driver_ip = get_ip() + + def sort_by_driver_then_worker_ip(item: WorkerMeta): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver, + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the worker is on a node with smaller IP address, it + should be placed first. + """ + ip = item.ip + return (0 if ip == driver_ip else 1, ip_counts[ip], ip) + + sorted_worker_meta = sorted(worker_meta_list, key=sort_by_driver_then_worker_ip) + + # Step 5: Build final workers list and regenerate master info + self._workers = [item.worker for item in sorted_worker_meta] + self._worker_names = [item.worker_name for item in sorted_worker_meta] + + self._master_addr, self._master_port = None, None + self._get_master_addr_port(pg, bundle_index=sorted_worker_meta[0].bundle_index) + + # Step 6: Build node_workers and node_gpus mappings + node_workers: dict[str, list[int]] = defaultdict(list) # node id -> list of worker ranks + node_gpus: dict[str, list[int]] = defaultdict(list) # node id -> list of gpu ids + + for rank, item in enumerate(sorted_worker_meta): + node_workers[item.node_id].append(rank) + node_gpus[item.node_id].extend(item.gpu_ids) + + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(set(gpu_ids)) + + # Step 7: Setup worker environment for each worker + setup_environment_futures = [] + for rank, item in enumerate(sorted_worker_meta): + local_rank = node_workers[item.node_id].index(rank) + visible_devices = str(node_gpus[item.node_id][local_rank]) + future = item.worker.setup_worker_environment.remote( + rank, 0, visible_devices, self._master_addr, self._master_port + ) + setup_environment_futures.append(future) + ray.get(setup_environment_futures) + + # Step 8: Initialize all workers + ray.get([item.worker.init_worker.remote() for item in sorted_worker_meta]) def _create_worker( self, pg, bundle_index, rank, local_rank, resource_pool, ray_cls_with_init, worker_env, detached ): + """Create a single worker actor. 
+ + Returns: + tuple: (worker_handle, worker_name) + """ world_size = resource_pool.world_size use_gpu = resource_pool.use_gpu local_world_size = resource_pool.store[0] @@ -523,6 +678,7 @@ def _create_worker( match = re.search(r"ActorClass\(([^)]+)\)", cia_name) # ray.remote(Obj) -> "ActorClass(Obj)" cia_name = match.group(1) if match else cia_name # "ActorClass(Obj)" -> "Obj" pg_idx = rank // local_world_size + # TODO: worker_name may be inaccurate after rank is adjusted, need to fix name = f"{self.name_prefix}{cia_name}_{pg_idx}:{local_rank}" # e.g. Worker_2:5 if self.profile_steps and self.device_name == "cuda": @@ -549,8 +705,7 @@ def _create_worker( num_gpus=num_gpus, device_name=self.device_name, ) - self._workers.append(worker) - self._worker_names.append(name) + return worker, name @property def worker_names(self): diff --git a/verl/utils/device.py b/verl/utils/device.py index 882497df562..fbb7d2621c2 100644 --- a/verl/utils/device.py +++ b/verl/utils/device.py @@ -110,6 +110,23 @@ def get_nccl_backend() -> str: return "nccl" +def get_ray_device_key() -> str: + """Get the Ray accelerator key for the current device type. + + Returns the appropriate key for accessing accelerator IDs in Ray's + runtime context based on the detected accelerator type. + + Returns: + str: 'NPU' for Ascend NPU, 'GPU' for CUDA, empty string otherwise. + """ + if is_npu_available: + return "NPU" + elif is_cuda_available: + return "GPU" + else: + return "CPU" + + def set_expandable_segments(enable: bool) -> None: """Configure CUDA memory allocator expandable segments setting. diff --git a/verl/utils/net_utils.py b/verl/utils/net_utils.py index 1acef76a434..9d790216604 100644 --- a/verl/utils/net_utils.py +++ b/verl/utils/net_utils.py @@ -70,6 +70,18 @@ def is_valid_ipv6_address(address: str) -> bool: return False +def get_ip() -> str: + try: + # try to get ip from network interface + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + except Exception as e: # noqa: BLE001 + # fallback to get ip from hostname + print(f"fail to get ip from network interface, fallback to get ip from hostname: {e}") + return socket.gethostbyname(socket.gethostname()) + + def get_free_port(address: str) -> tuple[int, socket.socket]: family = socket.AF_INET if is_valid_ipv6_address(address): diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 3c0748b7531..631a8eb53a8 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -1217,7 +1217,6 @@ def __init__(self, config: FSDPCriticConfig): timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) - self.config: FSDPCriticConfig = config # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index 917135d9c78..cb82b551ef1 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -125,7 +125,7 @@ async def init_hybrid(self, worker_group: RayWorkerGroup): await self.launch_servers() # TODO(sgm): this should be the default solution, but need to make the RolloutMode more clear. - async def init_colocated(self, resource_pool: RayResourcePool): + async def init_colocated(self, worker_group: RayWorkerGroup): """Init colocated rollout server, rollout engine and hybrid engine colocated in same ray placement group but in separate processes. 
@@ -133,17 +133,9 @@ async def init_colocated(self, resource_pool: RayResourcePool): resource_pool: RayResourcePool, ray placement group where hybrid engine processes have been launched. """ self.rollout_mode = RolloutMode.COLOCATED - self.resource_pool = resource_pool - - worker_group = RayWorkerGroup( - resource_pool=self.resource_pool, - ray_cls_with_init=self.get_ray_class_with_init_args(), - bin_pack=False, - name_prefix=f"rollout_colocate_{self.replica_rank}" - if not self.is_reward_model - else f"rollout_reward_colocate_{self.replica_rank}", - ) - self.workers = worker_group.workers + self.workers = worker_group.workers[ + self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1) + ] await self.launch_servers() async def init_standalone(self): diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 0437c0c23c3..ffe537f7ac3 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -46,6 +46,7 @@ from verl.single_controller.ray import RayClassWithInitArgs from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import get_visible_devices_keyword from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput @@ -186,6 +187,7 @@ def __init__( node_rank: int, gpus_per_node: int, nnodes: int, + cuda_visible_devices: str, ): """ Args: @@ -196,8 +198,10 @@ def __init__( node_rank (int): node rank. gpus_per_node (int): number of gpus per node. nnodes (int): number of nodes. + cuda_visible_devices (str): cuda visible devices. """ super().__init__() + os.environ[get_visible_devices_keyword()] = cuda_visible_devices self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) @@ -695,13 +699,17 @@ async def launch_servers(self): f"worker number {len(self.workers)} not equal to world size {self.world_size}" ) - # get node_id of all workers - worker_node_ids = await asyncio.gather( + # get (node_id, CUDA_VISIBLE_DEVICES) of all workers + worker_infos = await asyncio.gather( *[ - worker.__ray_call__.remote(lambda self: ray.get_runtime_context().get_node_id()) + worker.__ray_call__.remote( + lambda self: (ray.get_runtime_context().get_node_id(), os.environ[get_visible_devices_keyword()]) + ) for worker in self.workers ] ) + worker_cuda_visible_devices = [worker_info[1] for worker_info in worker_infos] + worker_node_ids = [worker_info[0] for worker_info in worker_infos] # For non-data parallel case, there's only one server whether it's single or multi nodes. 
nnodes, gpus_per_node = self.nnodes, self.gpus_per_node @@ -712,6 +720,9 @@ async def launch_servers(self): # create server actor in each node with node affinity for node_rank in range(nnodes): workers = self.workers[node_rank * gpus_per_node : (node_rank + 1) * gpus_per_node] + node_cuda_visible_devices = ",".join( + worker_cuda_visible_devices[node_rank * self.gpus_per_node : (node_rank + 1) * self.gpus_per_node] + ) node_id = worker_node_ids[node_rank * gpus_per_node] name = ( f"vllm_server_{self.replica_rank}_{node_rank}" @@ -734,6 +745,7 @@ async def launch_servers(self): node_rank=node_rank, gpus_per_node=gpus_per_node, nnodes=nnodes, + cuda_visible_devices=node_cuda_visible_devices, ) self.servers.append(server) From 659ed0afee85d7e81d48e569f6a9e7fca47a88ee Mon Sep 17 00:00:00 2001 From: jianjunzhong Date: Tue, 13 Jan 2026 15:05:35 +0800 Subject: [PATCH 3/3] fix ci Signed-off-by: jianjunzhong --- verl/single_controller/ray/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index d22ac069e7b..30d94690ea4 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -542,6 +542,7 @@ def init_worker(self): TwoPhaseInitWorker.init_worker(self) original_worker_cls.__init__(self, *self._init_args, **self._init_kwargs) + ray_cls_with_init = deepcopy(ray_cls_with_init) ray_cls_with_init.cls = ray.remote(TwoPhaseInitWrapper) # Step 2: Create all workers
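
Usage sketch (illustrative, not part of the patches): the snippet below mirrors test_multi_nodes and split_resource_pool from patch 1, assuming a cluster of 2 nodes with 4 GPUs each; EchoActor and its get_rank method are hypothetical stand-ins for a real Worker subclass, while RayResourcePool, split_resource_pool, RayWorkerGroup, RayClassWithInitArgs and execute_all_sync are the names introduced or used in this series.

```python
import os

import ray

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    split_resource_pool,
)


@ray.remote
class EchoActor(Worker):
    """Hypothetical worker used only to show scheduling."""

    def get_rank(self) -> str:
        # RANK is populated by the worker group via setup_worker_environment.
        return os.environ["RANK"]


if __name__ == "__main__":
    ray.init()

    # A single placement group is created across both nodes (2 nodes, 4 GPUs each).
    pool = RayResourcePool(process_on_nodes=[4, 4], use_gpu=True)
    assert pool.world_size == 8

    # Training workers occupy every bundle of the shared placement group.
    actor_wg = RayWorkerGroup(pool, RayClassWithInitArgs(cls=EchoActor))
    print(actor_wg.execute_all_sync("get_rank"))  # ranks 0..7 after IP-based sorting

    # Rollout replicas reuse the same placement group through SubRayResourcePool
    # slices of 2 bundles each, starting at bundle indices 0, 2, 4 and 6.
    for sub_pool in split_resource_pool(pool, split_size=2):
        rollout_wg = RayWorkerGroup(sub_pool, RayClassWithInitArgs(cls=EchoActor))
        print(sub_pool.start_bundle_index, rollout_wg.execute_all_sync("get_rank"))

    ray.shutdown()
```

A minimal sketch of the two-phase initialization contract from patch 2, assuming a worker subclasses TwoPhaseInitWorker directly (plain Worker subclasses are wrapped automatically inside _init_with_resource_pool); MyTrainerWorker and its model_path argument are hypothetical:

```python
import ray
import torch.distributed as dist

from verl.single_controller.base.worker import TwoPhaseInitWorker


@ray.remote
class MyTrainerWorker(TwoPhaseInitWorker):
    """Hypothetical worker following the two-phase contract."""

    def __init__(self, model_path: str):
        # Phase 1: only record constructor arguments; no heavy resources yet.
        super().__init__()
        self.model_path = model_path

    def init_worker(self):
        # Phase 2: called by RayWorkerGroup after setup_worker_environment(),
        # so RANK / LOCAL_RANK / MASTER_ADDR / MASTER_PORT already reflect the
        # IP-sorted topology (WORLD_SIZE is assumed to be set at actor creation).
        super().init_worker()  # raises if the environment was never set up
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl", init_method="env://")
        # ... load the model from self.model_path onto the local device, etc.
```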