From 114a87163c411d2967833f81a47cb8faebc32c29 Mon Sep 17 00:00:00 2001
From: Kerem Tezcan
Date: Tue, 27 Jan 2026 10:17:46 +0100
Subject: [PATCH 1/2] added tests for jepa and dinov2

---
 integration_tests/dinov21.yaml    | 351 ++++++++++++++++++++++++++++++
 integration_tests/dinov21_test.py | 102 +++++++++
 integration_tests/jepa1.yaml      | 326 ++++++++++++++++++++++----
 integration_tests/jepa1_test.py   |   2 +-
 4 files changed, 728 insertions(+), 53 deletions(-)
 create mode 100644 integration_tests/dinov21.yaml
 create mode 100644 integration_tests/dinov21_test.py

diff --git a/integration_tests/dinov21.yaml b/integration_tests/dinov21.yaml
new file mode 100644
index 000000000..c235d6a32
--- /dev/null
+++ b/integration_tests/dinov21.yaml
@@ -0,0 +1,351 @@
+
+embed_orientation: "channels"
+embed_unembed_mode: "block"
+embed_dropout_rate: 0.1
+
+ae_local_dim_embed: 1024
+ae_local_num_blocks: 2
+ae_local_num_heads: 16
+ae_local_dropout_rate: 0.1
+ae_local_with_qk_lnorm: True
+
+ae_local_num_queries: 1
+ae_local_queries_per_cell: False
+ae_adapter_num_heads: 16
+ae_adapter_embed: 128
+ae_adapter_with_qk_lnorm: True
+ae_adapter_with_residual: True
+ae_adapter_dropout_rate: 0.1
+
+ae_global_dim_embed: 2048
+ae_global_num_blocks: 8
+ae_global_num_heads: 32
+ae_global_dropout_rate: 0.1
+ae_global_with_qk_lnorm: True
+# TODO: switching to < 1 triggers triton-related issues.
+# See https://github.com/ecmwf/WeatherGenerator/issues/1050
+ae_global_att_dense_rate: 1.0
+ae_global_block_factor: 64
+ae_global_mlp_hidden_factor: 2
+ae_global_trailing_layer_norm: False
+
+ae_aggregation_num_blocks: 2
+ae_aggregation_num_heads: 32
+ae_aggregation_dropout_rate: 0.1
+ae_aggregation_with_qk_lnorm: True
+ae_aggregation_att_dense_rate: 1.0
+ae_aggregation_block_factor: 64
+ae_aggregation_mlp_hidden_factor: 2
+
+decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
+pred_adapter_kv: False
+pred_self_attention: True
+pred_dyadic_dims: False
+pred_mlp_adaln: True
+num_class_tokens: 4
+num_register_tokens: 8
+
+# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
+# one is training an auto-encoder
+fe_num_blocks: 0
+fe_num_heads: 16
+fe_dropout_rate: 0.1
+fe_with_qk_lnorm: True
+fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer
+fe_impute_latent_noise_std: 0.0 # 1e-4
+forecast_att_dense_rate: 1.0
+
+healpix_level: 5
+
+with_mixed_precision: True
+with_flash_attention: True
+compile_model: False
+with_fsdp: False
+ddp_find_unused_parameters: False
+attention_dtype: bf16
+mixed_precision_dtype: bf16
+mlp_norm_eps: 1e-5
+norm_eps: 1e-4
+
+latent_noise_kl_weight: 0.0 # 1e-5
+latent_noise_gamma: 2.0
+latent_noise_saturate_encodings: 5
+latent_noise_use_additive_noise: False
+latent_noise_deterministic_latents: True
+
+
+freeze_modules: ""
+
+norm_type: "LayerNorm"
+
+# type of zarr_store
+zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore
+#####################################
+
+# streams_directory: "./config/streams/era5_1deg/"
+# streams_directory: "./config/streams/era5_nppatms_synop/"
+streams_directory: "./integration_tests/streams/"
+streams: ???
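+# Values given as ??? are mandatory placeholders (in the OmegaConf sense):
+# they must be resolved at runtime, e.g. run_id, rank and world_size below
+# are expected to be filled in at launch time.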
+
+general:
+
+  # mutable parameters
+  istep: 0
+  rank: ???
+  world_size: ???
+
+  # local_rank,
+  # with_ddp,
+  # data_path_*,
+  # model_path,
+  # run_path,
+  # path_shared_
+
+  multiprocessing_method: "fork"
+
+  desc: ""
+  run_id: ???
+  run_history: []
+
+# logging frequency in the training loop (in number of batches)
+train_log_freq:
+  terminal: 10
+  metrics: 20
+  checkpoint: 250
+
+# parameters for data loading
+data_loading :
+
+  num_workers: 12
+  rng_seed: ???
+
+
+# config for training
+training_config:
+
+  # training_mode: "masking", "student_teacher", "latent_loss"
+  training_mode: ["student_teacher"]
+
+  num_mini_epochs: 1
+  samples_per_mini_epoch: 128
+  shuffle: True
+
+  start_date: 1999-01-01T00:00
+  end_date: 2022-12-31T00:00
+
+  time_window_step: 06:00:00
+  time_window_len: 06:00:00
+
+  window_offset_prediction : 0
+
+  learning_rate_scheduling :
+    lr_start: 1e-6
+    lr_max: 5e-5
+    lr_final_decay: 1e-6
+    lr_final: 0.0
+    num_steps_warmup: 4
+    num_steps_cooldown: 2
+    policy_warmup: "cosine"
+    policy_decay: "constant"
+    policy_cooldown: "linear"
+    parallel_scaling_policy: "sqrt"
+
+  optimizer:
+    grad_clip: 1.0
+    weight_decay: 0.1
+    log_grad_norms: False
+    adamw :
+      # parameters are scaled by number of DDP workers
+      beta1 : 0.975
+      beta2 : 0.9875
+      eps : 2e-08
+
+  losses : {
+    "student-teacher": {
+      enabled: True,
+      type: LossLatentSSLStudentTeacher,
+      weight: 1.0,
+      loss_fcts : {
+        "iBOT": {
+          weight: 0.75,
+          target_source_correspondence: {0 : {0 : "subset"}, 1 : {3 : "subset"},},
+          loss_extra_args: { "student_temp": 0.1,},
+          out_dim: 4096, # 16384,
+          teacher_temp: 0.1,
+          teacher_style: "softmax_center",
+          center_momentum: 0.9,
+          head: "mlp",
+          num_layers: 2,
+          hidden_factor: 2,
+        },
+        "DINO": {
+          weight: 0.25,
+          target_source_correspondence: {0 : {1: "subset", 2: "identity"}, 1 : {4: "subset", 5: "identity"},},
+          loss_extra_args: { "student_temp": 0.1,},
+          out_dim: 4096, # 16384,
+          teacher_temp: 0.1,
+          teacher_style: "softmax_center",
+          center_momentum: 0.9,
+          head: "mlp",
+          num_layers: 2,
+          hidden_factor: 2,
+        },
+      },
+      target_and_aux_calc: "EMATeacher",
+    }
+  }
+
+  model_input: {
+    "strategy1" : {
+      # masking strategy: "random", "healpix", "forecast"
+      masking_strategy: "random",
+      enabled: True,
+      num_samples: 1,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        rate : 0.5,
+        rate_sampling: False
+      },
+    },
+    "strategy2" : {
+      masking_strategy: "healpix",
+      enabled: True,
+      num_samples: 2,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        rate : 0.5,
+        hl_mask: 1,
+        rate_sampling: False
+      },
+    },
+    "strategy3" : {
+      masking_strategy: "healpix",
+      enabled: True,
+      num_samples: 1,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        hl_mask: 0,
+        # randomly sample the rate
+        rate_sampling: False,
+        rate: 0.0
+      },
+    },
+    "strategy4" : {
+      masking_strategy: "random",
+      enabled: True,
+      num_samples: 1,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        rate : 0.5,
+        # randomly sample the rate
+        rate_sampling: False
+      },
+    },
+    "strategy5" : {
+      masking_strategy: "healpix",
+      enabled: True,
+      num_samples: 2,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        rate : 0.5,
+        hl_mask: 1,
+        # randomly sample the rate
+        rate_sampling: False
+      },
+    },
+    "strategy6" : {
+      masking_strategy: "healpix",
+      enabled: True,
+      num_samples: 1,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        hl_mask: 0,
+        rate: 0.0,
+        # randomly sample the rate
+        rate_sampling: False
+      },
+    },
+  }
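+  # target_source_correspondence (in losses above) appears to map a target
+  # view index to the model_input strategies that act as its student views,
+  # with the relationship as value: e.g. iBOT's {0 : {0 : "subset"}} pairs
+  # target view 0 with "strategy1" as a subset view.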
+
+  target_input: {
+    "strategy1" : {
+      masking_strategy: "healpix",
+      num_samples: 1,
+      masking_strategy_config : { rate : 0.4, hl_mask: 0, rate_sampling: False },
+    },
+    "strategy2" : {
+      masking_strategy: "healpix",
+      num_samples: 1,
+      masking_strategy_config : { rate : 0.4, hl_mask: 0, rate_sampling: False },
+    },
+  }
+
+  forecast :
+    time_step: 06:00:00
+    num_steps: 0
+    policy: null
+
+
+# validation config; full validation config is merge of training and validation config
+validation_config:
+
+  samples_per_mini_epoch: 32 #256
+  shuffle: False
+
+  start_date: 2023-10-01T00:00
+  end_date: 2023-12-31T00:00
+
+  time_window_step: 06:00:00
+  time_window_len: 06:00:00
+
+  window_offset_prediction : 0
+  # whether to track the exponential moving average of weights for validation
+  validate_with_ema:
+    enabled : True
+    ema_ramp_up_ratio: 0.09
+    ema_halflife_in_thousands: 1e-3
+
+  # number of validation samples that are written to disk
+  write_num_samples: 0 #8
+  # output streams to write; default all
+  output_streams: null
+
+  # run validation before training starts (mainly for model development)
+  validate_before_training: False
+
+# test config; full test config is merge of validation and test config
+# test config is used by default when running inference
+
+# Tags for experiment tracking
+# These tags will be logged in MLFlow along with completed runs for train, eval, val
+# The tags are free-form, with the following rules:
+# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries
+# - tags should not duplicate existing config entries.
+# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags
+# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
+wgtags:
+  # The name of the organization of the person running the experiment.
+  # This may be autofilled in the future. Expected values are lowercase strings
+  # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
+  org: null
+  # The Github issue corresponding to this run (number such as 1234)
+  # Github issues are the central point when running experiments and contain
+  # links to hedgedocs, code branches, pull requests etc.
+  # It is recommended to associate a run with a Github issue.
+  issue: null
+  # The name of the experiment. This is a distinctive codename for the experiment campaign being run.
+  # This is expected to be the primary tag for comparing experiments in MLFlow, along with the
+  # issue number.
+  # Expected values are lowercase strings with no spaces, just underscores:
+  # Examples: "rollout_ablation_grid"
+  exp: null
+  # *** Experiment-specific tags ***
+  # All extra tags (including lists, dictionaries, etc.) are treated
+  # as strings by mlflow, so treat all extra tags as simple string key: value pairs.
+  grid: null
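+  # A filled-in example (hypothetical values, for illustration only):
+  #   org: "ecmwf"
+  #   issue: 1234
+  #   exp: "ssl_dinov2_baseline"
+  #   grid: "healpix5"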
diff --git a/integration_tests/dinov21_test.py b/integration_tests/dinov21_test.py
new file mode 100644
index 000000000..d0c951988
--- /dev/null
+++ b/integration_tests/dinov21_test.py
@@ -0,0 +1,102 @@
+"""
+Small test for the Weather Generator.
+This test must run on a GPU machine.
+It performs training and inference of the Weather Generator model.
+
+Command:
+uv run pytest ./integration_tests/dinov21_test.py
+"""
+
+import json
+import logging
+import os
+import shutil
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from weathergen.run_train import train_with_args
+from weathergen.utils.metrics import get_train_metrics_path
+
+logger = logging.getLogger(__name__)
+
+# Read from git the current commit hash and take the first 5 characters:
+try:
+    from git import Repo
+
+    repo = Repo(search_parent_directories=False)
+    commit_hash = repo.head.object.hexsha[:5]
+    logger.info(f"Current commit hash: {commit_hash}")
+except Exception as e:
+    commit_hash = "unknown"
+    logger.warning(f"Could not get commit hash: {e}")
+
+WEATHERGEN_HOME = Path(__file__).parent.parent
+
+
+@pytest.fixture()
+def setup(test_run_id):
+    logger.info(f"setup fixture with {test_run_id}")
+    shutil.rmtree(WEATHERGEN_HOME / "results" / test_run_id, ignore_errors=True)
+    shutil.rmtree(WEATHERGEN_HOME / "models" / test_run_id, ignore_errors=True)
+    yield
+    logger.info("end fixture")
+
+
+@pytest.mark.parametrize("test_run_id", ["test_dinov21_" + commit_hash])
+def test_train(setup, test_run_id):
+    logger.info(f"test_train with run_id {test_run_id} {WEATHERGEN_HOME}")
+
+    train_with_args(
+        [ f"--base-config={WEATHERGEN_HOME}/integration_tests/dinov21.yaml" ]
+        + [
+            "--run-id",
+            test_run_id,
+        ],
+        f"{WEATHERGEN_HOME}/config/streams/streams_test/",
+    )
+
+    assert_missing_metrics_file(test_run_id)
+    assert_nans_in_metrics_file(test_run_id)
+    logger.info("end test_train")
+
+
+def load_metrics(run_id):
+    """Load the newline-delimited JSON metrics file as a list of records."""
+    file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id)
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"Metrics file not found for run_id: {run_id}")
+    with open(file_path) as f:
+        json_str = f.readlines()
+    # Join the per-line records with commas and wrap them in brackets so the
+    # whole file parses as a single JSON array.
+    return json.loads("[" + r"".join([s.replace("\n", ",") for s in json_str])[:-1] + "]")
+
+
+def assert_missing_metrics_file(run_id):
+    """Assert that the metrics file exists and contains loadable metrics."""
+    file_path = get_train_metrics_path(base_path=WEATHERGEN_HOME / "results", run_id=run_id)
+    assert os.path.exists(file_path), f"Metrics file does not exist for run_id: {run_id}"
+    metrics = load_metrics(run_id)
+    logger.info(f"Loaded metrics for run_id: {run_id}: {metrics}")
+    assert metrics is not None, f"Failed to load metrics for run_id: {run_id}"
+
+
+def assert_nans_in_metrics_file(run_id):
+    """Assert that there are no NaNs in the metrics file."""
+    metrics = load_metrics(run_id)
+    loss_values_train = np.array([entry.get('LossLatentSSLStudentTeacher.loss_avg') for entry in metrics if entry.get("stage") == 'train'])
+    loss_values_val = np.array([entry.get('LossLatentSSLStudentTeacher.loss_avg') for entry in metrics if entry.get("stage") == 'val'])
+
+    # map string-encoded 'nan' entries to np.nan so np.isnan can flag them
+    loss_values_train = np.array([float(value) if value != 'nan' else np.nan for value in loss_values_train])
+    loss_values_val = np.array([float(value) if value != 'nan' else np.nan for value in loss_values_val])
+
+    assert not np.isnan(loss_values_train).any(), (
+        "NaN values found in training loss metrics!"
+    )
+
+    assert not np.isnan(loss_values_val).any(), (
+        "NaN values found in validation loss metrics!"
+    )
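+# For reference, each line of the metrics file is one JSON record; a
+# hypothetical example of a record that load_metrics() would parse:
+#   {"stage": "train", "LossLatentSSLStudentTeacher.loss_avg": "0.42"}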
+ ) + diff --git a/integration_tests/jepa1.yaml b/integration_tests/jepa1.yaml index 1cd69f734..888e6c7ad 100644 --- a/integration_tests/jepa1.yaml +++ b/integration_tests/jepa1.yaml @@ -1,54 +1,276 @@ +# (C) Copyright 2025 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +embed_orientation: "channels" +embed_unembed_mode: "block" +embed_dropout_rate: 0.1 + +ae_local_dim_embed: 1024 +ae_local_num_blocks: 4 +ae_local_num_heads: 16 +ae_local_dropout_rate: 0.1 +ae_local_with_qk_lnorm: True + +ae_local_num_queries: 1 +ae_local_queries_per_cell: False +ae_adapter_num_heads: 16 +ae_adapter_embed: 128 +ae_adapter_with_qk_lnorm: True +ae_adapter_with_residual: True +ae_adapter_dropout_rate: 0.1 + +ae_global_dim_embed: 2048 +ae_global_num_blocks: 2 +ae_global_num_heads: 32 +ae_global_dropout_rate: 0.1 +ae_global_with_qk_lnorm: True +# TODO: switching to < 1 triggers triton-related issues. +# See https://github.com/ecmwf/WeatherGenerator/issues/1050 +ae_global_att_dense_rate: 1.0 +ae_global_block_factor: 64 +ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False + +ae_aggregation_num_blocks: 8 +ae_aggregation_num_heads: 32 +ae_aggregation_dropout_rate: 0.1 +ae_aggregation_with_qk_lnorm: True +ae_aggregation_att_dense_rate: 1.0 +ae_aggregation_block_factor: 64 +ae_aggregation_mlp_hidden_factor: 2 + +decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning +pred_adapter_kv: False +pred_self_attention: True +pred_dyadic_dims: False +pred_mlp_adaln: True +num_class_tokens: 1 +num_register_tokens: 7 + +# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then +# one is training an auto-encoder +fe_num_blocks: 0 +fe_num_heads: 16 +fe_dropout_rate: 0.1 +fe_with_qk_lnorm: True +fe_layer_norm_after_blocks: [] # Index starts at 0. 
+fe_impute_latent_noise_std: 0.0 # 1e-4
+# currently fixed to 1.0 (due to limitations with flex_attention and triton)
+forecast_att_dense_rate: 1.0
+with_step_conditioning: True # False
+
+healpix_level: 5
+
+with_mixed_precision: True
+with_flash_attention: True
+compile_model: False
+with_fsdp: False
+ddp_find_unused_parameters: False
+attention_dtype: bf16
+mixed_precision_dtype: bf16
+mlp_norm_eps: 1e-5
+norm_eps: 1e-4
+
+latent_noise_kl_weight: 0.0 # 1e-5
+latent_noise_gamma: 2.0
+latent_noise_saturate_encodings: 5
+latent_noise_use_additive_noise: False
+latent_noise_deterministic_latents: True
+
+freeze_modules: ""
+
+norm_type: "LayerNorm"
+# type of zarr_store
+zarr_store: "zip" # "zarr" for LocalStore, "zip" for ZipStore
+
+#####################################
+
+# streams_directory: "./config/streams/era5_1deg/"
+# streams_directory: "./config/streams/era5_nppatms_synop/"
 streams_directory: "./integration_tests/streams/"
-run_path: "./results"
-model_path: "./models"
-loss_fcts: [["mse", 1.0]]
-loss_fcts_val: [["mse", 1.0]]
-num_mini_epochs: 1
-samples_per_mini_epoch: 100
-samples_per_validation: 5
-lr_steps: 4
-lr_steps_warmup: 2
-lr_steps_cooldown: 2
-loader_num_workers: 8
-
-sslpred_num_blocks: 12
-sslpred_num_heads: 12
-sslpred_dropout_rate: 0.1
-sslpred_with_qk_lnorm: True
-sslpred_intermediate_dim: 384
-
-train_log:
-  log_interval: 1
-### Example validation and training config for student-teacher with JEPA
-validation_config:
-  losses:
-    # null
-    LossPhysical: {'weight': 0.0}
-    LossLatentSSLStudentTeacher: {
-      "weight": 1.0,
-      "JEPA": {'weight': 5, "loss_extra_args": {}, "out_dim": 2048} }
-### Student-teacher configuration (only used when training_mode == "student_teacher")
+streams: ???
+
+general:
+
+  # mutable parameters
+  istep: 0
+  rank: ???
+  world_size: ???
+
+  # local_rank,
+  # with_ddp,
+  # data_path_*,
+  # model_path,
+  # run_path,
+  # path_shared_
+
+  multiprocessing_method: "fork"
+
+  desc: ""
+  run_id: ???
+  run_history: []
+
+# logging frequency in the training loop (in number of batches)
+train_log_freq:
+  terminal: 10
+  metrics: 20
+  checkpoint: 250
+
+# parameters for data loading
+data_loading :
+
+  num_workers: 12
+  rng_seed: ???
+
+
+# config for training
 training_config:
-  # null
-  # when this is "masking", we are basically only using the model_input subconfig
-  training_mode: "student_teacher" # "masking", "student_teacher", "forecast"
-  target_and_aux_calc: "EMATeacher"
-  losses :
-    LossPhysical: {weight: 0.0}
-    LossLatentSSLStudentTeacher: {
-      "weight": 1.0,
-      "JEPA": {'weight': 5, "loss_extra_args": {}, "out_dim": 2048} }
-  model_input:
-  - masking_strategy: "random" # "random", "healpix". Masking strategy to use for model input for masking, and local (student) views when doing student-teacher
-    num_samples: 1 # if student-teacher, the number of local (student) views to generate
-    masking_strategy_config : { diffusion_rn : False, rate : 0.4 }
-    # relationship: "independent" #, "subset", "disjoint". Relationship of student views to teacher view.
-    relationship: "subset" # "independent", "subset", "disjoint". Relationship of student views to teacher view.
-    loss : jepa
-    rate_sampling: False # randomly sample the rate per batch
-
-  target_input:
-  - masking_strategy: "healpix" # Strategy for teacher (global) view: "random", "healpix"
-    masking_strategy_config : { diffusion_rn : False, rate : 0.4, hl_mask: 0 }
-    num_samples: 1 # number of teacher views to generate
-    rate_sampling: False # randomly sample the rate per batch
+
+  # training_mode: "masking", "student_teacher", "latent_loss"
+  training_mode: ["student_teacher"]
+
+  num_mini_epochs: 1
+  samples_per_mini_epoch: 64
+  shuffle: True
+
+  start_date: 1979-01-01T00:00
+  end_date: 2022-12-31T00:00
+
+  time_window_step: 06:00:00
+  time_window_len: 06:00:00
+
+  window_offset_prediction : 0
+
+  learning_rate_scheduling :
+    lr_start: 1e-6
+    lr_max: 5e-5
+    lr_final_decay: 1e-6
+    lr_final: 0.0
+    num_steps_warmup: 4
+    num_steps_cooldown: 2
+    policy_warmup: "cosine"
+    policy_decay: "constant"
+    policy_cooldown: "linear"
+    parallel_scaling_policy: "sqrt"
+
+  optimizer:
+    grad_clip: 1.0
+    weight_decay: 0.1
+    log_grad_norms: False
+    adamw :
+      # parameters are scaled by number of DDP workers
+      beta1 : 0.975
+      beta2 : 0.9875
+      eps : 2e-08
+
+  losses : {
+    "student-teacher": {
+      enabled: True,
+      type: LossLatentSSLStudentTeacher,
+      weight: 1.0,
+      loss_fcts : {
+        "JEPA": {
+          'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
+          "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
+          "dropout_rate": 0.1,
+          target_source_correspondence: {0 : {0 : "subset"} },
+        },
+      },
+      target_and_aux_calc: { "EMATeacher" :
+        { ema_ramp_up_ratio : 0.09,
+          ema_halflife_in_thousands: 1e-3,
+          model_param_overrides : {
+            training_config: { losses: { student-teacher:{ loss_fcts :{JEPA: {head: identity} }}}}
+          },
+        }
+      }
+    }
+  }
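+  # Note: via model_param_overrides above, the EMA teacher inherits the
+  # student config but swaps the JEPA head to identity; as is common in
+  # JEPA-style setups, only the student then runs the transformer predictor
+  # while the teacher provides raw latent targets.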
+
+  model_input: {
+    "random_easy" : {
+      # masking strategy: "random", "forecast"
+      masking_strategy: "random",
+      num_samples: 1,
+      num_steps_input: 1,
+      masking_strategy_config : {
+        diffusion_rn : True,
+        rate : 0.6,
+        rate_sampling: False
+      },
+    },
+  }
+
+  target_input: {
+    "random_easy_target" : {
+      masking_strategy: "healpix",
+      num_samples: 1,
+      masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False },
+    },
+  }
+
+  forecast :
+    time_step: 00:00:00
+    num_steps: 0
+    policy: null
+
+
+# validation config; full validation config is merge of training and validation config
+validation_config:
+
+  samples_per_mini_epoch: 32
+  shuffle: False
+
+  start_date: 2023-10-01T00:00
+  end_date: 2023-12-31T00:00
+
+  # whether to track the exponential moving average of weights for validation
+  validate_with_ema:
+    enabled : False
+    ema_ramp_up_ratio: 0.09
+    ema_halflife_in_thousands: 1e-3
+
+  # number of validation samples that are written to disk
+  write_num_samples: 0
+  # output streams to write; default all
+  output_streams: null
+
+  # run validation before training starts (mainly for model development)
+  validate_before_training: False
+
+# test config; full test config is merge of validation and test config
+# test config is used by default when running inference
+
+# Tags for experiment tracking
+# These tags will be logged in MLFlow along with completed runs for train, eval, val
+# The tags are free-form, with the following rules:
+# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries
+# - tags should not duplicate existing config entries.
+# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags
+# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future)
+wgtags:
+  # The name of the organization of the person running the experiment.
+  # This may be autofilled in the future. Expected values are lowercase strings
+  # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience"
+  org: null
+  # The Github issue corresponding to this run (number such as 1234)
+  # Github issues are the central point when running experiments and contain
+  # links to hedgedocs, code branches, pull requests etc.
+  # It is recommended to associate a run with a Github issue.
+  issue: null
+  # The name of the experiment. This is a distinctive codename for the experiment campaign being run.
+  # This is expected to be the primary tag for comparing experiments in MLFlow, along with the
+  # issue number.
+  # Expected values are lowercase strings with no spaces, just underscores:
+  # Examples: "rollout_ablation_grid"
+  exp: null
+  # *** Experiment-specific tags ***
+  # All extra tags (including lists, dictionaries, etc.) are treated
+  # as strings by mlflow, so treat all extra tags as simple string key: value pairs.
+  grid: null
diff --git a/integration_tests/jepa1_test.py b/integration_tests/jepa1_test.py
index f2959f3c9..6380e2c8f 100644
--- a/integration_tests/jepa1_test.py
+++ b/integration_tests/jepa1_test.py
@@ -50,7 +50,7 @@ def test_train(setup, test_run_id):
     logger.info(f"test_train with run_id {test_run_id} {WEATHERGEN_HOME}")
 
     train_with_args(
-        [ f"--config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml" ]
+        [ f"--base-config={WEATHERGEN_HOME}/integration_tests/jepa1.yaml" ]
         + [
             "--run-id",
             test_run_id,

From b02eac4ea2011824774830df940df771b284b3d5 Mon Sep 17 00:00:00 2001
From: Kerem Tezcan
Date: Tue, 27 Jan 2026 10:24:34 +0100
Subject: [PATCH 2/2] update actions.sh

---
 scripts/actions.sh | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/scripts/actions.sh b/scripts/actions.sh
index f0f4852ec..4eba85243 100755
--- a/scripts/actions.sh
+++ b/scripts/actions.sh
@@ -108,6 +108,13 @@ case "$1" in
         uv run --offline pytest ./integration_tests/jepa1_test.py --verbose -s
         )
         ;;
+    integration-test-dinov2)
+        (
+        cd "$SCRIPT_DIR" || exit 1
+        uv sync --offline --all-packages --extra gpu
+        uv run --offline pytest ./integration_tests/dinov21_test.py --verbose -s
+        )
+        ;;
     integration-test)
         (
         cd "$SCRIPT_DIR" || exit 1