31 changes: 26 additions & 5 deletions tests/single_controller/test_high_level_scheduling_api.py
@@ -17,7 +17,7 @@
import ray

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, split_resource_pool
from verl.utils.device import get_device_name


@@ -68,15 +68,14 @@ def test():
del ref_wg
gc.collect() # make sure ray actors are deleted

[ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]
ray.util.remove_placement_group(resource_pool.get_placement_group())
print("wait 5s to remove placemeng_group")
time.sleep(5)
# test single-node-multi-partition

print("test single-node-multi-partition")
rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)
total_resource_pool = RayResourcePool([8], use_gpu=True, name_prefix="ref")
rm_resource_pool, ref_resource_pool = split_resource_pool(total_resource_pool, split_size=4)

assert rm_resource_pool.world_size == 4
assert ref_resource_pool.world_size == 4
@@ -101,3 +100,25 @@ def test():
assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]

ray.shutdown()


def test_multi_nodes():
ray.init()
class_with_args = RayClassWithInitArgs(cls=TestActor)
resource_pool = RayResourcePool([4, 4])
assert resource_pool.world_size == 8

# actor worker group
actor_wg = RayWorkerGroup(resource_pool, class_with_args)
assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

# split resource pool for rollout (world_size=2)
rollout_pools = split_resource_pool(resource_pool, split_size=2)
assert len(rollout_pools) == 4
for idx, rollout_pool in enumerate(rollout_pools):
assert rollout_pool.world_size == 2
assert rollout_pool.start_bundle_index == idx * 2
rollout_wg = RayWorkerGroup(rollout_pool, class_with_args)
assert rollout_wg.execute_all_sync("get_cuda_visible_devices") == [str(idx * 2 + i) for i in range(2)]

ray.shutdown()
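
For quick reference, the split-based scheduling that replaces merge_resource_pool in this test can be sketched as follows. The sketch only reuses calls that appear in the tests above (RayResourcePool, split_resource_pool, RayWorkerGroup, RayClassWithInitArgs and the test's TestActor class); the pool sizes and the "global" name prefix are illustrative, not part of this PR.

from verl.single_controller.ray.base import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    split_resource_pool,
)

# Build one global pool and carve it into equal sub-pools, instead of
# creating small pools and merging them with merge_resource_pool.
global_pool = RayResourcePool([8], use_gpu=True, name_prefix="global")  # e.g. 8 GPUs on one node
sub_pools = split_resource_pool(global_pool, split_size=4)              # -> two pools, world_size 4 each
worker_groups = [RayWorkerGroup(p, RayClassWithInitArgs(cls=TestActor)) for p in sub_pools]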
8 changes: 4 additions & 4 deletions tests/single_controller/test_split_resource_pool.py
@@ -51,7 +51,7 @@ def test_split_resource_pool_with_split_size():
ray.init()
# assume we have 2 nodes, with 4 GPUs each
global_resource_pool = RayResourcePool(process_on_nodes=[4, 4])
global_resource_pool.get_placement_groups(device_name=get_device_name())
global_resource_pool.get_placement_group(device_name=get_device_name())

# first 4 gpus for actor_1, last 4 gpus for actor_2
actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(resource_pool=global_resource_pool, split_size=4)
@@ -79,7 +79,7 @@ def test_split_resource_pool_with_split_size_list():
ray.init()
# assume we have 4 nodes, with 2 GPUs each
global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2])
global_resource_pool.get_placement_groups(device_name=get_device_name())
global_resource_pool.get_placement_group(device_name=get_device_name())

# first 2 gpus for actor_1, last 6 gpus for actor_2
actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(
@@ -113,7 +113,7 @@ def test_split_resource_pool_with_split_size_list_cross_nodes():
ray.init()
# assume we have 2 nodes, with 4 GPUs each
global_resource_pool = RayResourcePool(process_on_nodes=[4, 4])
global_resource_pool.get_placement_groups(device_name=get_device_name())
global_resource_pool.get_placement_group(device_name=get_device_name())

# first 2 gpus for actor_1, last 6 gpus for actor_2
actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(
@@ -149,7 +149,7 @@ def test_split_resource_pool_with_split_twice():

# assume we have 4 nodes, with 2 GPUs each
global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2])
global_resource_pool.get_placement_groups(device_name=get_device_name())
global_resource_pool.get_placement_group(device_name=get_device_name())

# actors with [2, 1, 1, 1, 1, 2] (split twice)
rp_1, rp_2, rp_3 = split_resource_pool(
@@ -45,8 +45,19 @@ def __init__(
node_rank: int,
gpus_per_node: int,
nnodes: int,
cuda_visible_devices: str,
):
super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)
super().__init__(
config,
model_config,
rollout_mode,
workers,
replica_rank,
node_rank,
gpus_per_node,
nnodes,
cuda_visible_devices,
)

# for cancelling LLMServer
self.paused = False
16 changes: 8 additions & 8 deletions verl/experimental/reward_loop/reward_model.py
@@ -16,7 +16,7 @@
import logging
import os

from verl.single_controller.ray.base import RayResourcePool, split_resource_pool
from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup
from verl.workers.config import HFModelConfig, RewardModelConfig
from verl.workers.rollout.replica import get_rollout_replica_class

@@ -75,14 +75,14 @@ def _initialize_llm_servers(self):
for replica_rank in range(num_replicas)
]
if self.resource_pool:
split_resource_pools = split_resource_pool(self.resource_pool, split_size=rollout_world_size)
assert len(split_resource_pools) == len(self.rollout_replicas)
self._run_all(
[
server.init_colocated(resource_pool)
for server, resource_pool in zip(self.rollout_replicas, split_resource_pools, strict=True)
]
ray_cls_with_init = self.rollout_replicas[0].get_ray_class_with_init_args()
worker_group = RayWorkerGroup(
resource_pool=self.resource_pool,
ray_cls_with_init=ray_cls_with_init,
bin_pack=False,
name_prefix="rollout_reward_colocate",
)
self._run_all([server.init_colocated(worker_group) for server in self.rollout_replicas])
else:
self._run_all([server.init_standalone() for server in self.rollout_replicas])
self.server_handles = [server._server_handle for server in self.rollout_replicas]
4 changes: 2 additions & 2 deletions verl/single_controller/base/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .worker import Worker
from .worker import TwoPhaseInitWorker, Worker, WorkerMeta
from .worker_group import ClassWithInitArgs, ResourcePool, WorkerGroup

__all__ = ["Worker", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"]
__all__ = ["Worker", "TwoPhaseInitWorker", "WorkerMeta", "WorkerGroup", "ClassWithInitArgs", "ResourcePool"]
129 changes: 129 additions & 0 deletions verl/single_controller/base/worker.py
@@ -31,6 +31,18 @@
from .decorator import Dispatch, Execute, register


@dataclass
class WorkerMeta:
"""Metadata for a worker, used for sorting and environment setup."""

worker: "ray.actor.ActorHandle"
worker_name: str = ""
bundle_index: int = 0
ip: str = ""
node_id: str = ""
gpu_ids: list = None


@dataclass
class DistRankInfo:
tp_rank: int
@@ -71,6 +83,123 @@ def get_availale_master_addr_port(self):
def get_available_master_addr_port(self):
return self._get_node_ip().strip("[]"), str(self._get_free_port())

@staticmethod
def get_worker_info() -> tuple[str, str, list[int]]:
"""Get the IP address, node ID and GPU IDs for this worker.

Returns:
tuple: A tuple of (ip, node_id, gpu_ids) where:
- ip is the IP address of the node
- node_id is the Ray node ID
- gpu_ids is a list of GPU IDs assigned to this worker
"""
from verl.utils.device import get_ray_device_key

ip = WorkerHelper._get_node_ip()
node_id = ray.get_runtime_context().get_node_id()
device_key = get_ray_device_key()

try:
gpu_ids = ray.get_runtime_context().get_accelerator_ids()[device_key]
except Exception as e:
raise RuntimeError(
f"Current platform does not support Ray accelerators. Device key: {device_key}. Error: {e}"
) from e

return ip, node_id, gpu_ids


class TwoPhaseInitWorker(WorkerHelper):
"""Base class for all workers that provides two-phase initialization.

This class provides environment setup and initialization methods that are
required by RayWorkerGroup for IP-based worker sorting and distributed training.
Workers that are needed to perform two-phase initialization should inherit from
this class.

Two-phase initialization:
1. __init__: Only store parameters, do NOT initialize heavy resources
2. setup_worker_environment: Set up distributed environment (rank, devices, master info)
3. init_worker: Called after environment setup, initialize resources here
"""

def __init__(self) -> None:
"""Initialize the worker base with environment setup flag."""
self._environment_setup = False

def setup_worker_environment(
self, rank: int, local_rank: int, visible_devices: str, master_addr: str, master_port: str
):
"""Setup the worker's distributed environment including rank, devices, and master info.

This method is called after workers are sorted by IP to ensure correct
topology. It sets up all necessary environment variables for distributed training.

Args:
rank: The new global rank for this worker.
local_rank: The new local rank for this worker.
visible_devices: The comma-separated visible device IDs for this worker.
master_addr: The IP address of the master node for distributed initialization.
master_port: The port number of the master node for distributed initialization.
"""
import os

# Update rank environment variables
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(local_rank)

# Update device visibility environment variable
device_keyword = get_visible_devices_keyword().upper()
os.environ[device_keyword] = visible_devices

# Update master address environment variables
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port

# Mark that environment setup has been completed
self._environment_setup = True

def init_worker(self):
"""Initialize the worker after environment setup.

This method should be overridden by subclasses to perform initialization
that depends on the correct rank assignment. The __init__ method should
only store passed parameters as class attributes, and actual initialization
(like torch.distributed.init_process_group) should be done in this method.

This is called after workers are created and their environment is set up based
on IP sorting to ensure correct topology.

Note:
Subclasses must call super().init_worker() at the beginning of their
implementation to ensure setup_worker_environment has been called. The subclass's
own initialization logic should follow after this call.

Raises:
RuntimeError: If setup_worker_environment has not been called before this method.
"""
if not self.is_environment_setup():
raise RuntimeError(
"init_worker() called before setup_worker_environment(). "
"setup_worker_environment() must be called first to ensure correct environment setup."
)

def set_environment_setup(self, setup: bool):
"""Set the environment setup flag.

Args:
setup: Whether environment setup has been completed.
"""
self._environment_setup = setup

def is_environment_setup(self) -> bool:
"""Check if environment setup has been completed.

Returns:
bool: True if environment setup has been completed, False otherwise.
"""
return self._environment_setup


# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
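
To illustrate the two-phase contract defined by TwoPhaseInitWorker above, here is a minimal sketch of a subclass. The class name ExampleWorker, its config argument, and the driver-side call order in the trailing comments are illustrative assumptions (in practice RayWorkerGroup is expected to drive these calls); only TwoPhaseInitWorker, setup_worker_environment, and init_worker come from this PR.

import torch.distributed as dist

from verl.single_controller.base.worker import TwoPhaseInitWorker


class ExampleWorker(TwoPhaseInitWorker):  # hypothetical subclass, for illustration only
    def __init__(self, config):
        # Phase 1: only store parameters; defer heavy initialization.
        super().__init__()
        self.config = config

    def init_worker(self):
        # Phase 2: super().init_worker() raises if setup_worker_environment()
        # was not called first, so rank/device env vars are guaranteed to be set.
        super().init_worker()
        # Initialization that depends on the final rank assignment, e.g.:
        dist.init_process_group(backend="nccl")  # reads RANK/MASTER_ADDR/MASTER_PORT from env;
        # WORLD_SIZE is assumed to be set elsewhere by the worker group.


# Presumed call order on the driver side (normally handled by RayWorkerGroup):
#   worker.setup_worker_environment(rank=0, local_rank=0, visible_devices="0",
#                                   master_addr="127.0.0.1", master_port="29500")
#   worker.init_worker()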