From 1d72d5f9e2dcf89e3574e02353f9484642cfe0c0 Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Thu, 5 Feb 2026 09:28:09 -0800 Subject: [PATCH 1/5] selective_averging_mpi4py code added --- src/gfn/utils/distributed.py | 173 +++++ tutorials/examples/multinode/spawn_policy.py | 694 +++++++++++++++++++ tutorials/examples/train_hypergrid.py | 119 +++- 3 files changed, 965 insertions(+), 21 deletions(-) diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py index 7e109775..0e04055c 100644 --- a/src/gfn/utils/distributed.py +++ b/src/gfn/utils/distributed.py @@ -7,6 +7,9 @@ import torch import torch.distributed as dist +import mpi4py +import mpi4py.MPI as MPI + logger = logging.getLogger(__name__) @@ -121,6 +124,156 @@ def average_models(model, training_group=None): dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=training_group) param.data = param_tensor / world_size +@dataclass +class DistributedContextmpi4py: + """Holds all distributed training/replay buffer groups and ranks.""" + + my_rank: int + world_size: int + num_training_ranks: int + agent_group_size: int + agent_groups: Optional[List[dist.ProcessGroup]] = None + agent_group_id: Optional[int] = None + train_global_group: Optional[dist.ProcessGroup] = None + assigned_buffer: Optional[int] = None + buffer_group: Optional[dist.ProcessGroup] = None + assigned_training_ranks: Optional[List[int]] = None + + def is_buffer_rank(self) -> bool: + """Check if the current rank is part of the buffer group.""" + return self.my_rank >= self.num_training_ranks + + def is_training_rank(self) -> bool: + """Check if the current rank is part of the training group.""" + return self.my_rank < self.num_training_ranks + +def initialize_distributed_compute_mpi4py( + num_remote_buffers: int, + num_agent_groups: int, +) -> DistributedContextmpi4py: + """Initalizes distributed compute using either ccl or mpi backends.""" + """ + Initalizes distributed compute using either ccl or mpi backends. + Args: + dist_backend: The backend to use for distributed compute. + num_remote_buffers: The number of remote buffers to use. + """ + + pmi_size = MPI.COMM_WORLD.Get_size() + print("+ Initalizing distributed compute, PMI_SIZE={}".format(pmi_size)) + + if pmi_size <= 1: + print("+ PMI_SIZE <= 1, running in single process mode.") + return DistributedContext( + my_rank=0, world_size=1, num_training_ranks=1, agent_group_size=1 + ) + + os.environ["RANK"] = str(MPI.COMM_WORLD.Get_rank() ) + os.environ["WORLD_SIZE"] = str(pmi_size) + + print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS")) + + world_size = MPI.COMM_WORLD.Get_size() + if world_size is None: + raise ValueError("WORLD_SIZE is not set") + rank = MPI.COMM_WORLD.Get_rank() + if rank is None: + raise ValueError("RANK is not set") + + dist.barrier() + print("+ Distributed compute initialized") + + my_rank = rank ##dist.get_rank() # Global! + #world_size = dist.get_world_size() # Global! + + num_training_ranks = world_size - num_remote_buffers + + # make sure that we have atmost 1 remote buffer per training rank. + assert num_training_ranks >= num_remote_buffers + print("num_train = ", num_training_ranks) + print("num_remote_buffers = ", num_remote_buffers) + + # for now, let us enforce that each agent gets equal number of ranks. + # TODO: later, we can relax this condition. 
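Editor's note on the mpi4py group handling used just below: sub-communicators are built by including an explicit rank list in a group derived from MPI.COMM_WORLD. A minimal standalone sketch, where the helper name make_subcomm is illustrative and not part of the patch:

import mpi4py.MPI as MPI

def make_subcomm(ranks):
    # Build a communicator containing only `ranks`; on processes that are
    # not members of the group, MPI.COMM_WORLD.Create() returns MPI.COMM_NULL.
    world_group = MPI.COMM_WORLD.Get_group()
    grp = world_group.Incl(ranks)
    return MPI.COMM_WORLD.Create(grp)

# e.g. make_subcomm(list(range(4))) yields a communicator over ranks 0..3 only.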
+ assert num_training_ranks % num_agent_groups == 0 + agent_group_size = num_training_ranks // num_agent_groups + agent_group_rank_list = [ + list(range(i * agent_group_size, (i + 1) * agent_group_size)) + for i in range(num_agent_groups) + ] + print(f"Agent group ranks: {agent_group_rank_list}") + world_group = MPI.COMM_WORLD.Get_group() + agent_group_list = [ ] + for i in range(num_agent_groups): + grp = world_group.Incl(agent_group_rank_list[i]) + agent_group_list.append( MPI.COMM_WORLD.Create(grp)) + + # all training ranks in one global group + training_ranks = [ + r for r in range(num_training_ranks) + ] # e.g., 0..num_training_ranks-1 + + #train_global_group = dist.new_group( + # ranks=training_ranks, + # backend=dist_backend, + # timeout=datetime.timedelta(minutes=5), + #) + grp = world_group.Incl(training_ranks) + train_global_group = MPI.COMM_WORLD.Create(grp) + #print(f"Training global group ranks: {training_ranks}, {train_global_group}") + #assert train_global_group != MPI.COMM_NULL + + buffer_group = None + assigned_buffer = None + assigned_training_ranks = {} + if num_remote_buffers > 0: + buffer_ranks = list( + range(num_training_ranks, num_training_ranks + num_remote_buffers) + ) + #buffer_group = dist.new_group( + # buffer_ranks, + # backend=dist_backend, + # timeout=datetime.timedelta(minutes=5), + #) + grp = world_group.Incl(buffer_ranks) + buffer_group = MPI.COMM_WORLD.Create(grp) + print(f" >>>>>>>>>>>>>>>>. Buffer group ranks: {buffer_ranks}, {buffer_group}") + #assert buffer_group != MPI.COMM_NULL + + print(f"Buffer group ranks: {buffer_ranks}") + + # Each training rank gets assigned to a buffer rank + if my_rank < (num_training_ranks): + assigned_buffer = num_training_ranks + (my_rank % num_remote_buffers) + else: + assigned_training_ranks[my_rank] = [ + ranks + for ranks in range(num_training_ranks) + if (ranks % num_remote_buffers) == (my_rank - num_training_ranks) + ] + + print(f"+ My rank: {my_rank} size: {world_size}") + if my_rank < (num_training_ranks): + print(f" -> Training group, assigned buffer rank = {assigned_buffer}") + else: + print(" -> Buffer group") + + #dist.barrier() + print("+ Distributed compute initialized, rank = ", my_rank) + + return DistributedContextmpi4py( + my_rank=my_rank, + world_size=world_size, + num_training_ranks=num_training_ranks, + agent_group_size=agent_group_size, + agent_groups=agent_group_list, + agent_group_id=my_rank // agent_group_size, + train_global_group=train_global_group, + assigned_buffer=assigned_buffer, + buffer_group=buffer_group, + assigned_training_ranks=assigned_training_ranks.get(my_rank, None), + ) + @dataclass class DistributedContext: @@ -136,6 +289,7 @@ class DistributedContext: assigned_buffer: Optional[int] = None buffer_group: Optional[dist.ProcessGroup] = None assigned_training_ranks: Optional[List[int]] = None + dc_mpi4py: Optional[DistributedContextmpi4py] = None def is_buffer_rank(self) -> bool: """Check if the current rank is part of the buffer group.""" @@ -145,6 +299,19 @@ def is_training_rank(self) -> bool: """Check if the current rank is part of the training group.""" return self.my_rank < self.num_training_ranks + def cleanup(self) -> None: + """Cleans up the distributed process group.""" + dist.destroy_process_group() + if self.dc_mpi4py is not None: + if self.dc_mpi4py.train_global_group is not None: + self.dc_mpi4py.train_global_group.Free() + if self.dc_mpi4py.buffer_group is not None: + self.dc_mpi4py.buffer_group.Free() + if self.dc_mpi4py.agent_groups is not None: + for ag in 
self.dc_mpi4py.agent_groups: + ag.Free() + MPI.Finalize() + def initialize_distributed_compute( dist_backend: str, @@ -302,6 +469,11 @@ def initialize_distributed_compute( dist.barrier() logger.info("Distributed compute initialized, rank = %d", my_rank) + dc = initialize_distributed_compute_mpi4py( + num_remote_buffers=num_remote_buffers, + num_agent_groups=num_agent_groups, + ) + return DistributedContext( my_rank=my_rank, world_size=world_size, @@ -313,6 +485,7 @@ def initialize_distributed_compute( assigned_buffer=assigned_buffer, buffer_group=buffer_group, assigned_training_ranks=assigned_training_ranks.get(my_rank, None), + dc_mpi4py=dc ) diff --git a/tutorials/examples/multinode/spawn_policy.py b/tutorials/examples/multinode/spawn_policy.py index c069cd3a..99a4a7a9 100644 --- a/tutorials/examples/multinode/spawn_policy.py +++ b/tutorials/examples/multinode/spawn_policy.py @@ -11,6 +11,13 @@ from gfn.gflownet.base import GFlowNet +import os +import random +import numpy as np +import mpi4py.MPI as MPI +from gfn.utils.common import Timer +import json + logger = logging.getLogger(__name__) @@ -90,6 +97,7 @@ def __init__( poll_interval_s: float = 0.01, threshold: Optional[float] = None, cooldown: int = 200, + timing: Optional[dict] = None, ## timing is a dict to capture timing info ) -> None: super().__init__(average_every) self.replacement_ratio = float(replacement_ratio) @@ -603,3 +611,689 @@ def _compute_averaging_weights( ) weights = 1.0 / (contributing_metrics + 1e-8) return weights / weights.sum() + +###########################################################################################3 +## Selective averaging based on async mpi-3 comms +# Version 1 (fast-general): One buffer with all the params, hence utilizes the network BW much better than general version. +# Hence different model param dtypes can be handled. +# Comments: +# - Need a fix where there is one buffer for all params but in byte format, +# after comms, the buffer can be converted to appropriate dtypes for corr params dtypes. +# - Right now each param has its own buffer and window, which may produce incorrect output. +# Notes: +# - Only mean averaging strategy is implemented here for now. +# - Neighbors are randomly selected from all ranks except self. +# - Agents are killed based on age only for now. +###########################################################################################3 + +class AsyncSelectiveAveragingPolicympi4pyGeneral(SpawnPolicy): + r""" + Asynchronous selective averaging version 2, uses mpi one-sided comms to get the + selectively averaged parameters from a random set of ranks. 
+ """ + def __init__( + self, + model_builder: Callable[[], Tuple[GFlowNet, torch.optim.Optimizer]], + model: GFlowNet, + average_every: int, + threshold_metric: float = 0.0, + replacement_ratio: float = 0.2, + averaging_strategy: str = "mean", + momentum: float = 0.0, + poll_interval_s: float = 0.01, + age_range: Tuple[int, int] = (50, 150), + group: MPI.Comm = MPI.COMM_WORLD, + ) -> None: + super().__init__(average_every) + self.myrank = group.Get_rank() + self.comm_size = group.Get_size() + + self.replacement_ratio = float(replacement_ratio) + self.averaging_strategy = str(averaging_strategy) + self.momentum = float(momentum) + self._model_builder = model_builder + + self._model: Optional[GFlowNet] = None + self.threshold_metric = float(threshold_metric) + ## timers + self.timing = {} + self.stats = {} + + self._model = model + self.train_comm_group = group + self._expose = False + + ###### new agents' stats #### + self.agents_killed = 0 + self.averaging_ranks = 0 + self._count = 0 + + self.debug_mode = False + self.age = 0 + self.age_range = age_range + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + + if self.debug_mode: + self.logfile = f"debug/selective_averaging_rank_{self.myrank}.log" + with open(self.logfile, 'w') as f: + f.write(f"Selective Averaging Log for Rank {self.myrank}\n") + f.write("=" * 50 + "\n") + + + ############################### + def shutdown(self) -> None: + for _,v in self._mpi_tensor_wins.items(): + v[0].Free() + + ############################## + def print_time(self) -> None: + print("Selective Averaging timings:", flush=True) + for k, v in self.timing.items(): + ## here v is a list, avg over the list + #avg = sum(v) / len(v) if len(v) > 0 else 0 + print(f"{k:<35}: {sum(v):>10.4f} seconds") + + def print_stats(self) -> None: + print("Selective Averaging comms stats:", flush=True) + avg_donors, num_calls = 0.0, 0 + for k, v in self.stats.items(): + # v is a list, print min, max ,avg, and len of it + minimum = min(v) if len(v) > 0 else 0 + maximum = max(v) if len(v) > 0 else 0 + avg = sum(v) / len(v) if len(v) > 0 else 0 + length = len(v) + print(f"Rank {self.myrank} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, count={length}") + if k == "donors": + avg_donors = avg + num_calls = length + + _named_params = list(self._model.named_parameters()) + named_params = [(name, param) for name, param in _named_params if param.dim() != 0] + print(f"Rank {self.myrank:<10} - {'param elements':<15} {'iter':<10} {'total params elements commd':<25}") + for name, param in named_params: + device = param.device + param_shape = param.data.shape + print(f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}") + + def capture_comm(self, name: str, size: int) -> None: + if name not in self.stats: + self.stats[name] = [] + self.stats[name].append(size) + + ################################ + def _ensure_initialized(self, model: GFlowNet) -> None: + self._model = model + self._initialized = True + ## export model parameters to mpi windows (should do that periodically to keep them fresh) + self._expose_model_parameters(model) + + ################################ + def reset_age(self) -> None: + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + self.age = 0 + + ################################ + def is_agent_dying(self, local_metric: float, threshold_metric: float, check_agent = 0) -> bool: + if check_agent == 0: ## static theshold + return local_metric < 
threshold_metric + + elif check_agent == 1: ## dynamic threshold based on age + if self.age >= self.max_age: + print('+ Agent killed due to age: ', self.age, ' max_age: ', self.max_age, flush=True) + self.reset_age() + return True + + self.age += 1 + return False + + else: + raise ValueError(f"Unknown is_agent_dying version: {check_agent}") + + + ################################ + # Execute this function regularly to copy model params to mpi windows + # for recepotrs to get recent params + def _copy_model_params_to_buf( + self, + model: GFlowNet, + ) -> None: + for name, param in model.named_parameters(): + win = self._mpi_tensor_wins[name][0] + win.Lock(rank=self.myrank, lock_type=MPI.LOCK_EXCLUSIVE) + self._mpi_tensor_wins[name][1][:] = param.data.cpu().numpy().flatten() + win.Unlock(rank=self.myrank) + + + ################################ + def _expose_model_parameters(self, model: GFlowNet) -> None: + rank = self.myrank + size = self.comm_size + + # Serialize model parameters to a contiguous numpy array + param_tensors = {} + for name, param in model.named_parameters(): + param_tensors[name] = np.zeros_like(param.data.cpu().numpy().flatten()) + + # Create MPI windows for each parameter and its shape (2 separate windows set) + self._mpi_tensor_wins = {} + self._mpi_shape_wins = {} + for name, tensor in param_tensors.items(): + buf = tensor + win = MPI.Win.Create(buf, comm=self.train_comm_group) + self._mpi_tensor_wins[name] = (win, buf) + + self._copy_model_params_to_buf(model) + + ################################ + def _get_donors(self, n, k, d) -> List[int]: + if k > n - 1: + raise ValueError("k must be ≤ n-1 when excluding one value") + + # All values from 0..n-1 except d + candidates = [x for x in range(n) if x != d] + # Pick k distinct values + return random.sample(candidates, k) + + + ################################ + def __call__(self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: float, + expose_params: bool = True, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + + if self._expose == False: + self._expose_model_parameters(model) + self._expose = True + + self._count += 1 + self._model = model + + check_agent = 1 ## version of dying agent check + ## validation info + layer_name = None + if self.debug_mode: + layer_name = "pb.module.last_layer.bias" + + if self.is_agent_dying(local_metric, self.threshold_metric, check_agent): + with Timer( + self.timing, "sa get_params_from_donors" + ): + if self.debug_mode: + with open(self.logfile, 'a') as f: + # kill this model and rebuild model with fresh weights + num_donors = max(1, int(self.comm_size * 0.5)) ## * self.replacement_ratio)) + donors = self._get_donors(self.comm_size, num_donors, self.myrank, self.averaging_strategy) + new_avg_params = self._get_model_params_from_donors(donors, layer_name, f) + json.dump({self._count: [self._mpi_tensor_wins[layer_name][1].tolist(), donors, new_avg_params[layer_name].tolist()]}, f) + f.write("\n") + else: + num_donors = max(1, int(self.comm_size * 0.5)) # self.replacement_ratio)) + donors = self._get_donors(self.comm_size, num_donors, self.myrank) + new_avg_params = self._get_model_params_from_donors(donors, layer_name, None) + + with Timer( + self.timing, "sa new_agent_model_rebuild" + ): + tic = self.get_time() + model, optimizer = self._model_builder() + for name, param in model.named_parameters(): + if name in new_avg_params: + param.data.copy_(new_avg_params[name]) + + if expose_params == True: + with 
Timer( + self.timing, "sa copy_params_to_buf" + ): + self._copy_model_params_to_buf(model) + + return model, optimizer, True + + + ################################ + def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dict[str, torch.Tensor]: + avg_state: Dict[str, torch.Tensor] = {} + _named_params = list(self._model.named_parameters()) + named_params = [(name, param) for name, param in _named_params if param.dim() != 0] + tot_comm_ele = 0 + + self.capture_comm("donors", len(donors)) + for name, param in named_params: + device = param.device + param_shape = param.data.shape + acc = torch.zeros_like(param.data) + all_donors = [] + + for i, src in enumerate(donors): + tensor_win, tensor_buf = self._mpi_tensor_wins[name] + tensor_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) + flat_size = np.prod(param_shape) + assert flat_size > 0 + tot_comm_ele += flat_size + donor_tensor_flat = np.zeros(flat_size, dtype=param.data.cpu().numpy().dtype) + tensor_win.Get([donor_tensor_flat, MPI.FLOAT], target_rank=src) + tensor_win.Unlock(rank=src) + + donor_tensor = torch.tensor(donor_tensor_flat.reshape(param_shape), device=device) + ## Adding all the donor tensors/params + acc.add_(donor_tensor) + + if self.debug_mode and name == layer_name: + all_donors.append(donor_tensor.tolist()) + ## Additions: Other averaging strategies can be implemented here + + if self.debug_mode and name == layer_name: + json.dump({self._count: all_donors}, f) + f.write("\n") + + # default to mean averaging + acc = acc / len(donors) + avg_state[name] = acc + + self.capture_comm("num_param_tensors_received", tot_comm_ele) + return avg_state + + + ################################ + def _get_model_params_from_donors_general(self, donors: List[int]) -> Dict[str, torch.Tensor]: + self.avg_state: Dict[str, torch.Tensor] = {} + named_params = list(self._model.named_parameters()) + + for name, param in named_params: + device = param.device + param_shape = param.data.shape + #acc = torch.zeros_like(param.data) + acc = [] + + for i, src in enumerate(donors): + # Get shape of parameter from donor + #shape_win, shape_buf = self._mpi_shape_wins[name] + #shape_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) + #donor_shape = tuple(shape_buf.tolist()) + #shape_win.Unlock(rank=src, lock_type=MPI.LOCK_SHARED) + + # Get parameter tensor from donor + tensor_win, tensor_buf = self._mpi_tensor_wins[name] + tensor_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) + flat_size = np.prod(param_shape) + donor_tensor_flat = np.zeros(flat_size, dtype=param.data.cpu().numpy().dtype) + tensor_win.Get([donor_tensor_flat, MPI.FLOAT], target_rank=src) + tensor_win.Unlock(rank=src) + + donor_tensor = torch.tensor(donor_tensor_flat.reshape(param_shape), device=device) + ## Adding all the donor tensors/params + #acc.add_(donor_tensor) + acc.append(donor_tensor) + self.capture_comm("num_param_tensors_received", flat_size); + + ## Additions: Other averaging strategies can be implemented here + + #if self.averaging_strategy == "mean" and len(donors) > 0: + # default to mean averaging + #acc = acc / len(donors) + self.avg_state[name] = acc + + return self.avg_state + + + ################################ + def _average_received_params( + self, + ) -> Dict[str, torch.Tensor]: + avg_state: Dict[str, torch.Tensor] = {} + for name, param in self.avg_state.items(): + #device = param.device + #param_shape = param.data.shape + acc = torch.zeros_like(param[0].data) + + for i, donor_tensor in enumerate(param): + # Adding all the donor tensors/params + 
acc.add_(donor_tensor) + + # default to mean averaging + if self.averaging_strategy == "mean": + acc = acc / len(param) + avg_state[name] = acc + + return avg_state + + + + +###########################################################################################3 +## Selective averaging based on async mpi-3 comms +# Version 2 (fast): One buffer with all the params, hence utilizes the network BW much better than general version. +# It assumes all the params are of same dtype (float32) +# Notes: +# - Only mean averaging strategy is implemented here for now. +# - Neighbors are randomly selected from all ranks except self. +# - Agents are killed based on age only for now. +###########################################################################################3 +class AsyncSelectiveAveragingPolicympi4pyFast(SpawnPolicy): + r""" + Asynchronous selective averaging version 2, uses mpi one-sided comms to get the + selectively averaged parameters from a random set of ranks. + """ + def __init__( + self, + model_builder: Callable[[], Tuple[GFlowNet, torch.optim.Optimizer]], + model: GFlowNet, + average_every: int, + threshold_metric: float = 0.0, + replacement_ratio: float = 0.2, + averaging_strategy: str = "mean", + momentum: float = 0.0, + age_range: Tuple[int, int] = (50, 150), + group: MPI.Comm = MPI.COMM_WORLD, + ) -> None: + super().__init__(average_every) + self.myrank = group.Get_rank() + self.comm_size = group.Get_size() + + self.replacement_ratio = float(replacement_ratio) + self.averaging_strategy = str(averaging_strategy) + self.momentum = float(momentum) + self._model_builder = model_builder + + self._model: Optional[GFlowNet] = None + self.threshold_metric = float(threshold_metric) + ## timers + self.timing = {} + self.stats = {} + + self._model = model + self.train_comm_group = group + #self._expose_model_parameters(model) + self._expose = False + + ###### new agents' stats #### + self.agents_killed = 0 + self.averaging_ranks = 0 + self._count = 0 + + self.total_iterations = 0 + self.num_replacements = 0 + self.debug_mode = False + + self.age = 0 + self.age_range = age_range + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + + ## test code, remove it later + if self.debug_mode: + self.logfile = f"debug/selective_averaging_rank_{self.myrank}.log" + with open(self.logfile, 'w') as f: + f.write(f"Selective Averaging Log for Rank {self.myrank}\n") + f.write("=" * 50 + "\n") + + ############################### + def shutdown(self) -> None: + self._mpi_tensor_wins[0].Free() + + ############################## + def print_time(self) -> None: + print("Selective Averaging timings:", flush=True) + for k, v in self.timing.items(): + ## here v is a list, avg over the list + print(f"{k:<35}: {sum(v):>10.4f} seconds") + + def print_stats(self) -> None: + print("Selective Averaging comms stats:", flush=True) + print(f"Rank {self.myrank} - Agent replaced for {self.num_replacements} out of {self.total_iterations} iterations.") + avg_donors, num_calls = 0.0, 0 + for k, v in self.stats.items(): + # v is a list, print min, max ,avg, and len of it + minimum = min(v) if len(v) > 0 else 0 + maximum = max(v) if len(v) > 0 else 0 + avg = sum(v) / len(v) if len(v) > 0 else 0 + length = len(v) + print(f"Rank {self.myrank:<10} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, across {length} iters") + if k == "donors": + avg_donors = avg + num_calls = length + + _named_params = list(self._model.named_parameters()) + named_params = [(name, param) for name, param in 
_named_params if param.dim() != 0] + print(f"Rank {self.myrank:<10} - {'param elements':<15} {'#comm_iters':<10} {'total params elements communicated':<25}") + for name, param in named_params: + device = param.device + param_shape = param.data.shape + print(f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}") + + def capture_comm(self, name: str, size: int) -> None: + if name not in self.stats: + self.stats[name] = [] + self.stats[name].append(size) + + ################################ + def _ensure_initialized(self, model: GFlowNet) -> None: + self._model = model + self._initialized = True + ## export model parameters to mpi windows (should do that periodically to keep them fresh) + self._expose_model_parameters(model) + + ################################ + def reset_age(self) -> None: + self.max_age = random.randint(self.age_range[0], self.age_range[1]) + self.age = 0 + + ################################ + def is_agent_dying(self, local_metric: float, threshold_metric: float, check_policy = 0) -> bool: + if check_policy == 0: ## static theshold + return local_metric < threshold_metric + + elif check_policy == 1: ## dynamic threshold based on age + if self.age >= self.max_age: + print('+ Agent killed due to age: ', self.age, ' max_age: ', self.max_age, flush=True) + self.reset_age() + return True + + self.age += 1 + return False + + else: + raise ValueError(f"Unknown is_agent_dying version: {check_policy}") + + + ################################ + # Execute this function regularly to copy model params to mpi windows + # for recepotrs to get recent params + def _copy_model_params_to_buf( + self, + model: GFlowNet, + ) -> None: + offset = 0 + for name, param in model.named_parameters(): + win = self._mpi_tensor_wins[0] + win.Lock(rank=self.myrank, lock_type=MPI.LOCK_EXCLUSIVE) + size = param.data.numel() + self._mpi_tensor_wins[1][offset:offset+size] = param.data.cpu().numpy().flatten() + offset += size + win.Unlock(rank=self.myrank) + + + ################################ + def _expose_model_parameters(self, model: GFlowNet) -> None: + print("+ Exposing model parameters via MPI windows...", flush=True) + # Serialize model parameters to a contiguous numpy array + param_size = 0 + dtypes = {param.dtype for param in model.parameters()} + ## todo: enable this to work with any dtypes for the model + # print('model dtypes: ', dtypes) + param_dtype = np.float32 + th_param_dtype = torch.float32 + + #self.param_shapes = {} + for _, param in model.named_parameters(): + param_size += param.data.numel() + param_dtype = param.data.cpu().numpy().dtype + + param_tensors_flat = np.zeros(param_size, dtype=param_dtype) + self.donor_tensor_flat = torch.zeros(param_size, dtype=th_param_dtype) + self.acc = torch.zeros(param_size, dtype=th_param_dtype) + + # Create MPI windows for the flat parameter tensor + buf = param_tensors_flat + win = MPI.Win.Create(buf, comm=self.train_comm_group) + ## buffer attached to the win, used to copy data in/out of the win + self._mpi_tensor_wins = (win, buf) + + self._copy_model_params_to_buf(model) + + ################################ + def _get_donors(self, n, k, d) -> List[int]: + if k > n - 1: + raise ValueError("k must be ≤ n-1 when excluding one value") + + # All values from 0..n-1 except d + candidates = [x for x in range(n) if x != d] + # Random policy Pick k distinct random values + return random.sample(candidates, k) + + ################################ + def __call__(self, + iteration: int, 
+ model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: float, + expose_params: bool = True, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + + if self._expose == False: + self._expose_model_parameters(model) + self._expose = True + + self._count += 1 + self._model = model + named_params = list(model.named_parameters()) + for name, param in named_params: + param_shape = param.data.shape + + ## validation info + layer_name = None + if self.debug_mode: + layer_name = "pb.module.last_layer.bias" + + self.total_iterations += 1 + check_agent = 1 ## 0: static thresholding, 1: dynamic based on age + if self.is_agent_dying(local_metric, self.threshold_metric, check_agent): + self.num_replacements += 1 + with Timer( + self.timing, "sa get_params_from_donors" + ): + # kill this model and rebuild model with fresh weights + num_donors = max(1, int(self.comm_size * 0.5)) ## <<<<< parameterize this one + donors = self._get_donors(self.comm_size, num_donors, self.myrank) + _new_avg_params = self._get_model_params_from_donors(donors, layer_name, None) + + + with Timer( + self.timing, "sa param_list_to_dict_convert" + ): + ## conver the flat tensor to model param dict + new_avg_params: Dict[str, torch.Tensor] = {} + win_buf: Dict[str, torch.Tensor] = {} + + offset = 0 + for name, param in model.named_parameters(): + device = param.device + flat_size = param.data.numel() + assert flat_size == np.prod(param.data.shape) + donor_tensor_flat = _new_avg_params[offset:offset + flat_size] + donor_tensor = torch.tensor(donor_tensor_flat.reshape(param.data.shape), device=device) + if self.debug_mode: + buf_tensor_flat = self._mpi_tensor_wins[1][offset:offset + flat_size] + buf_tensor = torch.tensor(buf_tensor_flat.reshape(param.data.shape), device=device) + win_buf[name] = buf_tensor + + new_avg_params[name] = donor_tensor + offset += flat_size + + if self.debug_mode: + with open(self.logfile, 'a') as f: + json.dump({self._count: [win_buf[layer_name].tolist(), donors, new_avg_params[layer_name].tolist()]}, f) + f.write("\n") + + with Timer( + self.timing, "sa new_agent_model_rebuild" + ): + model, optimizer = self._model_builder() + for name, param in model.named_parameters(): + if name in new_avg_params: + param.data.copy_(new_avg_params[name]) + + + if expose_params == True: + with Timer( + self.timing, "sa copy_params_to_buf" + ): + self._copy_model_params_to_buf(model) + + return model, optimizer, True + + + ################################ + def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dict[str, torch.Tensor]: + #avg_state: Dict[str, torch.Tensor] = {} + #_named_params = list(self._model.named_parameters()) + #named_params = [(name, param) for name, param in _named_params if param.dim() != 0] + tot_comm_ele = 0 + + self.capture_comm("donors", len(donors)) + tensor_win, tensor_buf = self._mpi_tensor_wins + flat_size = tensor_buf.size + self.donor_tensor_flat.zero_() + self.acc.zero_() + + if self.averaging_strategy == "mean": + for i, src in enumerate(donors): + tensor_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) + tensor_win.Get([self.donor_tensor_flat, MPI.FLOAT], target_rank=src) + tensor_win.Unlock(rank=src) + + ## Adding all the donor tensors/params + self.acc.add_(self.donor_tensor_flat) + tot_comm_ele = tot_comm_ele + flat_size + + self.acc = self.acc / len(donors) + else: + raise ValueError(f"Unknown averaging strategy: {self.averaging_strategy}") + + self.capture_comm("num_param_tensors_received", tot_comm_ele) + 
return self.acc + + +class AverageAllPolicympi4py(SpawnPolicy): + """Standard model averaging across all ranks every N iterations.""" + + def __init__(self, average_every: int) -> None: + super().__init__(average_every) + + @torch.no_grad() + def __call__( + self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: Optional[float] = None, + group=MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + if not dist.is_available() or not dist.is_initialized(): + return model, optimizer, {} + if iteration % self.average_every != 0: + return model, optimizer, {"averaged_this_iteration": False} + + #print("AverageAll mpi4py model parameters across all ranks ...", flush=True) + world_size = group.Get_size() + for param in model.parameters(): + param_tensor = param.detach().cpu().numpy().copy() ##param.data.clone().numpy() + #dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=group) + group.Allreduce(MPI.IN_PLACE, param_tensor, op=MPI.SUM) + param_tensor /= world_size + param.data.copy_(torch.from_numpy(param_tensor)) + + return model, optimizer, {"averaged_this_iteration": True} diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 2148ca4d..3c7544ab 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -62,7 +62,10 @@ from gfn.utils.modules import MLP, DiscreteUniform, Tabular from tutorials.examples.multinode.spawn_policy import ( AsyncSelectiveAveragingPolicy, + AsyncSelectiveAveragingPolicympi4pyFast, + AsyncSelectiveAveragingPolicympi4py, AverageAllPolicy, + AverageAllPolicympi4py, ) logger = logging.getLogger(__name__) @@ -822,6 +825,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: # Initialize some variables before the training loop. 
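The selective-averaging policies above rest on MPI-3 one-sided (RMA) communication: each rank exposes a flat parameter buffer through a window, and a dying agent pulls donors' buffers without the donors participating in the call. A minimal self-contained sketch of that expose/Lock/Get/Unlock pattern, with the buffer size and variable names chosen purely for illustration:

import mpi4py.MPI as MPI
import numpy as np

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Each rank exposes a flat float32 buffer through an RMA window.
local_buf = np.full(4, float(rank), dtype=np.float32)
win = MPI.Win.Create(local_buf, comm=comm)

# Passively read a peer's buffer; the target rank does not take part in this call.
src = (rank + 1) % comm.Get_size()
remote = np.zeros_like(local_buf)
win.Lock(rank=src, lock_type=MPI.LOCK_SHARED)
win.Get([remote, MPI.FLOAT], target_rank=src)
win.Unlock(rank=src)

win.Free()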
timing = {} + oldcode = args.oldcode time_start = time.time() l1_distances, validation_steps = [], [] @@ -839,17 +843,49 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: averaging_policy = None if args.distributed: if args.use_selective_averaging: - averaging_policy = AsyncSelectiveAveragingPolicy( # type: ignore[abstract] - model_builder=_model_builder, - average_every=args.average_every, - replacement_ratio=args.replacement_ratio, - averaging_strategy=args.averaging_strategy, - momentum=args.momentum, - threshold=args.performance_tracker_threshold, - cooldown=args.performance_tracker_cooldown, - ) + if oldcode: + averaging_policy = AsyncSelectiveAveragingPolicy( # type: ignore[abstract] + model_builder=_model_builder, + average_every=args.average_every, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + threshold=args.performance_tracker_threshold, + cooldown=args.performance_tracker_cooldown, + ) + else: + if args.fast_sa: + ## fast -- assumes all the params are of same precision + averaging_policy = AsyncSelectiveAveragingPolicympi4pyFast( # type: ignore[abstract] + model_builder=_model_builder, + model=gflownet, + average_every=args.average_every, + threshold_metric=args.performance_tracker_threshold, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + age_range=args.age_range, + group=distributed_context.dc_mpi4py.train_global_group, + ) + else: + # general -- more general where diffeernt params can be of different precision + averaging_policy = AsyncSelectiveAveragingPolicympi4py( # type: ignore[abstract] + model_builder=_model_builder, + model=gflownet, + average_every=args.average_every, + threshold_metric=args.performance_tracker_threshold, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + age_range=args.age_range, + group=distributed_context.dc_mpi4py.train_global_group, + ) else: - averaging_policy = AverageAllPolicy(average_every=args.average_every) + if oldcode: + averaging_policy = AverageAllPolicy(average_every=args.average_every) + else: + print("+ Using AverageAllPolicympi4py", flush=True) + averaging_policy = AverageAllPolicympi4py(average_every=args.average_every) # Accumulators for averaging score_dict between log intervals. score_dict_accum: dict[str, float] = {} @@ -963,16 +999,25 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: timing, "averaging_model", enabled=args.timing ) as model_averaging_timer: if averaging_policy is not None: - gflownet, optimizer, averaging_info = averaging_policy( - iteration=iteration, - model=gflownet, - optimizer=optimizer, - local_metric=( - score_dict["score"] if score_dict is not None else -loss.item() - ), - group=distributed_context.train_global_group, - ) - + if oldcode: + gflownet, optimizer, averaging_info = averaging_policy( + iteration=iteration, + model=gflownet, + optimizer=optimizer, + local_metric=( + score_dict["score"] if score_dict is not None else -loss.item() + ), + group=distributed_context.train_global_group, + ) + else: + gflownet, optimizer, averaging_info = averaging_policy( + iteration=iteration, + model=gflownet, + optimizer=optimizer, + local_metric=(score_dict["score"] if score_dict is not None else -loss.item() + ), + group=distributed_context.dc_mpi4py.train_global_group, + ) # Calculate how long this iteration took. 
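Both mpi4py policies selected above pick donors uniformly at random from the other training ranks. A small standalone sketch of that selection; pick_donors is an illustrative name, not an API in the patch:

import random

def pick_donors(world_size: int, num_donors: int, my_rank: int) -> list:
    # Choose num_donors distinct ranks, never including the caller's own rank.
    if num_donors > world_size - 1:
        raise ValueError("num_donors must be <= world_size - 1")
    candidates = [r for r in range(world_size) if r != my_rank]
    return random.sample(candidates, num_donors)

# e.g. pick_donors(8, 4, 3) returns four distinct ranks drawn from {0, 1, 2, 4, 5, 6, 7}.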
iteration_time = time.time() - iteration_start rest_time = iteration_time - sum( @@ -1063,7 +1108,8 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: with Timer(timing, "barrier 2", enabled=(args.timing and args.distributed)): if args.distributed and args.timing: - dist.barrier(group=distributed_context.train_global_group) + t_comm = distributed_context.dc_mpi4py.train_global_group + t_comm.Barrier() logger.info("Finished all iterations") total_time = time.time() - time_start @@ -1096,6 +1142,11 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: logger.info("-" * 80) for k, v in timing.items(): logger.info("%-25s %10.4fs", k, sum(v)) + try: + averaging_policy.print_time() + averaging_policy.print_stats() + except Exception: + pass # Stop the profiler if it's active. if args.profile: @@ -1119,6 +1170,14 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: # Send a termination signal to the replay buffer manager. ReplayBufferManager.send_termination_signal(distributed_context.assigned_buffer) + if args.distributed: + dist.barrier(group=distributed_context.train_global_group) + assert averaging_policy is not None + try: + distributed_context.cleanup() + except Exception: + pass + return to_log @@ -1207,6 +1266,23 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: default=0.01, help="Momentum factor for combining with previous weights (0.0 = no momentum, 1.0 = keep old weights)", ) + ## for mpi-3 code of selective averaging debug + parser.add_argument( + "--oldcode", + action="store_true", + help="Temp switch between old and new selective averaging code (dist mpi or mpi4py)", + ) + parser.add_argument( + "--fast_sa", + action="store_true", + help="Use fast (comms) selective averaging mpi4py code assuming all params are of same precision (e.g., float32)", + ) + parser.add_argument( + "--age_range", + type=lambda s: tuple(map(int, s.split(','))), + default=(5, 15), + help="Age range (iterations) for selective averaging policy as tuple (min_age, max_age), e.g., '5,15'", + ) # Environment settings. 
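As context for the --age_range option added below: the mpi4py policies replace an agent once its age reaches a limit drawn uniformly from this range, then draw a fresh limit for the new agent. A minimal sketch of that lifetime rule under the same assumptions (class and method names are illustrative):

import random

class AgentLifetime:
    def __init__(self, age_range=(5, 15)):
        self.age_range = age_range
        self.reset()

    def reset(self):
        # Draw a fresh maximum age for the newly spawned agent.
        self.max_age = random.randint(*self.age_range)
        self.age = 0

    def step(self) -> bool:
        # Returns True when the current agent should be killed and rebuilt.
        if self.age >= self.max_age:
            self.reset()
            return True
        self.age += 1
        return False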
parser.add_argument( @@ -1480,4 +1556,5 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: ) args = parser.parse_args() + assert args.age_range[1] >= args.age_range[0], "Invalid age_range: max_age must be ge min_age" main(args) From d84c1cb189831ec073b05c302efe42131e806b86 Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Fri, 6 Feb 2026 06:27:58 -0800 Subject: [PATCH 2/5] fixed bugs and precommit errors --- src/gfn/utils/distributed.py | 42 +- tutorials/examples/multinode/spawn_policy.py | 477 ++++++++++--------- tutorials/examples/train_box.py | 2 +- tutorials/examples/train_discreteebm.py | 2 +- tutorials/examples/train_hypergrid.py | 136 +++--- tutorials/examples/train_ising.py | 2 +- 6 files changed, 350 insertions(+), 311 deletions(-) diff --git a/src/gfn/utils/distributed.py b/src/gfn/utils/distributed.py index 0e04055c..a539277d 100644 --- a/src/gfn/utils/distributed.py +++ b/src/gfn/utils/distributed.py @@ -4,12 +4,10 @@ from dataclasses import dataclass from typing import Dict, List, Optional, cast +import mpi4py.MPI as MPI import torch import torch.distributed as dist -import mpi4py -import mpi4py.MPI as MPI - logger = logging.getLogger(__name__) @@ -124,6 +122,7 @@ def average_models(model, training_group=None): dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=training_group) param.data = param_tensor / world_size + @dataclass class DistributedContextmpi4py: """Holds all distributed training/replay buffer groups and ranks.""" @@ -132,11 +131,11 @@ class DistributedContextmpi4py: world_size: int num_training_ranks: int agent_group_size: int - agent_groups: Optional[List[dist.ProcessGroup]] = None + agent_groups: Optional[List[MPI.Comm]] = None agent_group_id: Optional[int] = None - train_global_group: Optional[dist.ProcessGroup] = None + train_global_group: MPI.Comm = MPI.COMM_WORLD assigned_buffer: Optional[int] = None - buffer_group: Optional[dist.ProcessGroup] = None + buffer_group: Optional[MPI.Comm] = None assigned_training_ranks: Optional[List[int]] = None def is_buffer_rank(self) -> bool: @@ -147,6 +146,7 @@ def is_training_rank(self) -> bool: """Check if the current rank is part of the training group.""" return self.my_rank < self.num_training_ranks + def initialize_distributed_compute_mpi4py( num_remote_buffers: int, num_agent_groups: int, @@ -164,11 +164,11 @@ def initialize_distributed_compute_mpi4py( if pmi_size <= 1: print("+ PMI_SIZE <= 1, running in single process mode.") - return DistributedContext( + return DistributedContextmpi4py( my_rank=0, world_size=1, num_training_ranks=1, agent_group_size=1 ) - os.environ["RANK"] = str(MPI.COMM_WORLD.Get_rank() ) + os.environ["RANK"] = str(MPI.COMM_WORLD.Get_rank()) os.environ["WORLD_SIZE"] = str(pmi_size) print("+ OMP_NUM_THREADS = ", os.getenv("OMP_NUM_THREADS")) @@ -183,8 +183,8 @@ def initialize_distributed_compute_mpi4py( dist.barrier() print("+ Distributed compute initialized") - my_rank = rank ##dist.get_rank() # Global! - #world_size = dist.get_world_size() # Global! + my_rank = rank # dist.get_rank() # Global! + # world_size = dist.get_world_size() # Global! 
num_training_ranks = world_size - num_remote_buffers @@ -203,25 +203,25 @@ def initialize_distributed_compute_mpi4py( ] print(f"Agent group ranks: {agent_group_rank_list}") world_group = MPI.COMM_WORLD.Get_group() - agent_group_list = [ ] + agent_group_list = [] for i in range(num_agent_groups): grp = world_group.Incl(agent_group_rank_list[i]) - agent_group_list.append( MPI.COMM_WORLD.Create(grp)) + agent_group_list.append(MPI.COMM_WORLD.Create(grp)) # all training ranks in one global group training_ranks = [ r for r in range(num_training_ranks) ] # e.g., 0..num_training_ranks-1 - #train_global_group = dist.new_group( + # train_global_group = dist.new_group( # ranks=training_ranks, # backend=dist_backend, # timeout=datetime.timedelta(minutes=5), - #) + # ) grp = world_group.Incl(training_ranks) train_global_group = MPI.COMM_WORLD.Create(grp) - #print(f"Training global group ranks: {training_ranks}, {train_global_group}") - #assert train_global_group != MPI.COMM_NULL + # print(f"Training global group ranks: {training_ranks}, {train_global_group}") + # assert train_global_group != MPI.COMM_NULL buffer_group = None assigned_buffer = None @@ -230,15 +230,15 @@ def initialize_distributed_compute_mpi4py( buffer_ranks = list( range(num_training_ranks, num_training_ranks + num_remote_buffers) ) - #buffer_group = dist.new_group( + # buffer_group = dist.new_group( # buffer_ranks, # backend=dist_backend, # timeout=datetime.timedelta(minutes=5), - #) + # ) grp = world_group.Incl(buffer_ranks) buffer_group = MPI.COMM_WORLD.Create(grp) print(f" >>>>>>>>>>>>>>>>. Buffer group ranks: {buffer_ranks}, {buffer_group}") - #assert buffer_group != MPI.COMM_NULL + # assert buffer_group != MPI.COMM_NULL print(f"Buffer group ranks: {buffer_ranks}") @@ -258,7 +258,7 @@ def initialize_distributed_compute_mpi4py( else: print(" -> Buffer group") - #dist.barrier() + # dist.barrier() print("+ Distributed compute initialized, rank = ", my_rank) return DistributedContextmpi4py( @@ -485,7 +485,7 @@ def initialize_distributed_compute( assigned_buffer=assigned_buffer, buffer_group=buffer_group, assigned_training_ranks=assigned_training_ranks.get(my_rank, None), - dc_mpi4py=dc + dc_mpi4py=dc, ) diff --git a/tutorials/examples/multinode/spawn_policy.py b/tutorials/examples/multinode/spawn_policy.py index 99a4a7a9..bec22ffb 100644 --- a/tutorials/examples/multinode/spawn_policy.py +++ b/tutorials/examples/multinode/spawn_policy.py @@ -1,22 +1,20 @@ from __future__ import annotations +import json import logging +import random import threading import time from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional, Set, Tuple +from typing import Callable, Dict, List, Optional, Set, Tuple, cast +import mpi4py.MPI as MPI +import numpy as np import torch import torch.distributed as dist from gfn.gflownet.base import GFlowNet - -import os -import random -import numpy as np -import mpi4py.MPI as MPI from gfn.utils.common import Timer -import json logger = logging.getLogger(__name__) @@ -97,7 +95,7 @@ def __init__( poll_interval_s: float = 0.01, threshold: Optional[float] = None, cooldown: int = 200, - timing: Optional[dict] = None, ## timing is a dict to capture timing info + timing: Optional[dict] = None, # timing is a dict to capture timing info ) -> None: super().__init__(average_every) self.replacement_ratio = float(replacement_ratio) @@ -612,25 +610,28 @@ def _compute_averaging_weights( weights = 1.0 / (contributing_metrics + 1e-8) return weights / weights.sum() 
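For intuition, the inverse-metric weighting returned by _compute_averaging_weights above behaves as follows; the metric values are made up for illustration:

import torch

contributing_metrics = torch.tensor([0.5, 1.0, 2.0])
weights = 1.0 / (contributing_metrics + 1e-8)
weights = weights / weights.sum()
# -> approximately [0.571, 0.286, 0.143]; ranks with smaller metrics receive larger weight.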
-###########################################################################################3 -## Selective averaging based on async mpi-3 comms -# Version 1 (fast-general): One buffer with all the params, hence utilizes the network BW much better than general version. -# Hence different model param dtypes can be handled. -# Comments: -# - Need a fix where there is one buffer for all params but in byte format, -# after comms, the buffer can be converted to appropriate dtypes for corr params dtypes. -# - Right now each param has its own buffer and window, which may produce incorrect output. -# Notes: -# - Only mean averaging strategy is implemented here for now. -# - Neighbors are randomly selected from all ranks except self. -# - Agents are killed based on age only for now. -###########################################################################################3 + +""" +Selective averaging based on async mpi-3 comms + Version 1 (fast-general): One buffer with all the params, hence utilizes the network BW much better than general version. + Hence different model param dtypes can be handled. + Comments: + - Need a fix where there is one buffer for all params but in byte format, + after comms, the buffer can be converted to appropriate dtypes for corr params dtypes. + - Right now each param has its own buffer and window, which may produce incorrect output. + Notes: + - Only mean averaging strategy is implemented here for now. + - Neighbors are randomly selected from all ranks except self. + - Agents are killed based on age only for now. +""" + class AsyncSelectiveAveragingPolicympi4pyGeneral(SpawnPolicy): r""" Asynchronous selective averaging version 2, uses mpi one-sided comms to get the selectively averaged parameters from a random set of ranks. 
""" + def __init__( self, model_builder: Callable[[], Tuple[GFlowNet, torch.optim.Optimizer]], @@ -655,7 +656,8 @@ def __init__( self._model: Optional[GFlowNet] = None self.threshold_metric = float(threshold_metric) - ## timers + + # timers self.timing = {} self.stats = {} @@ -663,7 +665,7 @@ def __init__( self.train_comm_group = group self._expose = False - ###### new agents' stats #### + # new agents' stats self.agents_killed = 0 self.averaging_ranks = 0 self._count = 0 @@ -675,22 +677,19 @@ def __init__( if self.debug_mode: self.logfile = f"debug/selective_averaging_rank_{self.myrank}.log" - with open(self.logfile, 'w') as f: + with open(self.logfile, "w") as f: f.write(f"Selective Averaging Log for Rank {self.myrank}\n") f.write("=" * 50 + "\n") - - ############################### def shutdown(self) -> None: - for _,v in self._mpi_tensor_wins.items(): + for _, v in self._mpi_tensor_wins.items(): v[0].Free() - ############################## def print_time(self) -> None: print("Selective Averaging timings:", flush=True) for k, v in self.timing.items(): - ## here v is a list, avg over the list - #avg = sum(v) / len(v) if len(v) > 0 else 0 + # here v is a list, avg over the list + # avg = sum(v) / len(v) if len(v) > 0 else 0 print(f"{k:<35}: {sum(v):>10.4f} seconds") def print_stats(self) -> None: @@ -702,44 +701,59 @@ def print_stats(self) -> None: maximum = max(v) if len(v) > 0 else 0 avg = sum(v) / len(v) if len(v) > 0 else 0 length = len(v) - print(f"Rank {self.myrank} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, count={length}") + print( + f"Rank {self.myrank} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, count={length}" + ) if k == "donors": avg_donors = avg num_calls = length - _named_params = list(self._model.named_parameters()) - named_params = [(name, param) for name, param in _named_params if param.dim() != 0] - print(f"Rank {self.myrank:<10} - {'param elements':<15} {'iter':<10} {'total params elements commd':<25}") + _named_params = ( + list(self._model.named_parameters()) if self._model is not None else [] + ) + named_params = [ + (name, param) for name, param in _named_params if param.dim() != 0 + ] + print( + f"Rank {self.myrank:<10} - {'param elements':<15} {'iter':<10} {'total params elements commd':<25}" + ) for name, param in named_params: - device = param.device + param.device param_shape = param.data.shape - print(f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}") + print( + f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}" + ) def capture_comm(self, name: str, size: int) -> None: if name not in self.stats: self.stats[name] = [] self.stats[name].append(size) - ################################ def _ensure_initialized(self, model: GFlowNet) -> None: self._model = model self._initialized = True - ## export model parameters to mpi windows (should do that periodically to keep them fresh) + # export model parameters to mpi windows (should do that periodically to keep them fresh) self._expose_model_parameters(model) - ################################ def reset_age(self) -> None: self.max_age = random.randint(self.age_range[0], self.age_range[1]) self.age = 0 - ################################ - def is_agent_dying(self, local_metric: float, threshold_metric: float, check_agent = 0) -> bool: - if check_agent == 0: ## static theshold + def is_agent_dying( + self, local_metric: float, 
threshold_metric: float, check_agent=0 + ) -> bool: + if check_agent == 0: # static theshold return local_metric < threshold_metric - elif check_agent == 1: ## dynamic threshold based on age + elif check_agent == 1: # dynamic threshold based on age if self.age >= self.max_age: - print('+ Agent killed due to age: ', self.age, ' max_age: ', self.max_age, flush=True) + print( + "+ Agent killed due to age: ", + self.age, + " max_age: ", + self.max_age, + flush=True, + ) self.reset_age() return True @@ -749,8 +763,6 @@ def is_agent_dying(self, local_metric: float, threshold_metric: float, check_age else: raise ValueError(f"Unknown is_agent_dying version: {check_agent}") - - ################################ # Execute this function regularly to copy model params to mpi windows # for recepotrs to get recent params def _copy_model_params_to_buf( @@ -763,11 +775,7 @@ def _copy_model_params_to_buf( self._mpi_tensor_wins[name][1][:] = param.data.cpu().numpy().flatten() win.Unlock(rank=self.myrank) - - ################################ def _expose_model_parameters(self, model: GFlowNet) -> None: - rank = self.myrank - size = self.comm_size # Serialize model parameters to a contiguous numpy array param_tensors = {} @@ -777,14 +785,15 @@ def _expose_model_parameters(self, model: GFlowNet) -> None: # Create MPI windows for each parameter and its shape (2 separate windows set) self._mpi_tensor_wins = {} self._mpi_shape_wins = {} + assert isinstance(self.train_comm_group, MPI.Intracomm) + comm = cast(MPI.Intracomm, self.train_comm_group) for name, tensor in param_tensors.items(): buf = tensor - win = MPI.Win.Create(buf, comm=self.train_comm_group) + win = MPI.Win.Create(buf, comm=comm) self._mpi_tensor_wins[name] = (win, buf) self._copy_model_params_to_buf(model) - ################################ def _get_donors(self, n, k, d) -> List[int]: if k > n - 1: raise ValueError("k must be ≤ n-1 when excluding one value") @@ -794,70 +803,87 @@ def _get_donors(self, n, k, d) -> List[int]: # Pick k distinct values return random.sample(candidates, k) + def __call__( + self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: float, + expose_params: bool = True, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: - ################################ - def __call__(self, - iteration: int, - model: GFlowNet, - optimizer: torch.optim.Optimizer, - local_metric: float, - expose_params: bool = True, - group: MPI.Comm = MPI.COMM_WORLD, - ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: - - if self._expose == False: + if self._expose is False: self._expose_model_parameters(model) self._expose = True self._count += 1 self._model = model - check_agent = 1 ## version of dying agent check - ## validation info + check_agent = 1 # version of dying agent check + # validation info layer_name = None if self.debug_mode: layer_name = "pb.module.last_layer.bias" if self.is_agent_dying(local_metric, self.threshold_metric, check_agent): - with Timer( - self.timing, "sa get_params_from_donors" - ): + with Timer(self.timing, "sa get_params_from_donors"): if self.debug_mode: - with open(self.logfile, 'a') as f: + with open(self.logfile, "a") as f: # kill this model and rebuild model with fresh weights - num_donors = max(1, int(self.comm_size * 0.5)) ## * self.replacement_ratio)) - donors = self._get_donors(self.comm_size, num_donors, self.myrank, self.averaging_strategy) - new_avg_params = self._get_model_params_from_donors(donors, layer_name, f) - 
json.dump({self._count: [self._mpi_tensor_wins[layer_name][1].tolist(), donors, new_avg_params[layer_name].tolist()]}, f) + num_donors = max( + 1, int(self.comm_size * 0.5) + ) # * self.replacement_ratio)) + donors = self._get_donors( + self.comm_size, num_donors, self.myrank + ) + new_avg_params = self._get_model_params_from_donors( + donors, layer_name, f + ) + + if layer_name is not None: + json.dump( + { + self._count: [ + self._mpi_tensor_wins[layer_name][1].tolist(), + donors, + new_avg_params[layer_name].tolist(), + ] + }, + f, + ) f.write("\n") else: - num_donors = max(1, int(self.comm_size * 0.5)) # self.replacement_ratio)) + num_donors = max( + 1, int(self.comm_size * 0.5) + ) # self.replacement_ratio)) donors = self._get_donors(self.comm_size, num_donors, self.myrank) - new_avg_params = self._get_model_params_from_donors(donors, layer_name, None) + new_avg_params = self._get_model_params_from_donors( + donors, layer_name, None + ) - with Timer( - self.timing, "sa new_agent_model_rebuild" - ): - tic = self.get_time() + with Timer(self.timing, "sa new_agent_model_rebuild"): model, optimizer = self._model_builder() for name, param in model.named_parameters(): if name in new_avg_params: param.data.copy_(new_avg_params[name]) - if expose_params == True: - with Timer( - self.timing, "sa copy_params_to_buf" - ): + if expose_params is True: + with Timer(self.timing, "sa copy_params_to_buf"): self._copy_model_params_to_buf(model) - return model, optimizer, True - + return model, optimizer, {"averaged_this_iteration": True} - ################################ - def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dict[str, torch.Tensor]: + def _get_model_params_from_donors( + self, donors: List[int], layer_name, f + ) -> Dict[str, torch.Tensor]: avg_state: Dict[str, torch.Tensor] = {} - _named_params = list(self._model.named_parameters()) - named_params = [(name, param) for name, param in _named_params if param.dim() != 0] + _named_params = ( + list(self._model.named_parameters()) if self._model is not None else [] + ) + named_params = [ + (name, param) for name, param in _named_params if param.dim() != 0 + ] tot_comm_ele = 0 self.capture_comm("donors", len(donors)) @@ -873,17 +899,21 @@ def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dic flat_size = np.prod(param_shape) assert flat_size > 0 tot_comm_ele += flat_size - donor_tensor_flat = np.zeros(flat_size, dtype=param.data.cpu().numpy().dtype) + donor_tensor_flat = np.zeros( + flat_size, dtype=param.data.cpu().numpy().dtype + ) tensor_win.Get([donor_tensor_flat, MPI.FLOAT], target_rank=src) tensor_win.Unlock(rank=src) - donor_tensor = torch.tensor(donor_tensor_flat.reshape(param_shape), device=device) - ## Adding all the donor tensors/params + donor_tensor = torch.tensor( + donor_tensor_flat.reshape(param_shape), device=device + ) + # Adding all the donor tensors/params acc.add_(donor_tensor) if self.debug_mode and name == layer_name: all_donors.append(donor_tensor.tolist()) - ## Additions: Other averaging strategies can be implemented here + # Additions: Other averaging strategies can be implemented here if self.debug_mode and name == layer_name: json.dump({self._count: all_donors}, f) @@ -896,57 +926,13 @@ def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dic self.capture_comm("num_param_tensors_received", tot_comm_ele) return avg_state - - ################################ - def _get_model_params_from_donors_general(self, donors: List[int]) -> Dict[str, 
torch.Tensor]: - self.avg_state: Dict[str, torch.Tensor] = {} - named_params = list(self._model.named_parameters()) - - for name, param in named_params: - device = param.device - param_shape = param.data.shape - #acc = torch.zeros_like(param.data) - acc = [] - - for i, src in enumerate(donors): - # Get shape of parameter from donor - #shape_win, shape_buf = self._mpi_shape_wins[name] - #shape_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) - #donor_shape = tuple(shape_buf.tolist()) - #shape_win.Unlock(rank=src, lock_type=MPI.LOCK_SHARED) - - # Get parameter tensor from donor - tensor_win, tensor_buf = self._mpi_tensor_wins[name] - tensor_win.Lock(rank=src, lock_type=MPI.LOCK_SHARED) - flat_size = np.prod(param_shape) - donor_tensor_flat = np.zeros(flat_size, dtype=param.data.cpu().numpy().dtype) - tensor_win.Get([donor_tensor_flat, MPI.FLOAT], target_rank=src) - tensor_win.Unlock(rank=src) - - donor_tensor = torch.tensor(donor_tensor_flat.reshape(param_shape), device=device) - ## Adding all the donor tensors/params - #acc.add_(donor_tensor) - acc.append(donor_tensor) - self.capture_comm("num_param_tensors_received", flat_size); - - ## Additions: Other averaging strategies can be implemented here - - #if self.averaging_strategy == "mean" and len(donors) > 0: - # default to mean averaging - #acc = acc / len(donors) - self.avg_state[name] = acc - - return self.avg_state - - - ################################ def _average_received_params( self, ) -> Dict[str, torch.Tensor]: avg_state: Dict[str, torch.Tensor] = {} for name, param in self.avg_state.items(): - #device = param.device - #param_shape = param.data.shape + # device = param.device + # param_shape = param.data.shape acc = torch.zeros_like(param[0].data) for i, donor_tensor in enumerate(param): @@ -961,22 +947,23 @@ def _average_received_params( return avg_state +""" +Selective averaging based on async mpi-3 comms + Version 2 (fast): One buffer with all the params, hence utilizes the network BW much better than general version. + It assumes all the params are of same dtype (float32) +Notes: + - Only mean averaging strategy is implemented here for now. + - Neighbors are randomly selected from all ranks except self. + - Agents are killed based on age only for now. +""" -###########################################################################################3 -## Selective averaging based on async mpi-3 comms -# Version 2 (fast): One buffer with all the params, hence utilizes the network BW much better than general version. -# It assumes all the params are of same dtype (float32) -# Notes: -# - Only mean averaging strategy is implemented here for now. -# - Neighbors are randomly selected from all ranks except self. -# - Agents are killed based on age only for now. -###########################################################################################3 class AsyncSelectiveAveragingPolicympi4pyFast(SpawnPolicy): r""" Asynchronous selective averaging version 2, uses mpi one-sided comms to get the selectively averaged parameters from a random set of ranks. 
""" + def __init__( self, model_builder: Callable[[], Tuple[GFlowNet, torch.optim.Optimizer]], @@ -1000,16 +987,16 @@ def __init__( self._model: Optional[GFlowNet] = None self.threshold_metric = float(threshold_metric) - ## timers + # timers self.timing = {} self.stats = {} self._model = model self.train_comm_group = group - #self._expose_model_parameters(model) + # self._expose_model_parameters(model) self._expose = False - ###### new agents' stats #### + # **** new agents' stats **** self.agents_killed = 0 self.averaging_ranks = 0 self._count = 0 @@ -1022,27 +1009,27 @@ def __init__( self.age_range = age_range self.max_age = random.randint(self.age_range[0], self.age_range[1]) - ## test code, remove it later + # test code, remove it later if self.debug_mode: self.logfile = f"debug/selective_averaging_rank_{self.myrank}.log" - with open(self.logfile, 'w') as f: + with open(self.logfile, "w") as f: f.write(f"Selective Averaging Log for Rank {self.myrank}\n") f.write("=" * 50 + "\n") - ############################### def shutdown(self) -> None: self._mpi_tensor_wins[0].Free() - ############################## def print_time(self) -> None: print("Selective Averaging timings:", flush=True) for k, v in self.timing.items(): - ## here v is a list, avg over the list + # here v is a list, avg over the list print(f"{k:<35}: {sum(v):>10.4f} seconds") def print_stats(self) -> None: print("Selective Averaging comms stats:", flush=True) - print(f"Rank {self.myrank} - Agent replaced for {self.num_replacements} out of {self.total_iterations} iterations.") + print( + f"Rank {self.myrank} - Agent replaced for {self.num_replacements} out of {self.total_iterations} iterations." + ) avg_donors, num_calls = 0.0, 0 for k, v in self.stats.items(): # v is a list, print min, max ,avg, and len of it @@ -1050,44 +1037,59 @@ def print_stats(self) -> None: maximum = max(v) if len(v) > 0 else 0 avg = sum(v) / len(v) if len(v) > 0 else 0 length = len(v) - print(f"Rank {self.myrank:<10} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, across {length} iters") + print( + f"Rank {self.myrank:<10} - Stat {k:30}: min={minimum}, max={maximum}, avg={avg:.6f}, across {length} iters" + ) if k == "donors": avg_donors = avg num_calls = length - _named_params = list(self._model.named_parameters()) - named_params = [(name, param) for name, param in _named_params if param.dim() != 0] - print(f"Rank {self.myrank:<10} - {'param elements':<15} {'#comm_iters':<10} {'total params elements communicated':<25}") + _named_params = ( + list(self._model.named_parameters()) if self._model is not None else [] + ) + named_params = [ + (name, param) for name, param in _named_params if param.dim() != 0 + ] + print( + f"Rank {self.myrank:<10} - {'param elements':<15} {'#comm_iters':<10} {'total params elements communicated':<25}" + ) for name, param in named_params: - device = param.device + param.device param_shape = param.data.shape - print(f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}") + print( + f"Rank {self.myrank:<10} - {np.prod(param_shape):<15} {avg_donors*num_calls:<10} {np.prod(param_shape)*avg_donors*num_calls:<15}" + ) def capture_comm(self, name: str, size: int) -> None: if name not in self.stats: self.stats[name] = [] self.stats[name].append(size) - ################################ def _ensure_initialized(self, model: GFlowNet) -> None: self._model = model self._initialized = True - ## export model parameters to mpi windows (should do that 
periodically to keep them fresh) + # export model parameters to mpi windows (should do that periodically to keep them fresh) self._expose_model_parameters(model) - ################################ def reset_age(self) -> None: self.max_age = random.randint(self.age_range[0], self.age_range[1]) self.age = 0 - ################################ - def is_agent_dying(self, local_metric: float, threshold_metric: float, check_policy = 0) -> bool: - if check_policy == 0: ## static theshold + def is_agent_dying( + self, local_metric: float, threshold_metric: float, check_policy=0 + ) -> bool: + if check_policy == 0: # static theshold return local_metric < threshold_metric - elif check_policy == 1: ## dynamic threshold based on age + elif check_policy == 1: # dynamic threshold based on age if self.age >= self.max_age: - print('+ Agent killed due to age: ', self.age, ' max_age: ', self.max_age, flush=True) + print( + "+ Agent killed due to age: ", + self.age, + " max_age: ", + self.max_age, + flush=True, + ) self.reset_age() return True @@ -1097,8 +1099,6 @@ def is_agent_dying(self, local_metric: float, threshold_metric: float, check_pol else: raise ValueError(f"Unknown is_agent_dying version: {check_policy}") - - ################################ # Execute this function regularly to copy model params to mpi windows # for recepotrs to get recent params def _copy_model_params_to_buf( @@ -1110,23 +1110,23 @@ def _copy_model_params_to_buf( win = self._mpi_tensor_wins[0] win.Lock(rank=self.myrank, lock_type=MPI.LOCK_EXCLUSIVE) size = param.data.numel() - self._mpi_tensor_wins[1][offset:offset+size] = param.data.cpu().numpy().flatten() + self._mpi_tensor_wins[1][offset : offset + size] = ( + param.data.cpu().numpy().flatten() + ) offset += size win.Unlock(rank=self.myrank) - - ################################ def _expose_model_parameters(self, model: GFlowNet) -> None: print("+ Exposing model parameters via MPI windows...", flush=True) # Serialize model parameters to a contiguous numpy array param_size = 0 - dtypes = {param.dtype for param in model.parameters()} - ## todo: enable this to work with any dtypes for the model + {param.dtype for param in model.parameters()} + # todo: enable this to work with any dtypes for the model # print('model dtypes: ', dtypes) param_dtype = np.float32 th_param_dtype = torch.float32 - #self.param_shapes = {} + # self.param_shapes = {} for _, param in model.named_parameters(): param_size += param.data.numel() param_dtype = param.data.cpu().numpy().dtype @@ -1137,13 +1137,14 @@ def _expose_model_parameters(self, model: GFlowNet) -> None: # Create MPI windows for the flat parameter tensor buf = param_tensors_flat + assert isinstance(self.train_comm_group, MPI.Intracomm) + cast(MPI.Intracomm, self.train_comm_group) win = MPI.Win.Create(buf, comm=self.train_comm_group) - ## buffer attached to the win, used to copy data in/out of the win + # buffer attached to the win, used to copy data in/out of the win self._mpi_tensor_wins = (win, buf) self._copy_model_params_to_buf(model) - ################################ def _get_donors(self, n, k, d) -> List[int]: if k > n - 1: raise ValueError("k must be ≤ n-1 when excluding one value") @@ -1153,17 +1154,17 @@ def _get_donors(self, n, k, d) -> List[int]: # Random policy Pick k distinct random values return random.sample(candidates, k) - ################################ - def __call__(self, - iteration: int, - model: GFlowNet, - optimizer: torch.optim.Optimizer, - local_metric: float, - expose_params: bool = True, - group: MPI.Comm = 
MPI.COMM_WORLD, - ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: - - if self._expose == False: + def __call__( + self, + iteration: int, + model: GFlowNet, + optimizer: torch.optim.Optimizer, + local_metric: float, + expose_params: bool = True, + group: MPI.Comm = MPI.COMM_WORLD, + ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: + + if self._expose is False: self._expose_model_parameters(model) self._expose = True @@ -1171,30 +1172,29 @@ def __call__(self, self._model = model named_params = list(model.named_parameters()) for name, param in named_params: - param_shape = param.data.shape + param.data.shape - ## validation info + # validation info layer_name = None if self.debug_mode: layer_name = "pb.module.last_layer.bias" self.total_iterations += 1 - check_agent = 1 ## 0: static thresholding, 1: dynamic based on age + check_agent = 1 # 0: static thresholding, 1: dynamic based on age if self.is_agent_dying(local_metric, self.threshold_metric, check_agent): self.num_replacements += 1 - with Timer( - self.timing, "sa get_params_from_donors" - ): + with Timer(self.timing, "sa get_params_from_donors"): # kill this model and rebuild model with fresh weights - num_donors = max(1, int(self.comm_size * 0.5)) ## <<<<< parameterize this one + num_donors = max( + 1, int(self.comm_size * 0.5) + ) # <<<<< parameterize this one donors = self._get_donors(self.comm_size, num_donors, self.myrank) - _new_avg_params = self._get_model_params_from_donors(donors, layer_name, None) - + _new_avg_params = self._get_model_params_from_donors( + donors, layer_name, None + ) - with Timer( - self.timing, "sa param_list_to_dict_convert" - ): - ## conver the flat tensor to model param dict + with Timer(self.timing, "sa param_list_to_dict_convert"): + # conver the flat tensor to model param dict new_avg_params: Dict[str, torch.Tensor] = {} win_buf: Dict[str, torch.Tensor] = {} @@ -1203,44 +1203,55 @@ def __call__(self, device = param.device flat_size = param.data.numel() assert flat_size == np.prod(param.data.shape) - donor_tensor_flat = _new_avg_params[offset:offset + flat_size] - donor_tensor = torch.tensor(donor_tensor_flat.reshape(param.data.shape), device=device) + donor_tensor_flat = _new_avg_params[offset : offset + flat_size] + donor_tensor = torch.tensor( + donor_tensor_flat.reshape(param.data.shape), device=device + ) if self.debug_mode: - buf_tensor_flat = self._mpi_tensor_wins[1][offset:offset + flat_size] - buf_tensor = torch.tensor(buf_tensor_flat.reshape(param.data.shape), device=device) + buf_tensor_flat = self._mpi_tensor_wins[1][ + offset : offset + flat_size + ] + buf_tensor = torch.tensor( + buf_tensor_flat.reshape(param.data.shape), device=device + ) win_buf[name] = buf_tensor new_avg_params[name] = donor_tensor offset += flat_size if self.debug_mode: - with open(self.logfile, 'a') as f: - json.dump({self._count: [win_buf[layer_name].tolist(), donors, new_avg_params[layer_name].tolist()]}, f) - f.write("\n") + with open(self.logfile, "a") as f: + if layer_name is not None: + json.dump( + { + self._count: [ + win_buf[layer_name].tolist(), + donors, + new_avg_params[layer_name].tolist(), + ] + }, + f, + ) + f.write("\n") - with Timer( - self.timing, "sa new_agent_model_rebuild" - ): + with Timer(self.timing, "sa new_agent_model_rebuild"): model, optimizer = self._model_builder() for name, param in model.named_parameters(): if name in new_avg_params: param.data.copy_(new_avg_params[name]) - - if expose_params == True: - with Timer( - self.timing, "sa copy_params_to_buf" - ): + if expose_params is 
True: + with Timer(self.timing, "sa copy_params_to_buf"): self._copy_model_params_to_buf(model) - return model, optimizer, True - + return model, optimizer, {"averaged_this_iteration": True} - ################################ - def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dict[str, torch.Tensor]: - #avg_state: Dict[str, torch.Tensor] = {} - #_named_params = list(self._model.named_parameters()) - #named_params = [(name, param) for name, param in _named_params if param.dim() != 0] + def _get_model_params_from_donors( + self, donors: List[int], layer_name, f + ) -> torch.Tensor: + # avg_state: Dict[str, torch.Tensor] = {} + # _named_params = list(self._model.named_parameters()) + # named_params = [(name, param) for name, param in _named_params if param.dim() != 0] tot_comm_ele = 0 self.capture_comm("donors", len(donors)) @@ -1255,7 +1266,7 @@ def _get_model_params_from_donors(self, donors: List[int], layer_name, f) -> Dic tensor_win.Get([self.donor_tensor_flat, MPI.FLOAT], target_rank=src) tensor_win.Unlock(rank=src) - ## Adding all the donor tensors/params + # Adding all the donor tensors/params self.acc.add_(self.donor_tensor_flat) tot_comm_ele = tot_comm_ele + flat_size @@ -1280,18 +1291,20 @@ def __call__( model: GFlowNet, optimizer: torch.optim.Optimizer, local_metric: Optional[float] = None, - group=MPI.COMM_WORLD, + group: MPI.Comm = MPI.COMM_WORLD, ) -> Tuple[GFlowNet, torch.optim.Optimizer, dict]: if not dist.is_available() or not dist.is_initialized(): return model, optimizer, {} if iteration % self.average_every != 0: return model, optimizer, {"averaged_this_iteration": False} - #print("AverageAll mpi4py model parameters across all ranks ...", flush=True) + # print("AverageAll mpi4py model parameters across all ranks ...", flush=True) world_size = group.Get_size() for param in model.parameters(): - param_tensor = param.detach().cpu().numpy().copy() ##param.data.clone().numpy() - #dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=group) + param_tensor = ( + param.detach().cpu().numpy().copy() + ) # param.data.clone().numpy() + # dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=group) group.Allreduce(MPI.IN_PLACE, param_tensor, op=MPI.SUM) param_tensor /= world_size param.data.copy_(torch.from_numpy(param_tensor)) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index e49dd2aa..d48152a2 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -12,12 +12,12 @@ import numpy as np import torch -import wandb from numpy.typing import NDArray from scipy.special import logsumexp from sklearn.neighbors import KernelDensity from tqdm import tqdm, trange +import wandb from gfn.estimators import ScalarEstimator from gfn.gflownet import ( DBGFlowNet, diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index b79b828d..133c89c7 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -15,9 +15,9 @@ from typing import cast import torch -import wandb from tqdm import tqdm, trange +import wandb from gfn.estimators import DiscretePolicyEstimator from gfn.gflownet import FMGFlowNet from gfn.gym import DiscreteEBM diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 3c7544ab..3388c036 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -36,6 +36,7 @@ from typing import Optional, Tuple, cast import matplotlib.pyplot as plt 
+import mpi4py.MPI as MPI import torch import torch.distributed as dist from matplotlib.gridspec import GridSpec @@ -63,7 +64,7 @@ from tutorials.examples.multinode.spawn_policy import ( AsyncSelectiveAveragingPolicy, AsyncSelectiveAveragingPolicympi4pyFast, - AsyncSelectiveAveragingPolicympi4py, + AsyncSelectiveAveragingPolicympi4pyGeneral, AverageAllPolicy, AverageAllPolicympi4py, ) @@ -840,52 +841,54 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: dist.barrier(group=distributed_context.train_global_group) # Set up averaging policy (called every iteration; internal guard checks cadence/distributed) - averaging_policy = None + averaging_policy_torch = None + averaging_policy_mpi4py = None + if args.distributed: + averaging_policy_torch = AverageAllPolicy(average_every=args.average_every) + averaging_policy_mpi4py = AverageAllPolicympi4py( + average_every=args.average_every + ) + + mpi4py_train_group = ( + distributed_context.dc_mpi4py.train_global_group + if distributed_context.dc_mpi4py is not None + else MPI.COMM_WORLD + ) + if args.use_selective_averaging: - if oldcode: - averaging_policy = AsyncSelectiveAveragingPolicy( # type: ignore[abstract] + averaging_policy_torch = AsyncSelectiveAveragingPolicy( # type: ignore[abstract] + model_builder=_model_builder, + average_every=args.average_every, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + threshold=args.performance_tracker_threshold, + cooldown=args.performance_tracker_cooldown, + ) + averaging_policy_mpi4py = AsyncSelectiveAveragingPolicympi4pyGeneral( # type: ignore[abstract] + model_builder=_model_builder, + model=gflownet, + average_every=args.average_every, + threshold_metric=args.performance_tracker_threshold, + replacement_ratio=args.replacement_ratio, + averaging_strategy=args.averaging_strategy, + momentum=args.momentum, + age_range=args.age_range, + group=mpi4py_train_group, + ) + if args.fast_sa: + averaging_policy_mpi4py = AsyncSelectiveAveragingPolicympi4pyFast( # type: ignore[abstract] model_builder=_model_builder, + model=gflownet, average_every=args.average_every, + threshold_metric=args.performance_tracker_threshold, replacement_ratio=args.replacement_ratio, averaging_strategy=args.averaging_strategy, momentum=args.momentum, - threshold=args.performance_tracker_threshold, - cooldown=args.performance_tracker_cooldown, + age_range=args.age_range, + group=mpi4py_train_group, ) - else: - if args.fast_sa: - ## fast -- assumes all the params are of same precision - averaging_policy = AsyncSelectiveAveragingPolicympi4pyFast( # type: ignore[abstract] - model_builder=_model_builder, - model=gflownet, - average_every=args.average_every, - threshold_metric=args.performance_tracker_threshold, - replacement_ratio=args.replacement_ratio, - averaging_strategy=args.averaging_strategy, - momentum=args.momentum, - age_range=args.age_range, - group=distributed_context.dc_mpi4py.train_global_group, - ) - else: - # general -- more general where diffeernt params can be of different precision - averaging_policy = AsyncSelectiveAveragingPolicympi4py( # type: ignore[abstract] - model_builder=_model_builder, - model=gflownet, - average_every=args.average_every, - threshold_metric=args.performance_tracker_threshold, - replacement_ratio=args.replacement_ratio, - averaging_strategy=args.averaging_strategy, - momentum=args.momentum, - age_range=args.age_range, - group=distributed_context.dc_mpi4py.train_global_group, - ) - else: - if oldcode: - 
averaging_policy = AverageAllPolicy(average_every=args.average_every) - else: - print("+ Using AverageAllPolicympi4py", flush=True) - averaging_policy = AverageAllPolicympi4py(average_every=args.average_every) # Accumulators for averaging score_dict between log intervals. score_dict_accum: dict[str, float] = {} @@ -998,23 +1001,33 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: with Timer( timing, "averaging_model", enabled=args.timing ) as model_averaging_timer: - if averaging_policy is not None: - if oldcode: - gflownet, optimizer, averaging_info = averaging_policy( + + if oldcode: + if averaging_policy_torch is not None: + gflownet, optimizer, averaging_info = averaging_policy_torch( iteration=iteration, model=gflownet, optimizer=optimizer, local_metric=( - score_dict["score"] if score_dict is not None else -loss.item() + score_dict["score"] + if score_dict is not None + else -loss.item() ), group=distributed_context.train_global_group, ) - else: - gflownet, optimizer, averaging_info = averaging_policy( + else: + if ( + averaging_policy_mpi4py is not None + and distributed_context.dc_mpi4py is not None + ): + gflownet, optimizer, averaging_info = averaging_policy_mpi4py( iteration=iteration, model=gflownet, optimizer=optimizer, - local_metric=(score_dict["score"] if score_dict is not None else -loss.item() + local_metric=( + score_dict["score"] + if score_dict is not None + else -loss.item() ), group=distributed_context.dc_mpi4py.train_global_group, ) @@ -1107,7 +1120,11 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: wandb.log(to_log, step=iteration) with Timer(timing, "barrier 2", enabled=(args.timing and args.distributed)): - if args.distributed and args.timing: + if ( + args.distributed + and args.timing + and distributed_context.dc_mpi4py is not None + ): t_comm = distributed_context.dc_mpi4py.train_global_group t_comm.Barrier() @@ -1120,9 +1137,11 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: if args.distributed: dist.barrier(group=distributed_context.train_global_group) - assert averaging_policy is not None + assert averaging_policy_torch is not None + assert averaging_policy_mpi4py is not None try: - averaging_policy.shutdown() + averaging_policy_torch.shutdown() + averaging_policy_mpi4py.shutdown() except Exception: pass @@ -1143,8 +1162,13 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: for k, v in timing.items(): logger.info("%-25s %10.4fs", k, sum(v)) try: - averaging_policy.print_time() - averaging_policy.print_stats() + if ( + not oldcode + and args.use_selective_averaging + and averaging_policy_mpi4py is not None + ): + averaging_policy_mpi4py.print_time() + averaging_policy_mpi4py.print_stats() except Exception: pass @@ -1172,7 +1196,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: if args.distributed: dist.barrier(group=distributed_context.train_global_group) - assert averaging_policy is not None + assert distributed_context is not None try: distributed_context.cleanup() except Exception: @@ -1279,7 +1303,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: ) parser.add_argument( "--age_range", - type=lambda s: tuple(map(int, s.split(','))), + type=lambda s: tuple(map(int, s.split(","))), default=(5, 15), help="Age range (iterations) for selective averaging policy as tuple (min_age, max_age), e.g., '5,15'", ) @@ -1545,7 +1569,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: parser.add_argument( "--performance_tracker_threshold", 
type=float, - default=None, + default=0.0, help="Threshold for the performance tracker. If None, the performance tracker is not triggered.", ) parser.add_argument( @@ -1556,5 +1580,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: ) args = parser.parse_args() - assert args.age_range[1] >= args.age_range[0], "Invalid age_range: max_age must be ge min_age" + assert ( + args.age_range[1] >= args.age_range[0] + ), "Invalid age_range: max_age must be ge min_age" main(args) diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py index 1595d778..45353709 100644 --- a/tutorials/examples/train_ising.py +++ b/tutorials/examples/train_ising.py @@ -1,9 +1,9 @@ from argparse import ArgumentParser import torch -import wandb from tqdm import tqdm +import wandb from gfn.estimators import DiscretePolicyEstimator from gfn.gflownet import FMGFlowNet from gfn.gym import DiscreteEBM From e9ca600942d2d73e245ac976da455e9bc1b83e47 Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Wed, 11 Feb 2026 03:48:51 -0800 Subject: [PATCH 3/5] refactored code --- tutorials/examples/train_hypergrid.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 3388c036..83c944d3 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -826,7 +826,6 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: # Initialize some variables before the training loop. timing = {} - oldcode = args.oldcode time_start = time.time() l1_distances, validation_steps = [], [] @@ -877,7 +876,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: age_range=args.age_range, group=mpi4py_train_group, ) - if args.fast_sa: + if args.mpi_sa_mode == "fast": averaging_policy_mpi4py = AsyncSelectiveAveragingPolicympi4pyFast( # type: ignore[abstract] model_builder=_model_builder, model=gflownet, @@ -1002,7 +1001,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: timing, "averaging_model", enabled=args.timing ) as model_averaging_timer: - if oldcode: + if args.spawn_backend == "dist": if averaging_policy_torch is not None: gflownet, optimizer, averaging_info = averaging_policy_torch( iteration=iteration, @@ -1163,7 +1162,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: logger.info("%-25s %10.4fs", k, sum(v)) try: if ( - not oldcode + args.spawn_backend == "mpi" and args.use_selective_averaging and averaging_policy_mpi4py is not None ): @@ -1292,14 +1291,20 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: ) ## for mpi-3 code of selective averaging debug parser.add_argument( - "--oldcode", - action="store_true", - help="Temp switch between old and new selective averaging code (dist mpi or mpi4py)", + "--spawn_backend", + choices=["dist", "mpi"], + default="mpi", + help="Backend for spawn policy implementation: torch.distributed or mpi4py", ) parser.add_argument( - "--fast_sa", - action="store_true", - help="Use fast (comms) selective averaging mpi4py code assuming all params are of same precision (e.g., float32)", + "--mpi_sa_mode", + choices=["general", "fast"], + default="general", + help=( + "MPI selective averaging implementation to use. " + "'fast' uses an optimized communication path assuming all parameters " + "have the same dtype (e.g., float32)." 
+ ), ) parser.add_argument( "--age_range", From fbe4fe8875dbc5f568a5c0a9dc60cfe1fde2cf93 Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Wed, 11 Feb 2026 04:10:46 -0800 Subject: [PATCH 4/5] spawn_backend choice renamed --- tutorials/examples/train_hypergrid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 83c944d3..9849ffea 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -1162,7 +1162,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: logger.info("%-25s %10.4fs", k, sum(v)) try: if ( - args.spawn_backend == "mpi" + args.spawn_backend == "mpi4py" and args.use_selective_averaging and averaging_policy_mpi4py is not None ): @@ -1292,8 +1292,8 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: ## for mpi-3 code of selective averaging debug parser.add_argument( "--spawn_backend", - choices=["dist", "mpi"], - default="mpi", + choices=["dist", "mpi4py"], + default="mpi4py", help="Backend for spawn policy implementation: torch.distributed or mpi4py", ) parser.add_argument( From b8bb472ce6ce1c2557aba97a5a31d8ba799b8353 Mon Sep 17 00:00:00 2001 From: chirayuharyan Date: Wed, 11 Feb 2026 05:51:59 -0800 Subject: [PATCH 5/5] changed default mpi4py selective average method --- tutorials/examples/train_hypergrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 9849ffea..2b1128aa 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -1299,7 +1299,7 @@ def _model_builder() -> Tuple[GFlowNet, torch.optim.Optimizer]: parser.add_argument( "--mpi_sa_mode", choices=["general", "fast"], - default="general", + default="fast", help=( "MPI selective averaging implementation to use. " "'fast' uses an optimized communication path assuming all parameters "
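For reference, a minimal standalone sketch (not part of this patch) of the one-buffer RMA pattern that AsyncSelectiveAveragingPolicympi4pyFast relies on: every rank exposes its flattened float32 parameters through a single MPI window, and a rank whose agent "dies" pulls that buffer from a few randomly chosen donor ranks with passive-target Lock/Get/Unlock and keeps the mean. The rank count, launcher, and the toy 8-element parameter vector below are illustrative assumptions; the window and RMA calls mirror the ones used in the patch.

    import random
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    n_params = 8                                    # stand-in for the flattened model size
    local = np.full(n_params, float(rank), dtype=np.float32)
    win = MPI.Win.Create(local, comm=comm)          # expose local params to every rank

    if size > 1:
        # donors are drawn uniformly at random from all ranks except self
        candidates = [r for r in range(size) if r != rank]
        donors = random.sample(candidates, k=min(len(candidates), max(1, size // 2)))

        acc = np.zeros(n_params, dtype=np.float32)
        recv = np.empty(n_params, dtype=np.float32)
        for src in donors:
            win.Lock(rank=src, lock_type=MPI.LOCK_SHARED)   # passive-target epoch
            win.Get([recv, MPI.FLOAT], target_rank=src)
            win.Unlock(rank=src)                            # the Get completes here
            acc += recv
        acc /= len(donors)                                  # mean over the donors only
        print(f"rank {rank}: donors={donors}, averaged value={acc[0]:.2f}", flush=True)

    comm.Barrier()   # make sure all remote reads finished before freeing the window
    win.Free()

Launching this with something like mpiexec -n 4 (the launcher name is an assumption) shows each rank averaging a different random donor subset, which is the asynchronous, per-rank behaviour the policy exploits.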
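AverageAllPolicympi4py replaces the torch.distributed all_reduce path with a collective in-place MPI Allreduce over each parameter, then divides by the communicator size and copies the mean back into the model. A minimal sketch of that loop follows; the nn.Linear stand-in model is an assumption (the patch averages the GFlowNet), and every rank in the communicator must execute the loop collectively.

    import torch
    import torch.nn as nn
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    model = nn.Linear(4, 2)   # toy stand-in; the patch averages the GFlowNet's parameters

    with torch.no_grad():
        for param in model.parameters():
            buf = param.detach().cpu().numpy().copy()       # contiguous host copy
            comm.Allreduce(MPI.IN_PLACE, buf, op=MPI.SUM)    # elementwise sum across ranks
            buf /= comm.Get_size()                           # turn the sum into a mean
            param.data.copy_(torch.from_numpy(buf))          # write the average back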
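Usage note on the refactored command-line interface: the backend and selective-averaging variant are now chosen with --spawn_backend {dist,mpi4py} (default mpi4py) and --mpi_sa_mode {general,fast} (default fast after the last commit in this series), while --age_range min,max controls the random agent lifetime and --performance_tracker_threshold now defaults to 0.0. An illustrative invocation, assuming an MPI launcher and rank count that are not specified by the patch, would pass these new flags to tutorials/examples/train_hypergrid.py, e.g. mpiexec -n 8 python tutorials/examples/train_hypergrid.py --spawn_backend mpi4py --mpi_sa_mode fast --age_range 5,15 --performance_tracker_threshold 0.0, together with whatever option strings the parser defines for the distributed and selective-averaging switches referenced in the code as args.distributed and args.use_selective_averaging.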