156 changes: 156 additions & 0 deletions integration_tests/test_relative_paths.py
@@ -0,0 +1,156 @@
"""Integration test to check for relative paths in the codebase."""

import os
import re
import pytest
from pathlib import Path


class TestRelativePaths:
    """Test suite to detect potentially problematic relative paths in the codebase."""

    # Root directory of the project
    PROJECT_ROOT = Path(__file__).parent.parent

    # Patterns that indicate relative paths
    RELATIVE_PATH_PATTERNS = [
        r'"\./[^"]+',  # "./path"
        r"'\./[^']+",  # './path'
        r'"\.\./[^"]+',  # "../path"
        r"'\.\./[^']+",  # '../path'
        r'open\(["\'][^/][^"\']+["\']',  # open('relative/path')
    ]

    # File extensions to check
    EXTENSIONS_TO_CHECK = {'.py'}  # , '.yaml', '.yml', '.json', '.toml'

    # Directories to exclude
    EXCLUDE_DIRS = {'__pycache__', '.git', '.venv', 'integration_tests'}

    # Files to exclude (e.g., this test file itself)
    EXCLUDE_FILES = {'test_relative_paths.py'}

    def get_all_source_files(self, extensions: set[str] | None = None):
        """Recursively get all source files in the project.

        Defaults to EXTENSIONS_TO_CHECK; pass `extensions` to scan other
        file types (e.g. YAML configs).
        """
        extensions = extensions if extensions is not None else self.EXTENSIONS_TO_CHECK
        source_files = []
        for root, dirs, files in os.walk(self.PROJECT_ROOT):
            # Skip excluded directories
            dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS]

            for file in files:
                if file in self.EXCLUDE_FILES:
                    continue
                if Path(file).suffix in extensions:
                    source_files.append(Path(root) / file)
        return source_files

    def find_relative_paths_in_file(self, filepath: Path) -> list[tuple[int, str, str]]:
        """
        Find relative paths in a file.

        Returns:
            List of tuples: (line_number, matched_pattern, line_content)
        """
        matches = []
        try:
            with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                for line_num, line in enumerate(f, 1):
                    # Skip comments
                    stripped = line.strip()
                    if stripped.startswith('#') or stripped.startswith('//'):
                        continue

                    for pattern in self.RELATIVE_PATH_PATTERNS:
                        found = re.search(pattern, line)
                        if found:
                            matches.append((line_num, found.group(), line.strip()))
        except Exception as e:
            pytest.skip(f"Could not read file {filepath}: {e}")

        return matches

    def test_no_hardcoded_relative_paths(self):
        """Ensure no hardcoded relative paths exist in Python files."""
        source_files = self.get_all_source_files()
        violations = []

        for filepath in source_files:
            print(f"Checking file: {filepath}")
            matches = self.find_relative_paths_in_file(filepath)
            for line_num, match, line in matches:
                violations.append({
                    'file': str(filepath.relative_to(self.PROJECT_ROOT)),
                    'line': line_num,
                    'match': match,
                    'content': line,
                })

        if violations:
            report = "\n\nRelative path violations found:\n"
            for v in violations:
                report += f"\n  {v['file']}:{v['line']}\n"
                report += f"    Match: {v['match']}\n"
                report += f"    Line: {v['content']}\n"

            pytest.fail(report)

    def test_yaml_configs_use_absolute_or_variable_paths(self):
        """Check that YAML config files don't use hardcoded relative paths."""
        # Request YAML files explicitly: the default extension set only covers
        # Python sources, so filtering its result would never find any.
        yaml_files = self.get_all_source_files(extensions={'.yaml', '.yml'})

        violations = []
        relative_path_pattern = re.compile(r':\s*["\']?\.\.?/[^"\'#\n]+')

        for yaml_file in yaml_files:
            print(f"Checking YAML file: {yaml_file}")
            try:
                with open(yaml_file, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        if relative_path_pattern.search(line):
                            violations.append({
                                'file': str(yaml_file.relative_to(self.PROJECT_ROOT)),
                                'line': line_num,
                                'content': line.strip(),
                            })
            except Exception:
                continue

        if violations:
            report = "\n\nRelative paths in YAML configs:\n"
            for v in violations:
                report += f"\n  {v['file']}:{v['line']}: {v['content']}\n"
            pytest.fail(report)

    def test_path_construction_uses_pathlib_or_os_path(self):
        """Check that path construction uses pathlib or os.path properly."""
        python_files = [f for f in self.get_all_source_files() if f.suffix == '.py']

        # Pattern for string concatenation that looks like path building
        bad_pattern = re.compile(r'["\'][^"\']*["\']\s*\+\s*["\']/|/["\']\s*\+')

        violations = []
        for py_file in python_files:
            print(f"Checking Python file: {py_file}")
            try:
                with open(py_file, 'r', encoding='utf-8') as f:
                    for line_num, line in enumerate(f, 1):
                        if bad_pattern.search(line):
                            violations.append({
                                'file': str(py_file.relative_to(self.PROJECT_ROOT)),
                                'line': line_num,
                                'content': line.strip(),
                            })
            except Exception:
                continue

        if violations:
            report = "\n\nPotential unsafe path concatenation found:\n"
            report += "(Consider using pathlib.Path or os.path.join)\n"
            for v in violations:
                report += f"\n  {v['file']}:{v['line']}: {v['content']}\n"
            pytest.fail(report)


if __name__ == '__main__':
    pytest.main([__file__, '-v'])
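The patterns above are easy to sanity-check in isolation. A minimal standalone sketch; the sample lines are invented for illustration, not taken from the repository:

import re

PATTERNS = [
    r'"\./[^"]+',
    r"'\./[^']+",
    r'"\.\./[^"]+',
    r"'\.\./[^']+",
    r'open\(["\'][^/][^"\']+["\']',
]

# Each sample pairs a line of code with whether the patterns should flag it.
samples = [
    ('path = "./config/runs.yml"', True),    # quoted ./ path
    ("data = '../data/era5.zarr'", True),    # quoted ../ path
    ('f = open("notes.txt")', True),         # bare relative path in open()
    ('path = "/abs/path/file.txt"', False),  # absolute path, not flagged
]

for line, expected in samples:
    hit = any(re.search(p, line) for p in PATTERNS)
    assert hit == expected, (line, hit, expected)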
12 changes: 11 additions & 1 deletion packages/common/src/weathergen/common/config.py
@@ -232,7 +232,7 @@ def load_run_config(run_id: str, mini_epoch: int | None, model_path: str | None)
    if model_path is None:
        path = get_path_model(run_id=run_id)
    else:
-       path = Path(model_path) / run_id
+       path = _get_shared_wg_path() / "models" / run_id

    fname = path / _get_model_config_file_read_name(run_id, mini_epoch)
    assert fname.exists(), (
@@ -627,6 +627,16 @@ def get_path_model(config: Config | None = None, run_id: str | None = None) -> Path:
    return _get_shared_wg_path() / "models" / run_id


+def get_path_output(config: Config | None = None, run_id: str | None = None) -> Path:
+    """Get the current run's output path for storing output files."""
+    if config or run_id:
+        run_id = run_id if run_id else get_run_id_from_config(config)
+    else:
+        msg = f"Missing run_id and cannot infer it from config: {config}"
+        raise ValueError(msg)
+    return _get_shared_wg_path() / "output" / run_id


def get_path_results(config: Config, mini_epoch: int) -> Path:
    """Get the path to validation results for a specific mini_epoch and rank."""
    ext = StoreType(config.zarr_store).value  # validate extension
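A usage sketch for the new helper; the run id "abc123" is a placeholder, and the root directory comes from _get_shared_wg_path():

from weathergen.common.config import get_path_output

out_dir = get_path_output(run_id="abc123")  # -> <shared-wg-path>/output/abc123

# With neither config nor run_id, the guard above raises:
try:
    get_path_output()
except ValueError as err:
    print(err)  # Missing run_id and cannot infer it from config: None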
4 changes: 2 additions & 2 deletions packages/common/src/weathergen/common/logger.py
@@ -14,7 +14,7 @@
import pathlib
from functools import cache

-from weathergen.common.config import _load_private_conf
+from weathergen.common.config import _load_private_conf, get_path_output

LOGGING_CONFIG = """
{
@@ -123,7 +123,7 @@ def init_loggers(run_id=None, logging_config=None):
    # output_dir = f"./output/{timestamp}-{run_id}"
    output_dir = ""
    if run_id is not None:
-       output_dir = f"./output/{run_id}"
+       output_dir = get_path_output(run_id=run_id)

    # load the structure for logging config
    if logging_config is None:
4 changes: 2 additions & 2 deletions packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -320,8 +320,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = None
        ):
            self.fname_zarr = fname_zarr
        else:
-           _logger.error(f"Zarr file {self.fname_zarr} does not exist.")
-           raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist")
+           _logger.error(f"Zarr file {fname_zarr} does not exist.")
+           raise FileNotFoundError(f"Zarr file {fname_zarr} does not exist")

    def get_data(
        self,
@@ -607,7 +607,7 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]:
            image_paths += names

        if image_paths:
-           image_paths=sorted(image_paths)
+           image_paths = sorted(image_paths)
            images = [Image.open(path) for path in image_paths]
            images[0].save(
                f"{map_output_dir}/animation_{self.run_id}_{tag}_{sa}_{self.stream}_{region}_{var}.gif",
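The save call is truncated in the diff above; for reference, the usual Pillow idiom for writing an animated GIF from the collected frames looks roughly like this (file names, frame duration, and loop count are illustrative assumptions):

from PIL import Image

image_paths = sorted(["frame_000.png", "frame_001.png", "frame_002.png"])
images = [Image.open(path) for path in image_paths]
images[0].save(
    "animation.gif",
    save_all=True,             # write every frame, not only the first
    append_images=images[1:],  # remaining frames follow the first
    duration=200,              # milliseconds per frame
    loop=0,                    # 0 = loop forever
)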
5 changes: 2 additions & 3 deletions src/weathergen/model/model_interface.py
@@ -11,7 +11,6 @@

import itertools
import logging
-from pathlib import Path

import omegaconf
import torch
@@ -21,7 +20,7 @@
)
from torch.distributed.tensor import distribute_tensor

-from weathergen.common.config import Config, merge_configs
+from weathergen.common.config import Config, get_path_model, merge_configs
from weathergen.model.attention import (
    MultiCrossAttentionHeadVarlen,
    MultiCrossAttentionHeadVarlenSlicedQ,
@@ -179,7 +178,7 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1):
    mini_epoch : The mini_epoch to load. Default (-1) is the latest mini_epoch
    """

-   path_run = Path(cf.model_path) / run_id
+   path_run = get_path_model(cf, run_id)
    mini_epoch_id = (
        f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest"
    )
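For reference, the checkpoint-naming convention used by load_model, extracted into a standalone sketch (checkpoint_id is a hypothetical helper name, not part of the module):

def checkpoint_id(mini_epoch: int | None) -> str:
    # -1 or None selects the most recent checkpoint; otherwise zero-pad to five digits.
    return f"chkpt{mini_epoch:05d}" if mini_epoch != -1 and mini_epoch is not None else "latest"

assert checkpoint_id(7) == "chkpt00007"
assert checkpoint_id(-1) == "latest"
assert checkpoint_id(None) == "latest"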
14 changes: 11 additions & 3 deletions src/weathergen/utils/plot_training.py
@@ -23,6 +23,8 @@
_logger = logging.getLogger(__name__)

DEFAULT_RUN_FILE = Path("./config/runs_plot_train.yml")
+DEFAULT_CONFIG_FILE = Path("./config/default_config.yml")
+DEFAULT_SHARED_PATH = config._get_shared_wg_path()


####################################################################################################
@@ -150,7 +152,7 @@ def clean_plot_folder(plot_dir: Path):


####################################################################################################
-def get_stream_names(run_id: str, model_path: Path | None = "./model"):
+def get_stream_names(run_id: str, model_path: Path | None = "./models") -> list[str]:
    """
    Get the stream names from the model configuration file.

@@ -492,7 +494,7 @@ def plot_loss_per_run(
    if errs is None:
        errs = ["mse"]

-   plot_dir = Path(plot_dir)
+   plot_dir = DEFAULT_SHARED_PATH / "plots"

    modes = [modes] if type(modes) is not list else modes
    # repeat colors when train and val are plotted simultaneously
@@ -650,7 +652,13 @@ def plot_train(args=None):
    # parse the command line arguments
    args = parser.parse_args(args)

-   model_base_dir = Path(args.model_base_dir) if args.model_base_dir else None
+   model_base_dir = DEFAULT_SHARED_PATH / "models"
+   if args.model_base_dir and model_base_dir != Path(args.model_base_dir):
+       _logger.warning(
+           f"Model base directory given on the command line ({args.model_base_dir}) "
+           f"differs from the default shared path ({model_base_dir}); "
+           f"ignoring it and using the default: {model_base_dir}"
+       )
    out_dir = Path(args.output_dir)
    streams = list(args.streams)
    x_types_valid = ["step"]  # TODO: add "reltime" support when fix available
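A hypothetical invocation illustrating the new behavior; the flag spellings are assumptions inferred from the dest names used above (args.model_base_dir, args.output_dir, args.streams):

# model_base_dir is now pinned to DEFAULT_SHARED_PATH / "models"; an explicit
# --model_base_dir that differs only triggers the warning and is otherwise ignored.
plot_train([
    "--model_base_dir", "/home/me/models",
    "--output_dir", "/tmp/plots",
    "--streams", "ERA5",
])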
23 changes: 14 additions & 9 deletions src/weathergen/utils/train_logger.py
@@ -144,7 +144,7 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
    )
    run_id = cf.general.run_id

-   result_dir_base = config.get_path_run(cf)
+   result_dir_base = config.get_path_run(cf).parent
    result_dir = result_dir_base / run_id
    fname_log_train = result_dir / f"{run_id}_train_log.txt"
    fname_log_val = result_dir / f"{run_id}_val_log.txt"
@@ -156,12 +156,12 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
    cols_train = ["dtime", "samples", "mse", "lr"]
    cols1 = [_weathergen_timestamp, "num_samples", "loss_avg_mean", "learning_rate"]
    for si in cf.streams:
-       for lf in cf.loss_fcts:
+       for lf in cf.training_config.losses.physical.loss_fcts:
            cols1 += [_key_loss(si["name"], lf[0])]
            cols_train += [
                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0]
            ]
-   with_stddev = [("stats" in lf) for lf in cf.loss_fcts]
+   with_stddev = [("stats" in lf) for lf in cf.training_config.losses.physical.loss_fcts]
    if with_stddev:
        for si in cf.streams:
            cols1 += [_key_stddev(si["name"])]
@@ -214,12 +214,12 @@ def read(run_id: str, model_path: str = None, mini_epoch: int = -1) -> Metrics:
    cols_val = ["dtime", "samples"]
    cols2 = [_weathergen_timestamp, "num_samples"]
    for si in cf.streams:
-       for lf in cf.loss_fcts_val:
+       for lf in cf.training_config.losses.physical.loss_fcts:
            cols_val += [
                si["name"].replace(",", "").replace("/", "_").replace(" ", "_") + ", " + lf[0]
            ]
            cols2 += [_key_loss(si["name"], lf[0])]
-   with_stddev = [("stats" in lf) for lf in cf.loss_fcts_val]
+   with_stddev = [("stats" in lf) for lf in cf.training_config.losses.physical.loss_fcts]
    if with_stddev:
        for si in cf.streams:
            cols2 += [_key_stddev(si["name"])]
@@ -370,6 +370,8 @@ def clean_df(df, columns: list[str] | None):
    idcs = [i for i in range(len(columns)) if columns[i] == "loss_avg_mean"]
    if len(idcs) > 0:
        columns[idcs[0]] = "loss_avg_0_mean"
+   # Log the available columns to ease debugging of column-name mismatches.
+   for key in list(df.columns):
+       _logger.info(key)
    df = df.select(columns)
    # Remove all rows where all columns are null
    df = df.filter(~pl.all_horizontal(pl.col(c).is_null() for c in columns))
@@ -392,18 +394,21 @@ def clean_name(s: str) -> str:


def _key_loss(st_name: str, lf_name: str) -> str:
    st_name = clean_name(st_name)
-   return f"stream.{st_name}.loss_{lf_name}.loss_avg"
+   return f"LossPhysical.{st_name}.{lf_name}.avg"  # e.g. LossPhysical.ERA5.mse.avg


def _key_loss_chn(st_name: str, lf_name: str, ch_name: str) -> str:
    st_name = clean_name(st_name)
-   return f"stream.{st_name}.loss_{lf_name}.loss_{ch_name}"
+   return f"LossPhysical.{st_name}.{lf_name}.{ch_name}"  # e.g. LossPhysical.ERA5.mse.t_500.1


def _key_stddev(st_name: str) -> str:
    st_name = clean_name(st_name)  # currently unused: all streams share one key
-   return f"stream.{st_name}.stddev_avg"
+   return "LossPhysical.loss_avg"
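Taken together, the helpers above emit metric keys on this scheme (the stream, loss, and channel names are illustrative; clean_name is sketched from the replace chain used elsewhere in this file):

def clean_name(s: str) -> str:  # stand-in for the real helper
    return s.replace(",", "").replace("/", "_").replace(" ", "_")

stream, loss_fct, channel = "ERA5", "mse", "t_500"
print(f"LossPhysical.{clean_name(stream)}.{loss_fct}.avg")        # per-stream average
print(f"LossPhysical.{clean_name(stream)}.{loss_fct}.{channel}")  # per-channel loss
print("LossPhysical.loss_avg")                                    # shared stddev key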


def prepare_losses_for_logging(