173 changes: 173 additions & 0 deletions src/gfn/utils/distributed.py
@@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, cast

import mpi4py.MPI as MPI
import torch
import torch.distributed as dist

@@ -122,6 +123,158 @@ def average_models(model, training_group=None):
param.data = param_tensor / world_size


@dataclass
class DistributedContextmpi4py:
"""Holds all distributed training/replay buffer groups and ranks."""

my_rank: int
world_size: int
num_training_ranks: int
agent_group_size: int
agent_groups: Optional[List[MPI.Comm]] = None
agent_group_id: Optional[int] = None
train_global_group: MPI.Comm = MPI.COMM_WORLD
assigned_buffer: Optional[int] = None
buffer_group: Optional[MPI.Comm] = None
assigned_training_ranks: Optional[List[int]] = None

def is_buffer_rank(self) -> bool:
"""Check if the current rank is part of the buffer group."""
return self.my_rank >= self.num_training_ranks

def is_training_rank(self) -> bool:
"""Check if the current rank is part of the training group."""
return self.my_rank < self.num_training_ranks


def initialize_distributed_compute_mpi4py(
num_remote_buffers: int,
num_agent_groups: int,
) -> DistributedContextmpi4py:
"""Initalizes distributed compute using either ccl or mpi backends."""
"""
Initalizes distributed compute using either ccl or mpi backends.
Args:
dist_backend: The backend to use for distributed compute.
num_remote_buffers: The number of remote buffers to use.
"""

pmi_size = MPI.COMM_WORLD.Get_size()
print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size))

if pmi_size <= 1:
print("+ PMI_SIZE <= 1, running in single process mode.")
return DistributedContextmpi4py(
my_rank=0, world_size=1, num_training_ranks=1, agent_group_size=1
)

os.environ["RANK"] = str(MPI.COMM_WORLD.Get_rank())
os.environ["WORLD_SIZE"] = str(pmi_size)

print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS"))

    # Get_size()/Get_rank() return valid integers once MPI is initialized.
    world_size = MPI.COMM_WORLD.Get_size()
    rank = MPI.COMM_WORLD.Get_rank()

    MPI.COMM_WORLD.Barrier()
    print("+ MPI communicator ready")

    my_rank = rank  # Global rank across training and buffer processes.

num_training_ranks = world_size - num_remote_buffers

    # Make sure that we have at most one remote buffer per training rank.
assert num_training_ranks >= num_remote_buffers
print("num_train = ", num_training_ranks)
print("num_remote_buffers = ", num_remote_buffers)

    # For now, enforce that each agent group gets an equal number of training
    # ranks (e.g. 8 training ranks and 2 agent groups give groups [0..3] and [4..7]).
# TODO: later, we can relax this condition.
assert num_training_ranks % num_agent_groups == 0
agent_group_size = num_training_ranks // num_agent_groups
agent_group_rank_list = [
list(range(i * agent_group_size, (i + 1) * agent_group_size))
for i in range(num_agent_groups)
]
print(f"Agent group ranks: {agent_group_rank_list}")
world_group = MPI.COMM_WORLD.Get_group()
agent_group_list = []
for i in range(num_agent_groups):
grp = world_group.Incl(agent_group_rank_list[i])
agent_group_list.append(MPI.COMM_WORLD.Create(grp))

    # All training ranks (0..num_training_ranks-1) form one global training group.
    training_ranks = list(range(num_training_ranks))
    grp = world_group.Incl(training_ranks)
    train_global_group = MPI.COMM_WORLD.Create(grp)

buffer_group = None
assigned_buffer = None
assigned_training_ranks = {}
if num_remote_buffers > 0:
buffer_ranks = list(
range(num_training_ranks, num_training_ranks + num_remote_buffers)
)
        grp = world_group.Incl(buffer_ranks)
        buffer_group = MPI.COMM_WORLD.Create(grp)
        print(f"Buffer group ranks: {buffer_ranks}")

        # Each training rank is assigned one buffer rank (round-robin); each
        # buffer rank records the list of training ranks it serves.
        if my_rank < num_training_ranks:
            assigned_buffer = num_training_ranks + (my_rank % num_remote_buffers)
        else:
            assigned_training_ranks[my_rank] = [
                r
                for r in range(num_training_ranks)
                if (r % num_remote_buffers) == (my_rank - num_training_ranks)
            ]

print(f"+ My rank: {my_rank} size: {world_size}")
if my_rank < (num_training_ranks):
print(f" -> Training group, assigned buffer rank = {assigned_buffer}")
else:
print(" -> Buffer group")

# dist.barrier()
print("+ Distributed compute initialized, rank = ", my_rank)

return DistributedContextmpi4py(
my_rank=my_rank,
world_size=world_size,
num_training_ranks=num_training_ranks,
agent_group_size=agent_group_size,
agent_groups=agent_group_list,
agent_group_id=my_rank // agent_group_size,
train_global_group=train_global_group,
assigned_buffer=assigned_buffer,
buffer_group=buffer_group,
assigned_training_ranks=assigned_training_ranks.get(my_rank, None),
)

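# Example usage (illustrative sketch): a caller might route data based on the
# returned context; `payload` and the message tag here are hypothetical.
#
#   ctx = initialize_distributed_compute_mpi4py(num_remote_buffers=2, num_agent_groups=2)
#   if ctx.is_training_rank():
#       MPI.COMM_WORLD.send(payload, dest=ctx.assigned_buffer, tag=0)
#   elif ctx.is_buffer_rank():
#       for src in ctx.assigned_training_ranks:
#           payload = MPI.COMM_WORLD.recv(source=src, tag=0)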

@dataclass
class DistributedContext:
"""Holds all distributed training/replay buffer groups and ranks."""
@@ -136,6 +289,7 @@ class DistributedContext:
assigned_buffer: Optional[int] = None
buffer_group: Optional[dist.ProcessGroup] = None
assigned_training_ranks: Optional[List[int]] = None
dc_mpi4py: Optional[DistributedContextmpi4py] = None

def is_buffer_rank(self) -> bool:
"""Check if the current rank is part of the buffer group."""
@@ -145,6 +299,19 @@ def is_training_rank(self) -> bool:
"""Check if the current rank is part of the training group."""
return self.my_rank < self.num_training_ranks

    def cleanup(self) -> None:
        """Destroys the torch process group and frees the MPI communicators."""
        dist.destroy_process_group()
        if self.dc_mpi4py is not None:
            # Communicators created with MPI.COMM_WORLD.Create() are
            # MPI.COMM_NULL on ranks outside the group; freeing MPI.COMM_NULL
            # (or MPI.COMM_WORLD) is an MPI error, so guard against both.
            train_group = self.dc_mpi4py.train_global_group
            if train_group not in (None, MPI.COMM_NULL, MPI.COMM_WORLD):
                train_group.Free()
            if self.dc_mpi4py.buffer_group not in (None, MPI.COMM_NULL):
                self.dc_mpi4py.buffer_group.Free()
            if self.dc_mpi4py.agent_groups is not None:
                for ag in self.dc_mpi4py.agent_groups:
                    if ag != MPI.COMM_NULL:
                        ag.Free()
            MPI.Finalize()
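    # Typical teardown (sketch; `run_training` and the remaining arguments are
    # placeholders for the caller's own code):
    #
    #   ctx = initialize_distributed_compute(dist_backend="mpi", ...)
    #   try:
    #       run_training(ctx)
    #   finally:
    #       ctx.cleanup()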


def initialize_distributed_compute(
dist_backend: str,
@@ -302,6 +469,11 @@ def initialize_distributed_compute(
dist.barrier()
logger.info("Distributed compute initialized, rank = %d", my_rank)

dc = initialize_distributed_compute_mpi4py(
num_remote_buffers=num_remote_buffers,
num_agent_groups=num_agent_groups,
)

return DistributedContext(
my_rank=my_rank,
world_size=world_size,
Expand All @@ -313,6 +485,7 @@ def initialize_distributed_compute(
assigned_buffer=assigned_buffer,
buffer_group=buffer_group,
assigned_training_ranks=assigned_training_ranks.get(my_rank, None),
dc_mpi4py=dc,
)

