diff --git a/gns/train.py b/gns/train.py index a730e75..52e3152 100644 --- a/gns/train.py +++ b/gns/train.py @@ -202,7 +202,7 @@ def predict(device: str, cfg: DictConfig): example_rollout["loss"] = loss.mean() filename = f"{cfg.output.filename}_ex{example_i}.pkl" filename_render = f"{cfg.output.filename}_ex{example_i}" - filename = os.path.join(cfg.output.path, filename_render) + filename = os.path.join(cfg.output.path, f"{filename_render}.pkl") with open(filename, "wb") as f: pickle.dump(example_rollout, f) if cfg.rendering.mode: @@ -628,6 +628,7 @@ def train(rank, cfg, world_size, device, verbose, use_dist): cfg, rank, device_id, + use_dist, ) writer.add_scalar("Loss/valid", valid_loss.item(), step) @@ -698,9 +699,15 @@ def train(rank, cfg, world_size, device, verbose, use_dist): if cfg.training.validation_interval is not None: sampled_valid_example = next(iter(valid_dl)) epoch_valid_loss = validation( - simulator, sampled_valid_example, n_features, cfg, rank, device_id + simulator, + sampled_valid_example, + n_features, + cfg, + rank, + device_id, + use_dist, ) - if device == torch.device("cuda"): + if use_dist: torch.distributed.reduce( epoch_valid_loss, dst=0, op=torch.distributed.ReduceOp.SUM ) @@ -807,7 +814,7 @@ def _get_simulator( return simulator -def validation(simulator, example, n_features, cfg, rank, device_id): +def validation(simulator, example, n_features, cfg, rank, device_id, use_dist): ( position, particle_type, @@ -830,7 +837,7 @@ def validation(simulator, example, n_features, cfg, rank, device_id): # Select the appropriate prediction function predict_accelerations = ( simulator.module.predict_accelerations - if isinstance(device_id, int) + if use_dist else simulator.predict_accelerations ) # Get the predictions and target accelerations