diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py
index 7e109775..a539277d 100644
--- a/src/gfn/utils/distributed.py
+++ b/src/gfn/utils/distributed.py
@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, cast
 
+import mpi4py.MPI as MPI
 import torch
 import torch.distributed as dist
 
@@ -122,6 +123,158 @@ def average_models(model, training_group=None):
         param.data = param_tensor / world_size
 
 
+@dataclass
+class DistributedContextmpi4py:
+    """Holds all distributed training/replay buffer groups and ranks."""
+
+    my_rank: int
+    world_size: int
+    num_training_ranks: int
+    agent_group_size: int
+    agent_groups: Optional[List[MPI.Comm]] = None
+    agent_group_id: Optional[int] = None
+    train_global_group: MPI.Comm = MPI.COMM_WORLD
+    assigned_buffer: Optional[int] = None
+    buffer_group: Optional[MPI.Comm] = None
+    assigned_training_ranks: Optional[List[int]] = None
+
+    def is_buffer_rank(self) -> bool:
+        """Check if the current rank is part of the buffer group."""
+        return self.my_rank >= self.num_training_ranks
+
+    def is_training_rank(self) -> bool:
+        """Check if the current rank is part of the training group."""
+        return self.my_rank < self.num_training_ranks
+
+
+def initialize_distributed_compute_mpi4py(
+    num_remote_buffers: int,
+    num_agent_groups: int,
+) -> DistributedContextmpi4py:
+    """Initializes distributed compute using the mpi4py (MPI) backend.
+
+    Args:
+        num_remote_buffers: The number of remote replay buffers to use.
+        num_agent_groups: The number of agent groups to split the training
+            ranks into.
+    """
+
+    pmi_size = MPI.COMM_WORLD.Get_size()
+    print("+ Initializing distributed compute, PMI_SIZE={}".format(pmi_size))
+
+    if pmi_size <= 1:
+        print("+ PMI_SIZE <= 1, running in single process mode.")
+        return DistributedContextmpi4py(
+            my_rank=0, world_size=1, num_training_ranks=1, agent_group_size=1
+        )
+
+    os.environ["RANK"] = str(MPI.COMM_WORLD.Get_rank())
+    os.environ["WORLD_SIZE"] = str(pmi_size)
+
+    print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS"))
+
+    world_size = MPI.COMM_WORLD.Get_size()
+    if world_size is None:
+        raise ValueError("WORLD_SIZE is not set")
+    rank = MPI.COMM_WORLD.Get_rank()
+    if rank is None:
+        raise ValueError("RANK is not set")
+
+    dist.barrier()
+    print("+ Distributed compute initialized")
+
+    my_rank = rank  # dist.get_rank()  # Global!
+    # world_size = dist.get_world_size()  # Global!
+
+    num_training_ranks = world_size - num_remote_buffers
+
+    # Make sure that we have at most 1 remote buffer per training rank.
+    assert num_training_ranks >= num_remote_buffers
+    print("num_train = ", num_training_ranks)
+    print("num_remote_buffers = ", num_remote_buffers)
+
+    # For now, enforce that each agent group gets an equal number of ranks.
+    # TODO: later, we can relax this condition.
+ assert num_training_ranks % num_agent_groups == 0 + agent_group_size = num_training_ranks // num_agent_groups + agent_group_rank_list = [ + list(range(i * agent_group_size, (i + 1) * agent_group_size)) + for i in range(num_agent_groups) + ] + print(f"Agent group ranks: {agent_group_rank_list}") + world_group = MPI.COMM_WORLD.Get_group() + agent_group_list = [] + for i in range(num_agent_groups): + grp = world_group.Incl(agent_group_rank_list[i]) + agent_group_list.append(MPI.COMM_WORLD.Create(grp)) + + # all training ranks in one global group + training_ranks = [ + r for r in range(num_training_ranks) + ] # e.g., 0..num_training_ranks-1 + + # train_global_group = dist.new_group( + # ranks=training_ranks, + # backend=dist_backend, + # timeout=datetime.timedelta(minutes=5), + # ) + grp = world_group.Incl(training_ranks) + train_global_group = MPI.COMM_WORLD.Create(grp) + # print(f"Training global group ranks: {training_ranks}, {train_global_group}") + # assert train_global_group != MPI.COMM_NULL + + buffer_group = None + assigned_buffer = None + assigned_training_ranks = {} + if num_remote_buffers > 0: + buffer_ranks = list( + range(num_training_ranks, num_training_ranks + num_remote_buffers) + ) + # buffer_group = dist.new_group( + # buffer_ranks, + # backend=dist_backend, + # timeout=datetime.timedelta(minutes=5), + # ) + grp = world_group.Incl(buffer_ranks) + buffer_group = MPI.COMM_WORLD.Create(grp) + print(f" >>>>>>>>>>>>>>>>. Buffer group ranks: {buffer_ranks}, {buffer_group}") + # assert buffer_group != MPI.COMM_NULL + + print(f"Buffer group ranks: {buffer_ranks}") + + # Each training rank gets assigned to a buffer rank + if my_rank < (num_training_ranks): + assigned_buffer = num_training_ranks + (my_rank % num_remote_buffers) + else: + assigned_training_ranks[my_rank] = [ + ranks + for ranks in range(num_training_ranks) + if (ranks % num_remote_buffers) == (my_rank - num_training_ranks) + ] + + print(f"+ My rank: {my_rank} size: {world_size}") + if my_rank < (num_training_ranks): + print(f" -> Training group, assigned buffer rank = {assigned_buffer}") + else: + print(" -> Buffer group") + + # dist.barrier() + print("+ Distributed compute initialized, rank = ", my_rank) + + return DistributedContextmpi4py( + my_rank=my_rank, + world_size=world_size, + num_training_ranks=num_training_ranks, + agent_group_size=agent_group_size, + agent_groups=agent_group_list, + agent_group_id=my_rank // agent_group_size, + train_global_group=train_global_group, + assigned_buffer=assigned_buffer, + buffer_group=buffer_group, + assigned_training_ranks=assigned_training_ranks.get(my_rank, None), + ) + + @dataclass class DistributedContext: """Holds all distributed training/replay buffer groups and ranks.""" @@ -136,6 +289,7 @@ class DistributedContext: assigned_buffer: Optional[int] = None buffer_group: Optional[dist.ProcessGroup] = None assigned_training_ranks: Optional[List[int]] = None + dc_mpi4py: Optional[DistributedContextmpi4py] = None def is_buffer_rank(self) -> bool: """Check if the current rank is part of the buffer group.""" @@ -145,6 +299,19 @@ def is_training_rank(self) -> bool: """Check if the current rank is part of the training group.""" return self.my_rank < self.num_training_ranks + def cleanup(self) -> None: + """Cleans up the distributed process group.""" + dist.destroy_process_group() + if self.dc_mpi4py is not None: + if self.dc_mpi4py.train_global_group is not None: + self.dc_mpi4py.train_global_group.Free() + if self.dc_mpi4py.buffer_group is not None: + 
+                self.dc_mpi4py.buffer_group.Free()
+            if self.dc_mpi4py.agent_groups is not None:
+                for ag in self.dc_mpi4py.agent_groups:
+                    ag.Free()
+            MPI.Finalize()
+
 
 def initialize_distributed_compute(
     dist_backend: str,
@@ -302,6 +469,11 @@ def initialize_distributed_compute(
     dist.barrier()
     logger.info("Distributed compute initialized, rank = %d", my_rank)
 
+    dc = initialize_distributed_compute_mpi4py(
+        num_remote_buffers=num_remote_buffers,
+        num_agent_groups=num_agent_groups,
+    )
+
     return DistributedContext(
         my_rank=my_rank,
         world_size=world_size,
@@ -313,6 +485,7 @@ def initialize_distributed_compute(
         assigned_buffer=assigned_buffer,
         buffer_group=buffer_group,
         assigned_training_ranks=assigned_training_ranks.get(my_rank, None),
+        dc_mpi4py=dc,
     )
 
 
diff --git a/tutorials/examples/multinode/spawn_policy.py b/tutorials/examples/multinode/spawn_policy.py
index c069cd3a..bec22ffb 100644
--- a/tutorials/examples/multinode/spawn_policy.py
+++ b/tutorials/examples/multinode/spawn_policy.py
@@ -1,15 +1,20 @@
 from __future__ import annotations
 
+import json
 import logging
+import random
 import threading
 import time
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple, cast
 
+import mpi4py.MPI as MPI
+import numpy as np
 import torch
 import torch.distributed as dist
 
 from gfn.gflownet.base import GFlowNet
+from gfn.utils.common import Timer
 
 logger = logging.getLogger(__name__)
 
@@ -90,6 +95,7 @@ def __init__(
         poll_interval_s: float = 0.01,
         threshold: Optional[float] = None,
         cooldown: int = 200,
+        timing: Optional[dict] = None,  # Dict used to capture timing info.
     ) -> None:
         super().__init__(average_every)
         self.replacement_ratio = float(replacement_ratio)
@@ -603,3 +609,704 @@ def _compute_averaging_weights(
         )
         weights = 1.0 / (contributing_metrics + 1e-8)
         return weights / weights.sum()
+
+
+"""
+Selective averaging based on asynchronous MPI-3 one-sided comms.
+    Version 1 (general): one buffer and one MPI window per parameter tensor, so
+        model params with different dtypes can be handled (uses the network BW
+        less efficiently than the fast version below).
+    Comments:
+        - Needs a fix so that a single byte buffer holds all params; after the
+          comms, the buffer can be converted back to the appropriate dtype for
+          each param.
+        - Right now each param has its own buffer and window, which may produce
+          incorrect output.
+    Notes:
+        - Only the mean averaging strategy is implemented here for now.
+        - Neighbors (donors) are randomly selected from all ranks except self.
+        - Agents are killed based on age only for now.
+"""
+
+
+class AsyncSelectiveAveragingPolicympi4pyGeneral(SpawnPolicy):
+    r"""
+    Asynchronous selective averaging, version 1 (general): uses MPI one-sided comms
+    to get the selectively averaged parameters from a random set of donor ranks.
+ """ + + def __init__( + self, + model_builder: Callable[[], Tuple[GFlowNet, torch.optim.Optimizer]], + model: GFlowNet, + average_every: int, + threshold_metric: float = 0.0, + replacement_ratio: float = 0.2, + averaging_strategy: str = "mean", + momentum: float = 0.0, + poll_interval_s: float = 0.01, + age_range: Tuple[int, int] = (50, 150), + group: MPI.Comm = MPI.COMM_WORLD, + ) -> None: + super().__init__(average_every) + self.myrank = group.Get_rank() + self.comm_size = group.Get_size() + + self.replacement_ratio = float(replacement_ratio) + self.averaging_strategy = str(averaging_strategy) + self.momentum = float(momentum) + self._model_builder = model_builder + + self._model: Optional[GFlowNet] = None + self.threshold_metric = float(threshold_metric) + + # timers + self.timing = {} + self.stats = {} + + self._model = model + self.train_comm_group = group + self._expose = False + + # new agents' stats + self.agents_killed = 0 + self.averaging_ranks = 0 + self._count = 0 + + self.debug_mode = False + self.age = 0 + self.age_range = age_range + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + + if self.debug_mode: + self.logfile = f"debug/selective_averaging_rank_{self.myrank}.log" + with open(self.logfile, "w") as f: + f.write(f"Selective Averaging Log for Rank {self.myrank}\n") + f.write("=" * 50 + "\n") + + def shutdown(self) -> None: + for _, v in self._mpi_tensor_wins.items(): + v[0].Free() + + def print_time(self) -> None: + print("Selective Averaging timings:", flush=True) + for k, v in self.timing.items(): + # here v is a list, avg over the list + # avg = sum(v) / len(v) if len(v) > 0 else 0 + print(f"{k:<35}: {sum(v):>10.4f} seconds") + + def print_stats(self) -> None: + print("Selective Averaging comms stats:", flush=True) + avg_donors, num_calls = 0.0, 0 + for k, v in self.stats.items(): + # v is a list, print min, max ,avg, and len of it + minimum = min(v) if len(v) > 0 else 0 + maximum = max(v) if len(v) > 0 else 0 + avg = sum(v) / len(v) if len(v) > 0 else 0 + length = len(v) + print( + f"Rank {self.myrank} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, count={length}" + ) + if k == "donors": + avg_donors = avg + num_calls = length + + _named_params = ( + list(self._model.named_parameters()) if self._model is not None else [] + ) + named_params = [ + (name, param) for name, param in _named_params if param.dim() != 0 + ] + print( + f"Rank {self.myrank:<10} - {'param elements':<15} {'iter':<10} {'total params elements commd':<25}" + ) + for name, param in named_params: + param.device + param_shape = param.data.shape + print( + f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}" + ) + + def capture_comm(self, name: str, size: int) -> None: + if name not in self.stats: + self.stats[name] = [] + self.stats[name].append(size) + + def _ensure_initialized(self, model: GFlowNet) -> None: + self._model = model + self._initialized = True + # export model parameters to mpi windows (should do that periodically to keep them fresh) + self._expose_model_parameters(model) + + def reset_age(self) -> None: + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + self.age = 0 + + def is_agent_dying( + self, local_metric: float, threshold_metric: float, check_agent=0 + ) -> bool: + if check_agent == 0: # static theshold + return local_metric < threshold_metric + + elif check_agent == 1: # dynamic threshold based on age + if self.age >= self.max_age: + print( + "+ 
Agent killed due to age: ", + self.age, + " max_age: ", + self.max_age, + flush=True, + ) + self.reset_age() + return True + + self.age += 1 + return False + + else: + raise ValueError(f"Unknown is_agent_dying version: {check_agent}") + + # Execute this function regularly to copy model params to mpi windows + # for recepotrs to get recent params + def _copy_model_params_to_buf( + self, + model: GFlowNet, + ) -> None: + for name, param in model.named_parameters(): + win = self._mpi_tensor_wins[name][0] + win.Lock(rank=self.myrank, lock_type=MPI.LOCK_EXCLUSIVE) + self._mpi_tensor_wins[name][1][:] = param.data.cpu().numpy().flatten() + win.Unlock(rank=self.myrank) + + def _expose_model_parameters(self, model: GFlowNet) -> None: + + # Serialize model parameters to a contiguous numpy array + param_tensors = {} + for name, param in model.named_parameters(): + param_tensors[name] = np.zeros_like(param.data.cpu().numpy().flatten()) + + # Create MPI windows for each parameter and its shape (2 separate windows set) + self._mpi_tensor_wins = {} + self._mpi_shape_wins = {} + assert isinstance(self.train_comm_group, MPI.Intracomm) + comm = cast(MPI.Intracomm, self.train_comm_group) + for name, tensor in param_tensors.items(): + buf = tensor + win = MPI.Win.Create(buf, comm=comm) + self._mpi_tensor_wins[name] = (win, buf) + + self._copy_model_params_to_buf(model) + + def _get_donors(self, n, k, d) -> List[int]: + if k > n - 1: + raise ValueError("k must be ≤ n-1 when excluding one value") + + # All values from 0..n-1 except d + candidates = [x for x in range(n) if x != d] + # Pick k distinct values + return random.sample(candidates, k) + + def __call__( + self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: float, + expose_params: bool = True, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + + if self._expose is False: + self._expose_model_parameters(model) + self._expose = True + + self._count += 1 + self._model = model + + check_agent = 1 # version of dying agent check + # validation info + layer_name = None + if self.debug_mode: + layer_name = "pb.module.last_layer.bias" + + if self.is_agent_dying(local_metric, self.threshold_metric, check_agent): + with Timer(self.timing, "sa get_params_from_donors"): + if self.debug_mode: + with open(self.logfile, "a") as f: + # kill this model and rebuild model with fresh weights + num_donors = max( + 1, int(self.comm_size * 0.5) + ) # * self.replacement_ratio)) + donors = self._get_donors( + self.comm_size, num_donors, self.myrank + ) + new_avg_params = self._get_model_params_from_donors( + donors, layer_name, f + ) + + if layer_name is not None: + json.dump( + { + self._count: [ + self._mpi_tensor_wins[layer_name][1].tolist(), + donors, + new_avg_params[layer_name].tolist(), + ] + }, + f, + ) + f.write("\n") + else: + num_donors = max( + 1, int(self.comm_size * 0.5) + ) # self.replacement_ratio)) + donors = self._get_donors(self.comm_size, num_donors, self.myrank) + new_avg_params = self._get_model_params_from_donors( + donors, layer_name, None + ) + + with Timer(self.timing, "sa new_agent_model_rebuild"): + model, optimizer = self._model_builder() + for name, param in model.named_parameters(): + if name in new_avg_params: + param.data.copy_(new_avg_params[name]) + + if expose_params is True: + with Timer(self.timing, "sa copy_params_to_buf"): + self._copy_model_params_to_buf(model) + + return model, optimizer, {"averaged_this_iteration": True} + + def 
_get_model_params_from_donors( + self, donors: List[int], layer_name, f + ) -> Dict[str, torch.Tensor]: + avg_state: Dict[str, torch.Tensor] = {} + _named_params = ( + list(self._model.named_parameters()) if self._model is not None else [] + ) + named_params = [ + (name, param) for name, param in _named_params if param.dim() != 0 + ] + tot_comm_ele = 0 + + self.capture_comm("donors", len(donors)) + for name, param in named_params: + device = param.device + param_shape = param.data.shape + acc = torch.zeros_like(param.data) + all_donors = [] + + for i, src in enumerate(donors): + tensor_win, tensor_buf = self._mpi_tensor_wins[name] + tensor_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) + flat_size = np.prod(param_shape) + assert flat_size > 0 + tot_comm_ele += flat_size + donor_tensor_flat = np.zeros( + flat_size, dtype=param.data.cpu().numpy().dtype + ) + tensor_win.Get([donor_tensor_flat, MPI.FLOAT], target_rank=src) + tensor_win.Unlock(rank=src) + + donor_tensor = torch.tensor( + donor_tensor_flat.reshape(param_shape), device=device + ) + # Adding all the donor tensors/params + acc.add_(donor_tensor) + + if self.debug_mode and name == layer_name: + all_donors.append(donor_tensor.tolist()) + # Additions: Other averaging strategies can be implemented here + + if self.debug_mode and name == layer_name: + json.dump({self._count: all_donors}, f) + f.write("\n") + + # default to mean averaging + acc = acc / len(donors) + avg_state[name] = acc + + self.capture_comm("num_param_tensors_received", tot_comm_ele) + return avg_state + + def _average_received_params( + self, + ) -> Dict[str, torch.Tensor]: + avg_state: Dict[str, torch.Tensor] = {} + for name, param in self.avg_state.items(): + # device = param.device + # param_shape = param.data.shape + acc = torch.zeros_like(param[0].data) + + for i, donor_tensor in enumerate(param): + # Adding all the donor tensors/params + acc.add_(donor_tensor) + + # default to mean averaging + if self.averaging_strategy == "mean": + acc = acc / len(param) + avg_state[name] = acc + + return avg_state + + +""" +Selective averaging based on async mpi-3 comms + Version 2 (fast): One buffer with all the params, hence utilizes the network BW much better than general version. + It assumes all the params are of same dtype (float32) +Notes: + - Only mean averaging strategy is implemented here for now. + - Neighbors are randomly selected from all ranks except self. + - Agents are killed based on age only for now. +""" + + +class AsyncSelectiveAveragingPolicympi4pyFast(SpawnPolicy): + r""" + Asynchronous selective averaging version 2, uses mpi one-sided comms to get the + selectively averaged parameters from a random set of ranks. 
+ """ + + def __init__( + self, + model_builder: Callable[[], Tuple[GFlowNet, torch.optim.Optimizer]], + model: GFlowNet, + average_every: int, + threshold_metric: float = 0.0, + replacement_ratio: float = 0.2, + averaging_strategy: str = "mean", + momentum: float = 0.0, + age_range: Tuple[int, int] = (50, 150), + group: MPI.Comm = MPI.COMM_WORLD, + ) -> None: + super().__init__(average_every) + self.myrank = group.Get_rank() + self.comm_size = group.Get_size() + + self.replacement_ratio = float(replacement_ratio) + self.averaging_strategy = str(averaging_strategy) + self.momentum = float(momentum) + self._model_builder = model_builder + + self._model: Optional[GFlowNet] = None + self.threshold_metric = float(threshold_metric) + # timers + self.timing = {} + self.stats = {} + + self._model = model + self.train_comm_group = group + # self._expose_model_parameters(model) + self._expose = False + + # **** new agents' stats **** + self.agents_killed = 0 + self.averaging_ranks = 0 + self._count = 0 + + self.total_iterations = 0 + self.num_replacements = 0 + self.debug_mode = False + + self.age = 0 + self.age_range = age_range + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + + # test code, remove it later + if self.debug_mode: + self.logfile = f"debug/selective_averaging_rank_{self.myrank}.log" + with open(self.logfile, "w") as f: + f.write(f"Selective Averaging Log for Rank {self.myrank}\n") + f.write("=" * 50 + "\n") + + def shutdown(self) -> None: + self._mpi_tensor_wins[0].Free() + + def print_time(self) -> None: + print("Selective Averaging timings:", flush=True) + for k, v in self.timing.items(): + # here v is a list, avg over the list + print(f"{k:<35}: {sum(v):>10.4f} seconds") + + def print_stats(self) -> None: + print("Selective Averaging comms stats:", flush=True) + print( + f"Rank {self.myrank} - Agent replaced for {self.num_replacements} out of {self.total_iterations} iterations." 
+ ) + avg_donors, num_calls = 0.0, 0 + for k, v in self.stats.items(): + # v is a list, print min, max ,avg, and len of it + minimum = min(v) if len(v) > 0 else 0 + maximum = max(v) if len(v) > 0 else 0 + avg = sum(v) / len(v) if len(v) > 0 else 0 + length = len(v) + print( + f"Rank {self.myrank:<10} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, across {length} iters" + ) + if k == "donors": + avg_donors = avg + num_calls = length + + _named_params = ( + list(self._model.named_parameters()) if self._model is not None else [] + ) + named_params = [ + (name, param) for name, param in _named_params if param.dim() != 0 + ] + print( + f"Rank {self.myrank:<10} - {'param elements':<15} {'#comm_iters':<10} {'total params elements communicated':<25}" + ) + for name, param in named_params: + param.device + param_shape = param.data.shape + print( + f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}" + ) + + def capture_comm(self, name: str, size: int) -> None: + if name not in self.stats: + self.stats[name] = [] + self.stats[name].append(size) + + def _ensure_initialized(self, model: GFlowNet) -> None: + self._model = model + self._initialized = True + # export model parameters to mpi windows (should do that periodically to keep them fresh) + self._expose_model_parameters(model) + + def reset_age(self) -> None: + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + self.age = 0 + + def is_agent_dying( + self, local_metric: float, threshold_metric: float, check_policy=0 + ) -> bool: + if check_policy == 0: # static theshold + return local_metric < threshold_metric + + elif check_policy == 1: # dynamic threshold based on age + if self.age >= self.max_age: + print( + "+ Agent killed due to age: ", + self.age, + " max_age: ", + self.max_age, + flush=True, + ) + self.reset_age() + return True + + self.age += 1 + return False + + else: + raise ValueError(f"Unknown is_agent_dying version: {check_policy}") + + # Execute this function regularly to copy model params to mpi windows + # for recepotrs to get recent params + def _copy_model_params_to_buf( + self, + model: GFlowNet, + ) -> None: + offset = 0 + for name, param in model.named_parameters(): + win = self._mpi_tensor_wins[0] + win.Lock(rank=self.myrank, lock_type=MPI.LOCK_EXCLUSIVE) + size = param.data.numel() + self._mpi_tensor_wins[1][offset : offset + size] = ( + param.data.cpu().numpy().flatten() + ) + offset += size + win.Unlock(rank=self.myrank) + + def _expose_model_parameters(self, model: GFlowNet) -> None: + print("+ Exposing model parameters via MPI windows...", flush=True) + # Serialize model parameters to a contiguous numpy array + param_size = 0 + {param.dtype for param in model.parameters()} + # todo: enable this to work with any dtypes for the model + # print('model dtypes: ', dtypes) + param_dtype = np.float32 + th_param_dtype = torch.float32 + + # self.param_shapes = {} + for _, param in model.named_parameters(): + param_size += param.data.numel() + param_dtype = param.data.cpu().numpy().dtype + + param_tensors_flat = np.zeros(param_size, dtype=param_dtype) + self.donor_tensor_flat = torch.zeros(param_size, dtype=th_param_dtype) + self.acc = torch.zeros(param_size, dtype=th_param_dtype) + + # Create MPI windows for the flat parameter tensor + buf = param_tensors_flat + assert isinstance(self.train_comm_group, MPI.Intracomm) + cast(MPI.Intracomm, self.train_comm_group) + win = MPI.Win.Create(buf, comm=self.train_comm_group) + # 
buffer attached to the win, used to copy data in/out of the win + self._mpi_tensor_wins = (win, buf) + + self._copy_model_params_to_buf(model) + + def _get_donors(self, n, k, d) -> List[int]: + if k > n - 1: + raise ValueError("k must be ≤ n-1 when excluding one value") + + # All values from 0..n-1 except d + candidates = [x for x in range(n) if x != d] + # Random policy Pick k distinct random values + return random.sample(candidates, k) + + def __call__( + self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: float, + expose_params: bool = True, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + + if self._expose is False: + self._expose_model_parameters(model) + self._expose = True + + self._count += 1 + self._model = model + named_params = list(model.named_parameters()) + for name, param in named_params: + param.data.shape + + # validation info + layer_name = None + if self.debug_mode: + layer_name = "pb.module.last_layer.bias" + + self.total_iterations += 1 + check_agent = 1 # 0: static thresholding, 1: dynamic based on age + if self.is_agent_dying(local_metric, self.threshold_metric, check_agent): + self.num_replacements += 1 + with Timer(self.timing, "sa get_params_from_donors"): + # kill this model and rebuild model with fresh weights + num_donors = max( + 1, int(self.comm_size * 0.5) + ) # <<<<< parameterize this one + donors = self._get_donors(self.comm_size, num_donors, self.myrank) + _new_avg_params = self._get_model_params_from_donors( + donors, layer_name, None + ) + + with Timer(self.timing, "sa param_list_to_dict_convert"): + # conver the flat tensor to model param dict + new_avg_params: Dict[str, torch.Tensor] = {} + win_buf: Dict[str, torch.Tensor] = {} + + offset = 0 + for name, param in model.named_parameters(): + device = param.device + flat_size = param.data.numel() + assert flat_size == np.prod(param.data.shape) + donor_tensor_flat = _new_avg_params[offset : offset + flat_size] + donor_tensor = torch.tensor( + donor_tensor_flat.reshape(param.data.shape), device=device + ) + if self.debug_mode: + buf_tensor_flat = self._mpi_tensor_wins[1][ + offset : offset + flat_size + ] + buf_tensor = torch.tensor( + buf_tensor_flat.reshape(param.data.shape), device=device + ) + win_buf[name] = buf_tensor + + new_avg_params[name] = donor_tensor + offset += flat_size + + if self.debug_mode: + with open(self.logfile, "a") as f: + if layer_name is not None: + json.dump( + { + self._count: [ + win_buf[layer_name].tolist(), + donors, + new_avg_params[layer_name].tolist(), + ] + }, + f, + ) + f.write("\n") + + with Timer(self.timing, "sa new_agent_model_rebuild"): + model, optimizer = self._model_builder() + for name, param in model.named_parameters(): + if name in new_avg_params: + param.data.copy_(new_avg_params[name]) + + if expose_params is True: + with Timer(self.timing, "sa copy_params_to_buf"): + self._copy_model_params_to_buf(model) + + return model, optimizer, {"averaged_this_iteration": True} + + def _get_model_params_from_donors( + self, donors: List[int], layer_name, f + ) -> torch.Tensor: + # avg_state: Dict[str, torch.Tensor] = {} + # _named_params = list(self._model.named_parameters()) + # named_params = [(name, param) for name, param in _named_params if param.dim() != 0] + tot_comm_ele = 0 + + self.capture_comm("donors", len(donors)) + tensor_win, tensor_buf = self._mpi_tensor_wins + flat_size = tensor_buf.size + self.donor_tensor_flat.zero_() + self.acc.zero_() + + if self.averaging_strategy 
== "mean": + for i, src in enumerate(donors): + tensor_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) + tensor_win.Get([self.donor_tensor_flat, MPI.FLOAT], target_rank=src) + tensor_win.Unlock(rank=src) + + # Adding all the donor tensors/params + self.acc.add_(self.donor_tensor_flat) + tot_comm_ele = tot_comm_ele + flat_size + + self.acc = self.acc / len(donors) + else: + raise ValueError(f"Unknown averaging strategy: {self.averaging_strategy}") + + self.capture_comm("num_param_tensors_received", tot_comm_ele) + return self.acc + + +class AverageAllPolicympi4py(SpawnPolicy): + """Standard model averaging across all ranks every N iterations.""" + + def __init__(self, average_every: int) -> None: + super().__init__(average_every) + + @torch.no_grad() + def __call__( + self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: Optional[float] = None, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + if not dist.is_available() or not dist.is_initialized(): + return model, optimizer, {} + if iteration % self.average_every != 0: + return model, optimizer, {"averaged_this_iteration": False} + + # print("AverageAll mpi4py model parameters across all ranks ...", flush=True) + world_size = group.Get_size() + for param in model.parameters(): + param_tensor = ( + param.detach().cpu().numpy().copy() + ) # param.data.clone().numpy() + # dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=group) + group.Allreduce(MPI.IN_PLACE, param_tensor, op=MPI.SUM) + param_tensor /= world_size + param.data.copy_(torch.from_numpy(param_tensor)) + + return model, optimizer, {"averaged_this_iteration": True} diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index e49dd2aa..d48152a2 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -12,12 +12,12 @@ import numpy as np import torch -import wandb from numpy.typing import NDArray from scipy.special import logsumexp from sklearn.neighbors import KernelDensity from tqdm import tqdm, trange +import wandb from gfn.estimators import ScalarEstimator from gfn.gflownet import ( DBGFlowNet, diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index b79b828d..133c89c7 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -15,9 +15,9 @@ from typing import cast import torch -import wandb from tqdm import tqdm, trange +import wandb from gfn.estimators import DiscretePolicyEstimator from gfn.gflownet import FMGFlowNet from gfn.gym import DiscreteEBM diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 2148ca4d..2b1128aa 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -36,6 +36,7 @@ from typing import Optional, Tuple, cast import matplotlib.pyplot as plt +import mpi4py.MPI as MPI import torch import torch.distributed as dist from matplotlib.gridspec import GridSpec @@ -62,7 +63,10 @@ from gfn.utils.modules import MLP, DiscreteUniform, Tabular from tutorials.examples.multinode.spawn_policy import ( AsyncSelectiveAveragingPolicy, + AsyncSelectiveAveragingPolicympi4pyFast, + AsyncSelectiveAveragingPolicympi4pyGeneral, AverageAllPolicy, + AverageAllPolicympi4py, ) logger = logging.getLogger(__name__) @@ -836,10 +840,23 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: dist.barrier(group=distributed_context.train_global_group) # Set up averaging policy 
(called every iteration; internal guard checks cadence/distributed) - averaging_policy = None + averaging_policy_torch = None + averaging_policy_mpi4py = None + if args.distributed: + averaging_policy_torch = AverageAllPolicy(average_every=args.average_every) + averaging_policy_mpi4py = AverageAllPolicympi4py( + average_every=args.average_every + ) + + mpi4py_train_group = ( + distributed_context.dc_mpi4py.train_global_group + if distributed_context.dc_mpi4py is not None + else MPI.COMM_WORLD + ) + if args.use_selective_averaging: - averaging_policy = AsyncSelectiveAveragingPolicy( # type: ignore[abstract] + averaging_policy_torch = AsyncSelectiveAveragingPolicy( # type: ignore[abstract] model_builder=_model_builder, average_every=args.average_every, replacement_ratio=args.replacement_ratio, @@ -848,8 +865,29 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: threshold=args.performance_tracker_threshold, cooldown=args.performance_tracker_cooldown, ) - else: - averaging_policy = AverageAllPolicy(average_every=args.average_every) + averaging_policy_mpi4py = AsyncSelectiveAveragingPolicympi4pyGeneral( # type: ignore[abstract] + model_builder=_model_builder, + model=gflownet, + average_every=args.average_every, + threshold_metric=args.performance_tracker_threshold, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + age_range=args.age_range, + group=mpi4py_train_group, + ) + if args.mpi_sa_mode == "fast": + averaging_policy_mpi4py = AsyncSelectiveAveragingPolicympi4pyFast( # type: ignore[abstract] + model_builder=_model_builder, + model=gflownet, + average_every=args.average_every, + threshold_metric=args.performance_tracker_threshold, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + age_range=args.age_range, + group=mpi4py_train_group, + ) # Accumulators for averaging score_dict between log intervals. score_dict_accum: dict[str, float] = {} @@ -962,17 +1000,36 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: with Timer( timing, "averaging_model", enabled=args.timing ) as model_averaging_timer: - if averaging_policy is not None: - gflownet, optimizer, averaging_info = averaging_policy( - iteration=iteration, - model=gflownet, - optimizer=optimizer, - local_metric=( - score_dict["score"] if score_dict is not None else -loss.item() - ), - group=distributed_context.train_global_group, - ) + if args.spawn_backend == "dist": + if averaging_policy_torch is not None: + gflownet, optimizer, averaging_info = averaging_policy_torch( + iteration=iteration, + model=gflownet, + optimizer=optimizer, + local_metric=( + score_dict["score"] + if score_dict is not None + else -loss.item() + ), + group=distributed_context.train_global_group, + ) + else: + if ( + averaging_policy_mpi4py is not None + and distributed_context.dc_mpi4py is not None + ): + gflownet, optimizer, averaging_info = averaging_policy_mpi4py( + iteration=iteration, + model=gflownet, + optimizer=optimizer, + local_metric=( + score_dict["score"] + if score_dict is not None + else -loss.item() + ), + group=distributed_context.dc_mpi4py.train_global_group, + ) # Calculate how long this iteration took. 
iteration_time = time.time() - iteration_start rest_time = iteration_time - sum( @@ -1062,8 +1119,13 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: wandb.log(to_log, step=iteration) with Timer(timing, "barrier 2", enabled=(args.timing and args.distributed)): - if args.distributed and args.timing: - dist.barrier(group=distributed_context.train_global_group) + if ( + args.distributed + and args.timing + and distributed_context.dc_mpi4py is not None + ): + t_comm = distributed_context.dc_mpi4py.train_global_group + t_comm.Barrier() logger.info("Finished all iterations") total_time = time.time() - time_start @@ -1074,9 +1136,11 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: if args.distributed: dist.barrier(group=distributed_context.train_global_group) - assert averaging_policy is not None + assert averaging_policy_torch is not None + assert averaging_policy_mpi4py is not None try: - averaging_policy.shutdown() + averaging_policy_torch.shutdown() + averaging_policy_mpi4py.shutdown() except Exception: pass @@ -1096,6 +1160,16 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: logger.info("-" * 80) for k, v in timing.items(): logger.info("%-25s %10.4fs", k, sum(v)) + try: + if ( + args.spawn_backend == "mpi4py" + and args.use_selective_averaging + and averaging_policy_mpi4py is not None + ): + averaging_policy_mpi4py.print_time() + averaging_policy_mpi4py.print_stats() + except Exception: + pass # Stop the profiler if it's active. if args.profile: @@ -1119,6 +1193,14 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: # Send a termination signal to the replay buffer manager. ReplayBufferManager.send_termination_signal(distributed_context.assigned_buffer) + if args.distributed: + dist.barrier(group=distributed_context.train_global_group) + assert distributed_context is not None + try: + distributed_context.cleanup() + except Exception: + pass + return to_log @@ -1207,6 +1289,29 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: default=0.01, help="Momentum factor for combining with previous weights (0.0 = no momentum, 1.0 = keep old weights)", ) + ## for mpi-3 code of selective averaging debug + parser.add_argument( + "--spawn_backend", + choices=["dist", "mpi4py"], + default="mpi4py", + help="Backend for spawn policy implementation: torch.distributed or mpi4py", + ) + parser.add_argument( + "--mpi_sa_mode", + choices=["general", "fast"], + default="fast", + help=( + "MPI selective averaging implementation to use. " + "'fast' uses an optimized communication path assuming all parameters " + "have the same dtype (e.g., float32)." + ), + ) + parser.add_argument( + "--age_range", + type=lambda s: tuple(map(int, s.split(","))), + default=(5, 15), + help="Age range (iterations) for selective averaging policy as tuple (min_age, max_age), e.g., '5,15'", + ) # Environment settings. parser.add_argument( @@ -1469,7 +1574,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: parser.add_argument( "--performance_tracker_threshold", type=float, - default=None, + default=0.0, help="Threshold for the performance tracker. 
If None, the performance tracker is not triggered.", ) parser.add_argument( @@ -1480,4 +1585,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: ) args = parser.parse_args() + assert ( + args.age_range[1] >= args.age_range[0] + ), "Invalid age_range: max_age must be ge min_age" main(args) diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py index 1595d778..45353709 100644 --- a/tutorials/examples/train_ising.py +++ b/tutorials/examples/train_ising.py @@ -1,9 +1,9 @@ from argparse import ArgumentParser import torch -import wandb from tqdm import tqdm +import wandb from gfn.estimators import DiscretePolicyEstimator from gfn.gflownet import FMGFlowNet from gfn.gym import DiscreteEBM
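The two mpi4py selective-averaging policies in this patch rely on MPI-3 passive-target one-sided communication: each rank exposes a flat copy of its parameters through an MPI window, and a rank whose agent is being replaced pulls a few random donors' buffers with Lock/Get/Unlock and averages them locally. Below is a minimal standalone sketch of that pattern (not part of the patch); it assumes mpi4py is installed and the script is launched with an MPI launcher such as mpiexec, and the buffer size and donor count are illustrative.

# Minimal sketch (not part of the patch) of the passive-target RMA pattern the
# mpi4py selective-averaging policies build on: every rank exposes a flat
# float32 copy of its parameters through an MPI window, and a rank whose agent
# is being replaced pulls a few random donors' buffers and averages them.
import random

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

n_params = 8  # illustrative flat parameter count
local = np.full(n_params, float(rank), dtype=np.float32)  # stand-in for model params

# The window keeps using `local` as its memory, so "refreshing" the exposed
# parameters is just rewriting `local` (under an exclusive lock in the patch).
win = MPI.Win.Create(local, comm=comm)

if size > 1:
    donors = random.sample([r for r in range(size) if r != rank], k=min(2, size - 1))
    acc = np.zeros(n_params, dtype=np.float32)
    recv = np.empty(n_params, dtype=np.float32)
    for src in donors:
        win.Lock(rank=src, lock_type=MPI.LOCK_SHARED)  # open a passive epoch on the donor
        win.Get([recv, MPI.FLOAT], target_rank=src)    # pull the donor's flat params
        win.Unlock(rank=src)                           # Get is complete after Unlock
        acc += recv
    avg = acc / len(donors)  # mean averaging, as in the patch
    print(f"rank {rank}: averaged donors {donors} -> first element {avg[0]:.2f}")

comm.Barrier()  # make sure no one is still reading before freeing the window
win.Free()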
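AverageAllPolicympi4py replaces the torch.distributed all_reduce with an mpi4py collective: each parameter is copied into a contiguous NumPy buffer, summed in place across all ranks with Allreduce, divided by the world size, and written back into the model. The sketch below mirrors that path outside the patch; the tiny torch.nn.Linear is an illustrative stand-in for the GFlowNet and is not part of the repository.

# Minimal sketch (not part of the patch) of the collective path used by
# AverageAllPolicympi4py: copy each parameter into a contiguous NumPy buffer,
# sum it in place across all ranks, divide by the world size, and copy the
# mean back into the model. A tiny Linear stands in for the GFlowNet here.
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD
world_size = comm.Get_size()

model = torch.nn.Linear(4, 2)  # illustrative stand-in for the real model

with torch.no_grad():
    for param in model.parameters():
        buf = param.detach().cpu().numpy().copy()      # contiguous host copy
        comm.Allreduce(MPI.IN_PLACE, buf, op=MPI.SUM)  # in-place sum over ranks
        buf /= world_size                              # mean of all replicas
        param.data.copy_(torch.from_numpy(buf))        # write the average back

print(f"rank {comm.Get_rank()}: parameters averaged across {world_size} ranks")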