diff --git a/integration_tests/test_relative_paths.py b/integration_tests/test_relative_paths.py
new file mode 100644
index 000000000..cc49f3efa
--- /dev/null
+++ b/integration_tests/test_relative_paths.py
@@ -0,0 +1,156 @@
+"""Integration test to check for relative paths in the codebase."""
+
+import os
+import re
+import pytest
+from pathlib import Path
+
+
+class TestRelativePaths:
+    """Test suite to detect potentially problematic relative paths in the codebase."""
+
+    # Root directory of the project
+    PROJECT_ROOT = Path(__file__).parent.parent
+
+    # Patterns that indicate relative paths
+    RELATIVE_PATH_PATTERNS = [
+        r'"\./[^"]+',  # "./path"
+        r"'\./[^']+",  # './path'
+        r'"\.\./[^"]+',  # "../path"
+        r"'\.\./[^']+",  # '../path'
+        r'open\(["\'][^/][^"\']+["\']',  # open('relative/path')
+    ]
+
+    # File extensions to check
+    EXTENSIONS_TO_CHECK = {'.py', '.yaml', '.yml'}  # could also include '.json', '.toml'
+
+    # Directories to exclude
+    EXCLUDE_DIRS = {'__pycache__', '.git', '.venv', 'integration_tests'}
+
+    # Files to exclude (e.g., this test file itself)
+    EXCLUDE_FILES = {'test_relative_paths.py'}
+
+    def get_all_source_files(self):
+        """Recursively get all source files in the project."""
+        source_files = []
+        for root, dirs, files in os.walk(self.PROJECT_ROOT):
+            # Skip excluded directories
+            dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS]
+
+            for file in files:
+                if file in self.EXCLUDE_FILES:
+                    continue
+                if Path(file).suffix in self.EXTENSIONS_TO_CHECK:
+                    source_files.append(Path(root) / file)
+        return source_files
+
+    def find_relative_paths_in_file(self, filepath: Path) -> list[tuple[int, str, str]]:
+        """
+        Find relative paths in a file.
+
+        Returns:
+            List of tuples: (line_number, matched_pattern, line_content)
+        """
+        matches = []
+        try:
+            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
+                for line_num, line in enumerate(f, 1):
+                    # Skip comments
+                    stripped = line.strip()
+                    if stripped.startswith('#'):
+                        continue
+
+                    for pattern in self.RELATIVE_PATH_PATTERNS:
+                        found = re.search(pattern, line)
+                        if found:
+                            matches.append((line_num, found.group(), line.strip()))
+        except Exception as e:
+            pytest.skip(f"Could not read file {filepath}: {e}")
+
+        return matches
+
+    def test_no_hardcoded_relative_paths(self):
+        """Ensure no hardcoded relative paths exist in Python files."""
+        source_files = [f for f in self.get_all_source_files() if f.suffix == '.py']
+        violations = []
+
+        for filepath in source_files:
+            print(f"Checking file: {filepath}")
+            matches = self.find_relative_paths_in_file(filepath)
+            for line_num, match, line in matches:
+                violations.append({
+                    'file': str(filepath.relative_to(self.PROJECT_ROOT)),
+                    'line': line_num,
+                    'match': match,
+                    'content': line
+                })
+
+        if violations:
+            report = "\n\nRelative path violations found:\n"
+            for v in violations:
+                report += f"\n  {v['file']}:{v['line']}\n"
+                report += f"    Match: {v['match']}\n"
+                report += f"    Line: {v['content']}\n"
+
+            pytest.fail(report)
+
+    def test_yaml_configs_use_absolute_or_variable_paths(self):
+        """Check that YAML config files don't use hardcoded relative paths."""
+        yaml_files = [f for f in self.get_all_source_files()
+                      if f.suffix in {'.yaml', '.yml'}]
+
+        violations = []
+        relative_path_pattern = re.compile(r':\s*["\']?\.\.?/[^"\'#\n]+')
+
+        for yaml_file in yaml_files:
+            print(f"Checking YAML file: {yaml_file}")
+            try:
+                with open(yaml_file, 'r', encoding='utf-8') as f:
+                    for line_num, line in enumerate(f, 1):
+                        if relative_path_pattern.search(line):
+                            violations.append({
+                                'file': str(yaml_file.relative_to(self.PROJECT_ROOT)),
+                                'line': line_num,
+                                'content': line.strip()
+                            })
+            except Exception:
+                continue
+
+        if violations:
+            report = "\n\nRelative paths in YAML configs:\n"
+            for v in violations:
+                report += f"\n  {v['file']}:{v['line']}: {v['content']}\n"
+            pytest.fail(report)
+
+    def test_path_construction_uses_pathlib_or_os_path(self):
+        """Check that path construction uses pathlib or os.path properly."""
+        python_files = [f for f in self.get_all_source_files() if f.suffix == '.py']
+
+        # Pattern for string concatenation that looks like path building
+        bad_pattern = re.compile(r'["\'][^"\']*["\']\s*\+\s*["\']/|/["\']\s*\+')
+
+        violations = []
+        for py_file in python_files:
+            print(f"Checking Python file: {py_file}")
+            try:
+                with open(py_file, 'r', encoding='utf-8') as f:
+                    for line_num, line in enumerate(f, 1):
+                        if bad_pattern.search(line):
+                            violations.append({
+                                'file': str(py_file.relative_to(self.PROJECT_ROOT)),
+                                'line': line_num,
+                                'content': line.strip()
+                            })
+            except Exception:
+                continue
+
+        if violations:
+            report = "\n\nPotential unsafe path concatenation found:\n"
+            report += "(Consider using pathlib.Path or os.path.join)\n"
+            for v in violations:
+                report += f"\n  {v['file']}:{v['line']}: {v['content']}\n"
+            pytest.fail(report)
+
+
+if __name__ == '__main__':
+    pytest.main([__file__, '-v'])
\ No newline at end of file
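
For reference, a minimal self-contained sketch of what the patterns above flag
(the sample lines are hypothetical, not taken from the codebase):

    import re

    # The five patterns from RELATIVE_PATH_PATTERNS above.
    patterns = [
        r'"\./[^"]+',                    # "./path"
        r"'\./[^']+",                    # './path'
        r'"\.\./[^"]+',                  # "../path"
        r"'\.\./[^']+",                  # '../path'
        r'open\(["\'][^/][^"\']+["\']',  # open('relative/path')
    ]

    samples = [
        'run_file = "./config/runs_plot_train.yml"',  # flagged: quoted "./" path
        "f = open('data/train.csv')",                 # flagged: open() on a relative path
        'root = "/absolute/path"',                    # not flagged: absolute path
    ]

    for line in samples:
        flagged = any(re.search(p, line) for p in patterns)
        print(f"{'FLAG' if flagged else 'ok'}\t{line}")
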
diff --git a/packages/common/src/weathergen/common/config.py b/packages/common/src/weathergen/common/config.py
index 794f829c6..889bd77da 100644
--- a/packages/common/src/weathergen/common/config.py
+++ b/packages/common/src/weathergen/common/config.py
@@ -232,7 +232,7 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None)
     if model_path is None:
         path = get_path_model(run_id=run_id)
     else:
-        path = Path(model_path) / run_id
+        path = _get_shared_wg_path() / "models" / run_id
 
     fname = path / _get_model_config_file_read_name(run_id, mini_epoch)
     assert fname.exists(), (
@@ -627,6 +627,16 @@ def get_path_model(config: Config | None = None, run_id: str | None = None) -> P
     return _get_shared_wg_path() / "models" / run_id
 
 
+def get_path_output(config: Config | None = None, run_id: str | None = None) -> Path:
+    """Get the current run's output path for storing output files."""
+    if config or run_id:
+        run_id = run_id if run_id else get_run_id_from_config(config)
+    else:
+        msg = f"Missing run_id and cannot infer it from config: {config}"
+        raise ValueError(msg)
+    return _get_shared_wg_path() / "output" / run_id
+
+
 def get_path_results(config: Config, mini_epoch: int) -> Path:
     """Get the path to validation results for a specific mini_epoch and rank."""
     ext = StoreType(config.zarr_store).value  # validate extension
diff --git a/packages/common/src/weathergen/common/logger.py b/packages/common/src/weathergen/common/logger.py
index d857f0baa..fd6de1352 100644
--- a/packages/common/src/weathergen/common/logger.py
+++ b/packages/common/src/weathergen/common/logger.py
@@ -14,7 +14,7 @@
 import pathlib
 from functools import cache
 
-from weathergen.common.config import _load_private_conf
+from weathergen.common.config import _load_private_conf, get_path_output
 
 LOGGING_CONFIG = """
 {
@@ -123,7 +123,7 @@ def init_loggers(run_id=None, logging_config=None):
     # output_dir = f"./output/{timestamp}-{run_id}"
     output_dir = ""
     if run_id is not None:
-        output_dir = f"./output/{run_id}"
+        output_dir = get_path_output(run_id=run_id)
 
     # load the structure for logging config
     if logging_config is None:
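
A usage sketch for the new helper (assumes the weathergen package is importable
and a shared path is configured; the run id is hypothetical):

    from weathergen.common.config import get_path_output

    out_dir = get_path_output(run_id="abc123")
    # -> <shared_wg_path>/output/abc123

    # With only a config, the run id is inferred via get_run_id_from_config(config);
    # with neither argument, get_path_output raises a ValueError.
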
diff --git a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
index 062031273..736859771 100644
--- a/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
+++ b/packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -320,8 +320,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
         ):
            self.fname_zarr = fname_zarr
         else:
-            _logger.error(f"Zarr file {self.fname_zarr} does not exist.")
-            raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist")
+            _logger.error(f"Zarr file {fname_zarr} does not exist.")
+            raise FileNotFoundError(f"Zarr file {fname_zarr} does not exist")
 
     def get_data(
         self,
diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
index 44ec60a3b..e3cdfefa6 100644
--- a/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
+++ b/packages/evaluate/src/weathergen/evaluate/plotting/plotter.py
@@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]:
                 image_paths += names
 
         if image_paths:
-            image_paths=sorted(image_paths)
+            image_paths = sorted(image_paths)
             images = [Image.open(path) for path in image_paths]
             images[0].save(
                 f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif",
diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py
index 24962da13..7310bb27e 100644
--- a/src/weathergen/model/model_interface.py
+++ b/src/weathergen/model/model_interface.py
@@ -11,7 +11,6 @@
 
 import itertools
 import logging
-from pathlib import Path
 
 import omegaconf
 import torch
@@ -21,7 +20,7 @@
 )
 from torch.distributed.tensor import distribute_tensor
 
-from weathergen.common.config import Config, merge_configs
+from weathergen.common.config import Config, get_path_model, merge_configs
 from weathergen.model.attention import (
     MultiCrossAttentionHeadVarlen,
     MultiCrossAttentionHeadVarlenSlicedQ,
@@ -179,7 +178,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1):
     mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch
     """
 
-    path_run = Path(cf.model_path) / run_id
+    path_run = get_path_model(cf, run_id)
     mini_epoch_id = (
         f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest"
     )
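
A sketch of how load_model now resolves the checkpoint location (run id and
epoch values are hypothetical):

    from weathergen.common.config import get_path_model

    run_id = "abc123"
    path_run = get_path_model(run_id=run_id)  # -> <shared_wg_path>/models/abc123

    mini_epoch = -1
    mini_epoch_id = (
        f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest"
    )
    # mini_epoch == -1 -> "latest"; mini_epoch == 3 -> "chkpt00003"
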
diff --git a/src/weathergen/utils/plot_training.py b/src/weathergen/utils/plot_training.py
index 35bfafe3e..4ef8a790e 100644
--- a/src/weathergen/utils/plot_training.py
+++ b/src/weathergen/utils/plot_training.py
@@ -23,6 +23,8 @@
 _logger = logging.getLogger(__name__)
 
 DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml")
+DEFAULT_CONFIG_FILE = Path("./config/default_config.yml")
+DEFAULT_SHARED_PATH = config._get_shared_wg_path()
 
 
 ####################################################################################################
@@ -150,7 +152,7 @@ def clean_plot_folder(plot_dir: Path):
 
 
 ####################################################################################################
-def get_stream_names(run_id: str, model_path: Path | None = "./model"):
+def get_stream_names(run_id: str, model_path: Path | None = "./models") -> list[str]:
     """
     Get the stream names from the model configuration file.
 
@@ -492,7 +494,7 @@ def plot_loss_per_run(
     if errs is None:
         errs = ["mse"]
 
-    plot_dir = Path(plot_dir)
+    plot_dir = DEFAULT_SHARED_PATH / "plots"
     modes = [modes] if type(modes) is not list else modes
 
     # repeat colors when train and val is plotted simultaneously
@@ -650,7 +652,13 @@ def plot_train(args=None):
     # parse the command line arguments
     args = parser.parse_args(args)
 
-    model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None
+    model_base_dir = DEFAULT_SHARED_PATH / "models"
+    if args.model_base_dir and Path(args.model_base_dir) != model_base_dir:
+        _logger.warning(
+            f"Model base directory specified in args ({args.model_base_dir}) "
+            f"differs from the default shared path ({model_base_dir}). "
+            f"Using the default shared path: {model_base_dir}"
+        )
     out_dir = Path(args.output_dir)
     streams = list(args.streams)
     x_types_valid = ["step"]  # TODO: add "reltime" support when fix available
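
The train_logger diff below switches the logged metric keys from the old
"stream.<name>..." scheme to a "LossPhysical..." scheme; a minimal sketch of
the new naming (stream, loss, and channel names are illustrative):

    def clean_name(s: str) -> str:
        # assumed normalization, mirroring the column construction in read()
        return s.replace(",", "").replace("/", "_").replace(" ", "_")

    def key_loss(st_name: str, lf_name: str) -> str:
        return f"LossPhysical.{clean_name(st_name)}.{lf_name}.avg"

    def key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str:
        return f"LossPhysical.{clean_name(st_name)}.{lf_name}.{ch_name}"

    print(key_loss("ERA5", "mse"))               # LossPhysical.ERA5.mse.avg
    print(key_loss_chn("ERA5", "mse", "t_500"))  # LossPhysical.ERA5.mse.t_500
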
f"stream.{st_name}.loss_{lf_name}.loss_avg" def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str: st_name = clean_name(st_name) - return f"stream.{st_name}.loss_{lf_name}.loss_{ch_name}" + return f"LossPhysical.{st_name}.{lf_name}.{ch_name}" # LossPhysical.ERA5.mse.t_500.1 + # return f"stream.{st_name}.loss_{lf_name}.loss_{ch_name}" def _key_stddev(st_name: str) -> str: st_name = clean_name(st_name) - return f"stream.{st_name}.stddev_avg" + return "LossPhysical.loss_avg" # + # return f"stream.{st_name}.stddev_avg" def prepare_losses_for_logging(