From 07ffab932fc6e15dfdd814e9b6780057007bdd9d Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 3 Feb 2026 12:56:49 +0100 Subject: [PATCH 01/13] added test script for relative paths --- integration_tests/test_relative_paths.py | 156 +++++++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 integration_tests/test_relative_paths.py diff --git a/integration_tests/test_relative_paths.py b/integration_tests/test_relative_paths.py new file mode 100644 index 000000000..cc49f3efa --- /dev/null +++ b/integration_tests/test_relative_paths.py @@ -0,0 +1,156 @@ +"""Integration test to check for relative paths in the codebase.""" + +import os +import re +import pytest +from pathlib import Path + + +class TestRelativePaths: + """Test suite to detect potentially problematic relative paths in the codebase.""" + + # Root directory of the project + PROJECT_ROOT = Path(__file__).parent.parent + + # Patterns that indicate relative paths + RELATIVE_PATH_PATTERNS = [ + r'"\./[^"]+', # "./path" + r"'\./[^']+", # './path' + r'"\.\./[^"]+', # "../path" + r"'\.\./[^']+", # '../path' + r'open\(["\'][^/][^"\']+["\']', # open('relative/path') + ] + + # File extensions to check + EXTENSIONS_TO_CHECK = {'.py'} # , '.yaml', '.yml', '.json', '.toml' + + # Directories to exclude + EXCLUDE_DIRS = {'__pycache__', '.git', '.venv', 'integration_tests'} + + # Files to exclude (e.g., this test file itself) + EXCLUDE_FILES = {'test_relative_paths.py'} + + def get_all_source_files(self): + """Recursively get all source files in the project.""" + source_files = [] + for root, dirs, files in os.walk(self.PROJECT_ROOT): + # Skip excluded directories + dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS] + + for file in files: + if file in self.EXCLUDE_FILES: + continue + if Path(file).suffix in self.EXTENSIONS_TO_CHECK: + source_files.append(Path(root) / file) + return source_files + + def find_relative_paths_in_file(self, filepath: Path) -> list[tuple[int, str, str]]: + """ + Find relative paths in a file. + + Returns: + List of tuples: (line_number, matched_pattern, line_content) + """ + matches = [] + try: + with open(filepath, 'r', encoding='utf-8', errors='ignore') as f: + for line_num, line in enumerate(f, 1): + # Skip comments + stripped = line.strip() + if stripped.startswith('#') or stripped.startswith('//'): + continue + + for pattern in self.RELATIVE_PATH_PATTERNS: + found = re.search(pattern, line) + if found: + matches.append((line_num, found.group(), line.strip())) + except Exception as e: + pytest.skip(f"Could not read file {filepath}: {e}") + + return matches + + def test_no_hardcoded_relative_paths(self): + """Ensure no hardcoded relative paths exist in Python files.""" + source_files = self.get_all_source_files() + violations = [] + + for filepath in source_files: + print(f"Checking file: {filepath}") + matches = self.find_relative_paths_in_file(filepath) + for line_num, match, line in matches: + violations.append({ + 'file': str(filepath.relative_to(self.PROJECT_ROOT)), + 'line': line_num, + 'match': match, + 'content': line + }) + + if violations: + report = "\n\nRelative path violations found:\n" + for v in violations: + report += f"\n {v['file']}:{v['line']}\n" + report += f" Match: {v['match']}\n" + report += f" Line: {v['content']}\n" + + pytest.fail(report) + + def test_yaml_configs_use_absolute_or_variable_paths(self): + """Check that YAML config files don't use hardcoded relative paths.""" + yaml_files = [f for f in self.get_all_source_files() + if f.suffix in {'.yaml', '.yml'}] + + violations = [] + relative_path_pattern = re.compile(r':\s*["\']?\.\.?/[^"\'#\n]+') + + for yaml_file in yaml_files: + print(f"Checking YAML file: {yaml_file}") + try: + with open(yaml_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + if relative_path_pattern.search(line): + violations.append({ + 'file': str(yaml_file.relative_to(self.PROJECT_ROOT)), + 'line': line_num, + 'content': line.strip() + }) + except Exception: + continue + + if violations: + report = "\n\nRelative paths in YAML configs:\n" + for v in violations: + report += f"\n {v['file']}:{v['line']}: {v['content']}\n" + pytest.fail(report) + + def test_path_construction_uses_pathlib_or_os_path(self): + """Check that path construction uses pathlib or os.path properly.""" + python_files = [f for f in self.get_all_source_files() if f.suffix == '.py'] + + # Pattern for string concatenation that looks like path building + bad_pattern = re.compile(r'["\'][^"\']*["\']\s*\+\s*["\']/|/["\']\s*\+') + + violations = [] + for py_file in python_files: + print(f"Checking Python file: {py_file}") + try: + with open(py_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + if bad_pattern.search(line): + violations.append({ + 'file': str(py_file.relative_to(self.PROJECT_ROOT)), + 'line': line_num, + 'content': line.strip() + }) + except Exception: + continue + + if violations: + report = "\n\nPotential unsafe path concatenation found:\n" + report += "(Consider using pathlib.Path or os.path.join)\n" + for v in violations: + report += f"\n {v['file']}:{v['line']}: {v['content']}\n" + pytest.fail(report) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) \ No newline at end of file From 7abcc1b943017fda3c4e33d4cf1dd6bed8f0c37f Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 9 Feb 2026 02:16:32 +0100 Subject: [PATCH 02/13] updated path_run --- src/weathergen/model/model_interface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 24962da13..8e8ff2959 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -11,7 +11,6 @@ import itertools import logging -from pathlib import Path import omegaconf import torch @@ -21,7 +20,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import Config, merge_configs +from weathergen.common.config import get_path_model, Config, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -179,7 +178,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch """ - path_run = Path(cf.model_path) / run_id + path_run = get_path_model(cf, run_id) mini_epoch_id = ( f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest" ) From 9db92dd468892307e7773ef0ffbbc20e41923e88 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 9 Feb 2026 02:18:07 +0100 Subject: [PATCH 03/13] changed loss keys reference --- src/weathergen/utils/train_logger.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 2812134a0..4e55baaa8 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -144,7 +144,7 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: ) run_id = cf.general.run_id - result_dir_base = config.get_path_run(cf) + result_dir_base = config.get_path_run(cf).parent result_dir = result_dir_base / run_id fname_log_train = result_dir / f"{run_id}_train_log.txt" fname_log_val = result_dir / f"{run_id}_val_log.txt" @@ -156,12 +156,12 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: cols_train = ["dtime", "samples", "mse", "lr"] cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"] for si in cf.streams: - for lf in cf.loss_fcts: + for lf in cf.training_config.losses.physical.loss_fcts: cols1 += [_key_loss(si["name"], lf[0])] cols_train += [ si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0] ] - with_stddev = [("stats" in lf) for lf in cf.loss_fcts] + with_stddev = [("stats" in lf) for lf in cf.training_config.losses.physical.loss_fcts] if with_stddev: for si in cf.streams: cols1 += [_key_stddev(si["name"])] @@ -214,12 +214,12 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics: cols_val = ["dtime", "samples"] cols2 = [_weathergen_timestamp, "num_samples"] for si in cf.streams: - for lf in cf.loss_fcts_val: + for lf in cf.training_config.losses.physical.loss_fcts: cols_val += [ si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0] ] cols2 += [_key_loss(si["name"], lf[0])] - with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val] + with_stddev = [("stats" in lf) for lf in cf.training_config.losses.physical.loss_fcts] if with_stddev: for si in cf.streams: cols2 += [_key_stddev(si["name"])] @@ -370,6 +370,8 @@ def clean_df(df, columns: list[str] | None): idcs = [i for i in range(len(columns)) if columns[i] == "loss_avg_mean"] if len(idcs) > 0: columns[idcs[0]] = "loss_avg_0_mean" + for key in list(df.columns): + _logger.info(key) df = df.select(columns) # Remove all rows where all columns are null df = df.filter(~pl.all_horizontal(pl.col(c).is_null() for c in columns)) @@ -392,18 +394,21 @@ def clean_name(s: str) -> str: def _key_loss(st_name: str, lf_name: str) -> str: - st_name = clean_name(st_name) - return f"stream.{st_name}.loss_{lf_name}.loss_avg" + st_name = clean_name(st_name) # LossPhysical.ERA5.mse.t_600.2 + return f"LossPhysical.{st_name}.mse.avg" # LossPhysical.ERA5.mse.avg + # return f"stream.{st_name}.loss_{lf_name}.loss_avg" def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str: - st_name = clean_name(st_name) - return f"stream.{st_name}.loss_{lf_name}.loss_{ch_name}" + st_name = clean_name(st_name) + return f"LossPhysical.{st_name}.{lf_name}.{ch_name}" # LossPhysical.ERA5.mse.t_500.1 + # return f"stream.{st_name}.loss_{lf_name}.loss_{ch_name}" def _key_stddev(st_name: str) -> str: st_name = clean_name(st_name) - return f"stream.{st_name}.stddev_avg" + return f"LossPhysical.loss_avg" # + # return f"stream.{st_name}.stddev_avg" def prepare_losses_for_logging( From af03e1f73c686184ea7f704a42e0c6dfd8c44e4d Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 9 Feb 2026 02:21:13 +0100 Subject: [PATCH 04/13] improved paths retrieval --- src/weathergen/utils/plot_training.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 35bfafe3e..1ce69a85d 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -23,7 +23,8 @@ _logger = logging.getLogger(__name__) DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml") - +DEFAULT_CONFIG_FILE= Path("./config/default_config.yml") +DEFAULT_MODEL_PATH = config._get_shared_wg_path() / "models" #################################################################################################### def _ensure_list(value): @@ -150,7 +151,7 @@ def clean_plot_folder(plot_dir: Path): #################################################################################################### -def get_stream_names(run_id: str, model_path: Path | None = "./model"): +def get_stream_names(run_id: str, model_path: Path | None = DEFAULT_MODEL_PATH): """ Get the stream names from the model configuration file. @@ -492,7 +493,7 @@ def plot_loss_per_run( if errs is None: errs = ["mse"] - plot_dir = Path(plot_dir) + plot_dir = config._get_shared_wg_path() / "plots" modes = [modes] if type(modes) is not list else modes # repeat colors when train and val is plotted simultaneously @@ -594,12 +595,14 @@ def plot_train(args=None): ) parser.add_argument( - "-o", "--output_dir", default="./plots/", type=Path, help="Directory where plots are saved" + "-o", "--output_dir", + default=config._get_shared_wg_path() / "plots", + type=Path, help="Directory where plots are saved" ) parser.add_argument( "-m", "--model_base_dir", - default=None, + default=config._get_shared_wg_path() / "models", type=Path, help="Base-directory where models are saved", ) @@ -673,7 +676,7 @@ def plot_train(args=None): clean_plot_folder(out_dir) # read logged data - + runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids] # determine which runs are still alive (as a process, though they might hang internally) From c8ff618ecdc466d582ac03d8ba3e7905b23a6c0a Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 9 Feb 2026 03:57:25 +0100 Subject: [PATCH 05/13] used _get_shared_wg_path to handle paths to models --- src/weathergen/utils/compare_run_configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/utils/compare_run_configs.py b/src/weathergen/utils/compare_run_configs.py index d33789da0..13539f6c6 100755 --- a/src/weathergen/utils/compare_run_configs.py +++ b/src/weathergen/utils/compare_run_configs.py @@ -31,7 +31,7 @@ import yaml from omegaconf import OmegaConf -from weathergen.common.config import load_run_config +from weathergen.common.config import _get_shared_wg_path, load_run_config def truncate_value(value, max_length=50): @@ -152,14 +152,14 @@ def main(): "-m1", "--model_directory_1", type=Path, - default=Path("models/"), + default=_get_shared_wg_path() / "models", help="Path to model directory for -r1/--run_id_1", ) parser.add_argument( "-m2", "--model_directory_2", type=Path, - default=Path("models/"), + default=_get_shared_wg_path() / "models", help="Path to model directory for -r2/--run_id_2", ) parser.add_argument( From 1068151935e8f913face068209b6d6f6ed3e8c70 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Mon, 9 Feb 2026 05:53:25 +0100 Subject: [PATCH 06/13] fix: Use shared paths for model/results/plots output Replace relative path handling with _get_shared_wg_path() across evaluation and export modules. Ensures models, results, and plots are written to correct shared workspace directories. Fixes path resolution for inference and evaluation on shared storage. --- packages/common/src/weathergen/common/config.py | 2 +- .../src/weathergen/evaluate/export/cf_utils.py | 6 +++++- .../weathergen/evaluate/export/export_inference.py | 5 +++-- .../src/weathergen/evaluate/export/io_utils.py | 6 ++++-- .../evaluate/export/parsers/netcdf_parser.py | 3 ++- .../evaluate/export/parsers/quaver_parser.py | 3 ++- .../src/weathergen/evaluate/io/wegen_reader.py | 4 ++++ .../src/weathergen/evaluate/plotting/plotter.py | 13 ++++++------- .../evaluate/src/weathergen/evaluate/utils/utils.py | 2 +- 9 files changed, 28 insertions(+), 16 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 794f829c6..8c3b593f9 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -232,7 +232,7 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None) if model_path is None: path = get_path_model(run_id=run_id) else: - path = Path(model_path) / run_id + path = _get_shared_wg_path() / "models" / run_id fname = path / _get_model_config_file_read_name(run_id, mini_epoch) assert fname.exists(), ( diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 201ffa168..3e240ceab 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -4,6 +4,8 @@ import numpy as np import xarray as xr +from weathergen.common.config import _get_shared_wg_path + _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -32,11 +34,13 @@ def __init__(self, config, **kwargs): self.fstep_hours = np.timedelta64(self.fstep_hours, "h") self.mapping = config.get("variables", {}) + self.output_dir = _get_shared_wg_path() / "results" / config.run_id + def get_output_filename(self) -> Path: """ Generate output filename based on run_id and output directory. """ - return Path(self.output_dir) / f"{self.run_id}.{self.file_extension}" + return self.output_dir / f"{self.run_id}.{self.file_extension}" def process_sample(self, fstep_iterator_results: iter, ref_time: np.datetime64): """ diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 0bf4be398..6c92e276b 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -20,7 +20,7 @@ from omegaconf import OmegaConf -from weathergen.common.config import _REPO_ROOT +from weathergen.common.config import _get_shared_wg_path, _REPO_ROOT from weathergen.evaluate.export.export_core import export_model_outputs _logger = logging.getLogger(__name__) @@ -217,7 +217,8 @@ def export_from_args(args: list) -> None: _logger.info(kwargs) # Ensure output directory exists - out_dir = Path(args.output_dir) + _logger.info(f"Path(args.output_dir) = {Path(args.output_dir)}") + out_dir = _get_shared_wg_path() / "results" / args.run_id out_dir.mkdir(parents=True, exist_ok=True) for dtype in args.type: diff --git a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py index 06f0cf25f..69801ec18 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from weathergen.common.config import get_model_results +from weathergen.common.config import _get_shared_wg_path, get_model_results from weathergen.common.io import zarrio_reader _logger = logging.getLogger(__name__) @@ -43,7 +43,9 @@ def output_filename( frt = np.datetime_as_string(forecast_ref_time, unit="h") if regrid_degree is not None: run_id += f"_regular{regrid_degree, regrid_degree}" - out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" + + output_dir = _get_shared_wg_path() / "results" / run_id + out_fname = output_dir / f"{prefix}_{frt}_{run_id}.{file_extension}" return out_fname diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index fe7655fbe..725f31cfe 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -9,6 +9,7 @@ from weathergen.evaluate.export.cf_utils import CfParser from weathergen.evaluate.export.reshape import Regridder, find_pl + _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -103,7 +104,7 @@ def get_output_filename(self, forecast_ref_time: np.datetime64) -> Path: frt = np.datetime_as_string(forecast_ref_time, unit="h") out_fname = ( - Path(self.output_dir) / f"{self.data_type}_{frt}_{self.run_id}.{self.file_extension}" + self.output_dir / f"{self.data_type}_{frt}_{self.run_id}.{self.file_extension}" ) return out_fname diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py index 95e58f87b..8f793f76b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py @@ -65,6 +65,7 @@ def __init__(self, config: OmegaConf, **kwargs): self.sf_file = ekd.create_target("file", self.get_output_filename("sfc")) self.template_cache = self.cache_templates() + def process_sample( self, @@ -190,7 +191,7 @@ def get_output_filename(self, level_type: str) -> Path: Output filename as a Path object. """ return ( - Path(self.output_dir) + self.output_dir / f"{self.data_type}_{level_type}_{self.run_id}_{self.expver}.{self.file_extension}" ) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 062031273..16eb4036f 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -21,6 +21,7 @@ # Local application / package from weathergen.common.config import ( + _get_shared_wg_path, get_path_run, load_merge_configs, load_run_config, @@ -47,6 +48,9 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.results_base_dir: self.results_base_dir = get_path_run(self.inference_cfg) _logger.info(f"Results directory obtained from private config: {self.results_base_dir}") + elif self.results_base_dir == "./results/": + self.results_base_dir = _get_shared_wg_path() / "results" / run_id + _logger.info(f"Results directory parsed: {self.results_base_dir}") else: _logger.info(f"Results directory parsed: {self.results_base_dir}") diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 44ec60a3b..724e9b452 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -17,7 +17,7 @@ from PIL import Image from scipy.stats import wilcoxon -from weathergen.common.config import _load_private_conf +from weathergen.common.config import _get_shared_wg_path, _load_private_conf from weathergen.evaluate.plotting.plot_utils import ( DefaultMarkerSize, ) @@ -44,7 +44,7 @@ class Plotter: Contains all basic plotting functions. """ - def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None): + def __init__(self, plotter_cfg: dict, run_id: str , stream: str | None = None): """ Initialize the Plotter class. @@ -57,9 +57,8 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | - dpi_val: DPI value for the saved images - fig_size: Size of the figure (width, height) in inches - tokenize_spacetime: If True, all valid times will be plotted in one plot - output_basedir: - Base directory under which the plots will be saved. - Expected scheme `/`. + run_id: + Run identifier used to organize the plots. stream: Stream identifier for which the plots will be created. It can also be set later via update_data_selection. @@ -75,9 +74,9 @@ def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | self.plot_subtimesteps = plotter_cfg.get( "plot_subtimesteps", False ) # True if plots are created for each valid time separately - self.run_id = output_basedir.name + self.run_id = run_id - self.out_plot_basedir = Path(output_basedir) / "plots" + self.out_plot_basedir = _get_shared_wg_path() / "results" / self.run_id / "plots" if not os.path.exists(self.out_plot_basedir): _logger.info(f"Creating dir {self.out_plot_basedir}") diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index d911a370b..4976e0881 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -356,7 +356,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: "regions": global_plotting_opts.get("regions", ["global"]), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False), } - plotter = Plotter(plotter_cfg, reader.runplot_dir) + plotter = Plotter(plotter_cfg, run_id) available_data = reader.check_availability(stream, mode="plotting") From a0872e286120311becf01f49f636ec1fc0bff826 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 10 Feb 2026 14:14:09 +0100 Subject: [PATCH 07/13] more shared path setting up --- .../common/src/weathergen/common/logger.py | 4 ++-- src/weathergen/utils/plot_training.py | 20 +++++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/packages/common/src/weathergen/common/logger.py b/packages/common/src/weathergen/common/logger.py index d857f0baa..5d59953cd 100644 --- a/packages/common/src/weathergen/common/logger.py +++ b/packages/common/src/weathergen/common/logger.py @@ -14,7 +14,7 @@ import pathlib from functools import cache -from weathergen.common.config import _load_private_conf +from weathergen.common.config import _load_private_conf, _get_shared_wg_path LOGGING_CONFIG = """ { @@ -123,7 +123,7 @@ def init_loggers(run_id=None, logging_config=None): # output_dir = f"./output/{timestamp}-{run_id}" output_dir = "" if run_id is not None: - output_dir = f"./output/{run_id}" + output_dir = _get_shared_wg_path() / "output" / run_id # load the structure for logging config if logging_config is None: diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 1ce69a85d..a5ee5e544 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -24,7 +24,7 @@ DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml") DEFAULT_CONFIG_FILE= Path("./config/default_config.yml") -DEFAULT_MODEL_PATH = config._get_shared_wg_path() / "models" +DEFAULT_SHARED_PATH = config._get_shared_wg_path() #################################################################################################### def _ensure_list(value): @@ -151,7 +151,10 @@ def clean_plot_folder(plot_dir: Path): #################################################################################################### -def get_stream_names(run_id: str, model_path: Path | None = DEFAULT_MODEL_PATH): +def get_stream_names( + run_id: str, + model_path: Path | None = DEFAULT_SHARED_PATH / "models" + ) -> list[str]: """ Get the stream names from the model configuration file. @@ -493,7 +496,7 @@ def plot_loss_per_run( if errs is None: errs = ["mse"] - plot_dir = config._get_shared_wg_path() / "plots" + plot_dir = DEFAULT_SHARED_PATH / "plots" modes = [modes] if type(modes) is not list else modes # repeat colors when train and val is plotted simultaneously @@ -596,13 +599,13 @@ def plot_train(args=None): parser.add_argument( "-o", "--output_dir", - default=config._get_shared_wg_path() / "plots", + default=DEFAULT_SHARED_PATH / "plots", type=Path, help="Directory where plots are saved" ) parser.add_argument( "-m", "--model_base_dir", - default=config._get_shared_wg_path() / "models", + default=DEFAULT_SHARED_PATH / "models", type=Path, help="Base-directory where models are saved", ) @@ -653,7 +656,12 @@ def plot_train(args=None): # parse the command line arguments args = parser.parse_args(args) - model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None + model_base_dir = DEFAULT_SHARED_PATH / "models" + if model_base_dir != Path(args.model_base_dir): + _logger.warning( + f"Model base directory specified in args ({args.model_base_dir}) is different from the default shared path ({model_base_dir}). " + f"Using the model base directory from args: {model_base_dir}" + ) out_dir = Path(args.output_dir) streams = list(args.streams) x_types_valid = ["step"] # TODO: add "reltime" support when fix available From 428f151707e132b9e929b03a9274292dde562528 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 10 Feb 2026 17:57:53 +0100 Subject: [PATCH 08/13] fix: correct attribute access in WeatherGenZarrReader error handling The code was attempting to log and raise an exception using self.fname_zarr before it was assigned, causing an AttributeError when the Zarr file doesn't exist. Changed to use the local fname_zarr variable instead. --- packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 16eb4036f..45a0d4bd3 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -324,8 +324,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non ): self.fname_zarr = fname_zarr else: - _logger.error(f"Zarr file {self.fname_zarr} does not exist.") - raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist") + _logger.error(f"Zarr file {fname_zarr} does not exist.") + raise FileNotFoundError(f"Zarr file {fname_zarr} does not exist") def get_data( self, From ceefa384c7aec2c9d1f5fa011db8aa237623b113 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 10 Feb 2026 18:04:47 +0100 Subject: [PATCH 09/13] added ruff requested changes --- packages/common/src/weathergen/common/logger.py | 2 +- .../evaluate/src/weathergen/evaluate/export/export_inference.py | 2 +- .../src/weathergen/evaluate/export/parsers/netcdf_parser.py | 1 - src/weathergen/model/model_interface.py | 2 +- src/weathergen/utils/train_logger.py | 2 +- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/packages/common/src/weathergen/common/logger.py b/packages/common/src/weathergen/common/logger.py index 5d59953cd..af59a0692 100644 --- a/packages/common/src/weathergen/common/logger.py +++ b/packages/common/src/weathergen/common/logger.py @@ -14,7 +14,7 @@ import pathlib from functools import cache -from weathergen.common.config import _load_private_conf, _get_shared_wg_path +from weathergen.common.config import _get_shared_wg_path, _load_private_conf LOGGING_CONFIG = """ { diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index 6c92e276b..df6285c47 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -20,7 +20,7 @@ from omegaconf import OmegaConf -from weathergen.common.config import _get_shared_wg_path, _REPO_ROOT +from weathergen.common.config import _REPO_ROOT, _get_shared_wg_path from weathergen.evaluate.export.export_core import export_model_outputs _logger = logging.getLogger(__name__) diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index 725f31cfe..80b2b928e 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -9,7 +9,6 @@ from weathergen.evaluate.export.cf_utils import CfParser from weathergen.evaluate.export.reshape import Regridder, find_pl - _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 8e8ff2959..7310bb27e 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -20,7 +20,7 @@ ) from torch.distributed.tensor import distribute_tensor -from weathergen.common.config import get_path_model, Config, merge_configs +from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 4e55baaa8..87145ef14 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -407,7 +407,7 @@ def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str: def _key_stddev(st_name: str) -> str: st_name = clean_name(st_name) - return f"LossPhysical.loss_avg" # + return "LossPhysical.loss_avg" # # return f"stream.{st_name}.stddev_avg" From ba55e004e1b31c3e7d10ed770b1276520f02569d Mon Sep 17 00:00:00 2001 From: sbAsma Date: Tue, 10 Feb 2026 18:07:22 +0100 Subject: [PATCH 10/13] fixed log message too long --- src/weathergen/utils/plot_training.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index a5ee5e544..1dccd0c93 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -659,7 +659,8 @@ def plot_train(args=None): model_base_dir = DEFAULT_SHARED_PATH / "models" if model_base_dir != Path(args.model_base_dir): _logger.warning( - f"Model base directory specified in args ({args.model_base_dir}) is different from the default shared path ({model_base_dir}). " + f"Model base directory specified in args ({args.model_base_dir}) " + f"is different from the default shared path ({model_base_dir}). " f"Using the model base directory from args: {model_base_dir}" ) out_dir = Path(args.output_dir) From b45fd11178961aa7bb47d5d568207ec65fa60e23 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Wed, 11 Feb 2026 15:58:43 +0100 Subject: [PATCH 11/13] reverted changes as requested --- .../weathergen/evaluate/export/cf_utils.py | 6 +---- .../evaluate/export/export_inference.py | 5 ++-- .../weathergen/evaluate/export/io_utils.py | 6 ++--- .../evaluate/export/parsers/netcdf_parser.py | 2 +- .../evaluate/export/parsers/quaver_parser.py | 3 +-- .../weathergen/evaluate/io/wegen_reader.py | 4 ---- .../weathergen/evaluate/plotting/plotter.py | 15 ++++++------ src/weathergen/utils/compare_run_configs.py | 6 ++--- src/weathergen/utils/plot_training.py | 24 ++++++++++--------- src/weathergen/utils/train_logger.py | 10 ++++---- 10 files changed, 36 insertions(+), 45 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py index 3e240ceab..201ffa168 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/cf_utils.py @@ -4,8 +4,6 @@ import numpy as np import xarray as xr -from weathergen.common.config import _get_shared_wg_path - _logger = logging.getLogger(__name__) _logger.setLevel(logging.INFO) @@ -34,13 +32,11 @@ def __init__(self, config, **kwargs): self.fstep_hours = np.timedelta64(self.fstep_hours, "h") self.mapping = config.get("variables", {}) - self.output_dir = _get_shared_wg_path() / "results" / config.run_id - def get_output_filename(self) -> Path: """ Generate output filename based on run_id and output directory. """ - return self.output_dir / f"{self.run_id}.{self.file_extension}" + return Path(self.output_dir) / f"{self.run_id}.{self.file_extension}" def process_sample(self, fstep_iterator_results: iter, ref_time: np.datetime64): """ diff --git a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py index df6285c47..0bf4be398 100755 --- a/packages/evaluate/src/weathergen/evaluate/export/export_inference.py +++ b/packages/evaluate/src/weathergen/evaluate/export/export_inference.py @@ -20,7 +20,7 @@ from omegaconf import OmegaConf -from weathergen.common.config import _REPO_ROOT, _get_shared_wg_path +from weathergen.common.config import _REPO_ROOT from weathergen.evaluate.export.export_core import export_model_outputs _logger = logging.getLogger(__name__) @@ -217,8 +217,7 @@ def export_from_args(args: list) -> None: _logger.info(kwargs) # Ensure output directory exists - _logger.info(f"Path(args.output_dir) = {Path(args.output_dir)}") - out_dir = _get_shared_wg_path() / "results" / args.run_id + out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) for dtype in args.type: diff --git a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py index 69801ec18..06f0cf25f 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/io_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/export/io_utils.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from weathergen.common.config import _get_shared_wg_path, get_model_results +from weathergen.common.config import get_model_results from weathergen.common.io import zarrio_reader _logger = logging.getLogger(__name__) @@ -43,9 +43,7 @@ def output_filename( frt = np.datetime_as_string(forecast_ref_time, unit="h") if regrid_degree is not None: run_id += f"_regular{regrid_degree, regrid_degree}" - - output_dir = _get_shared_wg_path() / "results" / run_id - out_fname = output_dir / f"{prefix}_{frt}_{run_id}.{file_extension}" + out_fname = Path(output_dir) / f"{prefix}_{frt}_{run_id}.{file_extension}" return out_fname diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py index 80b2b928e..fe7655fbe 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/netcdf_parser.py @@ -103,7 +103,7 @@ def get_output_filename(self, forecast_ref_time: np.datetime64) -> Path: frt = np.datetime_as_string(forecast_ref_time, unit="h") out_fname = ( - self.output_dir / f"{self.data_type}_{frt}_{self.run_id}.{self.file_extension}" + Path(self.output_dir) / f"{self.data_type}_{frt}_{self.run_id}.{self.file_extension}" ) return out_fname diff --git a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py index 8f793f76b..95e58f87b 100644 --- a/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py +++ b/packages/evaluate/src/weathergen/evaluate/export/parsers/quaver_parser.py @@ -65,7 +65,6 @@ def __init__(self, config: OmegaConf, **kwargs): self.sf_file = ekd.create_target("file", self.get_output_filename("sfc")) self.template_cache = self.cache_templates() - def process_sample( self, @@ -191,7 +190,7 @@ def get_output_filename(self, level_type: str) -> Path: Output filename as a Path object. """ return ( - self.output_dir + Path(self.output_dir) / f"{self.data_type}_{level_type}_{self.run_id}_{self.expver}.{self.file_extension}" ) diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py index 45a0d4bd3..736859771 100644 --- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py @@ -21,7 +21,6 @@ # Local application / package from weathergen.common.config import ( - _get_shared_wg_path, get_path_run, load_merge_configs, load_run_config, @@ -48,9 +47,6 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non if not self.results_base_dir: self.results_base_dir = get_path_run(self.inference_cfg) _logger.info(f"Results directory obtained from private config: {self.results_base_dir}") - elif self.results_base_dir == "./results/": - self.results_base_dir = _get_shared_wg_path() / "results" / run_id - _logger.info(f"Results directory parsed: {self.results_base_dir}") else: _logger.info(f"Results directory parsed: {self.results_base_dir}") diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py index 724e9b452..e3cdfefa6 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py @@ -17,7 +17,7 @@ from PIL import Image from scipy.stats import wilcoxon -from weathergen.common.config import _get_shared_wg_path, _load_private_conf +from weathergen.common.config import _load_private_conf from weathergen.evaluate.plotting.plot_utils import ( DefaultMarkerSize, ) @@ -44,7 +44,7 @@ class Plotter: Contains all basic plotting functions. """ - def __init__(self, plotter_cfg: dict, run_id: str , stream: str | None = None): + def __init__(self, plotter_cfg: dict, output_basedir: str | Path, stream: str | None = None): """ Initialize the Plotter class. @@ -57,8 +57,9 @@ def __init__(self, plotter_cfg: dict, run_id: str , stream: str | None = None): - dpi_val: DPI value for the saved images - fig_size: Size of the figure (width, height) in inches - tokenize_spacetime: If True, all valid times will be plotted in one plot - run_id: - Run identifier used to organize the plots. + output_basedir: + Base directory under which the plots will be saved. + Expected scheme `/`. stream: Stream identifier for which the plots will be created. It can also be set later via update_data_selection. @@ -74,9 +75,9 @@ def __init__(self, plotter_cfg: dict, run_id: str , stream: str | None = None): self.plot_subtimesteps = plotter_cfg.get( "plot_subtimesteps", False ) # True if plots are created for each valid time separately - self.run_id = run_id + self.run_id = output_basedir.name - self.out_plot_basedir = _get_shared_wg_path() / "results" / self.run_id / "plots" + self.out_plot_basedir = Path(output_basedir) / "plots" if not os.path.exists(self.out_plot_basedir): _logger.info(f"Creating dir {self.out_plot_basedir}") @@ -606,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]: image_paths += names if image_paths: - image_paths=sorted(image_paths) + image_paths = sorted(image_paths) images = [Image.open(path) for path in image_paths] images[0].save( f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif", diff --git a/src/weathergen/utils/compare_run_configs.py b/src/weathergen/utils/compare_run_configs.py index 13539f6c6..d33789da0 100755 --- a/src/weathergen/utils/compare_run_configs.py +++ b/src/weathergen/utils/compare_run_configs.py @@ -31,7 +31,7 @@ import yaml from omegaconf import OmegaConf -from weathergen.common.config import _get_shared_wg_path, load_run_config +from weathergen.common.config import load_run_config def truncate_value(value, max_length=50): @@ -152,14 +152,14 @@ def main(): "-m1", "--model_directory_1", type=Path, - default=_get_shared_wg_path() / "models", + default=Path("models/"), help="Path to model directory for -r1/--run_id_1", ) parser.add_argument( "-m2", "--model_directory_2", type=Path, - default=_get_shared_wg_path() / "models", + default=Path("models/"), help="Path to model directory for -r2/--run_id_2", ) parser.add_argument( diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 1dccd0c93..641ab3ede 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -23,9 +23,10 @@ _logger = logging.getLogger(__name__) DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml") -DEFAULT_CONFIG_FILE= Path("./config/default_config.yml") +DEFAULT_CONFIG_FILE = Path("./config/default_config.yml") DEFAULT_SHARED_PATH = config._get_shared_wg_path() + #################################################################################################### def _ensure_list(value): """ @@ -152,9 +153,8 @@ def clean_plot_folder(plot_dir: Path): #################################################################################################### def get_stream_names( - run_id: str, - model_path: Path | None = DEFAULT_SHARED_PATH / "models" - ) -> list[str]: + run_id: str, model_path: Path | None = DEFAULT_SHARED_PATH / "models" +) -> list[str]: """ Get the stream names from the model configuration file. @@ -496,7 +496,7 @@ def plot_loss_per_run( if errs is None: errs = ["mse"] - plot_dir = DEFAULT_SHARED_PATH / "plots" + plot_dir = DEFAULT_SHARED_PATH / "plots" modes = [modes] if type(modes) is not list else modes # repeat colors when train and val is plotted simultaneously @@ -598,9 +598,11 @@ def plot_train(args=None): ) parser.add_argument( - "-o", "--output_dir", - default=DEFAULT_SHARED_PATH / "plots", - type=Path, help="Directory where plots are saved" + "-o", + "--output_dir", + default=DEFAULT_SHARED_PATH / "plots", + type=Path, + help="Directory where plots are saved", ) parser.add_argument( "-m", @@ -656,8 +658,8 @@ def plot_train(args=None): # parse the command line arguments args = parser.parse_args(args) - model_base_dir = DEFAULT_SHARED_PATH / "models" - if model_base_dir != Path(args.model_base_dir): + model_base_dir = DEFAULT_SHARED_PATH / "models" + if model_base_dir != Path(args.model_base_dir): _logger.warning( f"Model base directory specified in args ({args.model_base_dir}) " f"is different from the default shared path ({model_base_dir}). " @@ -685,7 +687,7 @@ def plot_train(args=None): clean_plot_folder(out_dir) # read logged data - + runs_data = [TrainLogger.read(run_id, model_path=model_base_dir) for run_id in runs_ids] # determine which runs are still alive (as a process, though they might hang internally) diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index 87145ef14..25d168e13 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -394,20 +394,20 @@ def clean_name(s: str) -> str: def _key_loss(st_name: str, lf_name: str) -> str: - st_name = clean_name(st_name) # LossPhysical.ERA5.mse.t_600.2 - return f"LossPhysical.{st_name}.mse.avg" # LossPhysical.ERA5.mse.avg + st_name = clean_name(st_name) # LossPhysical.ERA5.mse.t_600.2 + return f"LossPhysical.{st_name}.mse.avg" # LossPhysical.ERA5.mse.avg # return f"stream.{st_name}.loss_{lf_name}.loss_avg" def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str: - st_name = clean_name(st_name) - return f"LossPhysical.{st_name}.{lf_name}.{ch_name}" # LossPhysical.ERA5.mse.t_500.1 + st_name = clean_name(st_name) + return f"LossPhysical.{st_name}.{lf_name}.{ch_name}" # LossPhysical.ERA5.mse.t_500.1 # return f"stream.{st_name}.loss_{lf_name}.loss_{ch_name}" def _key_stddev(st_name: str) -> str: st_name = clean_name(st_name) - return "LossPhysical.loss_avg" # + return "LossPhysical.loss_avg" # # return f"stream.{st_name}.stddev_avg" From 0e9b907d6bcda41c1648fc1f503426c19b316595 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Thu, 12 Feb 2026 03:23:47 +0100 Subject: [PATCH 12/13] reverted some changes and added method for output path --- packages/common/src/weathergen/common/config.py | 8 ++++++++ packages/common/src/weathergen/common/logger.py | 4 ++-- .../evaluate/src/weathergen/evaluate/utils/utils.py | 2 +- src/weathergen/utils/plot_training.py | 10 +++------- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 8c3b593f9..7c6f9a16f 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -626,6 +626,14 @@ def get_path_model(config: Config | None = None, run_id: str | None = None) -> P raise ValueError(msg) return _get_shared_wg_path() / "models" / run_id +def get_path_output(config: Config | None = None, run_id: str | None = None) -> Path: + """Get the current runs output path for storing output files.""" + if config or run_id: + run_id = run_id if run_id else get_run_id_from_config(config) + else: + msg = f"Missing run_id and cannot infer it from config: {config}" + raise ValueError(msg) + return _get_shared_wg_path() / "output" / run_id def get_path_results(config: Config, mini_epoch: int) -> Path: """Get the path to validation results for a specific mini_epoch and rank.""" diff --git a/packages/common/src/weathergen/common/logger.py b/packages/common/src/weathergen/common/logger.py index af59a0692..fd6de1352 100644 --- a/packages/common/src/weathergen/common/logger.py +++ b/packages/common/src/weathergen/common/logger.py @@ -14,7 +14,7 @@ import pathlib from functools import cache -from weathergen.common.config import _get_shared_wg_path, _load_private_conf +from weathergen.common.config import _load_private_conf, get_path_output LOGGING_CONFIG = """ { @@ -123,7 +123,7 @@ def init_loggers(run_id=None, logging_config=None): # output_dir = f"./output/{timestamp}-{run_id}" output_dir = "" if run_id is not None: - output_dir = _get_shared_wg_path() / "output" / run_id + output_dir = get_path_output(run_id=run_id) # load the structure for logging config if logging_config is None: diff --git a/packages/evaluate/src/weathergen/evaluate/utils/utils.py b/packages/evaluate/src/weathergen/evaluate/utils/utils.py index 4976e0881..d911a370b 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/utils.py @@ -356,7 +356,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None: "regions": global_plotting_opts.get("regions", ["global"]), "plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False), } - plotter = Plotter(plotter_cfg, run_id) + plotter = Plotter(plotter_cfg, reader.runplot_dir) available_data = reader.check_availability(stream, mode="plotting") diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 641ab3ede..1c8f7b7f0 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -153,7 +153,7 @@ def clean_plot_folder(plot_dir: Path): #################################################################################################### def get_stream_names( - run_id: str, model_path: Path | None = DEFAULT_SHARED_PATH / "models" + run_id: str, model_path: Path | None = "./models" ) -> list[str]: """ Get the stream names from the model configuration file. @@ -598,16 +598,12 @@ def plot_train(args=None): ) parser.add_argument( - "-o", - "--output_dir", - default=DEFAULT_SHARED_PATH / "plots", - type=Path, - help="Directory where plots are saved", + "-o", "--output_dir", default="./plots/", type=Path, help="Directory where plots are saved" ) parser.add_argument( "-m", "--model_base_dir", - default=DEFAULT_SHARED_PATH / "models", + default= None, type=Path, help="Base-directory where models are saved", ) From f2b38e93e52ea16a017c0d98c16a04093ec21616 Mon Sep 17 00:00:00 2001 From: sbAsma Date: Thu, 12 Feb 2026 03:33:22 +0100 Subject: [PATCH 13/13] lint changes --- packages/common/src/weathergen/common/config.py | 2 ++ src/weathergen/utils/plot_training.py | 6 ++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py index 7c6f9a16f..889bd77da 100644 --- a/packages/common/src/weathergen/common/config.py +++ b/packages/common/src/weathergen/common/config.py @@ -626,6 +626,7 @@ def get_path_model(config: Config | None = None, run_id: str | None = None) -> P raise ValueError(msg) return _get_shared_wg_path() / "models" / run_id + def get_path_output(config: Config | None = None, run_id: str | None = None) -> Path: """Get the current runs output path for storing output files.""" if config or run_id: @@ -635,6 +636,7 @@ def get_path_output(config: Config | None = None, run_id: str | None = None) -> raise ValueError(msg) return _get_shared_wg_path() / "output" / run_id + def get_path_results(config: Config, mini_epoch: int) -> Path: """Get the path to validation results for a specific mini_epoch and rank.""" ext = StoreType(config.zarr_store).value # validate extension diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py index 1c8f7b7f0..4ef8a790e 100644 --- a/src/weathergen/utils/plot_training.py +++ b/src/weathergen/utils/plot_training.py @@ -152,9 +152,7 @@ def clean_plot_folder(plot_dir: Path): #################################################################################################### -def get_stream_names( - run_id: str, model_path: Path | None = "./models" -) -> list[str]: +def get_stream_names(run_id: str, model_path: Path | None = "./models") -> list[str]: """ Get the stream names from the model configuration file. @@ -603,7 +601,7 @@ def plot_train(args=None): parser.add_argument( "-m", "--model_base_dir", - default= None, + default=None, type=Path, help="Base-directory where models are saved", )