diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 794f829c6..1a57725db 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -231,8 +231,10 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None) # Load model config here. In case model_path is not provided, get it from private conf if model_path is None: path = get_path_model(run_id=run_id) + _logger.info(f"Loading config from default model_path: {path}") else: path = Path(model_path) / run_id + _logger.info(f"Loading config from provided model_path: {path}") fname = path / _get_model_config_file_read_name(run_id, mini_epoch) assert fname.exists(), ( diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 24962da13..4a0339984 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -11,7 +11,6 @@ import itertools import logging -from pathlib import Path import omegaconf import torch @@ -21,7 +20,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -179,7 +178,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch """ - path_run = Path(cf.model_path) / run_id + path_run = get_path_model(run_id=run_id) mini_epoch_id = ( f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" )