From a0039ecfebd8634588a3f5afa44e66f712fa365f Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 6 Aug 2025 12:24:36 +0000 Subject: [PATCH 01/40] Log gradient norms --- src/weathergen/train/trainer.py | 18 ++++++++++++++++-- src/weathergen/utils/train_logger.py | 2 ++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 56a28d089..46653b9e9 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -76,6 +76,7 @@ def init( self.init_perf_monitoring() self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) + self.last_grad_norm = 0.0 def inference(self, cf, run_id_trained, epoch): # general initalization @@ -482,7 +483,19 @@ def train(self, epoch): # gradient clipping self.grad_scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.ddp_model.parameters(), max_norm=cf.grad_clip) + total_norm = torch.nn.utils.clip_grad_norm_( + self.ddp_model.parameters(), max_norm=cf.grad_clip + ) + + # log gradient norms + if bidx % log_interval == 0: + grad_norms = { "total_grad_norm" : total_norm.item() } + self.last_grad_norm = total_norm.item() + for name, param in self.ddp_model.named_parameters(): + if param.grad is not None: + grad_norms[name] = param.grad.norm().item() + self.train_logger.log_metrics(TRAIN, grad_norms) + # optimizer step self.grad_scaler.step(self.optimizer) @@ -718,7 +731,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): # samples per sec dt = time.time() - self.t_start pstr = "{:03d} : {:05d}/{:05d} : {:06d} : loss = {:.4E} " - pstr += "(lr={:.2E}, s/sec={:.3f})" + pstr += "(lr={:.2E}, gradient norm={:.3f}, s/sec={:.3f})" len_dataset = len(self.data_loader) // self.cf.batch_size_per_gpu print( pstr.format( @@ -728,6 +741,7 @@ def _log_terminal(self, bidx: int, epoch: int, stage: Stage): self.cf.istep, avg_loss.nanmean().item(), self.lr_scheduler.get_lr(), + self.last_grad_norm, (self.print_freq * self.cf.batch_size_per_gpu) / dt, ), flush=True, diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index be70a243b..c4db39172 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -146,6 +146,8 @@ def add_train( metrics[_performance_gpu] = perf_gpu if perf_mem > 0.0: metrics[_performance_memory] = perf_mem + + self.log_metrics("train", metrics) with open(self.path_run / (self.cf.run_id + "_perf_log.txt"), "ab") as f: np.savetxt(f, log_vals) From e83903b5f6799854933550dbe3ef4b0ac36b227c Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Wed, 6 Aug 2025 15:24:05 +0000 Subject: [PATCH 02/40] Prototype for recording grad norms --- pyproject.toml | 1 + src/weathergen/train/trainer.py | 2 +- src/weathergen/utils/plot_grad_norms.py | 483 ++++++++++++++++++++++++ uv.lock | 16 + 4 files changed, 501 insertions(+), 1 deletion(-) create mode 100644 src/weathergen/utils/plot_grad_norms.py diff --git a/pyproject.toml b/pyproject.toml index aa6232bb8..7511b0327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "dask~=2025.5.1", "hatchling", "weathergen-common", + "seaborn>=0.13.2", ] [project.urls] diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 46653b9e9..4430211ac 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -493,7 +493,7 @@ def train(self, epoch): self.last_grad_norm = total_norm.item() for name, param in self.ddp_model.named_parameters(): if param.grad is not None: - grad_norms[name] = param.grad.norm().item() + grad_norms["grad_norm_" + name] = param.grad.norm().item() self.train_logger.log_metrics(TRAIN, grad_norms) diff --git a/src/weathergen/utils/plot_grad_norms.py b/src/weathergen/utils/plot_grad_norms.py new file mode 100644 index 000000000..8a6ded4ac --- /dev/null +++ b/src/weathergen/utils/plot_grad_norms.py @@ -0,0 +1,483 @@ +import json +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from pathlib import Path +import seaborn as sns +from collections import defaultdict +import re + +class GradientNormsAnalyzer: + def __init__(self, json_file_path): + """ + Initialize the analyzer with path to JSON file containing gradient norms. + Expected format: one JSON object per line with step info and gradient norms. + """ + self.json_file_path = Path(json_file_path) + self.data = [] + self.df = None + self.load_data() + + def load_data(self): + """Load and parse the JSON data from file.""" + print(f"Loading data from {self.json_file_path}...") + + with open(self.json_file_path, 'r') as f: + for line_num, line in enumerate(f, 1): + try: + data_point = json.loads(line.strip()) + self.data.append(data_point) + except json.JSONDecodeError as e: + print(f"Warning: Could not parse line {line_num}: {e}") + + print(f"Loaded {len(self.data)} data points") + self.create_dataframe() + + def create_dataframe(self): + """Convert loaded data into a pandas DataFrame for easier analysis.""" + rows = [] + + for ith, entry in enumerate(self.data): + # step = entry.get('num_samples', entry.get('epoch', 0)) + step = ith * 5 + + # Handle different possible data structures + if 'gradients' in entry: + grad_data = entry['gradients'] + elif 'grad_norms' in entry: + grad_data = entry['grad_norms'] + else: + # Assume all keys except step/epoch are gradient data + grad_data = {k: v for k, v in entry.items() + if 'stream' not in k and ('q_cells' in k or '0' in k)} + + for param_name, norm_value in grad_data.items(): + rows.append({ + 'num_samples': step, + 'parameter': param_name, + 'grad_norm': float(norm_value), + 'layer_type': self.extract_layer_type(param_name), + 'layer_depth': self.extract_layer_depth(param_name) + }) + + self.df = pd.DataFrame(rows) + print(f"Created DataFrame with {len(self.df)} gradient norm records") + + def extract_layer_type(self, param_name): + """Extract layer type from parameter name.""" + param_name_lower = param_name.lower() + + # Handle your specific naming patterns + if param_name_lower.startswith('embeds.'): + if '.embed.' in param_name_lower: + return 'embedding' + elif '.unembed.' in param_name_lower: + return 'unembedding' + elif '.ln_final.' in param_name_lower: + return 'layer_norm_final' + elif 'proj_heads_q' in param_name_lower: + return 'attention_q' + elif 'proj_heads_k' in param_name_lower: + return 'attention_k' + elif 'proj_heads_v' in param_name_lower: + return 'attention_v' + elif 'proj_out' in param_name_lower: + return 'attention_out' + elif '.layers.' in param_name_lower and ('weight' in param_name_lower or 'bias' in param_name_lower): + return 'ffn' + else: + return 'embeds_other' + + elif param_name_lower.startswith('ae_local_blocks.'): + if 'proj_heads_q' in param_name_lower: + return 'ae_local_attention_q' + elif 'proj_heads_k' in param_name_lower: + return 'ae_local_attention_k' + elif 'proj_heads_v' in param_name_lower: + return 'ae_local_attention_v' + elif 'proj_out' in param_name_lower: + return 'ae_local_attention_out' + elif '.layers.' in param_name_lower: + return 'ae_local_ffn' + else: + return 'ae_local_other' + + elif param_name_lower.startswith('ae_global_blocks.'): + if 'proj_heads_q' in param_name_lower: + return 'ae_global_attention_q' + elif 'proj_heads_k' in param_name_lower: + return 'ae_global_attention_k' + elif 'proj_heads_v' in param_name_lower: + return 'ae_global_attention_v' + elif 'proj_out' in param_name_lower: + return 'ae_global_attention_out' + elif '.layers.' in param_name_lower: + return 'ae_global_ffn' + else: + return 'ae_global_other' + + elif param_name_lower.startswith('ae_adapter.'): + if 'proj_heads_q' in param_name_lower: + return 'ae_adapter_attention_q' + elif 'proj_heads_k' in param_name_lower: + return 'ae_adapter_attention_k' + elif 'proj_heads_v' in param_name_lower: + return 'ae_adapter_attention_v' + elif 'proj_out' in param_name_lower: + return 'ae_adapter_attention_out' + elif '.layers.' in param_name_lower: + return 'ae_adapter_ffn' + else: + return 'ae_adapter_other' + + elif param_name_lower.startswith('target_token_engines.'): + if 'proj_heads_q' in param_name_lower: + return 'tte_attention_q' + elif 'proj_heads_k' in param_name_lower: + return 'tte_attention_k' + elif 'proj_heads_v' in param_name_lower: + return 'tte_attention_v' + elif 'proj_out' in param_name_lower: + return 'tte_attention_out' + elif 'embed_aux' in param_name_lower: + return 'tte_embed_aux' + elif 'lnorm' in param_name_lower: + return 'tte_layer_norm' + elif '.layers.' in param_name_lower: + return 'tte_ffn' + else: + return 'tte_other' + + elif param_name_lower.startswith('embed_target_coords.'): + return 'target_coords_embedding' + + elif param_name_lower.startswith('pred_heads.'): + return 'prediction_head' + + # Fallback for standard patterns (if any) + elif 'embed' in param_name_lower: + return 'embedding' + elif 'attention' in param_name_lower or 'attn' in param_name_lower: + if 'q_proj' in param_name_lower or 'query' in param_name_lower: + return 'attention_q' + elif 'k_proj' in param_name_lower or 'key' in param_name_lower: + return 'attention_k' + elif 'v_proj' in param_name_lower or 'value' in param_name_lower: + return 'attention_v' + elif 'o_proj' in param_name_lower or 'out' in param_name_lower: + return 'attention_out' + else: + return 'attention' + elif 'layernorm' in param_name_lower or 'layer_norm' in param_name_lower or 'ln' in param_name_lower: + return 'layernorm' + else: + return 'other' + + def extract_layer_depth(self, param_name): + """Extract layer depth/index from parameter name.""" + param_name_lower = param_name.lower() + + # Look for patterns specific to your architecture + patterns = [ + # embeds.0.layers.N.* (transformer layers within embeds) + r'embeds\.\d+\.layers\.(\d+)\.', + # embeds.0.unembed.N.* (unembedding layers) + r'embeds\.\d+\.unembed\.(\d+)\.', + # embeds.0.ln_final.N.* (final layer norms) + r'embeds\.\d+\.ln_final\.(\d+)\.', + # ae_local_blocks.N.* (autoencoder local blocks) + r'ae_local_blocks\.(\d+)\.', + # ae_global_blocks.N.* (autoencoder global blocks) + r'ae_global_blocks\.(\d+)\.', + # ae_adapter.N.* (autoencoder adapter blocks) + r'ae_adapter\.(\d+)\.', + # target_token_engines.0.tte.N.* (target token engine blocks) + r'target_token_engines\.\d+\.tte\.(\d+)\.', + # target_token_engines.0.tte.N.block.M.* (nested blocks) + r'target_token_engines\.\d+\.tte\.(\d+)\.block\.(\d+)\.', + # pred_heads.0.pred_heads.0.N.* (prediction head layers) + r'pred_heads\.\d+\.pred_heads\.\d+\.(\d+)\.', + # Generic patterns for any numbered layers + r'layer[s]?\.(\d+)', + r'h\.(\d+)', + r'transformer\.(\d+)', + r'blocks\.(\d+)', + ] + + for pattern in patterns: + match = re.search(pattern, param_name_lower) + if match: + # For nested patterns (like tte blocks), combine indices + if len(match.groups()) > 1: + # Combine indices: e.g., tte.1.block.2 -> 12 (or 1*10+2) + return int(match.group(1)) * 10 + int(match.group(2)) + else: + return int(match.group(1)) + + # Special handling for components without clear depth + if param_name_lower.startswith('embed_target_coords.'): + return 0 # Coordinate embeddings at the start + elif 'total_grad_norm' in param_name_lower: + return -2 # Special marker for total norm + elif any(x in param_name_lower for x in ['weathergen', 'stage', 'q_cells']): + return -3 # Special marker for metadata + + return -1 # Unknown depth + + def plot_total_gradient_norms(self, figsize=(12, 6)): + """Plot total gradient norm over training steps.""" + # Calculate total norm per step + total_norms = [] + steps = [] + + for ith, entry in enumerate(self.data): + # step = entry.get('num_samples', entry.get('epoch', 0)) + step = ith * 5 + + if 'gradients' in entry: + grad_data = entry['gradients'] + elif 'grad_norms' in entry: + grad_data = entry['grad_norms'] + else: + grad_data = {k: v for k, v in entry.items() + if 'q_cells' in k or '0' in k} + + if len(grad_data) == 0: + continue + + # Calculate total norm (L2 norm of all gradients) + total_norm = np.sqrt(sum(float(v)**2 for v in grad_data.values())) + total_norms.append(total_norm) + steps.append(step) + + plt.figure(figsize=figsize) + plt.plot(steps, total_norms, linewidth=1.5, alpha=0.8) + plt.xlabel('Training Step') + plt.ylabel('Total Gradient Norm') + plt.title('Total Gradient Norm vs Training Steps') + plt.yscale('log') + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig("plots/total_grad_norm.png") + + return steps, total_norms + + def plot_layer_type_norms(self, figsize=(14, 8)): + """Plot gradient norms grouped by layer type.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + plt.figure(figsize=figsize) + + # Get unique layer types + layer_types = self.df['layer_type'].unique() + print(layer_types) + colors = plt.cm.tab10(np.linspace(0, 1, len(layer_types))) + + for i, layer_type in enumerate(layer_types): + layer_data = self.df[self.df['layer_type'] == layer_type] + + # Calculate mean gradient norm per step for this layer type + mean_norms = layer_data.groupby('num_samples')['grad_norm'].mean() + + plt.plot(mean_norms.index, mean_norms.values, + label=layer_type, color=colors[i], alpha=0.8) + + plt.xlabel('Training Step') + plt.ylabel('Mean Gradient Norm') + plt.title('Gradient Norms by Layer Type') + plt.yscale('log') + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + plt.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig("plots/grad_norm_by_layer_type.png") + + def plot_layer_depth_analysis(self, figsize=(12, 8)): + """Plot gradient norms by layer depth.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + # Filter out unknown depths + depth_data = self.df[self.df['layer_depth'] >= 0] + + if len(depth_data) == 0: + print("No layer depth information found in parameter names.") + return + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) + + # Plot 1: Mean gradient norm by depth over time + depths = sorted(depth_data['layer_depth'].unique()) + colors = plt.cm.viridis(np.linspace(0, 1, len(depths))) + + for i, depth in enumerate(depths): + layer_data = depth_data[depth_data['layer_depth'] == depth] + mean_norms = layer_data.groupby('num_samples')['grad_norm'].mean() + + ax1.plot(mean_norms.index, mean_norms.values, + label=f'Layer {depth}', color=colors[i], alpha=0.8) + + ax1.set_xlabel('Training Step') + ax1.set_ylabel('Mean Gradient Norm') + ax1.set_title('Gradient Norms by Layer Depth') + ax1.set_yscale('log') + ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + ax1.grid(True, alpha=0.3) + + # Plot 2: Heatmap of gradient norms by depth and step + pivot_data = depth_data.groupby(['num_samples', 'layer_depth'])['grad_norm'].mean().unstack() + + # Sample data if too many steps for readability + if len(pivot_data) > 100: + sample_idx = np.linspace(0, len(pivot_data)-1, 100, dtype=int) + pivot_data = pivot_data.iloc[sample_idx] + + im = ax2.imshow(pivot_data.T, aspect='auto', cmap='viridis', + extent=[pivot_data.index.min(), pivot_data.index.max(), + pivot_data.columns.min(), pivot_data.columns.max()]) + ax2.set_xlabel('Training Step') + ax2.set_ylabel('Layer Depth') + ax2.set_title('Gradient Norm Heatmap (Layer Depth vs Step)') + + cbar = plt.colorbar(im, ax=ax2) + cbar.set_label('Gradient Norm') + + plt.tight_layout() + plt.savefig("plots/grad_norm_heatmap.png") + + def plot_gradient_distribution(self, figsize=(15, 10)): + """Plot distribution of gradient norms.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + fig, axes = plt.subplots(2, 2, figsize=figsize) + + # Plot 1: Histogram of all gradient norms + axes[0, 0].hist(np.log10(self.df['grad_norm'].values), bins=50, alpha=0.7) + axes[0, 0].set_xlabel('Log10(Gradient Norm)') + axes[0, 0].set_ylabel('Frequency') + axes[0, 0].set_title('Distribution of Gradient Norms (Log Scale)') + axes[0, 0].grid(True, alpha=0.3) + + # Plot 2: Box plot by layer type + layer_types = self.df['layer_type'].unique()[:10] # Limit to 10 for readability + plot_data = [np.log10(self.df[self.df['layer_type'] == lt]['grad_norm'].values) + for lt in layer_types] + + axes[0, 1].boxplot(plot_data, labels=layer_types) + axes[0, 1].set_xlabel('Layer Type') + axes[0, 1].set_ylabel('Log10(Gradient Norm)') + axes[0, 1].set_title('Gradient Norm Distribution by Layer Type') + axes[0, 1].tick_params(axis='x', rotation=45) + axes[0, 1].grid(True, alpha=0.3) + + # Plot 3: Gradient norms over time (sample of parameters) + sample_params = self.df['parameter'].unique()[:20] # Sample 20 parameters + for param in sample_params: + param_data = self.df[self.df['parameter'] == param] + axes[1, 0].plot(param_data['num_samples'], param_data['grad_norm'], + alpha=0.6, linewidth=0.8) + + axes[1, 0].set_xlabel('Training Step') + axes[1, 0].set_ylabel('Gradient Norm') + axes[1, 0].set_title('Individual Parameter Gradient Norms (Sample)') + axes[1, 0].set_yscale('log') + axes[1, 0].grid(True, alpha=0.3) + + # Plot 4: Statistics over time + stats_by_step = self.df.groupby('num_samples')['grad_norm'].agg(['mean', 'std', 'min', 'max']) + + axes[1, 1].fill_between(stats_by_step.index, + stats_by_step['mean'] - stats_by_step['std'], + stats_by_step['mean'] + stats_by_step['std'], + alpha=0.3, label='±1 std') + axes[1, 1].plot(stats_by_step.index, stats_by_step['mean'], + label='Mean', linewidth=2) + axes[1, 1].plot(stats_by_step.index, stats_by_step['max'], + label='Max', linewidth=1, alpha=0.8) + axes[1, 1].plot(stats_by_step.index, stats_by_step['min'], + label='Min', linewidth=1, alpha=0.8) + + axes[1, 1].set_xlabel('Training Step') + axes[1, 1].set_ylabel('Gradient Norm') + axes[1, 1].set_title('Gradient Norm Statistics Over Time') + axes[1, 1].set_yscale('log') + axes[1, 1].legend() + axes[1, 1].grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig("plots/grad_norm_over_time.png") + + def generate_summary_report(self): + """Generate a summary report of gradient norm statistics.""" + if self.df is None: + print("No DataFrame available. Load data first.") + return + + print("=== GRADIENT NORMS ANALYSIS REPORT ===") + print(f"Total data points: {len(self.df)}") + print(f"Training steps: {self.df['num_samples'].nunique()}") + print(f"Unique parameters: {self.df['parameter'].nunique()}") + print() + + print("Overall Statistics:") + print(f"Mean gradient norm: {self.df['grad_norm'].mean():.6f}") + print(f"Median gradient norm: {self.df['grad_norm'].median():.6f}") + print(f"Min gradient norm: {self.df['grad_norm'].min():.6f}") + print(f"Max gradient norm: {self.df['grad_norm'].max():.6f}") + print() + + print("Statistics by Layer Type:") + layer_stats = self.df.groupby('layer_type')['grad_norm'].agg(['count', 'mean', 'std', 'min', 'max']) + print(layer_stats) + print() + + # Check for potential issues + print("Potential Issues:") + very_small = (self.df['grad_norm'] < 1e-6).sum() + very_large = (self.df['grad_norm'] > 10.0).sum() + + if very_small > 0: + print(f"⚠️ {very_small} gradient norms < 1e-6 (possible vanishing gradients)") + if very_large > 0: + print(f"⚠️ {very_large} gradient norms > 10.0 (possible exploding gradients)") + + if very_small == 0 and very_large == 0: + print("✅ No obvious gradient issues detected") + +# Usage example +def analyze_gradient_file(json_file_path): + """ + Main function to analyze gradient norms from a JSON file. + + Usage: + analyze_gradient_file('gradient_norms.jsonl') + """ + + analyzer = GradientNormsAnalyzer(json_file_path) + + # Generate summary report + analyzer.generate_summary_report() + + # Create all plots + print("\n=== GENERATING PLOTS ===") + + print("1. Total gradient norms over time...") + analyzer.plot_total_gradient_norms() + + print("2. Gradient norms by layer type...") + analyzer.plot_layer_type_norms() + + print("3. Layer depth analysis...") + analyzer.plot_layer_depth_analysis() + + print("4. Gradient distribution analysis...") + analyzer.plot_gradient_distribution() + + return analyzer + +# Example usage: +analyzer = analyze_gradient_file('results/yvhxm2jc/yvhxm2jc_train_metrics.json') diff --git a/uv.lock b/uv.lock index 51d6a0485..253e7171a 100644 --- a/uv.lock +++ b/uv.lock @@ -1614,6 +1614,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed", size = 40966184, upload-time = "2025-05-08T16:06:52.623Z" }, ] +[[package]] +name = "seaborn" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696, upload-time = "2024-01-25T13:21:52.551Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" }, +] + [[package]] name = "semantic-version" version = "2.10.0" @@ -1897,6 +1911,7 @@ dependencies = [ { name = "polars" }, { name = "psutil" }, { name = "pynvml" }, + { name = "seaborn" }, { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux' and sys_platform != 'macosx' and sys_platform != 'win32'" }, { name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "sys_platform == 'macosx'" }, { name = "torch", version = "2.6.0+cu124", source = { registry = "https://download.pytorch.org/whl/cu124" }, marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -1928,6 +1943,7 @@ requires-dist = [ { name = "polars", specifier = "~=1.25.2" }, { name = "psutil" }, { name = "pynvml" }, + { name = "seaborn", specifier = ">=0.13.2" }, { name = "torch", marker = "sys_platform != 'linux' and sys_platform != 'macosx' and sys_platform != 'win32'", specifier = "==2.6.0" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'win32'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cu124" }, { name = "torch", marker = "sys_platform == 'macosx'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu" }, From d2995b4b6d2b3a7b2ac71c4312eb670f082f1298 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 7 Aug 2025 10:17:48 +0000 Subject: [PATCH 03/40] Address review changes + hide behind feature flag --- config/default_config.yml | 1 + src/weathergen/train/trainer.py | 29 +++++++++++++++++++------ src/weathergen/utils/plot_grad_norms.py | 5 ++++- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index e8f21204a..403b1c20d 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -105,6 +105,7 @@ grad_clip: 1.0 weight_decay: 0.1 norm_type: "LayerNorm" nn_module: "te" +log_grad_norms: True start_date: 197901010000 end_date: 202012310000 diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 4430211ac..b8bf07ea4 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -13,6 +13,8 @@ import time from typing import Any +from omegaconf import OmegaConf + import numpy as np import torch import tqdm @@ -54,6 +56,10 @@ def init( ): self.cf = cf + self.cf = OmegaConf.merge( + OmegaConf.create({"log_grad_norms": False}), self.cf + ) + assert cf.samples_per_epoch % cf.batch_size_per_gpu == 0 assert cf.samples_per_validation % cf.batch_size_validation_per_gpu == 0 assert cf.forecast_policy if cf.forecast_steps > 0 else True @@ -76,7 +82,6 @@ def init( self.init_perf_monitoring() self.train_logger = TrainLogger(cf, config.get_path_run(self.cf)) - self.last_grad_norm = 0.0 def inference(self, cf, run_id_trained, epoch): # general initalization @@ -459,6 +464,7 @@ def train(self, epoch): # Unweighted loss, real weighted loss, std for losses that need it self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + self.last_grad_norm = 0.0 # training loop self.t_start = time.time() @@ -489,12 +495,7 @@ def train(self, epoch): # log gradient norms if bidx % log_interval == 0: - grad_norms = { "total_grad_norm" : total_norm.item() } - self.last_grad_norm = total_norm.item() - for name, param in self.ddp_model.named_parameters(): - if param.grad is not None: - grad_norms["grad_norm_" + name] = param.grad.norm().item() - self.train_logger.log_metrics(TRAIN, grad_norms) + self._log_instant_grad_norms(TRAIN, total_norm) # optimizer step @@ -709,6 +710,20 @@ def _log(self, stage: Stage): self.loss_unweighted_hist, self.loss_model_hist, self.stdev_unweighted_hist = [], [], [] + def _log_instant_grad_norms(self, stage: Stage, total_norm): + """ + Log instantaneous grad norms, we do not average because of the cost and because we want to + measure the actual values + + TODO test DDP case + """ + grad_norms = { "total_grad_norm" : total_norm.item() } + self.last_grad_norm = total_norm.item() + for name, param in self.ddp_model.named_parameters(): + if param.grad is not None: + grad_norms["grad_norm_" + name] = param.grad.norm().item() + self.train_logger.log_metrics(TRAIN, grad_norms) + def _log_terminal(self, bidx: int, epoch: int, stage: Stage): if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL: # compute from last iteration diff --git a/src/weathergen/utils/plot_grad_norms.py b/src/weathergen/utils/plot_grad_norms.py index 8a6ded4ac..0ff1a1f5c 100644 --- a/src/weathergen/utils/plot_grad_norms.py +++ b/src/weathergen/utils/plot_grad_norms.py @@ -480,4 +480,7 @@ def analyze_gradient_file(json_file_path): return analyzer # Example usage: -analyzer = analyze_gradient_file('results/yvhxm2jc/yvhxm2jc_train_metrics.json') +# uv run python src/weathergen/utils/plot_grad_norms.py results/yvhxm2jc/yvhxm2jc_train_metrics.json +if __name__ == '__main__': + import sys + analyzer = analyze_gradient_file(sys.argv[1]) From 26c6869eccfc595173db9f11e94ad3f62b1ad210 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 7 Aug 2025 10:49:05 +0000 Subject: [PATCH 04/40] Final fixes including backward compatibility --- config/default_config.yml | 2 +- src/weathergen/train/trainer.py | 13 +++++-------- src/weathergen/utils/plot_grad_norms.py | 14 +++++++------- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 403b1c20d..9fa9d359e 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -105,7 +105,7 @@ grad_clip: 1.0 weight_decay: 0.1 norm_type: "LayerNorm" nn_module: "te" -log_grad_norms: True +log_grad_norms: False start_date: 197901010000 end_date: 202012310000 diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index b8bf07ea4..9619c93d2 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -56,10 +56,6 @@ def init( ): self.cf = cf - self.cf = OmegaConf.merge( - OmegaConf.create({"log_grad_norms": False}), self.cf - ) - assert cf.samples_per_epoch % cf.batch_size_per_gpu == 0 assert cf.samples_per_validation % cf.batch_size_validation_per_gpu == 0 assert cf.forecast_policy if cf.forecast_steps > 0 else True @@ -72,6 +68,8 @@ def init( # num_ranks gets overwritten by current setting during init_ddp() self.num_ranks_original = cf.get("num_ranks", None) + self.log_grad_norms = cf.get("log_grad_norms", False) + # TODO remove num_ranks, rank, with_with ddp from config self.init_ddp(cf) @@ -494,10 +492,9 @@ def train(self, epoch): ) # log gradient norms - if bidx % log_interval == 0: + if bidx % log_interval == 0 and self.log_grad_norms: self._log_instant_grad_norms(TRAIN, total_norm) - # optimizer step self.grad_scaler.step(self.optimizer) self.grad_scaler.update() @@ -712,12 +709,12 @@ def _log(self, stage: Stage): def _log_instant_grad_norms(self, stage: Stage, total_norm): """ - Log instantaneous grad norms, we do not average because of the cost and because we want to + Log instantaneous grad norms, we do not average because of the cost and because we want to measure the actual values TODO test DDP case """ - grad_norms = { "total_grad_norm" : total_norm.item() } + grad_norms = {"total_grad_norm": total_norm.item()} self.last_grad_norm = total_norm.item() for name, param in self.ddp_model.named_parameters(): if param.grad is not None: diff --git a/src/weathergen/utils/plot_grad_norms.py b/src/weathergen/utils/plot_grad_norms.py index 0ff1a1f5c..de50ad8f5 100644 --- a/src/weathergen/utils/plot_grad_norms.py +++ b/src/weathergen/utils/plot_grad_norms.py @@ -49,7 +49,7 @@ def create_dataframe(self): else: # Assume all keys except step/epoch are gradient data grad_data = {k: v for k, v in entry.items() - if 'stream' not in k and ('q_cells' in k or '0' in k)} + if 'stream' not in k and ('grad_norm' in k)} for param_name, norm_value in grad_data.items(): rows.append({ @@ -65,7 +65,7 @@ def create_dataframe(self): def extract_layer_type(self, param_name): """Extract layer type from parameter name.""" - param_name_lower = param_name.lower() + param_name_lower = param_name.lower()[10:] # Handle your specific naming patterns if param_name_lower.startswith('embeds.'): @@ -180,13 +180,13 @@ def extract_layer_depth(self, param_name): # Look for patterns specific to your architecture patterns = [ # embeds.0.layers.N.* (transformer layers within embeds) - r'embeds\.\d+\.layers\.(\d+)\.', + r'grad_norm_embeds\.\d+\.layers\.(\d+)\.', # embeds.0.unembed.N.* (unembedding layers) - r'embeds\.\d+\.unembed\.(\d+)\.', + r'grad_norm_embeds\.\d+\.unembed\.(\d+)\.', # embeds.0.ln_final.N.* (final layer norms) - r'embeds\.\d+\.ln_final\.(\d+)\.', + r'grad_norm_embeds\.\d+\.ln_final\.(\d+)\.', # ae_local_blocks.N.* (autoencoder local blocks) - r'ae_local_blocks\.(\d+)\.', + r'grad_norm_ae_local_blocks\.(\d+)\.', # ae_global_blocks.N.* (autoencoder global blocks) r'ae_global_blocks\.(\d+)\.', # ae_adapter.N.* (autoencoder adapter blocks) @@ -240,7 +240,7 @@ def plot_total_gradient_norms(self, figsize=(12, 6)): grad_data = entry['grad_norms'] else: grad_data = {k: v for k, v in entry.items() - if 'q_cells' in k or '0' in k} + if 'grad_norm' in k} if len(grad_data) == 0: continue From 9a66f7217d79a44700fa6d4280ae9b0f2eccc714 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 7 Aug 2025 10:51:40 +0000 Subject: [PATCH 05/40] Ruff --- src/weathergen/train/trainer.py | 2 -- src/weathergen/utils/train_logger.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 41a9aab68..b65987484 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -13,8 +13,6 @@ import time from typing import Any -from omegaconf import OmegaConf - import numpy as np import torch import tqdm diff --git a/src/weathergen/utils/train_logger.py b/src/weathergen/utils/train_logger.py index b6840df31..f60e748f7 100644 --- a/src/weathergen/utils/train_logger.py +++ b/src/weathergen/utils/train_logger.py @@ -149,7 +149,6 @@ def add_train( if perf_mem > 0.0: metrics[_performance_memory] = perf_mem - self.log_metrics("train", metrics) with open(self.path_run / (self.cf.run_id + "_perf_log.txt"), "ab") as f: np.savetxt(f, log_vals) From 22a6fd72d9903dfd9b804d463a93c9d06df782f8 Mon Sep 17 00:00:00 2001 From: sophiex <24638638+sophie-xhonneux@users.noreply.github.com> Date: Thu, 7 Aug 2025 12:01:38 +0000 Subject: [PATCH 06/40] More ruff stuff --- src/weathergen/utils/plot_grad_norms.py | 593 +++++++++++++----------- 1 file changed, 316 insertions(+), 277 deletions(-) diff --git a/src/weathergen/utils/plot_grad_norms.py b/src/weathergen/utils/plot_grad_norms.py index de50ad8f5..ec310c0fc 100644 --- a/src/weathergen/utils/plot_grad_norms.py +++ b/src/weathergen/utils/plot_grad_norms.py @@ -1,11 +1,13 @@ import json +import re +from pathlib import Path + import matplotlib.pyplot as plt import numpy as np import pandas as pd -from pathlib import Path -import seaborn as sns -from collections import defaultdict -import re + +# ruff: noqa: T201 + class GradientNormsAnalyzer: def __init__(self, json_file_path): @@ -17,193 +19,202 @@ def __init__(self, json_file_path): self.data = [] self.df = None self.load_data() - + def load_data(self): """Load and parse the JSON data from file.""" print(f"Loading data from {self.json_file_path}...") - - with open(self.json_file_path, 'r') as f: + + with open(self.json_file_path) as f: for line_num, line in enumerate(f, 1): try: data_point = json.loads(line.strip()) self.data.append(data_point) except json.JSONDecodeError as e: print(f"Warning: Could not parse line {line_num}: {e}") - + print(f"Loaded {len(self.data)} data points") self.create_dataframe() - + def create_dataframe(self): """Convert loaded data into a pandas DataFrame for easier analysis.""" rows = [] - + for ith, entry in enumerate(self.data): # step = entry.get('num_samples', entry.get('epoch', 0)) step = ith * 5 - + # Handle different possible data structures - if 'gradients' in entry: - grad_data = entry['gradients'] - elif 'grad_norms' in entry: - grad_data = entry['grad_norms'] + if "gradients" in entry: + grad_data = entry["gradients"] + elif "grad_norms" in entry: + grad_data = entry["grad_norms"] else: # Assume all keys except step/epoch are gradient data - grad_data = {k: v for k, v in entry.items() - if 'stream' not in k and ('grad_norm' in k)} - + grad_data = { + k: v for k, v in entry.items() if "stream" not in k and ("grad_norm" in k) + } + for param_name, norm_value in grad_data.items(): - rows.append({ - 'num_samples': step, - 'parameter': param_name, - 'grad_norm': float(norm_value), - 'layer_type': self.extract_layer_type(param_name), - 'layer_depth': self.extract_layer_depth(param_name) - }) - + rows.append( + { + "num_samples": step, + "parameter": param_name, + "grad_norm": float(norm_value), + "layer_type": self.extract_layer_type(param_name), + "layer_depth": self.extract_layer_depth(param_name), + } + ) + self.df = pd.DataFrame(rows) print(f"Created DataFrame with {len(self.df)} gradient norm records") - + def extract_layer_type(self, param_name): """Extract layer type from parameter name.""" param_name_lower = param_name.lower()[10:] - + # Handle your specific naming patterns - if param_name_lower.startswith('embeds.'): - if '.embed.' in param_name_lower: - return 'embedding' - elif '.unembed.' in param_name_lower: - return 'unembedding' - elif '.ln_final.' in param_name_lower: - return 'layer_norm_final' - elif 'proj_heads_q' in param_name_lower: - return 'attention_q' - elif 'proj_heads_k' in param_name_lower: - return 'attention_k' - elif 'proj_heads_v' in param_name_lower: - return 'attention_v' - elif 'proj_out' in param_name_lower: - return 'attention_out' - elif '.layers.' in param_name_lower and ('weight' in param_name_lower or 'bias' in param_name_lower): - return 'ffn' + if param_name_lower.startswith("embeds."): + if ".embed." in param_name_lower: + return "embedding" + elif ".unembed." in param_name_lower: + return "unembedding" + elif ".ln_final." in param_name_lower: + return "layer_norm_final" + elif "proj_heads_q" in param_name_lower: + return "attention_q" + elif "proj_heads_k" in param_name_lower: + return "attention_k" + elif "proj_heads_v" in param_name_lower: + return "attention_v" + elif "proj_out" in param_name_lower: + return "attention_out" + elif ".layers." in param_name_lower and ( + "weight" in param_name_lower or "bias" in param_name_lower + ): + return "ffn" else: - return 'embeds_other' - - elif param_name_lower.startswith('ae_local_blocks.'): - if 'proj_heads_q' in param_name_lower: - return 'ae_local_attention_q' - elif 'proj_heads_k' in param_name_lower: - return 'ae_local_attention_k' - elif 'proj_heads_v' in param_name_lower: - return 'ae_local_attention_v' - elif 'proj_out' in param_name_lower: - return 'ae_local_attention_out' - elif '.layers.' in param_name_lower: - return 'ae_local_ffn' + return "embeds_other" + + elif param_name_lower.startswith("ae_local_blocks."): + if "proj_heads_q" in param_name_lower: + return "ae_local_attention_q" + elif "proj_heads_k" in param_name_lower: + return "ae_local_attention_k" + elif "proj_heads_v" in param_name_lower: + return "ae_local_attention_v" + elif "proj_out" in param_name_lower: + return "ae_local_attention_out" + elif ".layers." in param_name_lower: + return "ae_local_ffn" else: - return 'ae_local_other' - - elif param_name_lower.startswith('ae_global_blocks.'): - if 'proj_heads_q' in param_name_lower: - return 'ae_global_attention_q' - elif 'proj_heads_k' in param_name_lower: - return 'ae_global_attention_k' - elif 'proj_heads_v' in param_name_lower: - return 'ae_global_attention_v' - elif 'proj_out' in param_name_lower: - return 'ae_global_attention_out' - elif '.layers.' in param_name_lower: - return 'ae_global_ffn' + return "ae_local_other" + + elif param_name_lower.startswith("ae_global_blocks."): + if "proj_heads_q" in param_name_lower: + return "ae_global_attention_q" + elif "proj_heads_k" in param_name_lower: + return "ae_global_attention_k" + elif "proj_heads_v" in param_name_lower: + return "ae_global_attention_v" + elif "proj_out" in param_name_lower: + return "ae_global_attention_out" + elif ".layers." in param_name_lower: + return "ae_global_ffn" else: - return 'ae_global_other' - - elif param_name_lower.startswith('ae_adapter.'): - if 'proj_heads_q' in param_name_lower: - return 'ae_adapter_attention_q' - elif 'proj_heads_k' in param_name_lower: - return 'ae_adapter_attention_k' - elif 'proj_heads_v' in param_name_lower: - return 'ae_adapter_attention_v' - elif 'proj_out' in param_name_lower: - return 'ae_adapter_attention_out' - elif '.layers.' in param_name_lower: - return 'ae_adapter_ffn' + return "ae_global_other" + + elif param_name_lower.startswith("ae_adapter."): + if "proj_heads_q" in param_name_lower: + return "ae_adapter_attention_q" + elif "proj_heads_k" in param_name_lower: + return "ae_adapter_attention_k" + elif "proj_heads_v" in param_name_lower: + return "ae_adapter_attention_v" + elif "proj_out" in param_name_lower: + return "ae_adapter_attention_out" + elif ".layers." in param_name_lower: + return "ae_adapter_ffn" else: - return 'ae_adapter_other' - - elif param_name_lower.startswith('target_token_engines.'): - if 'proj_heads_q' in param_name_lower: - return 'tte_attention_q' - elif 'proj_heads_k' in param_name_lower: - return 'tte_attention_k' - elif 'proj_heads_v' in param_name_lower: - return 'tte_attention_v' - elif 'proj_out' in param_name_lower: - return 'tte_attention_out' - elif 'embed_aux' in param_name_lower: - return 'tte_embed_aux' - elif 'lnorm' in param_name_lower: - return 'tte_layer_norm' - elif '.layers.' in param_name_lower: - return 'tte_ffn' + return "ae_adapter_other" + + elif param_name_lower.startswith("target_token_engines."): + if "proj_heads_q" in param_name_lower: + return "tte_attention_q" + elif "proj_heads_k" in param_name_lower: + return "tte_attention_k" + elif "proj_heads_v" in param_name_lower: + return "tte_attention_v" + elif "proj_out" in param_name_lower: + return "tte_attention_out" + elif "embed_aux" in param_name_lower: + return "tte_embed_aux" + elif "lnorm" in param_name_lower: + return "tte_layer_norm" + elif ".layers." in param_name_lower: + return "tte_ffn" else: - return 'tte_other' - - elif param_name_lower.startswith('embed_target_coords.'): - return 'target_coords_embedding' - - elif param_name_lower.startswith('pred_heads.'): - return 'prediction_head' - + return "tte_other" + + elif param_name_lower.startswith("embed_target_coords."): + return "target_coords_embedding" + + elif param_name_lower.startswith("pred_heads."): + return "prediction_head" + # Fallback for standard patterns (if any) - elif 'embed' in param_name_lower: - return 'embedding' - elif 'attention' in param_name_lower or 'attn' in param_name_lower: - if 'q_proj' in param_name_lower or 'query' in param_name_lower: - return 'attention_q' - elif 'k_proj' in param_name_lower or 'key' in param_name_lower: - return 'attention_k' - elif 'v_proj' in param_name_lower or 'value' in param_name_lower: - return 'attention_v' - elif 'o_proj' in param_name_lower or 'out' in param_name_lower: - return 'attention_out' + elif "embed" in param_name_lower: + return "embedding" + elif "attention" in param_name_lower or "attn" in param_name_lower: + if "q_proj" in param_name_lower or "query" in param_name_lower: + return "attention_q" + elif "k_proj" in param_name_lower or "key" in param_name_lower: + return "attention_k" + elif "v_proj" in param_name_lower or "value" in param_name_lower: + return "attention_v" + elif "o_proj" in param_name_lower or "out" in param_name_lower: + return "attention_out" else: - return 'attention' - elif 'layernorm' in param_name_lower or 'layer_norm' in param_name_lower or 'ln' in param_name_lower: - return 'layernorm' + return "attention" + elif ( + "layernorm" in param_name_lower + or "layer_norm" in param_name_lower + or "ln" in param_name_lower + ): + return "layernorm" else: - return 'other' - + return "other" + def extract_layer_depth(self, param_name): """Extract layer depth/index from parameter name.""" param_name_lower = param_name.lower() - + # Look for patterns specific to your architecture patterns = [ # embeds.0.layers.N.* (transformer layers within embeds) - r'grad_norm_embeds\.\d+\.layers\.(\d+)\.', + r"grad_norm_embeds\.\d+\.layers\.(\d+)\.", # embeds.0.unembed.N.* (unembedding layers) - r'grad_norm_embeds\.\d+\.unembed\.(\d+)\.', + r"grad_norm_embeds\.\d+\.unembed\.(\d+)\.", # embeds.0.ln_final.N.* (final layer norms) - r'grad_norm_embeds\.\d+\.ln_final\.(\d+)\.', + r"grad_norm_embeds\.\d+\.ln_final\.(\d+)\.", # ae_local_blocks.N.* (autoencoder local blocks) - r'grad_norm_ae_local_blocks\.(\d+)\.', + r"grad_norm_ae_local_blocks\.(\d+)\.", # ae_global_blocks.N.* (autoencoder global blocks) - r'ae_global_blocks\.(\d+)\.', + r"ae_global_blocks\.(\d+)\.", # ae_adapter.N.* (autoencoder adapter blocks) - r'ae_adapter\.(\d+)\.', + r"ae_adapter\.(\d+)\.", # target_token_engines.0.tte.N.* (target token engine blocks) - r'target_token_engines\.\d+\.tte\.(\d+)\.', + r"target_token_engines\.\d+\.tte\.(\d+)\.", # target_token_engines.0.tte.N.block.M.* (nested blocks) - r'target_token_engines\.\d+\.tte\.(\d+)\.block\.(\d+)\.', + r"target_token_engines\.\d+\.tte\.(\d+)\.block\.(\d+)\.", # pred_heads.0.pred_heads.0.N.* (prediction head layers) - r'pred_heads\.\d+\.pred_heads\.\d+\.(\d+)\.', + r"pred_heads\.\d+\.pred_heads\.\d+\.(\d+)\.", # Generic patterns for any numbered layers - r'layer[s]?\.(\d+)', - r'h\.(\d+)', - r'transformer\.(\d+)', - r'blocks\.(\d+)', + r"layer[s]?\.(\d+)", + r"h\.(\d+)", + r"transformer\.(\d+)", + r"blocks\.(\d+)", ] - + for pattern in patterns: match = re.search(pattern, param_name_lower) if match: @@ -213,274 +224,302 @@ def extract_layer_depth(self, param_name): return int(match.group(1)) * 10 + int(match.group(2)) else: return int(match.group(1)) - + # Special handling for components without clear depth - if param_name_lower.startswith('embed_target_coords.'): + if param_name_lower.startswith("embed_target_coords."): return 0 # Coordinate embeddings at the start - elif 'total_grad_norm' in param_name_lower: + elif "total_grad_norm" in param_name_lower: return -2 # Special marker for total norm - elif any(x in param_name_lower for x in ['weathergen', 'stage', 'q_cells']): + elif any(x in param_name_lower for x in ["weathergen", "stage", "q_cells"]): return -3 # Special marker for metadata - + return -1 # Unknown depth - + def plot_total_gradient_norms(self, figsize=(12, 6)): """Plot total gradient norm over training steps.""" # Calculate total norm per step total_norms = [] steps = [] - + for ith, entry in enumerate(self.data): # step = entry.get('num_samples', entry.get('epoch', 0)) step = ith * 5 - - if 'gradients' in entry: - grad_data = entry['gradients'] - elif 'grad_norms' in entry: - grad_data = entry['grad_norms'] + + if "gradients" in entry: + grad_data = entry["gradients"] + elif "grad_norms" in entry: + grad_data = entry["grad_norms"] else: - grad_data = {k: v for k, v in entry.items() - if 'grad_norm' in k} + grad_data = {k: v for k, v in entry.items() if "grad_norm" in k} if len(grad_data) == 0: continue - + # Calculate total norm (L2 norm of all gradients) - total_norm = np.sqrt(sum(float(v)**2 for v in grad_data.values())) + total_norm = np.sqrt(sum(float(v) ** 2 for v in grad_data.values())) total_norms.append(total_norm) steps.append(step) - + plt.figure(figsize=figsize) plt.plot(steps, total_norms, linewidth=1.5, alpha=0.8) - plt.xlabel('Training Step') - plt.ylabel('Total Gradient Norm') - plt.title('Total Gradient Norm vs Training Steps') - plt.yscale('log') + plt.xlabel("Training Step") + plt.ylabel("Total Gradient Norm") + plt.title("Total Gradient Norm vs Training Steps") + plt.yscale("log") plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig("plots/total_grad_norm.png") - + return steps, total_norms - + def plot_layer_type_norms(self, figsize=(14, 8)): """Plot gradient norms grouped by layer type.""" if self.df is None: print("No DataFrame available. Load data first.") return - + plt.figure(figsize=figsize) - + # Get unique layer types - layer_types = self.df['layer_type'].unique() + layer_types = self.df["layer_type"].unique() print(layer_types) colors = plt.cm.tab10(np.linspace(0, 1, len(layer_types))) - + for i, layer_type in enumerate(layer_types): - layer_data = self.df[self.df['layer_type'] == layer_type] - + layer_data = self.df[self.df["layer_type"] == layer_type] + # Calculate mean gradient norm per step for this layer type - mean_norms = layer_data.groupby('num_samples')['grad_norm'].mean() - - plt.plot(mean_norms.index, mean_norms.values, - label=layer_type, color=colors[i], alpha=0.8) - - plt.xlabel('Training Step') - plt.ylabel('Mean Gradient Norm') - plt.title('Gradient Norms by Layer Type') - plt.yscale('log') - plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + mean_norms = layer_data.groupby("num_samples")["grad_norm"].mean() + + plt.plot( + mean_norms.index, mean_norms.values, label=layer_type, color=colors[i], alpha=0.8 + ) + + plt.xlabel("Training Step") + plt.ylabel("Mean Gradient Norm") + plt.title("Gradient Norms by Layer Type") + plt.yscale("log") + plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig("plots/grad_norm_by_layer_type.png") - + def plot_layer_depth_analysis(self, figsize=(12, 8)): """Plot gradient norms by layer depth.""" if self.df is None: print("No DataFrame available. Load data first.") return - + # Filter out unknown depths - depth_data = self.df[self.df['layer_depth'] >= 0] - + depth_data = self.df[self.df["layer_depth"] >= 0] + if len(depth_data) == 0: print("No layer depth information found in parameter names.") return - + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize) - + # Plot 1: Mean gradient norm by depth over time - depths = sorted(depth_data['layer_depth'].unique()) + depths = sorted(depth_data["layer_depth"].unique()) colors = plt.cm.viridis(np.linspace(0, 1, len(depths))) - + for i, depth in enumerate(depths): - layer_data = depth_data[depth_data['layer_depth'] == depth] - mean_norms = layer_data.groupby('num_samples')['grad_norm'].mean() - - ax1.plot(mean_norms.index, mean_norms.values, - label=f'Layer {depth}', color=colors[i], alpha=0.8) - - ax1.set_xlabel('Training Step') - ax1.set_ylabel('Mean Gradient Norm') - ax1.set_title('Gradient Norms by Layer Depth') - ax1.set_yscale('log') - ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + layer_data = depth_data[depth_data["layer_depth"] == depth] + mean_norms = layer_data.groupby("num_samples")["grad_norm"].mean() + + ax1.plot( + mean_norms.index, + mean_norms.values, + label=f"Layer {depth}", + color=colors[i], + alpha=0.8, + ) + + ax1.set_xlabel("Training Step") + ax1.set_ylabel("Mean Gradient Norm") + ax1.set_title("Gradient Norms by Layer Depth") + ax1.set_yscale("log") + ax1.legend(bbox_to_anchor=(1.05, 1), loc="upper left") ax1.grid(True, alpha=0.3) - + # Plot 2: Heatmap of gradient norms by depth and step - pivot_data = depth_data.groupby(['num_samples', 'layer_depth'])['grad_norm'].mean().unstack() - + pivot_data = ( + depth_data.groupby(["num_samples", "layer_depth"])["grad_norm"].mean().unstack() + ) + # Sample data if too many steps for readability if len(pivot_data) > 100: - sample_idx = np.linspace(0, len(pivot_data)-1, 100, dtype=int) + sample_idx = np.linspace(0, len(pivot_data) - 1, 100, dtype=int) pivot_data = pivot_data.iloc[sample_idx] - - im = ax2.imshow(pivot_data.T, aspect='auto', cmap='viridis', - extent=[pivot_data.index.min(), pivot_data.index.max(), - pivot_data.columns.min(), pivot_data.columns.max()]) - ax2.set_xlabel('Training Step') - ax2.set_ylabel('Layer Depth') - ax2.set_title('Gradient Norm Heatmap (Layer Depth vs Step)') - + + im = ax2.imshow( + pivot_data.T, + aspect="auto", + cmap="viridis", + extent=[ + pivot_data.index.min(), + pivot_data.index.max(), + pivot_data.columns.min(), + pivot_data.columns.max(), + ], + ) + ax2.set_xlabel("Training Step") + ax2.set_ylabel("Layer Depth") + ax2.set_title("Gradient Norm Heatmap (Layer Depth vs Step)") + cbar = plt.colorbar(im, ax=ax2) - cbar.set_label('Gradient Norm') - + cbar.set_label("Gradient Norm") + plt.tight_layout() plt.savefig("plots/grad_norm_heatmap.png") - + def plot_gradient_distribution(self, figsize=(15, 10)): """Plot distribution of gradient norms.""" if self.df is None: print("No DataFrame available. Load data first.") return - + fig, axes = plt.subplots(2, 2, figsize=figsize) - + # Plot 1: Histogram of all gradient norms - axes[0, 0].hist(np.log10(self.df['grad_norm'].values), bins=50, alpha=0.7) - axes[0, 0].set_xlabel('Log10(Gradient Norm)') - axes[0, 0].set_ylabel('Frequency') - axes[0, 0].set_title('Distribution of Gradient Norms (Log Scale)') + axes[0, 0].hist(np.log10(self.df["grad_norm"].values), bins=50, alpha=0.7) + axes[0, 0].set_xlabel("Log10(Gradient Norm)") + axes[0, 0].set_ylabel("Frequency") + axes[0, 0].set_title("Distribution of Gradient Norms (Log Scale)") axes[0, 0].grid(True, alpha=0.3) - + # Plot 2: Box plot by layer type - layer_types = self.df['layer_type'].unique()[:10] # Limit to 10 for readability - plot_data = [np.log10(self.df[self.df['layer_type'] == lt]['grad_norm'].values) - for lt in layer_types] - + layer_types = self.df["layer_type"].unique()[:10] # Limit to 10 for readability + plot_data = [ + np.log10(self.df[self.df["layer_type"] == lt]["grad_norm"].values) for lt in layer_types + ] + axes[0, 1].boxplot(plot_data, labels=layer_types) - axes[0, 1].set_xlabel('Layer Type') - axes[0, 1].set_ylabel('Log10(Gradient Norm)') - axes[0, 1].set_title('Gradient Norm Distribution by Layer Type') - axes[0, 1].tick_params(axis='x', rotation=45) + axes[0, 1].set_xlabel("Layer Type") + axes[0, 1].set_ylabel("Log10(Gradient Norm)") + axes[0, 1].set_title("Gradient Norm Distribution by Layer Type") + axes[0, 1].tick_params(axis="x", rotation=45) axes[0, 1].grid(True, alpha=0.3) - + # Plot 3: Gradient norms over time (sample of parameters) - sample_params = self.df['parameter'].unique()[:20] # Sample 20 parameters + sample_params = self.df["parameter"].unique()[:20] # Sample 20 parameters for param in sample_params: - param_data = self.df[self.df['parameter'] == param] - axes[1, 0].plot(param_data['num_samples'], param_data['grad_norm'], - alpha=0.6, linewidth=0.8) - - axes[1, 0].set_xlabel('Training Step') - axes[1, 0].set_ylabel('Gradient Norm') - axes[1, 0].set_title('Individual Parameter Gradient Norms (Sample)') - axes[1, 0].set_yscale('log') + param_data = self.df[self.df["parameter"] == param] + axes[1, 0].plot( + param_data["num_samples"], param_data["grad_norm"], alpha=0.6, linewidth=0.8 + ) + + axes[1, 0].set_xlabel("Training Step") + axes[1, 0].set_ylabel("Gradient Norm") + axes[1, 0].set_title("Individual Parameter Gradient Norms (Sample)") + axes[1, 0].set_yscale("log") axes[1, 0].grid(True, alpha=0.3) - + # Plot 4: Statistics over time - stats_by_step = self.df.groupby('num_samples')['grad_norm'].agg(['mean', 'std', 'min', 'max']) - - axes[1, 1].fill_between(stats_by_step.index, - stats_by_step['mean'] - stats_by_step['std'], - stats_by_step['mean'] + stats_by_step['std'], - alpha=0.3, label='±1 std') - axes[1, 1].plot(stats_by_step.index, stats_by_step['mean'], - label='Mean', linewidth=2) - axes[1, 1].plot(stats_by_step.index, stats_by_step['max'], - label='Max', linewidth=1, alpha=0.8) - axes[1, 1].plot(stats_by_step.index, stats_by_step['min'], - label='Min', linewidth=1, alpha=0.8) - - axes[1, 1].set_xlabel('Training Step') - axes[1, 1].set_ylabel('Gradient Norm') - axes[1, 1].set_title('Gradient Norm Statistics Over Time') - axes[1, 1].set_yscale('log') + stats_by_step = self.df.groupby("num_samples")["grad_norm"].agg( + ["mean", "std", "min", "max"] + ) + + axes[1, 1].fill_between( + stats_by_step.index, + stats_by_step["mean"] - stats_by_step["std"], + stats_by_step["mean"] + stats_by_step["std"], + alpha=0.3, + label="±1 std", + ) + axes[1, 1].plot(stats_by_step.index, stats_by_step["mean"], label="Mean", linewidth=2) + axes[1, 1].plot( + stats_by_step.index, stats_by_step["max"], label="Max", linewidth=1, alpha=0.8 + ) + axes[1, 1].plot( + stats_by_step.index, stats_by_step["min"], label="Min", linewidth=1, alpha=0.8 + ) + + axes[1, 1].set_xlabel("Training Step") + axes[1, 1].set_ylabel("Gradient Norm") + axes[1, 1].set_title("Gradient Norm Statistics Over Time") + axes[1, 1].set_yscale("log") axes[1, 1].legend() axes[1, 1].grid(True, alpha=0.3) - + plt.tight_layout() plt.savefig("plots/grad_norm_over_time.png") - + def generate_summary_report(self): """Generate a summary report of gradient norm statistics.""" if self.df is None: print("No DataFrame available. Load data first.") return - + print("=== GRADIENT NORMS ANALYSIS REPORT ===") print(f"Total data points: {len(self.df)}") print(f"Training steps: {self.df['num_samples'].nunique()}") print(f"Unique parameters: {self.df['parameter'].nunique()}") print() - + print("Overall Statistics:") print(f"Mean gradient norm: {self.df['grad_norm'].mean():.6f}") print(f"Median gradient norm: {self.df['grad_norm'].median():.6f}") print(f"Min gradient norm: {self.df['grad_norm'].min():.6f}") print(f"Max gradient norm: {self.df['grad_norm'].max():.6f}") print() - + print("Statistics by Layer Type:") - layer_stats = self.df.groupby('layer_type')['grad_norm'].agg(['count', 'mean', 'std', 'min', 'max']) + layer_stats = self.df.groupby("layer_type")["grad_norm"].agg( + ["count", "mean", "std", "min", "max"] + ) print(layer_stats) print() - + # Check for potential issues print("Potential Issues:") - very_small = (self.df['grad_norm'] < 1e-6).sum() - very_large = (self.df['grad_norm'] > 10.0).sum() - + very_small = (self.df["grad_norm"] < 1e-6).sum() + very_large = (self.df["grad_norm"] > 10.0).sum() + if very_small > 0: print(f"⚠️ {very_small} gradient norms < 1e-6 (possible vanishing gradients)") if very_large > 0: print(f"⚠️ {very_large} gradient norms > 10.0 (possible exploding gradients)") - + if very_small == 0 and very_large == 0: print("✅ No obvious gradient issues detected") + # Usage example def analyze_gradient_file(json_file_path): """ Main function to analyze gradient norms from a JSON file. - + Usage: analyze_gradient_file('gradient_norms.jsonl') """ - + analyzer = GradientNormsAnalyzer(json_file_path) - + # Generate summary report analyzer.generate_summary_report() - + # Create all plots print("\n=== GENERATING PLOTS ===") - + print("1. Total gradient norms over time...") analyzer.plot_total_gradient_norms() - + print("2. Gradient norms by layer type...") analyzer.plot_layer_type_norms() - + print("3. Layer depth analysis...") analyzer.plot_layer_depth_analysis() - + print("4. Gradient distribution analysis...") analyzer.plot_gradient_distribution() - + return analyzer + # Example usage: # uv run python src/weathergen/utils/plot_grad_norms.py results/yvhxm2jc/yvhxm2jc_train_metrics.json -if __name__ == '__main__': +if __name__ == "__main__": import sys + analyzer = analyze_gradient_file(sys.argv[1]) From a1d7a2797e34dbd9be073c853fdc205f496c067a Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Tue, 12 Aug 2025 14:24:28 +0000 Subject: [PATCH 07/40] Update to develop, prepare for new experiment series --- config/default_config.yml | 24 ++++++++++++------------ config/eval_config.yml | 28 ++++++++++++++++++++++++++++ config/runs_plot_train.yml | 6 ++++++ 3 files changed, 46 insertions(+), 12 deletions(-) create mode 100644 config/eval_config.yml create mode 100644 config/runs_plot_train.yml diff --git a/config/default_config.yml b/config/default_config.yml index 76bdd2694..e3772e842 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -10,7 +10,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 +ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -24,7 +24,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -40,13 +40,13 @@ pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: "fixed" forecast_freeze_model: False -forecast_att_dense_rate: 0.25 -fe_num_blocks: 0 +forecast_att_dense_rate: 1.0 +fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -75,7 +75,7 @@ batch_size_validation_per_gpu: 1 # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -91,17 +91,17 @@ masking_strategy: "random" # "channel": requires "mode" to be specified, "per_cell" or "global", masking_strategy_config: {"hl_mask": 3} -num_epochs: 32 +num_epochs: 64 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True lr_scaling_policy: "sqrt" lr_start: 1e-6 -lr_max: 5e-5 -lr_final_decay: 1e-6 +lr_max: 0.0001 +lr_final_decay: 2e-6 lr_final: 0.0 -lr_steps_warmup: 512 +lr_steps_warmup: 256 lr_steps_cooldown: 512 lr_policy_warmup: "cosine" lr_policy_decay: "linear" diff --git a/config/eval_config.yml b/config/eval_config.yml new file mode 100644 index 000000000..937bc59be --- /dev/null +++ b/config/eval_config.yml @@ -0,0 +1,28 @@ +verbose: true +image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. +dpi_val : 300 +summary_plots : true +print_summary: false + +evaluation: + metrics : ["rmse"] + regions: ["global"] + +run_ids : + + ptluswdo: + label: "ptluswdo: 64ep 2fs (naoj54ch) + 32ep 8fs 2e-5" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + #channels: ["2t", "q_850", ] + evaluation: + sample: "all" + forecast_step: "all" + plotting: + sample: [0] + forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] + plot_maps: true + plot_histograms: false \ No newline at end of file diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml new file mode 100644 index 000000000..49924b524 --- /dev/null +++ b/config/runs_plot_train.yml @@ -0,0 +1,6 @@ +train : + plot : + lnjzhore : + slurm_id: 0 + description: "Christian's naoj54ch with new code" + eval: vgbndhco \ No newline at end of file From 754d31c660d2fb6f40e285b59f5630971c519d73 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Wed, 8 Oct 2025 14:22:12 +0000 Subject: [PATCH 08/40] forecast config with small decoder --- config/default_config.yml | 29 +++++++++++++++-------------- config/streams/era5_1deg/era5.yml | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 2ecf4f6b8..dde6fafbc 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -10,7 +10,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True ae_local_dim_embed: 1024 -ae_local_num_blocks: 2 +ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 ae_local_with_qk_lnorm: True @@ -24,7 +24,7 @@ ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 ae_global_dim_embed: 2048 -ae_global_num_blocks: 8 +ae_global_num_blocks: 4 ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True @@ -34,18 +34,19 @@ ae_global_mlp_hidden_factor: 2 decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False -pred_self_attention: True +pred_self_attention: False pred_dyadic_dims: False pred_mlp_adaln: True # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder -forecast_offset : 0 +forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 0 -forecast_policy: null +forecast_steps: 2 +forecast_policy: "fixed" +forecast_freeze_model: False forecast_att_dense_rate: 1.0 -fe_num_blocks: 0 +fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True @@ -85,7 +86,7 @@ freeze_modules: "" # training mode: "forecast" or "masking" (masked token modeling) # for "masking" to train with auto-encoder mode, forecast_offset should be 0 -training_mode: "masking" +training_mode: "forecast" # masking rate when training mode is "masking"; ignored in foreacast mode masking_rate: 0.6 # sample the masking rate (with normal distribution centered at masking_rate) @@ -93,7 +94,7 @@ masking_rate: 0.6 masking_rate_sampling: True # sample a subset of all target points, useful e.g. to reduce memory requirements (also can specify per-stream) sampling_rate_target: 1.0 -# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "causal" and "combination" +# include a masking strategy here, currently only supporting "random", "block", "healpix", "channel", "combination" masking_strategy: "random" # masking_strategy_config is a dictionary of additional parameters for the masking strategy # required for "healpix" and "channel" masking strategies @@ -105,17 +106,17 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_epochs: 32 +num_epochs: 64 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True lr_scaling_policy: "sqrt" lr_start: 1e-6 -lr_max: 5e-5 -lr_final_decay: 1e-6 +lr_max: 0.0001 +lr_final_decay: 2e-6 lr_final: 0.0 -lr_steps_warmup: 512 +lr_steps_warmup: 256 lr_steps_cooldown: 512 lr_policy_warmup: "cosine" lr_policy_decay: "linear" @@ -151,4 +152,4 @@ run_id: ??? # Parameters for logging/printing in the training loop train_log: # The period to log metrics (in number of batch steps) - log_interval: 20 + log_interval: 20 \ No newline at end of file diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index a03bb3b40..aaf1bbf53 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -29,7 +29,7 @@ ERA5 : dim_embed : 256 target_readout : type : 'obs_value' # token or obs_value - num_layers : 2 + num_layers : 1 num_heads : 4 # sampling_rate : 0.2 pred_head : From 7c756a3544c91e4f59962f8e1ae6290cf20a45ba Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Thu, 9 Oct 2025 08:38:14 +0000 Subject: [PATCH 09/40] fixed uv.lock --- uv.lock | 292 +++++++++----------------------------------------------- 1 file changed, 44 insertions(+), 248 deletions(-) diff --git a/uv.lock b/uv.lock index 56e875859..79a5b2e2f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.12.*" resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", @@ -874,7 +874,7 @@ name = "jinja2" version = "3.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "markupsafe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } wheels = [ @@ -1251,52 +1251,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/1c/6d343e030815c7c97a1f9fbad00211b47717c7fe446834c224bd5311e6f1/numpy-2.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:bd8df082b6c4695753ad6193018c05aac465d634834dca47a3ae06d4bb22d9ea", size = 9891498, upload-time = "2025-06-07T14:43:36.332Z" }, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771, upload-time = "2024-06-18T19:28:09.881Z" }, - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805, upload-time = "2024-04-03T20:57:06.025Z" }, - { url = "https://files.pythonhosted.org/packages/e2/2a/4f27ca96232e8b5269074a72e03b4e0d43aa68c9b965058b1684d07c6ff8/nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc", size = 396895858, upload-time = "2024-04-03T21:03:31.996Z" }, -] - [[package]] name = "nvidia-cublas-cu12" version = "12.6.4.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322, upload-time = "2024-11-20T17:40:25.65Z" }, { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615, upload-time = "2024-11-20T17:39:52.715Z" }, { url = "https://files.pythonhosted.org/packages/84/f7/985e9bdbe3e0ac9298fcc8cfa51a392862a46a0ffaccbbd56939b62a9c83/nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8", size = 434535301, upload-time = "2024-11-20T17:50:41.681Z" }, ] -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556, upload-time = "2024-06-18T19:30:40.546Z" }, - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957, upload-time = "2024-04-03T20:55:01.564Z" }, - { url = "https://files.pythonhosted.org/packages/f3/79/8cf313ec17c58ccebc965568e5bcb265cdab0a1df99c4e674bb7a3b99bfe/nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922", size = 9938035, upload-time = "2024-04-03T21:01:01.109Z" }, -] - [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.6.80" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764, upload-time = "2024-11-20T17:35:41.03Z" }, { url = "https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756, upload-time = "2024-10-01T16:57:45.507Z" }, @@ -1305,52 +1273,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/81/7796f096afaf726796b1b648f3bc80cafc61fe7f77f44a483c89e6c5ef34/nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a", size = 5724175, upload-time = "2024-10-01T17:09:47.955Z" }, ] -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372, upload-time = "2024-06-18T19:32:00.576Z" }, - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306, upload-time = "2024-04-03T20:56:01.463Z" }, - { url = "https://files.pythonhosted.org/packages/7c/30/8c844bfb770f045bcd8b2c83455c5afb45983e1a8abf0c4e5297b481b6a5/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec", size = 19751955, upload-time = "2024-04-03T21:01:51.133Z" }, -] - [[package]] name = "nvidia-cuda-nvrtc-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/f4/2f/72df534873235983cc0a5371c3661bebef7c4682760c275590b972c7b0f9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13", size = 23162955, upload-time = "2024-10-01T16:59:50.922Z" }, { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380, upload-time = "2024-10-01T17:00:14.643Z" }, { url = "https://files.pythonhosted.org/packages/f5/46/d3a1cdda8bb113c80f43a0a6f3a853356d487b830f3483f92d49ce87fa55/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a", size = 39026742, upload-time = "2024-10-01T17:10:49.058Z" }, ] -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177, upload-time = "2024-06-18T19:32:52.877Z" }, - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737, upload-time = "2024-04-03T20:54:51.355Z" }, - { url = "https://files.pythonhosted.org/packages/a8/8b/450e93fab75d85a69b50ea2d5fdd4ff44541e0138db16f9cd90123ef4de4/nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e", size = 878808, upload-time = "2024-04-03T21:00:49.77Z" }, -] - [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052, upload-time = "2024-11-20T17:35:19.905Z" }, { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040, upload-time = "2024-10-01T16:57:22.221Z" }, @@ -1359,30 +1295,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/76/4c80fa138333cc975743fd0687a745fccb30d167f906f13c1c7f9a85e5ea/nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f", size = 891773, upload-time = "2024-10-01T17:09:26.362Z" }, ] -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741, upload-time = "2024-04-22T15:24:15.253Z" }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892, upload-time = "2024-04-22T15:24:53.333Z" }, -] - [[package]] name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/99/93/a201a12d3ec1caa8c6ac34c1c2f9eeb696b886f0c36ff23c638b46603bd0/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def", size = 570523509, upload-time = "2024-10-25T19:53:03.148Z" }, @@ -1390,31 +1308,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/b2/3f60d15f037fa5419d9d7f788b100ef33ea913ae5315c87ca6d6fa606c35/nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8", size = 565440743, upload-time = "2024-10-25T19:55:49.74Z" }, ] -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548, upload-time = "2024-06-18T19:33:39.396Z" }, - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117, upload-time = "2024-04-03T20:57:40.402Z" }, - { url = "https://files.pythonhosted.org/packages/f6/ee/3f3f8e9874f0be5bbba8fb4b62b3de050156d159f8b6edc42d6f1074113b/nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b", size = 210576476, upload-time = "2024-04-03T21:04:06.422Z" }, -] - [[package]] name = "nvidia-cufft-cu12" version = "11.3.0.4" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144, upload-time = "2024-11-20T17:40:58.288Z" }, @@ -1424,26 +1323,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/38/36fd800cec8f6e89b7c1576edaaf8076e69ec631644cdbc1b5f2e2b5a9df/nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464", size = 199356881, upload-time = "2024-10-01T17:13:01.861Z" }, ] -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.5.147" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811, upload-time = "2024-06-18T19:34:48.575Z" }, - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206, upload-time = "2024-04-03T20:58:08.722Z" }, - { url = "https://files.pythonhosted.org/packages/1c/22/2573503d0d4e45673c263a313f79410e110eb562636b0617856fdb2ff5f6/nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771", size = 55799918, upload-time = "2024-04-03T21:04:34.45Z" }, -] - [[package]] name = "nvidia-curand-cu12" version = "10.3.7.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/42/ac/36543605358a355632f1a6faa3e2d5dfb91eab1e4bc7d552040e0383c335/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8", size = 56289881, upload-time = "2024-10-01T17:04:18.981Z" }, { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010, upload-time = "2024-11-20T17:42:50.958Z" }, @@ -1452,35 +1335,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/a8/0cd0cec757bd4b4b4ef150fca62ec064db7d08a291dced835a0be7d2c147/nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905", size = 55783873, upload-time = "2024-10-01T17:13:30.377Z" }, ] -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111, upload-time = "2024-06-18T19:35:01.793Z" }, - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057, upload-time = "2024-04-03T20:58:28.735Z" }, - { url = "https://files.pythonhosted.org/packages/f2/be/d435b7b020e854d5d5a682eb5de4328fd62f6182507406f2818280e206e2/nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c", size = 125224015, upload-time = "2024-04-03T21:04:53.339Z" }, -] - [[package]] name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628, upload-time = "2024-10-01T17:05:05.591Z" }, @@ -1490,31 +1352,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/53/fff50a0808df7113d77e3bbc7c2b7eaed6f57d5eb80fbe93ead2aea1e09a/nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7", size = 149287877, upload-time = "2024-10-01T17:13:49.804Z" }, ] -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987, upload-time = "2024-06-18T19:35:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763, upload-time = "2024-04-03T20:58:59.995Z" }, - { url = "https://files.pythonhosted.org/packages/a2/e0/3155ca539760a8118ec94cc279b34293309bcd14011fc724f87f31988843/nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f", size = 204684315, upload-time = "2024-04-03T21:05:26.031Z" }, -] - [[package]] name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147, upload-time = "2024-11-20T17:44:18.055Z" }, @@ -1524,26 +1367,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/ef/876ad8e4260e1128e6d4aac803d9d51baf3791ebdb4a9b8d9b8db032b4b0/nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20", size = 213712630, upload-time = "2024-10-01T17:14:23.779Z" }, ] -[[package]] -name = "nvidia-cusparselt-cu12" -version = "0.6.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/8e/675498726c605c9441cf46653bd29cb1b8666da1fb1469ffa25f67f20c58/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8", size = 149422781, upload-time = "2024-07-23T17:35:27.203Z" }, - { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751, upload-time = "2024-07-23T02:35:53.074Z" }, - { url = "https://files.pythonhosted.org/packages/56/8f/2c33082238b6c5e783a877dc8786ab62619e3e6171c083bd3bba6e3fe75e/nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70", size = 148755794, upload-time = "2024-07-23T02:35:00.261Z" }, -] - [[package]] name = "nvidia-cusparselt-cu12" version = "0.6.3" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/62/da/4de092c61c6dea1fc9c936e69308a02531d122e12f1f649825934ad651b5/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1", size = 156402859, upload-time = "2024-10-16T02:23:17.184Z" }, { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796, upload-time = "2024-10-15T21:29:17.709Z" }, @@ -1567,52 +1394,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414, upload-time = "2024-04-03T15:32:57.427Z" }, ] -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510, upload-time = "2024-06-18T20:20:13.871Z" }, - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810, upload-time = "2024-04-03T20:59:46.957Z" }, - { url = "https://files.pythonhosted.org/packages/81/19/0babc919031bee42620257b9a911c528f05fb2688520dcd9ca59159ffea8/nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1", size = 95336325, upload-time = "2024-04-03T21:06:25.073Z" }, -] - [[package]] name = "nvidia-nvjitlink-cu12" version = "12.6.85" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971, upload-time = "2024-11-20T17:46:53.366Z" }, { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338, upload-time = "2024-11-20T17:46:29.758Z" }, { url = "https://files.pythonhosted.org/packages/89/76/93c1467b1387387440a4d25102d86b7794535449b689f8e2dc22c1c8ff7f/nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c", size = 161908572, upload-time = "2024-11-20T17:52:40.124Z" }, ] -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417, upload-time = "2024-06-18T20:16:22.484Z" }, - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144, upload-time = "2024-04-03T20:56:12.406Z" }, - { url = "https://files.pythonhosted.org/packages/54/1b/f77674fbb73af98843be25803bbd3b9a4f0a96c75b8d33a2854a5c7d2d77/nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485", size = 66307, upload-time = "2024-04-03T21:02:01.959Z" }, -] - [[package]] name = "nvidia-nvtx-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/b9/93/80f8a520375af9d7ee44571a6544653a176e53c2b8ccce85b97b83c2491b/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b", size = 90549, upload-time = "2024-11-20T17:38:17.387Z" }, { url = "https://files.pythonhosted.org/packages/2b/53/36e2fd6c7068997169b49ffc8c12d5af5e5ff209df6e1a2c4d373b3a638f/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059", size = 90539, upload-time = "2024-10-01T17:00:27.179Z" }, @@ -2283,6 +2078,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed", size = 40966184, upload-time = "2025-05-08T16:06:52.623Z" }, ] +[[package]] +name = "seaborn" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pandas", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696, upload-time = "2024-01-25T13:21:52.551Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" }, +] + [[package]] name = "semantic-version" version = "2.10.0" @@ -2426,8 +2235,8 @@ wheels = [ [[package]] name = "torch" -version = "2.6.0" -source = { registry = "https://pypi.org/simple" } +version = "2.6.0+cpu" +source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", "platform_machine == 'x86_64' and sys_platform == 'linux'", @@ -2437,29 +2246,14 @@ dependencies = [ { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "networkx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.1.0.70", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.2.1.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", version = "10.3.5.147", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.6.1.9", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", version = "0.6.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sympy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563, upload-time = "2025-01-29T16:23:19.084Z" }, - { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867, upload-time = "2025-01-29T16:25:55.649Z" }, - { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469, upload-time = "2025-01-29T16:24:01.821Z" }, - { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538, upload-time = "2025-01-29T16:24:18.976Z" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:59e78aa0c690f70734e42670036d6b541930b8eabbaa18d94e090abf14cc4d91" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:318290e8924353c61b125cdc8768d15208704e279e7757c113b9620740deca98" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:4027d982eb2781c93825ab9527f17fbbb12dbabf422298e4b954be60016f87d8" }, ] [[package]] @@ -2508,19 +2302,19 @@ dependencies = [ { name = "fsspec", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "jinja2", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "networkx", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", version = "10.3.7.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", version = "0.6.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -2687,6 +2481,7 @@ dependencies = [ { name = "polars", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "psutil", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "pynvml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "seaborn", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-evaluate", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2696,7 +2491,7 @@ dependencies = [ [package.optional-dependencies] cpu = [ - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] gpu = [ { name = "flash-attn", version = "2.7.3", source = { url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-10-weathergen-gpu') or (platform_machine != 'aarch64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2735,11 +2530,12 @@ requires-dist = [ { name = "polars", specifier = "~=1.25.2" }, { name = "psutil" }, { name = "pynvml" }, + { name = "seaborn", specifier = ">=0.13.2" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { name = "torch", marker = "sys_platform == 'linux' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, + { name = "torch", marker = "sys_platform != 'linux' and extra == 'cpu'", specifier = "==2.6.0" }, { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'gpu') or (sys_platform != 'linux' and extra == 'gpu')", specifier = "==2.6.0+cu126" }, - { name = "torch", marker = "sys_platform == 'macosx' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, - { name = "torch", marker = "sys_platform != 'macosx' and extra == 'cpu'", specifier = "==2.6.0" }, { name = "tqdm" }, { name = "weathergen-common", editable = "packages/common" }, { name = "weathergen-evaluate", editable = "packages/evaluate" }, From 41716a670c0fbddbe96a3433210ff9d3cd717236 Mon Sep 17 00:00:00 2001 From: Julian Kuehnert Date: Thu, 9 Oct 2025 16:02:02 +0000 Subject: [PATCH 10/40] test gradient logging on mutli gpus --- config/default_config.yml | 2 +- src/weathergen/train/trainer.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index b14fddcba..d67d5359e 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -128,7 +128,7 @@ grad_clip: 1.0 weight_decay: 0.1 norm_type: "LayerNorm" nn_module: "te" -log_grad_norms: False +log_grad_norms: True start_date: 197901010000 end_date: 202012310000 diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 17f7e4433..83515d317 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -930,12 +930,23 @@ def _log_instant_grad_norms(self, stage: Stage, total_norm): TODO test DDP case """ - grad_norms = {"total_grad_norm": total_norm.item()} - self.last_grad_norm = total_norm.item() - for name, param in self.ddp_model.named_parameters(): + self.last_grad_norm = ( + total_norm.full_tensor().item() if self.cf.world_size > 1 else total_norm.item() + ) + grad_norms = {"total_grad_norm": self.last_grad_norm} + for name, param in self.model.named_parameters(): if param.grad is not None: - grad_norms["grad_norm_" + name] = param.grad.norm().item() - self.train_logger.log_metrics(TRAIN, grad_norms) + # grad_norms["grad_norm_" + name] = param.grad.norm().item() + grad_norms["grad_norm_" + name] = ( + param.grad.norm().full_tensor().item() + if self.cf.world_size > 1 + else param.grad.norm().item() + ) + + # print(".item():", param.grad.norm().item()) + # print(".full_tensor().item()", param.grad.norm().full_tensor().item()) + if is_root(): + self.train_logger.log_metrics(TRAIN, grad_norms) def _log_terminal(self, bidx: int, epoch: int, stage: Stage): if bidx % self.print_freq == 0 and bidx > 0 or stage == VAL: From c12e1905a1fa50390f626f76bd3e64c5c9b6f3a8 Mon Sep 17 00:00:00 2001 From: Matthias Karlbauer Date: Fri, 10 Oct 2025 21:25:36 +0200 Subject: [PATCH 11/40] Setting o48 as default in era5 config Committer: Matthias Karlbauer On branch mk/develop/fe_experiments Your branch is ahead of 'origin/mk/develop/fe_experiments' by 57 commits. (use "git push" to publish your local commits) Changes to be committed: modified: config/streams/era5_1deg/era5.yml --- config/streams/era5_1deg/era5.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index bb2234c4e..e9cc9a6b8 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -9,7 +9,8 @@ ERA5 : type : anemoi - filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + #filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. From d95277e33754969e0652005e08f7d9ab4c8c1785 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 10 Oct 2025 21:28:38 +0200 Subject: [PATCH 12/40] Updated default config to 256 dim latent size On branch mk/develop/fe_experiments Your branch is ahead of 'origin/mk/develop/fe_experiments' by 58 commits. (use "git push" to publish your local commits) Changes to be committed: modified: config/default_config.yml --- config/default_config.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 140d04892..3bb87c950 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -9,7 +9,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True -ae_local_dim_embed: 1024 +ae_local_dim_embed: 256 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -23,9 +23,9 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 2048 +ae_global_dim_embed: 256 ae_global_num_blocks: 4 -ae_global_num_heads: 32 +ae_global_num_heads: 16 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True # TODO: switching to < 1 triggers triton-related issues. From a73447178f00993efcad4c7e1058dc2e47cf3b8e Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 13 Oct 2025 12:24:48 +0200 Subject: [PATCH 13/40] Update branch to latest develop --- uv.lock | 272 ++++++-------------------------------------------------- 1 file changed, 26 insertions(+), 246 deletions(-) diff --git a/uv.lock b/uv.lock index 56e875859..469c6a41f 100644 --- a/uv.lock +++ b/uv.lock @@ -1251,52 +1251,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/1c/6d343e030815c7c97a1f9fbad00211b47717c7fe446834c224bd5311e6f1/numpy-2.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:bd8df082b6c4695753ad6193018c05aac465d634834dca47a3ae06d4bb22d9ea", size = 9891498, upload-time = "2025-06-07T14:43:36.332Z" }, ] -[[package]] -name = "nvidia-cublas-cu12" -version = "12.4.5.8" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771, upload-time = "2024-06-18T19:28:09.881Z" }, - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805, upload-time = "2024-04-03T20:57:06.025Z" }, - { url = "https://files.pythonhosted.org/packages/e2/2a/4f27ca96232e8b5269074a72e03b4e0d43aa68c9b965058b1684d07c6ff8/nvidia_cublas_cu12-12.4.5.8-py3-none-win_amd64.whl", hash = "sha256:5a796786da89203a0657eda402bcdcec6180254a8ac22d72213abc42069522dc", size = 396895858, upload-time = "2024-04-03T21:03:31.996Z" }, -] - [[package]] name = "nvidia-cublas-cu12" version = "12.6.4.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322, upload-time = "2024-11-20T17:40:25.65Z" }, { url = "https://files.pythonhosted.org/packages/97/0d/f1f0cadbf69d5b9ef2e4f744c9466cb0a850741d08350736dfdb4aa89569/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668", size = 390794615, upload-time = "2024-11-20T17:39:52.715Z" }, { url = "https://files.pythonhosted.org/packages/84/f7/985e9bdbe3e0ac9298fcc8cfa51a392862a46a0ffaccbbd56939b62a9c83/nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8", size = 434535301, upload-time = "2024-11-20T17:50:41.681Z" }, ] -[[package]] -name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556, upload-time = "2024-06-18T19:30:40.546Z" }, - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957, upload-time = "2024-04-03T20:55:01.564Z" }, - { url = "https://files.pythonhosted.org/packages/f3/79/8cf313ec17c58ccebc965568e5bcb265cdab0a1df99c4e674bb7a3b99bfe/nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922", size = 9938035, upload-time = "2024-04-03T21:01:01.109Z" }, -] - [[package]] name = "nvidia-cuda-cupti-cu12" version = "12.6.80" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/e6/8b/2f6230cb715646c3a9425636e513227ce5c93c4d65823a734f4bb86d43c3/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc", size = 8236764, upload-time = "2024-11-20T17:35:41.03Z" }, { url = "https://files.pythonhosted.org/packages/25/0f/acb326ac8fd26e13c799e0b4f3b2751543e1834f04d62e729485872198d4/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4", size = 8236756, upload-time = "2024-10-01T16:57:45.507Z" }, @@ -1305,52 +1273,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/81/7796f096afaf726796b1b648f3bc80cafc61fe7f77f44a483c89e6c5ef34/nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a", size = 5724175, upload-time = "2024-10-01T17:09:47.955Z" }, ] -[[package]] -name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372, upload-time = "2024-06-18T19:32:00.576Z" }, - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306, upload-time = "2024-04-03T20:56:01.463Z" }, - { url = "https://files.pythonhosted.org/packages/7c/30/8c844bfb770f045bcd8b2c83455c5afb45983e1a8abf0c4e5297b481b6a5/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:a961b2f1d5f17b14867c619ceb99ef6fcec12e46612711bcec78eb05068a60ec", size = 19751955, upload-time = "2024-04-03T21:01:51.133Z" }, -] - [[package]] name = "nvidia-cuda-nvrtc-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/f4/2f/72df534873235983cc0a5371c3661bebef7c4682760c275590b972c7b0f9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13", size = 23162955, upload-time = "2024-10-01T16:59:50.922Z" }, { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380, upload-time = "2024-10-01T17:00:14.643Z" }, { url = "https://files.pythonhosted.org/packages/f5/46/d3a1cdda8bb113c80f43a0a6f3a853356d487b830f3483f92d49ce87fa55/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a", size = 39026742, upload-time = "2024-10-01T17:10:49.058Z" }, ] -[[package]] -name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177, upload-time = "2024-06-18T19:32:52.877Z" }, - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737, upload-time = "2024-04-03T20:54:51.355Z" }, - { url = "https://files.pythonhosted.org/packages/a8/8b/450e93fab75d85a69b50ea2d5fdd4ff44541e0138db16f9cd90123ef4de4/nvidia_cuda_runtime_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:09c2e35f48359752dfa822c09918211844a3d93c100a715d79b59591130c5e1e", size = 878808, upload-time = "2024-04-03T21:00:49.77Z" }, -] - [[package]] name = "nvidia-cuda-runtime-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/8f/ea/590b2ac00d772a8abd1c387a92b46486d2679ca6622fd25c18ff76265663/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd", size = 908052, upload-time = "2024-11-20T17:35:19.905Z" }, { url = "https://files.pythonhosted.org/packages/b7/3d/159023799677126e20c8fd580cca09eeb28d5c5a624adc7f793b9aa8bbfa/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e", size = 908040, upload-time = "2024-10-01T16:57:22.221Z" }, @@ -1359,30 +1295,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/76/4c80fa138333cc975743fd0687a745fccb30d167f906f13c1c7f9a85e5ea/nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f", size = 891773, upload-time = "2024-10-01T17:09:26.362Z" }, ] -[[package]] -name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741, upload-time = "2024-04-22T15:24:15.253Z" }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892, upload-time = "2024-04-22T15:24:53.333Z" }, -] - [[package]] name = "nvidia-cudnn-cu12" version = "9.5.1.17" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/99/93/a201a12d3ec1caa8c6ac34c1c2f9eeb696b886f0c36ff23c638b46603bd0/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def", size = 570523509, upload-time = "2024-10-25T19:53:03.148Z" }, @@ -1390,31 +1308,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/b2/3f60d15f037fa5419d9d7f788b100ef33ea913ae5315c87ca6d6fa606c35/nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8", size = 565440743, upload-time = "2024-10-25T19:55:49.74Z" }, ] -[[package]] -name = "nvidia-cufft-cu12" -version = "11.2.1.3" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548, upload-time = "2024-06-18T19:33:39.396Z" }, - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117, upload-time = "2024-04-03T20:57:40.402Z" }, - { url = "https://files.pythonhosted.org/packages/f6/ee/3f3f8e9874f0be5bbba8fb4b62b3de050156d159f8b6edc42d6f1074113b/nvidia_cufft_cu12-11.2.1.3-py3-none-win_amd64.whl", hash = "sha256:d802f4954291101186078ccbe22fc285a902136f974d369540fd4a5333d1440b", size = 210576476, upload-time = "2024-04-03T21:04:06.422Z" }, -] - [[package]] name = "nvidia-cufft-cu12" version = "11.3.0.4" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/37/c50d2b2f2c07e146776389e3080f4faf70bcc4fa6e19d65bb54ca174ebc3/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6", size = 200164144, upload-time = "2024-11-20T17:40:58.288Z" }, @@ -1424,26 +1323,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/38/36fd800cec8f6e89b7c1576edaaf8076e69ec631644cdbc1b5f2e2b5a9df/nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464", size = 199356881, upload-time = "2024-10-01T17:13:01.861Z" }, ] -[[package]] -name = "nvidia-curand-cu12" -version = "10.3.5.147" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811, upload-time = "2024-06-18T19:34:48.575Z" }, - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206, upload-time = "2024-04-03T20:58:08.722Z" }, - { url = "https://files.pythonhosted.org/packages/1c/22/2573503d0d4e45673c263a313f79410e110eb562636b0617856fdb2ff5f6/nvidia_curand_cu12-10.3.5.147-py3-none-win_amd64.whl", hash = "sha256:f307cc191f96efe9e8f05a87096abc20d08845a841889ef78cb06924437f6771", size = 55799918, upload-time = "2024-04-03T21:04:34.45Z" }, -] - [[package]] name = "nvidia-curand-cu12" version = "10.3.7.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/42/ac/36543605358a355632f1a6faa3e2d5dfb91eab1e4bc7d552040e0383c335/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8", size = 56289881, upload-time = "2024-10-01T17:04:18.981Z" }, { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010, upload-time = "2024-11-20T17:42:50.958Z" }, @@ -1452,35 +1335,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/a8/0cd0cec757bd4b4b4ef150fca62ec064db7d08a291dced835a0be7d2c147/nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905", size = 55783873, upload-time = "2024-10-01T17:13:30.377Z" }, ] -[[package]] -name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111, upload-time = "2024-06-18T19:35:01.793Z" }, - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057, upload-time = "2024-04-03T20:58:28.735Z" }, - { url = "https://files.pythonhosted.org/packages/f2/be/d435b7b020e854d5d5a682eb5de4328fd62f6182507406f2818280e206e2/nvidia_cusolver_cu12-11.6.1.9-py3-none-win_amd64.whl", hash = "sha256:e77314c9d7b694fcebc84f58989f3aa4fb4cb442f12ca1a9bde50f5e8f6d1b9c", size = 125224015, upload-time = "2024-04-03T21:04:53.339Z" }, -] - [[package]] name = "nvidia-cusolver-cu12" version = "11.7.1.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/93/17/dbe1aa865e4fdc7b6d4d0dd308fdd5aaab60f939abfc0ea1954eac4fb113/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0", size = 157833628, upload-time = "2024-10-01T17:05:05.591Z" }, @@ -1490,31 +1352,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/53/fff50a0808df7113d77e3bbc7c2b7eaed6f57d5eb80fbe93ead2aea1e09a/nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7", size = 149287877, upload-time = "2024-10-01T17:13:49.804Z" }, ] -[[package]] -name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987, upload-time = "2024-06-18T19:35:32.989Z" }, - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763, upload-time = "2024-04-03T20:58:59.995Z" }, - { url = "https://files.pythonhosted.org/packages/a2/e0/3155ca539760a8118ec94cc279b34293309bcd14011fc724f87f31988843/nvidia_cusparse_cu12-12.3.1.170-py3-none-win_amd64.whl", hash = "sha256:9bc90fb087bc7b4c15641521f31c0371e9a612fc2ba12c338d3ae032e6b6797f", size = 204684315, upload-time = "2024-04-03T21:05:26.031Z" }, -] - [[package]] name = "nvidia-cusparse-cu12" version = "12.5.4.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] dependencies = [ - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/eb/eb/6681efd0aa7df96b4f8067b3ce7246833dd36830bb4cec8896182773db7d/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887", size = 216451147, upload-time = "2024-11-20T17:44:18.055Z" }, @@ -1524,26 +1367,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/ef/876ad8e4260e1128e6d4aac803d9d51baf3791ebdb4a9b8d9b8db032b4b0/nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20", size = 213712630, upload-time = "2024-10-01T17:14:23.779Z" }, ] -[[package]] -name = "nvidia-cusparselt-cu12" -version = "0.6.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/8e/675498726c605c9441cf46653bd29cb1b8666da1fb1469ffa25f67f20c58/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8", size = 149422781, upload-time = "2024-07-23T17:35:27.203Z" }, - { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751, upload-time = "2024-07-23T02:35:53.074Z" }, - { url = "https://files.pythonhosted.org/packages/56/8f/2c33082238b6c5e783a877dc8786ab62619e3e6171c083bd3bba6e3fe75e/nvidia_cusparselt_cu12-0.6.2-py3-none-win_amd64.whl", hash = "sha256:0057c91d230703924c0422feabe4ce768841f9b4b44d28586b6f6d2eb86fbe70", size = 148755794, upload-time = "2024-07-23T02:35:00.261Z" }, -] - [[package]] name = "nvidia-cusparselt-cu12" version = "0.6.3" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/62/da/4de092c61c6dea1fc9c936e69308a02531d122e12f1f649825934ad651b5/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1", size = 156402859, upload-time = "2024-10-16T02:23:17.184Z" }, { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796, upload-time = "2024-10-15T21:29:17.709Z" }, @@ -1567,52 +1394,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414, upload-time = "2024-04-03T15:32:57.427Z" }, ] -[[package]] -name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510, upload-time = "2024-06-18T20:20:13.871Z" }, - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810, upload-time = "2024-04-03T20:59:46.957Z" }, - { url = "https://files.pythonhosted.org/packages/81/19/0babc919031bee42620257b9a911c528f05fb2688520dcd9ca59159ffea8/nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1", size = 95336325, upload-time = "2024-04-03T21:06:25.073Z" }, -] - [[package]] name = "nvidia-nvjitlink-cu12" version = "12.6.85" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971, upload-time = "2024-11-20T17:46:53.366Z" }, { url = "https://files.pythonhosted.org/packages/31/db/dc71113d441f208cdfe7ae10d4983884e13f464a6252450693365e166dcf/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41", size = 19270338, upload-time = "2024-11-20T17:46:29.758Z" }, { url = "https://files.pythonhosted.org/packages/89/76/93c1467b1387387440a4d25102d86b7794535449b689f8e2dc22c1c8ff7f/nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c", size = 161908572, upload-time = "2024-11-20T17:52:40.124Z" }, ] -[[package]] -name = "nvidia-nvtx-cu12" -version = "12.4.127" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417, upload-time = "2024-06-18T20:16:22.484Z" }, - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144, upload-time = "2024-04-03T20:56:12.406Z" }, - { url = "https://files.pythonhosted.org/packages/54/1b/f77674fbb73af98843be25803bbd3b9a4f0a96c75b8d33a2854a5c7d2d77/nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485", size = 66307, upload-time = "2024-04-03T21:02:01.959Z" }, -] - [[package]] name = "nvidia-nvtx-cu12" version = "12.6.77" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/b9/93/80f8a520375af9d7ee44571a6544653a176e53c2b8ccce85b97b83c2491b/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b", size = 90549, upload-time = "2024-11-20T17:38:17.387Z" }, { url = "https://files.pythonhosted.org/packages/2b/53/36e2fd6c7068997169b49ffc8c12d5af5e5ff209df6e1a2c4d373b3a638f/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059", size = 90539, upload-time = "2024-10-01T17:00:27.179Z" }, @@ -2426,8 +2221,8 @@ wheels = [ [[package]] name = "torch" -version = "2.6.0" -source = { registry = "https://pypi.org/simple" } +version = "2.6.0+cpu" +source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "platform_machine == 'aarch64' and sys_platform == 'linux'", "platform_machine == 'x86_64' and sys_platform == 'linux'", @@ -2437,29 +2232,14 @@ dependencies = [ { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "networkx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "nvidia-cublas-cu12", version = "12.4.5.8", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.1.0.70", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.2.1.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", version = "10.3.5.147", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.6.1.9", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.3.1.170", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", version = "0.6.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", version = "12.4.127", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sympy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563, upload-time = "2025-01-29T16:23:19.084Z" }, - { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867, upload-time = "2025-01-29T16:25:55.649Z" }, - { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469, upload-time = "2025-01-29T16:24:01.821Z" }, - { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538, upload-time = "2025-01-29T16:24:18.976Z" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:59e78aa0c690f70734e42670036d6b541930b8eabbaa18d94e090abf14cc4d91" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:318290e8924353c61b125cdc8768d15208704e279e7757c113b9620740deca98" }, + { url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp312-cp312-win_amd64.whl", hash = "sha256:4027d982eb2781c93825ab9527f17fbbb12dbabf422298e4b954be60016f87d8" }, ] [[package]] @@ -2508,19 +2288,19 @@ dependencies = [ { name = "fsspec", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "jinja2", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "networkx", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cublas-cu12", version = "12.6.4.1", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-cupti-cu12", version = "12.6.80", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cuda-runtime-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-cu12", version = "9.5.1.17", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cufft-cu12", version = "11.3.0.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-curand-cu12", version = "10.3.7.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusolver-cu12", version = "11.7.1.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", version = "12.5.4.2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cusparselt-cu12", version = "0.6.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", version = "12.6.85", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-nvtx-cu12", version = "12.6.77", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -2696,7 +2476,7 @@ dependencies = [ [package.optional-dependencies] cpu = [ - { name = "torch", version = "2.6.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "torch", version = "2.6.0+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] gpu = [ { name = "flash-attn", version = "2.7.3", source = { url = "https://object-store.os-api.cci1.ecmwf.int/weathergenerator-dev/wheels/flash_attn-2.7.3-cp312-cp312-linux_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'extra-10-weathergen-gpu') or (platform_machine != 'aarch64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2737,9 +2517,9 @@ requires-dist = [ { name = "pynvml" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { name = "torch", marker = "sys_platform == 'linux' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, + { name = "torch", marker = "sys_platform != 'linux' and extra == 'cpu'", specifier = "==2.6.0" }, { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'gpu') or (sys_platform != 'linux' and extra == 'gpu')", specifier = "==2.6.0+cu126" }, - { name = "torch", marker = "sys_platform == 'macosx' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, - { name = "torch", marker = "sys_platform != 'macosx' and extra == 'cpu'", specifier = "==2.6.0" }, { name = "tqdm" }, { name = "weathergen-common", editable = "packages/common" }, { name = "weathergen-evaluate", editable = "packages/evaluate" }, From eba89a6a8181ae3905fc64157cf247e5e3ce2fe2 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 13 Oct 2025 17:01:52 +0200 Subject: [PATCH 14/40] Change epochs from 64 to 32 --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index abbcb47f2..efb6e95b3 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -109,7 +109,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_epochs: 64 +num_epochs: 32 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True From 56156346daa15a664d1ec0d147463e0aeb64f11f Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 10 Nov 2025 16:22:20 +0100 Subject: [PATCH 15/40] LayerNorm replication and analysis tools --- config/default_config.yml | 2 + scripts/model_weight_progression.py | 82 +++++++++++++++++++++++++++++ src/weathergen/model/engines.py | 9 ++++ src/weathergen/model/model.py | 30 ++++++++++- 4 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 scripts/model_weight_progression.py diff --git a/config/default_config.yml b/config/default_config.yml index efb6e95b3..be3a5eb03 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -33,6 +33,7 @@ ae_global_with_qk_lnorm: True ae_global_att_dense_rate: 1.0 ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 +ae_global_trailing_layer_norm: False decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False @@ -52,6 +53,7 @@ fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True +fe_layer_norm_at_layers: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer impute_latent_noise_std: 0.0 # 1e-4 healpix_level: 4 diff --git a/scripts/model_weight_progression.py b/scripts/model_weight_progression.py new file mode 100644 index 000000000..1af1aaaad --- /dev/null +++ b/scripts/model_weight_progression.py @@ -0,0 +1,82 @@ +import os +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +import tqdm + + +def load_checkpoint(run_id: str, epoch: int) -> dict[str, torch.Tensor]: + chkpt = torch.load(f"models/{run_id}/{run_id}_epoch{str(epoch).zfill(5)}.chkpt") + fe_keys = [key for key in list(chkpt.keys()) if "fe" in key] + fe_chkpt = {fe_key: chkpt[fe_key] for fe_key in fe_keys} + return fe_keys, fe_chkpt + + +def get_layer_and_name(key: str) -> [int, str]: + key_split = key.split(".") + layer = int(key_split[1]) + name = ".".join(key_split[2:]) + return layer, name + + +def prepare_weights_and_eigenvalues( + w_dict: dict[str, torch.Tensor], +) -> [dict[str, list], dict[str, list]]: + # Compute eigenvectors of each layer. Set to [0, 0] if no matrix. + e_dict = { + key: (w_dict[key].svd().S.cpu().numpy()) if len(w_dict[key].shape) > 1 else [0, 0] + for key in w_dict + } + # Flatten all weights + w_dict = {key: w_dict[key].flatten().cpu().numpy() for key in w_dict} + return w_dict, e_dict + + +def plot_results( + w_dict: dict[str, torch.Tensor], epoch: int, layers: int, run_id: str, plot_dir: str +): + w_dict, e_dict = prepare_weights_and_eigenvalues(w_dict=w_dict) + fig, axs = plt.subplots(2, 1, figsize=(len(w_dict.keys()), 5), sharex=True) + axs[0].boxplot(w_dict.values(), tick_labels=w_dict.keys()) + axs[1].violinplot(e_dict.values()) + axs[0].grid() + axs[1].grid() + axs[0].set_title("Weight distribution") + axs[1].set_title("Singular value distribution") + plt.xticks(rotation=45, ha="right") + os.makedirs(plot_dir, exist_ok=True) + plot_path = plot_dir / Path(f"w-dist_{run_id}_epoch{str(epoch).zfill(3)}.png") + fig.savefig(plot_path, bbox_inches="tight", pad_inches=0) + + +if __name__ == "__main__": + run_id = "vso7p6dt" + epochs = [2, 4, 8, 16, 32, 63] + plot_dir = Path("plots", "w_dist", run_id) + + for epoch in tqdm.tqdm(epochs, desc="Processing epoch"): + fe_keys, fe_w_dict = load_checkpoint(run_id=run_id, epoch=epoch) + + # + # Option 1: All layers in one plot + plot_results(w_dict=fe_w_dict, epoch=epoch, layers=15, run_id=run_id, plot_dir=plot_dir) + + # + # Option 2: One plot per layer + # layer = -1 # init + # for fe_key in fe_keys: + # l, name = get_layer_and_name(key=fe_key) + + # if layer != l: + # if layer != -1: + # plot_results(w_dict=w_per_layer, epoch=epoch, layers=15, run_id=run_id, plot_dir=plot_dir) + # # Reset layer dict for new layer + # layer = l + # w_per_layer = dict() + + # w_per_layer[name] = fe_w_dict[fe_key] + + # print(fe_key) + + # plot_results(w_dict=w_per_layer, epoch=epoch, layer=l) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78d11a4a6..68d6809cb 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -260,6 +260,10 @@ def create(self) -> torch.nn.ModuleList: norm_eps=self.cf.mlp_norm_eps, ) ) + if getattr(self.cf, "ae_global_trailing_layer_norm", False): + self.ae_global_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) return self.ae_global_blocks @@ -329,6 +333,11 @@ def create(self) -> torch.nn.ModuleList: norm_eps=self.cf.mlp_norm_eps, ) ) + # Optionally, add LayerNorm after i-th layer + if i in getattr(self.cf, "fe_layer_norm_at_layers", []): + self.fe_blocks.append( + torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) + ) def init_weights_final(m): if isinstance(m, torch.nn.Linear): diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 803c0312b..b83ec7173 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -547,6 +547,25 @@ def forward_jac(self, *args): return tuple(preds_all[0]) + ######################################### + def plot_token_distribution(self, tokens, fstep): + plot_path = Path(self.cf.run_path, self.cf.run_id, "plots", "ERA5", "latent_hists") + import os + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.hist(tokens.flatten().to("cpu").numpy(), bins=30) + if not hasattr(self, "xlim"): + self.xlim = ax.get_xlim() + self.ylim = ax.get_ylim() + ax.set_xlim(self.xlim) + ax.set_ylim(self.ylim) + ax.set_title(f"Forecast step {fstep}") + os.makedirs(plot_path, exist_ok=True) + fig.savefig(plot_path / f"fstep_{str(fstep).zfill(2)}.png") + plt.close() + ######################################### def forward(self, model_params: ModelParams, batch, forecast_offset: int, forecast_steps: int): """Performs the forward pass of the model to generate forecasts @@ -576,6 +595,9 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.assimilate_global(model_params, tokens) + if not self.training: + self.plot_token_distribution(tokens=tokens, fstep=0) + # roll-out in latent space preds_all = [] for fstep in range(forecast_offset, forecast_offset + forecast_steps): @@ -598,6 +620,9 @@ def forward(self, model_params: ModelParams, batch, forecast_offset: int, foreca tokens = self.forecast(model_params, tokens) + if not self.training: + self.plot_token_distribution(tokens=tokens, fstep=fstep) + # prediction for final step preds_all += [ self.predict( @@ -807,7 +832,10 @@ def forecast(self, model_params: ModelParams, tokens: torch.Tensor) -> torch.Ten for it, block in enumerate(self.fe_blocks): aux_info = torch.tensor([it], dtype=torch.float32, device="cuda") - tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) + if isinstance(block, torch.nn.modules.normalization.LayerNorm): + tokens = block(tokens) + else: + tokens = checkpoint(block, tokens, aux_info, use_reentrant=False) return tokens From 9ccc95eb5e0822773a5bca3855910288d1643339 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 10 Nov 2025 16:53:42 +0100 Subject: [PATCH 16/40] Rename fe_layer_norm_at_layers to fe_layer_norm_after_blocks --- config/default_config.yml | 2 +- src/weathergen/model/engines.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index be3a5eb03..d3f1eaa83 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -53,7 +53,7 @@ fe_num_blocks: 8 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_layer_norm_at_layers: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer impute_latent_noise_std: 0.0 # 1e-4 healpix_level: 4 diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 68d6809cb..d3f3c3b86 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -334,7 +334,7 @@ def create(self) -> torch.nn.ModuleList: ) ) # Optionally, add LayerNorm after i-th layer - if i in getattr(self.cf, "fe_layer_norm_at_layers", []): + if i in getattr(self.cf, "fe_layer_norm_after_blocks", []): self.fe_blocks.append( torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) From 240031dfb50c518a346af5f3b6e986838e7b0ecd Mon Sep 17 00:00:00 2001 From: Matthias Date: Thu, 13 Nov 2025 08:24:32 +0100 Subject: [PATCH 17/40] Increase epochs from 32 to 64 and resolve minor bug --- config/default_config.yml | 2 +- src/weathergen/model/model.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index d3f1eaa83..6b2327833 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -111,7 +111,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_epochs: 32 +num_epochs: 64 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index b83ec7173..2a55c7808 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -549,9 +549,12 @@ def forward_jac(self, *args): ######################################### def plot_token_distribution(self, tokens, fstep): + # When validating (distributed setup), don't plot the token distribution + if tokens.dtype == torch.bfloat16: + return + plot_path = Path(self.cf.run_path, self.cf.run_id, "plots", "ERA5", "latent_hists") import os - import matplotlib.pyplot as plt fig, ax = plt.subplots() From 20ae505d67af4249081b6482575c24d92bfc7e40 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 17 Nov 2025 13:45:49 +0100 Subject: [PATCH 18/40] Update default_config back to d2048 on the O96 grid --- config/default_config.yml | 6 +++--- uv.lock | 16 ---------------- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 566bc98c5..d0b4d7ef0 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -9,7 +9,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True -ae_local_dim_embed: 256 +ae_local_dim_embed: 1024 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -23,9 +23,9 @@ ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 -ae_global_dim_embed: 256 +ae_global_dim_embed: 2048 ae_global_num_blocks: 4 -ae_global_num_heads: 16 +ae_global_num_heads: 32 ae_global_dropout_rate: 0.1 ae_global_with_qk_lnorm: True # TODO: switching to < 1 triggers triton-related issues. diff --git a/uv.lock b/uv.lock index b7e0144b0..4cdcbdcc5 100644 --- a/uv.lock +++ b/uv.lock @@ -2295,20 +2295,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed", size = 40966184, upload-time = "2025-05-08T16:06:52.623Z" }, ] -[[package]] -name = "seaborn" -version = "0.13.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "matplotlib", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "pandas", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/86/59/a451d7420a77ab0b98f7affa3a1d78a313d2f7281a57afb1a34bae8ab412/seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7", size = 1457696, upload-time = "2024-01-25T13:21:52.551Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/83/11/00d3c3dfc25ad54e731d91449895a79e4bf2384dc3ac01809010ba88f6d5/seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987", size = 294914, upload-time = "2024-01-25T13:21:49.598Z" }, -] - [[package]] name = "semantic-version" version = "2.10.0" @@ -2742,7 +2728,6 @@ dependencies = [ { name = "polars", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "psutil", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "pynvml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, - { name = "seaborn", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-evaluate", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2792,7 +2777,6 @@ requires-dist = [ { name = "polars", specifier = "~=1.25.2" }, { name = "psutil" }, { name = "pynvml" }, - { name = "seaborn", specifier = ">=0.13.2" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" }, { name = "torch", marker = "sys_platform == 'linux' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, From 2731d295448280bc7e8abf99413ca5009104115b Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 17 Nov 2025 13:46:24 +0100 Subject: [PATCH 19/40] Update ERA5 stream to O96 grid --- config/streams/era5_1deg/era5.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 23e369cf4..b3b98ff57 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -9,8 +9,8 @@ ERA5 : type : anemoi - #filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] - filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. From 028bb98a2cdb45a838e3202ad6737878198ce386 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 18 Nov 2025 13:28:47 +0100 Subject: [PATCH 20/40] Resolving bug after merging with develop and updating default_config --- config/default_config.yml | 2 +- src/weathergen/model/engines.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index d0b4d7ef0..f5e80b608 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -37,7 +37,7 @@ ae_global_trailing_layer_norm: False decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False -pred_self_attention: False +pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index f79f1d564..acb45269f 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -302,7 +302,6 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: self.ae_global_blocks.append( torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) - return self.ae_global_blocks def forward(self, tokens, use_reentrant): for block in self.ae_global_blocks: From ba840668cff70e2e7ae1c3cd7b2e24c30648cc8f Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 19 Nov 2025 13:56:15 +0100 Subject: [PATCH 21/40] Enable loading old model checkpoints after recent merges --- src/weathergen/train/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index a4ced9b2b..3de1e65f0 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -796,6 +796,7 @@ def load_model(self, run_id: str, mini_epoch=-1): is_model_sharded = self.cf.with_ddp and self.cf.with_fsdp if is_model_sharded: + params = self.model.rename_old_state_dict(params=params) # For backward compatibility meta_sharded_sd = self.model.state_dict() maybe_sharded_sd = {} for param_name, full_tensor in params.items(): From 4f00cc6252d91bb65aa8b0f0804ccb2ec6f0af2c Mon Sep 17 00:00:00 2001 From: Matthias Date: Wed, 19 Nov 2025 14:09:49 +0100 Subject: [PATCH 22/40] Update WeatherGenReader with mini-epoch notation --- packages/evaluate/src/weathergen/evaluate/io_reader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/io_reader.py b/packages/evaluate/src/weathergen/evaluate/io_reader.py index 66fb2602d..fe12c58e5 100644 --- a/packages/evaluate/src/weathergen/evaluate/io_reader.py +++ b/packages/evaluate/src/weathergen/evaluate/io_reader.py @@ -469,7 +469,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non super().__init__(eval_cfg, run_id, private_paths) - self.mini_epoch = eval_cfg.mini_epoch + # TODO: remove backwards compatibility to "epoch" in Feb. 2026 + self.mini_epoch = getattr(eval_cfg, "mini_epoch", eval_cfg["epoch"]) self.rank = eval_cfg.rank # Load model configuration and set (run-id specific) directories @@ -889,7 +890,7 @@ def load_scores(self, stream: str, region: str, metric: str) -> xr.DataArray | N """ score_path = ( Path(self.metrics_dir) - / f"{self.run_id}_{stream}_{region}_{metric}_epoch{self.epoch:05d}.json" + / f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json" ) _logger.debug(f"Looking for: {score_path}") From e44e1393f7c380f963ff5a53b51c77db7bf1cee5 Mon Sep 17 00:00:00 2001 From: Matthias Date: Thu, 20 Nov 2025 13:55:21 +0100 Subject: [PATCH 23/40] Minor modifications to latent histogram plotting --- src/weathergen/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 2a2b9cd73..a623a2a83 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -612,11 +612,11 @@ def plot_token_distribution(self, tokens, fstep): if not hasattr(self, "xlim"): self.xlim = ax.get_xlim() self.ylim = ax.get_ylim() - ax.set_xlim(self.xlim) + ax.set_xlim(0.2*self.xlim) ax.set_ylim(self.ylim) ax.set_title(f"Forecast step {fstep}") os.makedirs(plot_path, exist_ok=True) - fig.savefig(plot_path / f"fstep_{str(fstep).zfill(2)}.png") + fig.savefig(plot_path / f"fstep_{str(fstep).zfill(3)}.png") plt.close() ######################################### From c979ab416a9e91d522963d170bfca505fb9411d2 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 21 Nov 2025 10:39:01 +0100 Subject: [PATCH 24/40] Resolve bug in histogram plotting --- src/weathergen/model/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a623a2a83..ba021860b 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -610,9 +610,9 @@ def plot_token_distribution(self, tokens, fstep): fig, ax = plt.subplots() ax.hist(tokens.flatten().to("cpu").numpy(), bins=30) if not hasattr(self, "xlim"): - self.xlim = ax.get_xlim() - self.ylim = ax.get_ylim() - ax.set_xlim(0.2*self.xlim) + self.xlim = np.array(ax.get_xlim()) + self.ylim = np.array(ax.get_ylim()) + ax.set_xlim(0.5 * self.xlim) ax.set_ylim(self.ylim) ax.set_title(f"Forecast step {fstep}") os.makedirs(plot_path, exist_ok=True) From d24c4b6800b45bd1f859e61d8b29eab5a540c176 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 21 Nov 2025 12:44:33 +0100 Subject: [PATCH 25/40] Replace getattr by cf.get --- src/weathergen/model/engines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index acb45269f..03d036b59 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -298,7 +298,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: norm_eps=self.cf.mlp_norm_eps, ) ) - if getattr(self.cf, "ae_global_trailing_layer_norm", False): + if self.cf.get("ae_global_trailing_layer_norm", False): self.ae_global_blocks.append( torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) @@ -371,7 +371,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: ) ) # Optionally, add LayerNorm after i-th layer - if i in getattr(self.cf, "fe_layer_norm_after_blocks", []): + if i in self.cf.get("fe_layer_norm_after_blocks", []): self.fe_blocks.append( torch.nn.LayerNorm(self.cf.ae_global_dim_embed, elementwise_affine=False) ) From 89670bf508b0a915188f37ea69493940ab10741f Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 24 Nov 2025 11:54:10 +0100 Subject: [PATCH 26/40] Change target read-out engine from 1 to 2 layers --- config/streams/era5_1deg/era5.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index b3b98ff57..912075c4b 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -30,7 +30,7 @@ ERA5 : dim_embed : 256 target_readout : type : 'obs_value' # token or obs_value - num_layers : 1 + num_layers : 2 num_heads : 4 # sampling_rate : 0.2 pred_head : From 58474b285faaff40b48aed73f3d292befbb65073 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 28 Nov 2025 09:24:57 +0100 Subject: [PATCH 27/40] Set aux-info for fe-blocks to none --- src/weathergen/model/engines.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 03d036b59..3f8ecc3e8 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -312,7 +312,7 @@ def forward(self, tokens, use_reentrant): class ForecastingEngine(torch.nn.Module): name: "ForecastingEngine" - def __init__(self, cf: Config, num_healpix_cells: int) -> None: + def __init__(self, cf: Config, num_healpix_cells: int, dim_aux: int = None) -> None: """ Initialize the ForecastingEngine with the configuration. @@ -337,7 +337,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), ) @@ -353,7 +353,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_qk_lnorm=self.cf.fe_with_qk_lnorm, with_flash=self.cf.with_flash_attention, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), ) @@ -366,7 +366,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: with_residual=True, dropout_rate=self.cf.fe_dropout_rate, norm_type=self.cf.norm_type, - dim_aux=1, + dim_aux=dim_aux, norm_eps=self.cf.mlp_norm_eps, ) ) @@ -386,8 +386,8 @@ def init_weights_final(m): block.apply(init_weights_final) def forward(self, tokens, fstep): - aux_info = torch.tensor([fstep], dtype=torch.float32, device="cuda") - for block in self.fe_blocks: + aux_info = None + for b_idx, block in enumerate(self.fe_blocks): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = block(tokens) else: From 184dcd962e59d105f15a1503946186db5919fe52 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis <79579567+SavvasMel@users.noreply.github.com> Date: Fri, 12 Dec 2025 11:04:21 +0100 Subject: [PATCH 28/40] fix a plotting bug (#1453) --- packages/evaluate/src/weathergen/evaluate/plotter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotter.py b/packages/evaluate/src/weathergen/evaluate/plotter.py index cb15e6f24..007b957e7 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotter.py +++ b/packages/evaluate/src/weathergen/evaluate/plotter.py @@ -14,8 +14,8 @@ from matplotlib.lines import Line2D from PIL import Image from scipy.stats import wilcoxon - from weathergen.common.config import _load_private_conf + from weathergen.evaluate.plot_utils import ( DefaultMarkerSize, ) @@ -482,7 +482,7 @@ def scatter_plot( # TODO: make this nicer parts = ["map", self.run_id, tag] - if self.sample: + if self.sample is not None: parts.append(str(self.sample)) if "valid_time" in data.coords: From d3b63d25a918c385c1f046f70f22a7e742515347 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 12 Dec 2025 11:20:27 +0100 Subject: [PATCH 29/40] Update train/val dates, HL=5, fsteps=2, lat-weighting --- config/default_config.yml | 4 ++-- config/streams/era5_1deg/era5.yml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index f5e80b608..552649583 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -140,8 +140,8 @@ log_grad_norms: False start_date: 197901010000 end_date: 202012310000 -start_date_val: 202101010000 -end_date_val: 202201010000 +start_date_val: 2023100100 +end_date_val: 2023123100 len_hrs: 6 step_hrs: 6 input_window_steps: 1 diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 912075c4b..33826b58d 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -14,6 +14,7 @@ ERA5 : source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] loss_weight : 1. + location_weight : cosine_latitude masking_rate : 0.6 masking_rate_none : 0.05 token_size : 8 From a584e41887364d1939fccb232a9ea404898d289d Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 19 Dec 2025 08:47:30 +0100 Subject: [PATCH 30/40] Defined base config for parameter search --- config/default_config.yml | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 552649583..750d8209b 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -9,7 +9,7 @@ embed_dropout_rate: 0.1 target_cell_local_prediction: True -ae_local_dim_embed: 1024 +ae_local_dim_embed: 2048 ae_local_num_blocks: 0 ae_local_num_heads: 16 ae_local_dropout_rate: 0.1 @@ -45,18 +45,18 @@ pred_mlp_adaln: True # one is training an auto-encoder forecast_offset : 1 forecast_delta_hrs: 0 -forecast_steps: 4 +forecast_steps: 2 forecast_policy: "fixed" forecast_freeze_model: False forecast_att_dense_rate: 1.0 -fe_num_blocks: 8 +fe_num_blocks: 16 fe_num_heads: 16 fe_dropout_rate: 0.1 fe_with_qk_lnorm: True -fe_layer_norm_after_blocks: [] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer -impute_latent_noise_std: 0.0 # 1e-4 +fe_layer_norm_after_blocks: [7] # Index starts at 0. Thus, [3] adds a LayerNorm after the fourth layer +impute_latent_noise_std: 1e-4 -healpix_level: 4 +healpix_level: 5 with_mixed_precision: True with_flash_attention: True @@ -116,7 +116,7 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"], "same_strategy_per_batch": false } -num_epochs: 64 +num_epochs: 128 samples_per_epoch: 4096 samples_per_validation: 512 shuffle: True @@ -164,3 +164,24 @@ train_log_freq: terminal: 10 metrics: 20 checkpoint: 250 + +# Tags for experiment tracking +# These tags will be logged in MLFlow along with completed runs for train, eval, val +# The tags are free-form, with the following rules: +# - tags should be primitive types (strings, numbers, booleans). NO lists or dictionaries +# - tags should not duplicate existing config entries. +# - try to reuse existing tags where possible. MLFlow does not like having too many unique tags +# - do not use long strings in values (less than 20 characters is a good rule of thumb, we may enforce this in the future) +wgtags: + # The name of the organization of the person running the experiment. + # This may be autofilled in the future. Expected values are lowercase strings of + # the organizations codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List + # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" + org: None + # The name of the experiment. This is a distinctive codename for the experiment campaign being run. + # This is expected to be the primary tag for comparing experiments in MLFlow. + # Expected values are lowercase strings with no spaces, just underscores: + # Examples: "rollout_ablation_grid" + exp: "rollout_params" + # *** Experiment-specific tags *** + grid_search: 1 \ No newline at end of file From d7e75eb111e1a5622968c6e1e1aaadabae633575 Mon Sep 17 00:00:00 2001 From: Matthias Date: Sat, 20 Dec 2025 13:40:41 +0100 Subject: [PATCH 31/40] Increase encoder/decoder size and add mlflow tags --- config/default_config.yml | 7 ++++--- config/streams/era5_1deg/era5.yml | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 750d8209b..5d1246f7e 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -139,9 +139,9 @@ nn_module: "te" log_grad_norms: False start_date: 197901010000 -end_date: 202012310000 -start_date_val: 2023100100 -end_date_val: 2023123100 +end_date: 202212310000 +start_date_val: 202310010000 +end_date_val: 202312310000 len_hrs: 6 step_hrs: 6 input_window_steps: 1 @@ -178,6 +178,7 @@ wgtags: # the organizations codenames in https://confluence.ecmwf.int/display/MAEL/Staff+Contact+List # e.g. "ecmwf", "cmcc", "metnor", "jsc", "escience" org: None + issue: 1495 # The name of the experiment. This is a distinctive codename for the experiment campaign being run. # This is expected to be the primary tag for comparing experiments in MLFlow. # Expected values are lowercase strings with no spaces, just underscores: diff --git a/config/streams/era5_1deg/era5.yml b/config/streams/era5_1deg/era5.yml index 33826b58d..33b47d9bd 100644 --- a/config/streams/era5_1deg/era5.yml +++ b/config/streams/era5_1deg/era5.yml @@ -24,11 +24,11 @@ ERA5 : net : transformer num_tokens : 1 num_heads : 8 - dim_embed : 256 + dim_embed : 512 num_blocks : 2 embed_target_coords : net : linear - dim_embed : 256 + dim_embed : 512 target_readout : type : 'obs_value' # token or obs_value num_layers : 2 From 9bc45c577b75ed1e13b2ff1112669c852fe07eb8 Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 22 Dec 2025 14:48:35 +0100 Subject: [PATCH 32/40] Added plot_train content --- config/runs_plot_train.yml | 198 ++++++++++++++++++++++++++++++++++++- 1 file changed, 196 insertions(+), 2 deletions(-) diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index 49924b524..24c71920f 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -1,6 +1,200 @@ train : plot : - lnjzhore : + # in9eslqf : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.05, v1" + # eval: vgbndhco + # t5vqafju : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.05, v2" + # eval: vgbndhco + # qz9n6815 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.05, v3" + # eval: vgbndhco + + # scapqu18 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.1, v1" + # eval: vgbndhco + # yqy2ezoa : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.1, v2" + # eval: vgbndhco + # bp9lcgwn : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.1, v3" + # eval: vgbndhco + + # vzeakjlb : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.2, v1" + # eval: vgbndhco + # a8n4zrfs : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.2, v2" + # eval: vgbndhco + # cxicf671 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.2, v3" + # eval: vgbndhco + + # z3infogw : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.4, v1" + # eval: vgbndhco + # a9hp2qju : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.4, v2" + # eval: vgbndhco + # achvju39 : + # slurm_id: 0 + # description: "lr 5e-4, w-dec 0.4, v3" + # eval: vgbndhco + + cszpe803 : slurm_id: 0 - description: "Christian's naoj54ch with new code" + description: "lr 5e-4, w-dec 0.6, v1" + eval: vgbndhco + y0kauh4s : + slurm_id: 0 + description: "lr 5e-4, w-dec 0.6, v2" + eval: vgbndhco + eneq4ahr : + slurm_id: 0 + description: "lr 5e-4, w-dec 0.6, v3" + eval: vgbndhco + + + + # mia69x1h : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.05, v1" + # eval: vgbndhco + # lgzkdwls : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.05, v2" + # eval: vgbndhco + # jr39znm6 : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.05, v3" + # eval: vgbndhco + + # gzxgp7cw : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.1, v1" + # eval: vgbndhco + # el6zytfd : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.1, v2" + # eval: vgbndhco + # c64w3cgy : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.1, v3" + # eval: vgbndhco + + # m4x3a0jt : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.2, v1" + # eval: vgbndhco + # manyrowd : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.2, v2" + # eval: vgbndhco + # ijwbpy3k : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.2, v3" + # eval: vgbndhco + + # i9qkv084 : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.4, v1" + # eval: vgbndhco + # l78tqy2z : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.4, v2" + # eval: vgbndhco + # xn4wa7b2 : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.4, v3" + # eval: vgbndhco + + s9sldzyb : + slurm_id: 0 + description: "lr 1e-4, w-dec 0.6, v1" + eval: vgbndhco + e29izt1j : + slurm_id: 0 + description: "lr 1e-4, w-dec 0.6, v2" + eval: vgbndhco + bmoc645w : + slurm_id: 0 + description: "lr 1e-4, w-dec 0.6, v3" + eval: vgbndhco + + + + # q4l8jb2e : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.05, v1" + # eval: vgbndhco + # eytr9nki : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.05, v2" + # eval: vgbndhco + # bbcm27x1 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.05, v3" + # eval: vgbndhco + + # jjbfpuya : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.1, v1" + # eval: vgbndhco + # wxehoqic : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.1, v2" + # eval: vgbndhco + # sbylixor : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.1, v3" + # eval: vgbndhco + + # scguorkl : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.2, v1" + # eval: vgbndhco + # uh0iz8sa : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.2, v2" + # eval: vgbndhco + # jizcxg9f : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.2, v3" + # eval: vgbndhco + + # d2wgjec9 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.4, v1" + # eval: vgbndhco + # g10zvcn4 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.4, v2" + # eval: vgbndhco + # a2vlj964 : + # slurm_id: 0 + # description: "lr 5e-5, w-dec 0.4, v3" + # eval: vgbndhco + + z8vx03bg : + slurm_id: 0 + description: "lr 5e-5, w-dec 0.6, v1" + eval: vgbndhco + e3k2v450 : + slurm_id: 0 + description: "lr 5e-5, w-dec 0.6, v2" + eval: vgbndhco + qmil5gwk : + slurm_id: 0 + description: "lr 5e-5, w-dec 0.6, v3" eval: vgbndhco \ No newline at end of file From 40b070e8ebed318c29e6a22170c1d42b18d7ded1 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 23 Dec 2025 13:09:05 +0100 Subject: [PATCH 33/40] Adam betas per cli and aifs channel weighting option --- config/default_config.yml | 2 + config/streams/era5_1deg_w-aifs/era5.yml | 105 +++++++++++++++++++++++ launch_multi.sh | 31 +++++++ src/weathergen/train/trainer.py | 9 +- 4 files changed, 145 insertions(+), 2 deletions(-) create mode 100644 config/streams/era5_1deg_w-aifs/era5.yml create mode 100644 launch_multi.sh diff --git a/config/default_config.yml b/config/default_config.yml index 5d1246f7e..d287f2abc 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -131,6 +131,8 @@ lr_steps_cooldown: 512 lr_policy_warmup: "cosine" lr_policy_decay: "constant" lr_policy_cooldown: "linear" +adam_beta1: null # Becomes 0.8 with 2 nodes +adam_beta2: null # Becomes 0.9 with 2 nodes grad_clip: 1.0 weight_decay: 0.1 diff --git a/config/streams/era5_1deg_w-aifs/era5.yml b/config/streams/era5_1deg_w-aifs/era5.yml new file mode 100644 index 000000000..69108aa11 --- /dev/null +++ b/config/streams/era5_1deg_w-aifs/era5.yml @@ -0,0 +1,105 @@ +# (C) Copyright 2024 WeatherGenerator contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +ERA5 : + type : anemoi + filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr'] + # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr'] + source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp'] + target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp'] + loss_weight : 1. + location_weight : cosine_latitude + channel_weights : + q_50: 0.2 + q_100: 0.23 + q_150: 0.26 + q_200: 0.29 + q_250: 0.33 + q_300: 0.36 + q_400: 0.42 + q_500: 0.48 + q_600: 0.55 + q_700: 0.61 + q_850: 0.71 + q_925: 0.75 + q_1000: 0.8 + t_50: 0.2 + t_100: 0.23 + t_150: 0.26 + t_200: 0.29 + t_250: 0.33 + t_300: 0.36 + t_400: 0.42 + t_500: 0.48 + t_600: 0.55 + t_700: 0.61 + t_850: 0.71 + t_925: 0.75 + t_1000: 0.8 + u_50: 0.2 + u_100: 0.23 + u_150: 0.26 + u_200: 0.29 + u_250: 0.33 + u_300: 0.36 + u_400: 0.42 + u_500: 0.48 + u_600: 0.55 + u_700: 0.61 + u_850: 0.71 + u_925: 0.75 + u_1000: 0.8 + v_50: 0.2 + v_100: 0.23 + v_150: 0.26 + v_200: 0.29 + v_250: 0.33 + v_300: 0.36 + v_400: 0.42 + v_500: 0.48 + v_600: 0.55 + v_700: 0.61 + v_850: 0.71 + v_925: 0.75 + v_1000: 0.8 + z_50: 0.2 + z_100: 0.23 + z_150: 0.26 + z_200: 0.29 + z_250: 0.33 + z_300: 0.36 + z_400: 0.42 + z_500: 0.48 + z_600: 0.55 + z_700: 0.61 + z_850: 0.71 + z_925: 0.75 + z_1000: 0.8 + masking_rate : 0.6 + masking_rate_none : 0.05 + token_size : 8 + tokenize_spacetime : True + max_num_targets: -1 + embed : + net : transformer + num_tokens : 1 + num_heads : 8 + dim_embed : 512 + num_blocks : 2 + embed_target_coords : + net : linear + dim_embed : 512 + target_readout : + type : 'obs_value' # token or obs_value + num_layers : 2 + num_heads : 4 + # sampling_rate : 0.2 + pred_head : + ens_size : 1 + num_layers : 1 diff --git a/launch_multi.sh b/launch_multi.sh new file mode 100644 index 000000000..ead59ac53 --- /dev/null +++ b/launch_multi.sh @@ -0,0 +1,31 @@ + +# # Resume training +# #for run_id in in9eslqf t5vqafju qz9n6815 scapqu18 yqy2ezoa bp9lcgwn vzeakjlb a8n4zrfs cxicf671 z3infogw a9hp2qju achvju39 cszpe803 y0kauh4s eneq4ahr ; do +# #for run_id in mia69x1h lgzkdwls jr39znm6 gzxgp7cw el6zytfd c64w3cgy m4x3a0jt manyrowd ijwbpy3k i9qkv084 l78tqy2z xn4wa7b2 s9sldzyb e29izt1j bmoc645w ; do +# for run_id in q4l8jb2e eytr9nki bbcm27x1 jjbfpuya wxehoqic sbylixor scguorkl uh0iz8sa jizcxg9f d2wgjec9 g10zvcn4 a2vlj964 z8vx03bg e3k2v450 qmil5gwk ; do +# echo "$run_id" +# #cp ../WeatherGenerator-private/hpc/santis/weathergen_slurm_train.sh /capstor/scratch/cscs/mkarlbau/slurm/slurm_weathergen_"$run_id"_dir/WeatherGenerator-private/hpc/santis/. +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + + +# GRID SEARCH 1 +# for lr in "5e-4" "1e-4" "5e-5" ; do +# for w_dec in 0.05 0.1 0.2 0.4 0.6 ; do +# echo "$lr $w_dec" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=$lr weight_decay=$w_dec +# done +# done + + +# GRID SEARCH 2 +# beta1 0.6 0.7 0.8 0.9 0.95 +# beta2 0.8 0.9 0.95, 0.99 +# streams_directory="./config/streams/era5_1deg_w-aifs" "./config/streams/era5_1deg" +for beta1 in 0.6 0.7 0.8 0.9 0.95 ; do + for beta2 in 0.8 0.9 0.95 0.99 ; do + for sd in "./config/streams/era5_1deg_w-aifs" "./config/streams/era5_1deg" ; do + echo "$beta1 $beta2 $sd" + ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=0.0001 weight_decay=0.1 adam_beta1=$beta1 adam_beta2=$beta2 streams_directory=$sd + done +done \ No newline at end of file diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 3de1e65f0..6473d156f 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -339,15 +339,20 @@ def run(self, cf, devices, run_id_contd=None, mini_epoch_contd=None): if is_root() and not cf.with_fsdp and not cf.with_ddp: self.model.print_num_parameters() + # Retrieve Adam betas from config or compute them dynamically if not specified + beta1, beta2 = cf.get("adam_beta1", None), cf.get("adam_beta2", None) + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ # aiming for beta1=0.9 and beta2=0.95 following the MAE paper https://arxiv.org/pdf/2111.06377 kappa = ( cf.batch_size_per_gpu * cf.world_size ) # I doubt this holds for us from some anecdotal runs + # aiming for beta1 = 0.9 at one node, ie kappa=B=4 beta1 = max( 0.5, 1.0 - kappa * (1.0 - 0.975) - ) # aiming for beta1 = 0.9 at one node, ie kappa=B=4 - beta2 = 1.0 - kappa * (1.0 - 0.9875) # aiming for beta2 = 0.95 at one node, ie B=4 + ) if beta1 is None else beta1 + # aiming for beta2 = 0.95 at one node, ie B=4 + beta2 = 1.0 - kappa * (1.0 - 0.9875) if beta2 is None else beta2 eps = 2e-08 / np.sqrt(kappa) self.optimizer = torch.optim.AdamW( From 8c6dafbbaf6c3b594e4822daeb56490931773334 Mon Sep 17 00:00:00 2001 From: Matthias Date: Tue, 6 Jan 2026 13:09:36 +0100 Subject: [PATCH 34/40] Updated launch script --- launch_multi.sh | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/launch_multi.sh b/launch_multi.sh index ead59ac53..33e9ea924 100644 --- a/launch_multi.sh +++ b/launch_multi.sh @@ -1,4 +1,12 @@ +# GRID SEARCH 1 +# for lr in "5e-4" "1e-4" "5e-5" ; do +# for w_dec in 0.05 0.1 0.2 0.4 0.6 ; do +# echo "$lr $w_dec" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=$lr weight_decay=$w_dec +# done +# done + # # Resume training # #for run_id in in9eslqf t5vqafju qz9n6815 scapqu18 yqy2ezoa bp9lcgwn vzeakjlb a8n4zrfs cxicf671 z3infogw a9hp2qju achvju39 cszpe803 y0kauh4s eneq4ahr ; do # #for run_id in mia69x1h lgzkdwls jr39znm6 gzxgp7cw el6zytfd c64w3cgy m4x3a0jt manyrowd ijwbpy3k i9qkv084 l78tqy2z xn4wa7b2 s9sldzyb e29izt1j bmoc645w ; do @@ -8,24 +16,34 @@ # ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv # done +# # Delete validation.zarr and plots directories +# # ww0r248v, pg5oe6rq, jc10pgys, usizx285, yqy2ezoa, xbgazwtc, bxf4bdlz, tmnrdyk9, ue5ky698, hd4mvfa1, s301cqla, m63ocvtq, fq2jposb, fh1drqz8, r2ut4clp, u16fvide, t790hb8a, d0bou6tk, ge9zmby0, c745lzyr, sevmrclb, rqer37pc, ot3hqr0x, kj2qxw9k, x18rkx3s, afe4cwb0, srkhuy4g, mph51qok, bh2z0jkt, b0lwy3rk, qehytran, eionpvqj, oo4hq36z, a1x2cdf0, dn13x6ql, c7c480k2, uot4snvp, p3kvrg9j, mcugwbsp, qiz2bfkv, solj81d4, ku4r3omn, kro1j69u, gfm9e1z6, njhycz89 +for run_id in ww0r248v ; do + echo "results/$run_id/validation_chkpt00000_rank0000.zarr" + rm -r "results/$run_id/validation_chkpt00000_rank0000.zarr" + rm -r "results/$run_id/plots" +done -# GRID SEARCH 1 -# for lr in "5e-4" "1e-4" "5e-5" ; do -# for w_dec in 0.05 0.1 0.2 0.4 0.6 ; do -# echo "$lr $w_dec" -# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=$lr weight_decay=$w_dec -# done -# done +# # Train continue grid search 2 +# # for run_id in s3cr2ef4 mrcw4e8h djuk9pzn aepv7wu1 v3h9k2bj ynlvs89z tz49cgl2 l98sz3jw m1k9eplh tgmftqi4 pu6e03ng lnph48zk t69trx2e j9eiz4dq z1sd0pnr tide3hv6 m45w7y1t np85vcnq ldv8n37w cfblam6s fb4ms5ec w584rjeg ctg1usw6 jz8tjv6m w4xqvt69 qrjo9xs0 u34liwsg exvmlypc uubfj3gz h392evts tp8b24cj oxfpqjw1 xpv8qnyd xpw5o9fx ti4vmpsg qkq9yjl7 e1enpxz3 nwj4z09v rhk794ou h1pwbi4a bcmvof51 b2weib8p hzd5uet7 k5pilh01 v4ah1kzx ppxufsjg pcl0snok wurb6xtk npmszy6v y9kfdpom x9ay3kmw irn3mjyc z4awp1g5 gmlzaqhj r0spjzik knouy46d q3tjv9di dwp6e7d9 n1qikpey j2dg98i0 zed62zhu a4vqxo9u suz4ilra m3vscjad mclsg8op rgtp5jy9 jnwrgzpf w9rkhg1v ob6r9moj t6ibtc1j mgyzswul z271maoe q57y4ve3 ohce4138 ao13xq8w t6bjgt9a x762dlfb xdo9plre hlrt72oe hbnm4u1x ; do +# # for run_id in z08sckyz k52e7hfo p09l6hpz gl3ev20k iaw2gbrn m6jniu34 obwocxim ep5q8gzn tkywi1hd lu3lj0z4 y75fha42 kmdx1g3n kmztan34 pg9k1fli qgv2layj heomgws2 yo8jyrbv jfd9lr0u x3rqmjes gk1ny5oh c0yf9pdt n2vm0sxq fvq6tyeo ql6spk9b p5ibvwda cp0k1an8 c01zraug hv97duk1 i2lrahz9 ge6v74ir e6ijt39q sx6m2ejs va9d6yv3 nw43c2n0 z1yi6s2c jsfhgdq1 d3o1a9lc u0gf6k3l gf3r9noc nb6342qh zyxq24l1 t5uh3gv9 hng8tw29 xd60xpu5 y90egrso jnsxo6mt lfq8om96 hhqnr8l9 sqne63ot unq1uwez cm3tolj1 k8irbcyk fqxcans2 ujqxld2f ksdn58ca wirp7n6q ttlvz8f2 xx68golp sm9i8lgx qeopqgju d2mj6l1d u1h675cf dge0wox1 gz6ad0sc g0hb53if n6nymxjc l1np8abi pybuh3o4 gn8jcdug vj6gn49y lnwequ80 t486blik il6f0p87 yixm5qls uafuiesw fh4vuolp ey17ezol im8g2pzi ff1ldwox wb47lhdj ; do +# # for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do +# for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done # GRID SEARCH 2 # beta1 0.6 0.7 0.8 0.9 0.95 # beta2 0.8 0.9 0.95, 0.99 -# streams_directory="./config/streams/era5_1deg_w-aifs" "./config/streams/era5_1deg" -for beta1 in 0.6 0.7 0.8 0.9 0.95 ; do - for beta2 in 0.8 0.9 0.95 0.99 ; do - for sd in "./config/streams/era5_1deg_w-aifs" "./config/streams/era5_1deg" ; do - echo "$beta1 $beta2 $sd" - ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=0.0001 weight_decay=0.1 adam_beta1=$beta1 adam_beta2=$beta2 streams_directory=$sd - done -done \ No newline at end of file +# streams_directory="./config/streams/era5_1deg, "./config/streams/era5_1deg_w-aifs"" +# for beta1 in 0.6 0.7 0.8 0.9 0.95 ; do +# for beta2 in 0.8 0.9 0.95 0.99 ; do +# for sd in "./config/streams/era5_1deg" "./config/streams/era5_1deg_w-aifs" ; do +# echo "$beta1 $beta2 $sd" +# # ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=0.0001 weight_decay=0.1 adam_beta1=$beta1 adam_beta2=$beta2 streams_directory=$sd +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --chain-jobs 3 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=$beta1 adam_beta2=$beta2 streams_directory=$sd +# done +# done +# done From b5c98cb1ef906dba7cd06b7c741b66f6427a9741 Mon Sep 17 00:00:00 2001 From: Matthias Date: Fri, 9 Jan 2026 11:59:26 +0100 Subject: [PATCH 35/40] Updated multiple job launch script with pre-training dropout and noise ablation --- config/default_config.yml | 2 +- launch_multi.sh | 42 ++++++++++++++++++++++++++------------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index d287f2abc..f817cf898 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -187,4 +187,4 @@ wgtags: # Examples: "rollout_ablation_grid" exp: "rollout_params" # *** Experiment-specific tags *** - grid_search: 1 \ No newline at end of file + grid_search: "dropout" \ No newline at end of file diff --git a/launch_multi.sh b/launch_multi.sh index 33e9ea924..0add2bd09 100644 --- a/launch_multi.sh +++ b/launch_multi.sh @@ -18,22 +18,13 @@ # # Delete validation.zarr and plots directories # # ww0r248v, pg5oe6rq, jc10pgys, usizx285, yqy2ezoa, xbgazwtc, bxf4bdlz, tmnrdyk9, ue5ky698, hd4mvfa1, s301cqla, m63ocvtq, fq2jposb, fh1drqz8, r2ut4clp, u16fvide, t790hb8a, d0bou6tk, ge9zmby0, c745lzyr, sevmrclb, rqer37pc, ot3hqr0x, kj2qxw9k, x18rkx3s, afe4cwb0, srkhuy4g, mph51qok, bh2z0jkt, b0lwy3rk, qehytran, eionpvqj, oo4hq36z, a1x2cdf0, dn13x6ql, c7c480k2, uot4snvp, p3kvrg9j, mcugwbsp, qiz2bfkv, solj81d4, ku4r3omn, kro1j69u, gfm9e1z6, njhycz89 -for run_id in ww0r248v ; do - echo "results/$run_id/validation_chkpt00000_rank0000.zarr" - rm -r "results/$run_id/validation_chkpt00000_rank0000.zarr" - rm -r "results/$run_id/plots" -done - - -# # Train continue grid search 2 -# # for run_id in s3cr2ef4 mrcw4e8h djuk9pzn aepv7wu1 v3h9k2bj ynlvs89z tz49cgl2 l98sz3jw m1k9eplh tgmftqi4 pu6e03ng lnph48zk t69trx2e j9eiz4dq z1sd0pnr tide3hv6 m45w7y1t np85vcnq ldv8n37w cfblam6s fb4ms5ec w584rjeg ctg1usw6 jz8tjv6m w4xqvt69 qrjo9xs0 u34liwsg exvmlypc uubfj3gz h392evts tp8b24cj oxfpqjw1 xpv8qnyd xpw5o9fx ti4vmpsg qkq9yjl7 e1enpxz3 nwj4z09v rhk794ou h1pwbi4a bcmvof51 b2weib8p hzd5uet7 k5pilh01 v4ah1kzx ppxufsjg pcl0snok wurb6xtk npmszy6v y9kfdpom x9ay3kmw irn3mjyc z4awp1g5 gmlzaqhj r0spjzik knouy46d q3tjv9di dwp6e7d9 n1qikpey j2dg98i0 zed62zhu a4vqxo9u suz4ilra m3vscjad mclsg8op rgtp5jy9 jnwrgzpf w9rkhg1v ob6r9moj t6ibtc1j mgyzswul z271maoe q57y4ve3 ohce4138 ao13xq8w t6bjgt9a x762dlfb xdo9plre hlrt72oe hbnm4u1x ; do -# # for run_id in z08sckyz k52e7hfo p09l6hpz gl3ev20k iaw2gbrn m6jniu34 obwocxim ep5q8gzn tkywi1hd lu3lj0z4 y75fha42 kmdx1g3n kmztan34 pg9k1fli qgv2layj heomgws2 yo8jyrbv jfd9lr0u x3rqmjes gk1ny5oh c0yf9pdt n2vm0sxq fvq6tyeo ql6spk9b p5ibvwda cp0k1an8 c01zraug hv97duk1 i2lrahz9 ge6v74ir e6ijt39q sx6m2ejs va9d6yv3 nw43c2n0 z1yi6s2c jsfhgdq1 d3o1a9lc u0gf6k3l gf3r9noc nb6342qh zyxq24l1 t5uh3gv9 hng8tw29 xd60xpu5 y90egrso jnsxo6mt lfq8om96 hhqnr8l9 sqne63ot unq1uwez cm3tolj1 k8irbcyk fqxcans2 ujqxld2f ksdn58ca wirp7n6q ttlvz8f2 xx68golp sm9i8lgx qeopqgju d2mj6l1d u1h675cf dge0wox1 gz6ad0sc g0hb53if n6nymxjc l1np8abi pybuh3o4 gn8jcdug vj6gn49y lnwequ80 t486blik il6f0p87 yixm5qls uafuiesw fh4vuolp ey17ezol im8g2pzi ff1ldwox wb47lhdj ; do -# # for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do -# for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do -# echo "$run_id" -# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# for run_id in ww0r248v ; do +# echo "results/$run_id/validation_chkpt00000_rank0000.zarr" +# rm -r "results/$run_id/validation_chkpt00000_rank0000.zarr" +# rm -r "results/$run_id/plots" # done + # GRID SEARCH 2 # beta1 0.6 0.7 0.8 0.9 0.95 # beta2 0.8 0.9 0.95, 0.99 @@ -47,3 +38,26 @@ done # done # done # done + +# # Train continue grid search 2 +# # for run_id in s3cr2ef4 mrcw4e8h djuk9pzn aepv7wu1 v3h9k2bj ynlvs89z tz49cgl2 l98sz3jw m1k9eplh tgmftqi4 pu6e03ng lnph48zk t69trx2e j9eiz4dq z1sd0pnr tide3hv6 m45w7y1t np85vcnq ldv8n37w cfblam6s fb4ms5ec w584rjeg ctg1usw6 jz8tjv6m w4xqvt69 qrjo9xs0 u34liwsg exvmlypc uubfj3gz h392evts tp8b24cj oxfpqjw1 xpv8qnyd xpw5o9fx ti4vmpsg qkq9yjl7 e1enpxz3 nwj4z09v rhk794ou h1pwbi4a bcmvof51 b2weib8p hzd5uet7 k5pilh01 v4ah1kzx ppxufsjg pcl0snok wurb6xtk npmszy6v y9kfdpom x9ay3kmw irn3mjyc z4awp1g5 gmlzaqhj r0spjzik knouy46d q3tjv9di dwp6e7d9 n1qikpey j2dg98i0 zed62zhu a4vqxo9u suz4ilra m3vscjad mclsg8op rgtp5jy9 jnwrgzpf w9rkhg1v ob6r9moj t6ibtc1j mgyzswul z271maoe q57y4ve3 ohce4138 ao13xq8w t6bjgt9a x762dlfb xdo9plre hlrt72oe hbnm4u1x ; do +# # for run_id in z08sckyz k52e7hfo p09l6hpz gl3ev20k iaw2gbrn m6jniu34 obwocxim ep5q8gzn tkywi1hd lu3lj0z4 y75fha42 kmdx1g3n kmztan34 pg9k1fli qgv2layj heomgws2 yo8jyrbv jfd9lr0u x3rqmjes gk1ny5oh c0yf9pdt n2vm0sxq fvq6tyeo ql6spk9b p5ibvwda cp0k1an8 c01zraug hv97duk1 i2lrahz9 ge6v74ir e6ijt39q sx6m2ejs va9d6yv3 nw43c2n0 z1yi6s2c jsfhgdq1 d3o1a9lc u0gf6k3l gf3r9noc nb6342qh zyxq24l1 t5uh3gv9 hng8tw29 xd60xpu5 y90egrso jnsxo6mt lfq8om96 hhqnr8l9 sqne63ot unq1uwez cm3tolj1 k8irbcyk fqxcans2 ujqxld2f ksdn58ca wirp7n6q ttlvz8f2 xx68golp sm9i8lgx qeopqgju d2mj6l1d u1h675cf dge0wox1 gz6ad0sc g0hb53if n6nymxjc l1np8abi pybuh3o4 gn8jcdug vj6gn49y lnwequ80 t486blik il6f0p87 yixm5qls uafuiesw fh4vuolp ey17ezol im8g2pzi ff1ldwox wb47lhdj ; do +# # for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do +# for run_id in eosc19jg kl0rcua6 ukwujlf4 kfg9vle7 gp5cnl91 d5fwxbil lzg3iprw qxdtz5m1 p9d1ya0q x03dmoey f1pezdro i0jh85s1 gw6leyu4 ev2zrdsu b9kzho1t smgpkdbj tmk1adys wt89w7kp kuzwmra5 ezimdn39 npeywgca qx9bgzqw z8qvl037 jv6am9p8 lokz5def fvef6dpn gw3fz7rc f7jn9ep8 nblsm3uo ze1mh2cb t5fsvqxh vwz3ue4g ez2iolsm t946rtp8 bh65jf8k zb7s32vu momerjg2 o45qt3oc z8vwrchb lfc1rl63 gxibzqeo jx4pw8jb icnlv2kw lx5c0qh4 e2ndef59 aeqf5ozl e0riq4o6 psq79jhc cgncym6p idpvqbaw qn52qavo hygrvmh4 n6azvsmo difz8odj s5o06pyi v450oj3h g4n6iu78 yqtgkaz3 ezl2c8od vcophb7i kl7gop09 uhj4npb6 yjql2371 e1ungd2z awn0a856 hxtscki8 rln8xs45 i6jkuig9 s8qetp41 vn903brm sjxehz2y ch2cnwf1 qo7aihzs v2djkyu0 irmnhejc jct7zx6j y3ezyh1x prx0wqsk iecmarv3 kr4gtp38 ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + + +# # DROPOUT [0.05, 0.1, 0.15, 0.2, 0.3, 0.4] +# for dropout in 0.05 0.1 0.15 0.2 0.3 0.4 ; do +# echo "$dropout" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=0.85 adam_beta2=0.9 streams_directory="./config/streams/era5_1deg_w-aifs" embed_dropout_rate=$dropout ae_local_dropout_rate=$dropout ae_adapter_dropout_rate=$dropout ae_global_dropout_rate=$dropout fe_dropout_rate=$dropout +# done + + +# # NOISE LEVEL [1e-3, 5e-4, (1e-4), 5e-5, 1e-5] +# for nl in "1e-3" "5e-4" "1e-4" "5e-5" "1e-5" ; do +# echo "$nl" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=0.85 adam_beta2=0.9 streams_directory="./config/streams/era5_1deg_w-aifs" impute_latent_noise_std=$nl +# done \ No newline at end of file From 76cb62c55725896bf685f43b558836bb2ee3a2be Mon Sep 17 00:00:00 2001 From: Matthias Date: Mon, 12 Jan 2026 10:57:33 +0100 Subject: [PATCH 36/40] Updated experiment launch script --- launch_multi.sh | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/launch_multi.sh b/launch_multi.sh index 0add2bd09..b2a820248 100644 --- a/launch_multi.sh +++ b/launch_multi.sh @@ -55,9 +55,27 @@ # ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=0.85 adam_beta2=0.9 streams_directory="./config/streams/era5_1deg_w-aifs" embed_dropout_rate=$dropout ae_local_dropout_rate=$dropout ae_adapter_dropout_rate=$dropout ae_global_dropout_rate=$dropout fe_dropout_rate=$dropout # done +# # Train continue dropout +# for run_id in d2f1p4vh m8e7psdl n2nmxc7b flaucoz5 dnl5r61x aojt3c1z p1phw3g9 kacy7jbz uk0uvcfn d1fhev63 pxg7jnzt z40dbxjy fpymqrv3 pe93az4w saxqsfzb yjlzi5g7 zewh2o5n dy36qb7e ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + # # NOISE LEVEL [1e-3, 5e-4, (1e-4), 5e-5, 1e-5] # for nl in "1e-3" "5e-4" "1e-4" "5e-5" "1e-5" ; do # echo "$nl" # ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --link-venv --options lr_max=0.00005 weight_decay=0.05 adam_beta1=0.85 adam_beta2=0.9 streams_directory="./config/streams/era5_1deg_w-aifs" impute_latent_noise_std=$nl +# done + +# # Train continue noise level +# for run_id in fmpesclt h6lu3sh8 tgmwaifc vavdy4zf qf24wjsq nbc3il5x vvwizau9 wyhcr51m n207tod4 cey23p7w rh49o7yj qt72d4iy vl5n39cj ocyx09uw fydmc3vg ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +# done + + +# for run_id in fmpesclt vvwizau9 vl5n39cj ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv # done \ No newline at end of file From d14291c2231b69007e5d30707d2d5d8a79fb918d Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Mon, 19 Jan 2026 14:41:35 +0100 Subject: [PATCH 37/40] added config and scripts for launching multiple runs --- config/default_config.yml | 2 +- config/eval_config.yml | 189 +++++++++++++++++++++++++++++--- config/eval_config_matthias.yml | 115 +++++++++++++++++++ config/eval_config_seq.yml | 35 ++++++ config/runs_plot_train.yml | 26 +++-- launch_multi.sh | 17 ++- launch_multi_infer.sh | 5 + launch_multi_lr.sh | 34 ++++++ pyproject.toml | 1 + uv.lock | 2 + weather_slurm_inferece.sh | 111 +++++++++++++++++++ 11 files changed, 513 insertions(+), 24 deletions(-) create mode 100644 config/eval_config_matthias.yml create mode 100644 config/eval_config_seq.yml create mode 100644 launch_multi_infer.sh create mode 100644 launch_multi_lr.sh create mode 100644 weather_slurm_inferece.sh diff --git a/config/default_config.yml b/config/default_config.yml index f817cf898..a88d359ef 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -187,4 +187,4 @@ wgtags: # Examples: "rollout_ablation_grid" exp: "rollout_params" # *** Experiment-specific tags *** - grid_search: "dropout" \ No newline at end of file + grid_search: "dropout" diff --git a/config/eval_config.yml b/config/eval_config.yml index 937bc59be..0403f148b 100644 --- a/config/eval_config.yml +++ b/config/eval_config.yml @@ -1,28 +1,187 @@ -verbose: true -image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. -dpi_val : 300 -summary_plots : true -print_summary: false +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + ERA5: + marker_size: 4 evaluation: - metrics : ["rmse"] + metrics : ["froct", "rmse"] regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: true run_ids : - ptluswdo: - label: "ptluswdo: 64ep 2fs (naoj54ch) + 32ep 8fs 2e-5" + # lr=5e-4 + xs5l8zmj: + label: "cosine scheduler lr_max=5e-4 v1" epoch: 0 rank: 0 streams: ERA5: - channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] - #channels: ["2t", "q_850", ] + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] evaluation: sample: "all" forecast_step: "all" - plotting: - sample: [0] - forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] - plot_maps: true - plot_histograms: false \ No newline at end of file + + x9zvml1k: + label: "cosine scheduler lr_max=5e-4 v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + an8rap5h: + label: "cosine scheduler lr_max=5e-4 v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + # lr=1e-4 + u2qk39pi: + label: "cosine scheduler lr_max=1e-4 v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + zswipf53: + label: "cosine scheduler lr_max=1e-4 v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + dsdvzg59: + label: "cosine scheduler lr_max=1e-4 v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + # lr=5e-5 + #r812ji96: + # label: "cosine scheduler lr_max=5e-5 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #gj6eq2dx: + # label: "cosine scheduler lr_max=5e-5 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #ff80snum: + # label: "cosine scheduler lr_max=5e-5 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + ## lr=1e-5 + v0yha29i: + label: "cosine scheduler lr_max=1e-5 v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + cbmk73y0: + label: "cosine scheduler lr_max=1e-5 v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + ngdrjcbt: + label: "cosine scheduler lr_max=1e-5 v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + # lr=5e-6 + voulcvsi: + label: "cosine scheduler lr_max=5e-6 v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + urlp39xq: + label: "cosine scheduler lr_max=5e-6 v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + ch1n05gd: + label: "cosine scheduler lr_max=5e-6 v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + diff --git a/config/eval_config_matthias.yml b/config/eval_config_matthias.yml new file mode 100644 index 000000000..cbd200d66 --- /dev/null +++ b/config/eval_config_matthias.yml @@ -0,0 +1,115 @@ +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + ERA5: + marker_size: 5 + +evaluation: + metrics : ["froct", "rmse"] + regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: true + +run_ids : + + otnh7lcg: + label: "Control: fine-tune 1979-2022, 8x32 epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + v8bc9sxg: + label: "fine-tune 2018-2022, 8x32 epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + adja05fc: + label: "fine-tune 2018-2022, 8x32 epochs, v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + nub26n5i: + label: "fine-tune 2018-2022, seq [4, 6, 8] epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + kl5g92ne: + label: "fine-tune 2018-2022, seq [4, 6, 8] epochs, v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + bwzo5qfn: + label: "fine-tune 2018-2022, seq [4, 6, 8] epochs, v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + jrp2kgem: + label: "fine-tune 2018-2022, seq [3, 4, ..., 8] epochs, v1" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + ibmenr7o: + label: "fine-tune 2018-2022, seq [3, 4, ..., 8] epochs, v2" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" + + renlmg1i: + label: "fine-tune 2018-2022, seq [3, 4, ..., 8] epochs, v3" + epoch: 0 + rank: 0 + streams: + ERA5: + channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + evaluation: + sample: "all" + forecast_step: "all" diff --git a/config/eval_config_seq.yml b/config/eval_config_seq.yml new file mode 100644 index 000000000..232dd0877 --- /dev/null +++ b/config/eval_config_seq.yml @@ -0,0 +1,35 @@ +global_plotting_options: + image_format : "png" #options: "png", "pdf", "svg", "eps", "jpg" .. + dpi_val : 300 + +evaluation: + metrics : ["froct", "rmse","acc"] #, "mae"] + regions: ["global"] + summary_plots : true + summary_dir: "./plots/" + print_summary: false #print out score values on screen. it can be verbose + log_scale: false + add_grid: false + score_cards: true + num_processes: auto + + +run_ids : + m9p1w7kx: + label: "8 steps all" + epoch: 0 + rank: 0 + streams: + ERA5: + climatology_path: /iopsstor/scratch/cscs/lessig/data/assets/climatology/aifs-ea-an-oper-0001-mars-o96-1980-2020-6h-v6_climatology.zarr + channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ] + #channels: ["2t", "q_850", ] + evaluation: + sample: "all" + forecast_step: "all" + plotting: + sample: [0] + forecast_step: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80 ] + plot_maps: false + plot_histograms: false + plot_animations: false diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index 24c71920f..cb3d33d36 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -52,19 +52,31 @@ train : # description: "lr 5e-4, w-dec 0.4, v3" # eval: vgbndhco - cszpe803 : + h8l1yem5 : slurm_id: 0 - description: "lr 5e-4, w-dec 0.6, v1" + description: "lr 5e-4, v2" eval: vgbndhco - y0kauh4s : + jeubm9ld : slurm_id: 0 - description: "lr 5e-4, w-dec 0.6, v2" + description: "lr 5e-4, v3" eval: vgbndhco - eneq4ahr : + bezt6v8g : slurm_id: 0 - description: "lr 5e-4, w-dec 0.6, v3" + description: "lr 5e-4, v1" eval: vgbndhco + ya0gty48 : + slurm_id: 0 + description: "lr 1e-4, v1" + eval: vgbndhco + djiy1v3e : + slurm_id: 0 + description: "lr 1e-4, v2" + eval: vgbndhco + g96y1dq5 : + slurm_id: 0 + description: "lr 1e-4, v3" + eval: vgbndhco # mia69x1h : @@ -197,4 +209,4 @@ train : qmil5gwk : slurm_id: 0 description: "lr 5e-5, w-dec 0.6, v3" - eval: vgbndhco \ No newline at end of file + eval: vgbndhco diff --git a/launch_multi.sh b/launch_multi.sh index b2a820248..fbd571b06 100644 --- a/launch_multi.sh +++ b/launch_multi.sh @@ -78,4 +78,19 @@ # for run_id in fmpesclt vvwizau9 vl5n39cj ; do # echo "$run_id" # ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv -# done \ No newline at end of file +# done +# + +for run_id in fl9xrpao ; do + echo "$run_id" + ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv +done + + +# Cosine learning_rate test +#for lr_max in "5e-4" "1e-4" "5e-5" "1e-5" "5e-6" ; do +# echo "$lr_max" +# for from_run_id in dnl5r61x ; do +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --time 24:00:00 --from-run-id $from_run_id --link-venv --options istep=0 num_epochs=32 lr_max=$lr_max lr_policy_decay="cosine" forecast_steps=8 freeze_modules=".*global.*|.*local.*|.*adapter.*|.*ERA5.*" +#done +#done diff --git a/launch_multi_infer.sh b/launch_multi_infer.sh new file mode 100644 index 000000000..a7574748a --- /dev/null +++ b/launch_multi_infer.sh @@ -0,0 +1,5 @@ +#for run_id in unov2gdz pv5hu3mc exsm2wty czfrhdae zha9i6x3 ypr1b3a4 vctfgruv htdwjqpx gakr74pw qpogewjf y3trwpx7 asnz2gyl dd1cq6nv dn15vfks fl9xrpao ; do +for run_id in lvlfd8er hr1l2whz a68hqu13 ; do + echo $ "$run_id" + sbatch weather_slurm_inferece.sh "$run_id" + done diff --git a/launch_multi_lr.sh b/launch_multi_lr.sh new file mode 100644 index 000000000..85e31cdcc --- /dev/null +++ b/launch_multi_lr.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Extract all unique (lr, from_run_id) tuples from mapping +tuples=( + "5e-4 dnl5r61x" # zd4t0zmp + "5e-4 vvwizau9" # fzmgdsev + "5e-4 wyhcr51m" # rqhn8y14 + "5e-5 dnl5r61x" # nibpxofg + "5e-5 wyhcr51m" # zylxr8pm + "1e-5 dnl5r61x" # s1urb38z + "1e-5 vvwizau9" # lmix3abo + "1e-5 wyhcr51m" # gu20n5l8 + "5e-6 dnl5r61x" # sxlqdhue + "5e-6 vvwizau9" # s8pfvmle +) + +echo "Launching ${#tuples[@]} experiments..." + +for tuple in "${tuples[@]}"; do + read lr_max from_run_id <<< "$tuple" + echo "=== $from_run_id @ $lr_max ===" + + ../WeatherGenerator-private/hpc/launch-slurm.py \ + --nodes 2 \ + --time 24:00:00 \ + --from-run-id "$from_run_id" \ + --link-venv \ + --options istep=0 num_epochs=32 lr_max=$lr_max lr_policy_decay="cosine" forecast_steps=8 freeze_modules=".*global.*|.*local.*|.*adapter.*|.*ERA5.*" + + echo "----------------------------------------" +done + +echo "All $((${#tuples[@]})) jobs submitted!" + diff --git a/pyproject.toml b/pyproject.toml index 0f0f7a296..7051bfa36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "weathergen-common", "weathergen-evaluate", "weathergen-readers-extra", + "pyyaml>=6.0.2", ] diff --git a/uv.lock b/uv.lock index 4cdcbdcc5..d10a5e2dd 100644 --- a/uv.lock +++ b/uv.lock @@ -2728,6 +2728,7 @@ dependencies = [ { name = "polars", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "psutil", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "pynvml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, + { name = "pyyaml", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "tqdm", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-common", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, { name = "weathergen-evaluate", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-10-weathergen-cpu' and extra == 'extra-10-weathergen-gpu')" }, @@ -2777,6 +2778,7 @@ requires-dist = [ { name = "polars", specifier = "~=1.25.2" }, { name = "psutil" }, { name = "pynvml" }, + { name = "pyyaml", specifier = ">=6.0.2" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-linux_aarch64.whl" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'gpu'", url = "https://download.pytorch.org/whl/cu126/torch-2.6.0%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl" }, { name = "torch", marker = "sys_platform == 'linux' and extra == 'cpu'", specifier = "==2.6.0", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "weathergen", extra = "cpu" } }, diff --git a/weather_slurm_inferece.sh b/weather_slurm_inferece.sh new file mode 100644 index 000000000..3106475c0 --- /dev/null +++ b/weather_slurm_inferece.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +#SBATCH --job-name=train +#SBATCH --output=./logs/output_%j.txt +#SBATCH --error=./logs/error_%j.txt +#SBATCH --exclusive --mem=450G +#SBATCH --partition=normal +#SBATCH --gres=gpu:1 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --time=01:00:00 +#SBATCH -A ch17 +#SBATCH --output=logs/weathergen-%x.%j.out +#SBATCH --error=logs/weathergen-%x.%j.err + + +UENV_IMAGE="prgenv-gnu/25.6:v2" + +RUN_ID="$1" +export RUN_ID + +echo "Top-level from_run_id: $RUN_ID" + +echo "=== Checking for uenv image: $UENV_IMAGE ===" + +IMAGE_EXISTS=false + +if [ "$IMAGE_EXISTS" = false ]; then + if uenv image inspect "$UENV_IMAGE" &>/dev/null; then + IMAGE_EXISTS=true + fi +fi + +if [ "$IMAGE_EXISTS" = false ]; then + echo "========================================" + echo "ERROR: uenv image '$UENV_IMAGE' not found!" + echo "========================================" + echo "" + echo "The image needs to be pulled before use." + echo "" + echo "Steps to fix:" + echo "" + echo " 1. On the santis login node, run:" + echo " uenv image pull $UENV_IMAGE" + echo "" + echo " 2. Wait for download to complete (this may take a few minutes)" + echo "" + echo " 3. Verify the image is available:" + echo " uenv image ls" + echo "" + echo " 4. Re-submit your SLURM job" + echo "" + echo "========================================" + exit 1 +fi + +echo "✓ Image '$UENV_IMAGE' found" +echo "" + +RUN_ID="$1" + +uenv run "$UENV_IMAGE" --view=modules -- bash << 'EOF' + +module load aws-ofi-nccl/1.16.0 + +export NCCL_NET="AWS Libfabric" +export MPICH_GPU_SUPPORT_ENABLED=0 +export NCCL_NET_GDR_LEVEL=PHB +export NCCL_CROSS_NIC=1 +export NCCL_PROTO=^LL128 + +export FI_CXI_DEFAULT_CQ_SIZE=131072 +export FI_CXI_DEFAULT_TX_SIZE=16384 +export FI_CXI_DISABLE_HOST_REGISTER=1 +export FI_CXI_RX_MATCH_MODE=software +export FI_MR_CACHE_MONITOR=userfaultfd + +export MASTER_ADDR="$(scontrol show hostnames "$SLURM_NODELIST" | head -n 1)" +export MASTER_PORT=29514 + +# disable core dumps +ulimit -c 0 +ulimit -t unlimited + +export CC=/usr/bin/gcc +export NCCL_DEBUG=INFO + +echo "Starting job." +echo "Number of Nodes: $SLURM_JOB_NUM_NODES" +echo "Number of Tasks: $SLURM_NTASKS" +echo "from_run_id: $RUN_ID" +echo "WEATHERGEN_HOME: $WEATHERGEN_HOME" +echo "WEATHERGEN_CONFIG_EXTRA: $WEATHERGEN_CONFIG_EXTRA" +echo "SLURM_JOB_ID: $SLURM_JOB_ID" +echo "SLURM_JOB_NAME: $SLURM_JOB_NAME" +echo "SLURM_SUBMIT_DIR: $SLURM_SUBMIT_DIR" +echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST" +date + + +#cd $WEATHERGEN_HOME +source .venv/bin/activate + +srun uv run --offline inference --from_run_id "$RUN_ID" --samples=16 --options forecast_steps=80 + + +echo "Finished job." +sstat -j $SLURM_JOB_ID.batch --format=JobID,MaxVMSize +date +EOF + From 42b4318f2795e36c3184ef860669c1159cfa5e5e Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Tue, 27 Jan 2026 09:26:21 +0100 Subject: [PATCH 38/40] yml and multi scripts --- config/eval_config.yml | 186 ++++++++++++++++++++++--------------- config/runs_plot_train.yml | 114 ++++++++++++++--------- launch_multi.sh | 29 +++--- launch_multi_infer.sh | 33 ++++++- weather_slurm_inferece.sh | 20 ++-- 5 files changed, 241 insertions(+), 141 deletions(-) diff --git a/config/eval_config.yml b/config/eval_config.yml index 0403f148b..7b02522c6 100644 --- a/config/eval_config.yml +++ b/config/eval_config.yml @@ -16,8 +16,42 @@ evaluation: run_ids : # lr=5e-4 - xs5l8zmj: - label: "cosine scheduler lr_max=5e-4 v1" + #xs5l8zmj: + # label: "cosine scheduler lr_max=5e-4 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #x9zvml1k: + # label: "cosine scheduler lr_max=5e-4 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + #an8rap5h: + # label: "cosine scheduler lr_max=5e-4 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" + + # lr=1e-4 + u2qk39pi: + label: "cosine scheduler lr_max=1e-4 v1 epoch=32" epoch: 0 rank: 0 streams: @@ -27,8 +61,8 @@ run_ids : sample: "all" forecast_step: "all" - x9zvml1k: - label: "cosine scheduler lr_max=5e-4 v2" + zswipf53: + label: "cosine scheduler lr_max=1e-4 v2 epoch=32" epoch: 0 rank: 0 streams: @@ -38,8 +72,8 @@ run_ids : sample: "all" forecast_step: "all" - an8rap5h: - label: "cosine scheduler lr_max=5e-4 v3" + dsdvzg59: + label: "cosine scheduler lr_max=1e-4 v3 epoch=32" epoch: 0 rank: 0 streams: @@ -49,9 +83,8 @@ run_ids : sample: "all" forecast_step: "all" - # lr=1e-4 - u2qk39pi: - label: "cosine scheduler lr_max=1e-4 v1" + ikvap8bm: + label: "cosine scheduler lr_max=1e-4 v1 epoch=48" epoch: 0 rank: 0 streams: @@ -61,8 +94,8 @@ run_ids : sample: "all" forecast_step: "all" - zswipf53: - label: "cosine scheduler lr_max=1e-4 v2" + o5b9wnu3: + label: "cosine scheduler lr_max=1e-4 v2 epoch=48" epoch: 0 rank: 0 streams: @@ -72,8 +105,8 @@ run_ids : sample: "all" forecast_step: "all" - dsdvzg59: - label: "cosine scheduler lr_max=1e-4 v3" + y94xmnhj: + label: "cosine scheduler lr_max=1e-4 v3 epoch=48" epoch: 0 rank: 0 streams: @@ -82,8 +115,7 @@ run_ids : evaluation: sample: "all" forecast_step: "all" - - # lr=5e-5 + ## lr=5e-5 #r812ji96: # label: "cosine scheduler lr_max=5e-5 v1" # epoch: 0 @@ -117,71 +149,71 @@ run_ids : # sample: "all" # forecast_step: "all" - ## lr=1e-5 - v0yha29i: - label: "cosine scheduler lr_max=1e-5 v1" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + # ## lr=1e-5 + #v0yha29i: + # label: "cosine scheduler lr_max=1e-5 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - cbmk73y0: - label: "cosine scheduler lr_max=1e-5 v2" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #cbmk73y0: + # label: "cosine scheduler lr_max=1e-5 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - ngdrjcbt: - label: "cosine scheduler lr_max=1e-5 v3" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #ngdrjcbt: + # label: "cosine scheduler lr_max=1e-5 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - # lr=5e-6 - voulcvsi: - label: "cosine scheduler lr_max=5e-6 v1" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + ## lr=5e-6 + #voulcvsi: + # label: "cosine scheduler lr_max=5e-6 v1" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - urlp39xq: - label: "cosine scheduler lr_max=5e-6 v2" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #urlp39xq: + # label: "cosine scheduler lr_max=5e-6 v2" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - ch1n05gd: - label: "cosine scheduler lr_max=5e-6 v3" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #ch1n05gd: + # label: "cosine scheduler lr_max=5e-6 v3" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" diff --git a/config/runs_plot_train.yml b/config/runs_plot_train.yml index cb3d33d36..2d18fa589 100644 --- a/config/runs_plot_train.yml +++ b/config/runs_plot_train.yml @@ -52,31 +52,31 @@ train : # description: "lr 5e-4, w-dec 0.4, v3" # eval: vgbndhco - h8l1yem5 : - slurm_id: 0 - description: "lr 5e-4, v2" - eval: vgbndhco - jeubm9ld : - slurm_id: 0 - description: "lr 5e-4, v3" - eval: vgbndhco - bezt6v8g : - slurm_id: 0 - description: "lr 5e-4, v1" - eval: vgbndhco - - ya0gty48 : - slurm_id: 0 - description: "lr 1e-4, v1" - eval: vgbndhco - djiy1v3e : - slurm_id: 0 - description: "lr 1e-4, v2" - eval: vgbndhco - g96y1dq5 : - slurm_id: 0 - description: "lr 1e-4, v3" - eval: vgbndhco + #h8l1yem5 : + # slurm_id: 0 + # description: "lr 5e-4, v2" + # eval: vgbndhco + #jeubm9ld : + # slurm_id: 0 + # description: "lr 5e-4, v3" + # eval: vgbndhco + #bezt6v8g : + # slurm_id: 0 + # description: "lr 5e-4, v1" + # eval: vgbndhco + # + #ya0gty48 : + # slurm_id: 0 + # description: "lr 1e-4, v1" + # eval: vgbndhco + #djiy1v3e : + # slurm_id: 0 + # description: "lr 1e-4, v2" + # eval: vgbndhco + #g96y1dq5 : + # slurm_id: 0 + # description: "lr 1e-4, v3" + # eval: vgbndhco # mia69x1h : @@ -131,18 +131,18 @@ train : # description: "lr 1e-4, w-dec 0.4, v3" # eval: vgbndhco - s9sldzyb : - slurm_id: 0 - description: "lr 1e-4, w-dec 0.6, v1" - eval: vgbndhco - e29izt1j : - slurm_id: 0 - description: "lr 1e-4, w-dec 0.6, v2" - eval: vgbndhco - bmoc645w : - slurm_id: 0 - description: "lr 1e-4, w-dec 0.6, v3" - eval: vgbndhco + #s9sldzyb : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.6, v1" + # eval: vgbndhco + #e29izt1j : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.6, v2" + # eval: vgbndhco + #bmoc645w : + # slurm_id: 0 + # description: "lr 1e-4, w-dec 0.6, v3" + # eval: vgbndhco @@ -198,15 +198,43 @@ train : # description: "lr 5e-5, w-dec 0.4, v3" # eval: vgbndhco - z8vx03bg : + #lvy8406i : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=512, v1" + # eval: vgbndhco + #ipn3jryk : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=512, v2" + # eval: vgbndhco + #wrucxsk6 : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=2048, v1" + # eval: vgbndhco + #dfvo0ir1 : + # slurm_id: 0 + # description: "lr 1e-4, cool_down_steps=2048, v2" + # eval: vgbndhco + czfrhdae : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=32 epochs, v1" + eval: vgbndhco + zha9i6x3 : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=32 epochs, v1" + eval: vgbndhco + ypr1b3a4 : + slurm_id: 0 + description: "lr 1e-4, cosine_lr=32 epochs, v1" + eval: vgbndhco + otn1u3oe : slurm_id: 0 - description: "lr 5e-5, w-dec 0.6, v1" + description: "lr 1e-4, cosine_lr=48 epochs, v1" eval: vgbndhco - e3k2v450 : + r2z01faj : slurm_id: 0 - description: "lr 5e-5, w-dec 0.6, v2" + description: "lr 1e-4, cosine_lr=48 epochs, v2" eval: vgbndhco - qmil5gwk : + xxjfcwq1 : slurm_id: 0 - description: "lr 5e-5, w-dec 0.6, v3" + description: "lr 1e-4, cosine_lr=48 epochs,v2" eval: vgbndhco diff --git a/launch_multi.sh b/launch_multi.sh index fbd571b06..092e1d4f6 100644 --- a/launch_multi.sh +++ b/launch_multi.sh @@ -75,22 +75,27 @@ # done -# for run_id in fmpesclt vvwizau9 vl5n39cj ; do -# echo "$run_id" -# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv -# done -# - -for run_id in fl9xrpao ; do - echo "$run_id" - ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --run-id $run_id --link-venv -done + #for run_id in fmpesclt vvwizau9 vl5n39cj ; do + for run_id in r2z01faj xxjfcwq1 ; do + echo "$run_id" + ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --time 24:00:00 --from-run-id $run_id --run-id $run_id --link-venv + done + + +#for run_id in bwtkj0he pji5hbze pu5ct7ox; do +#for run_id in lvy8406i ipn3jryk it0uzsl3 qm45twzj d3gc8fdn ; do +#for run_id in it0uzsl3 qm45twzj d3gc8fdn ; do +#for run_id in ipn3jryk d3gc8fdn ; do +# echo "$run_id" +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --from-run-id $run_id --time 24:00:00 --link-venv --options num_epochs=36 lr_steps_cooldown=2048 +#done # Cosine learning_rate test #for lr_max in "5e-4" "1e-4" "5e-5" "1e-5" "5e-6" ; do +#for lr_max in "1e-4" ; do # echo "$lr_max" -# for from_run_id in dnl5r61x ; do -# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --time 24:00:00 --from-run-id $from_run_id --link-venv --options istep=0 num_epochs=32 lr_max=$lr_max lr_policy_decay="cosine" forecast_steps=8 freeze_modules=".*global.*|.*local.*|.*adapter.*|.*ERA5.*" +# for from_run_id in dnl5r61x vvwizau9 wyhcr51m; do +# ../WeatherGenerator-private/hpc/launch-slurm.py --nodes 2 --time 24:00:00 --from-run-id $from_run_id --link-venv --options istep=0 num_epochs=48 lr_max=$lr_max lr_policy_decay="cosine" forecast_steps=8 freeze_modules=".*global.*|.*local.*|.*adapter.*|.*ERA5.*" #done #done diff --git a/launch_multi_infer.sh b/launch_multi_infer.sh index a7574748a..f3ff9bc85 100644 --- a/launch_multi_infer.sh +++ b/launch_multi_infer.sh @@ -1,5 +1,34 @@ +#!/bin/bash + +# (from_run_id, to_run_id) tuples from mapping +run_pairs=( + #"czfrhdae u2qk39pi" + #"zha9i6x3 zswipf53" + #"ypr1b3a4 dsdvzg59" + #"vctfgruv r812ji96" + #"htdwqjpx gj6eq2dx" + #"gakr74pw ff80snum" + #"qpogewjf v0yha29i" + #"y3trwpx7 cbmk73y0" + #"asnz2gyl ngdrjcbt" + #"dd1cq6nv voulcvsi" + #"dn15vfks urlp39xq" + #"fl9xrpao ch1n05gd" +) + +#for tuple in "${run_pairs[@]}"; do +# read from_run_id run_id <<< "$tuple" +# echo "From: $from_run_id → Run_id: $run_id" +# sbatch weather_slurm_inferece.sh "$from_run_id" "$run_id" +#done + + + + + #for run_id in unov2gdz pv5hu3mc exsm2wty czfrhdae zha9i6x3 ypr1b3a4 vctfgruv htdwjqpx gakr74pw qpogewjf y3trwpx7 asnz2gyl dd1cq6nv dn15vfks fl9xrpao ; do -for run_id in lvlfd8er hr1l2whz a68hqu13 ; do +#for run_id in xqbky3ht whsolnr7 e0yzx968 ; do +for run_id in otn1u3oe r2z01faj xxjfcwq1 ; do echo $ "$run_id" sbatch weather_slurm_inferece.sh "$run_id" - done +done diff --git a/weather_slurm_inferece.sh b/weather_slurm_inferece.sh index 3106475c0..28b1fc2ab 100644 --- a/weather_slurm_inferece.sh +++ b/weather_slurm_inferece.sh @@ -16,10 +16,14 @@ UENV_IMAGE="prgenv-gnu/25.6:v2" -RUN_ID="$1" -export RUN_ID +FROM_RUN_ID="$1" +#RUN_ID="$2" -echo "Top-level from_run_id: $RUN_ID" +export FROM_RUN_ID +#export RUN_ID + +echo "Top-level from_run_id: $FROM_RUN_ID" +echo "Top-level run_id: $RUN_ID" echo "=== Checking for uenv image: $UENV_IMAGE ===" @@ -57,7 +61,8 @@ fi echo "✓ Image '$UENV_IMAGE' found" echo "" -RUN_ID="$1" +FROM_RUN_ID="$1" +#RUN_ID="$2" uenv run "$UENV_IMAGE" --view=modules -- bash << 'EOF' @@ -88,7 +93,8 @@ export NCCL_DEBUG=INFO echo "Starting job." echo "Number of Nodes: $SLURM_JOB_NUM_NODES" echo "Number of Tasks: $SLURM_NTASKS" -echo "from_run_id: $RUN_ID" +echo "from_run_id: $FROM_RUN_ID" +#echo "run_id: $RUN_ID" echo "WEATHERGEN_HOME: $WEATHERGEN_HOME" echo "WEATHERGEN_CONFIG_EXTRA: $WEATHERGEN_CONFIG_EXTRA" echo "SLURM_JOB_ID: $SLURM_JOB_ID" @@ -101,8 +107,8 @@ date #cd $WEATHERGEN_HOME source .venv/bin/activate -srun uv run --offline inference --from_run_id "$RUN_ID" --samples=16 --options forecast_steps=80 - +srun uv run --offline inference --from_run_id "$FROM_RUN_ID" --samples=16 --options forecast_steps=80 +#srun uv run inference --from_run_id "$FROM_RUN_ID" --run_id "$RUN_ID" --samples 16 --start_date=2023-10-01 --end_date=2023-12-01 --options forecast_steps=80 echo "Finished job." sstat -j $SLURM_JOB_ID.batch --format=JobID,MaxVMSize From 892f5096112477f783817c94950f62dd286149f9 Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Tue, 27 Jan 2026 18:53:04 +0100 Subject: [PATCH 39/40] changed eval_config and inference script --- config/eval_config.yml | 66 +++++++++++++++++++-------------------- weather_slurm_inferece.sh | 2 +- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/config/eval_config.yml b/config/eval_config.yml index 7b02522c6..a8135c324 100644 --- a/config/eval_config.yml +++ b/config/eval_config.yml @@ -50,40 +50,40 @@ run_ids : # forecast_step: "all" # lr=1e-4 - u2qk39pi: - label: "cosine scheduler lr_max=1e-4 v1 epoch=32" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #u2qk39pi: + # label: "cosine scheduler lr_max=1e-4 v1 epoch=32" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - zswipf53: - label: "cosine scheduler lr_max=1e-4 v2 epoch=32" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #zswipf53: + # label: "cosine scheduler lr_max=1e-4 v2 epoch=32" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - dsdvzg59: - label: "cosine scheduler lr_max=1e-4 v3 epoch=32" - epoch: 0 - rank: 0 - streams: - ERA5: - channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] - evaluation: - sample: "all" - forecast_step: "all" + #dsdvzg59: + # label: "cosine scheduler lr_max=1e-4 v3 epoch=32" + # epoch: 0 + # rank: 0 + # streams: + # ERA5: + # channels: ["2t", "10u", "10v", "q_850", "t_850", "u_850", "v_850", "z_500"] + # evaluation: + # sample: "all" + # forecast_step: "all" - ikvap8bm: + qc5dw7ki: label: "cosine scheduler lr_max=1e-4 v1 epoch=48" epoch: 0 rank: 0 @@ -94,7 +94,7 @@ run_ids : sample: "all" forecast_step: "all" - o5b9wnu3: + oqe79vpk: label: "cosine scheduler lr_max=1e-4 v2 epoch=48" epoch: 0 rank: 0 @@ -105,7 +105,7 @@ run_ids : sample: "all" forecast_step: "all" - y94xmnhj: + hhblaokc: label: "cosine scheduler lr_max=1e-4 v3 epoch=48" epoch: 0 rank: 0 diff --git a/weather_slurm_inferece.sh b/weather_slurm_inferece.sh index 28b1fc2ab..7818b2eb6 100644 --- a/weather_slurm_inferece.sh +++ b/weather_slurm_inferece.sh @@ -107,7 +107,7 @@ date #cd $WEATHERGEN_HOME source .venv/bin/activate -srun uv run --offline inference --from_run_id "$FROM_RUN_ID" --samples=16 --options forecast_steps=80 +srun uv run --offline inference --from_run_id "$FROM_RUN_ID" --samples=16 --start_date=2023-10-01 --end_date=2023-12-01 --options forecast_steps=80 #srun uv run inference --from_run_id "$FROM_RUN_ID" --run_id "$RUN_ID" --samples 16 --start_date=2023-10-01 --end_date=2023-12-01 --options forecast_steps=80 echo "Finished job." From a44d5a4f86025437bb27066ba9f1e3aa2805e67a Mon Sep 17 00:00:00 2001 From: ankitpatnala Date: Thu, 29 Jan 2026 14:14:26 +0100 Subject: [PATCH 40/40] added spike function to the fstep weighting --- config/default_config.yml | 7 ++++++- src/weathergen/train/loss.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index a88d359ef..f86c1a591 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -80,7 +80,12 @@ loss_fcts_val: - - "mse" - 1.0 - +timestep_weight: [spike_function, + {"type":"probability", + "values":{ 4 : 0.6, + 6 : 0.2, + 8 : 0.1, + 10 : 0.1} ] batch_size_per_gpu: 1 batch_size_validation_per_gpu: 1 diff --git a/src/weathergen/train/loss.py b/src/weathergen/train/loss.py index 406cd051c..5f620a810 100644 --- a/src/weathergen/train/loss.py +++ b/src/weathergen/train/loss.py @@ -195,3 +195,24 @@ def gamma_decay(forecast_steps, gamma): fsteps = np.arange(forecast_steps) weights = gamma**fsteps return weights * (len(fsteps) / np.sum(weights)) + + +def spike_function(forecast_steps, spike_type): + fstep = np.arange(forecast_steps) + weights = np.zeros_like(fstep, dtype=float) + if spike_type["type"] == "last": + weights[-1] = 1.0 + elif spike_type["type"] == "probability": + steps_probs = spike_type["values"] + fs_steps = list(steps_probs.keys()) + fs_steps = [int(x) for x in fs_steps] + assert max(fs_steps) <= forecast_steps, ( + f"Max step {max(fs_steps)} > forecast_steps {forecast_steps}" + ) + fs_probs = list(steps_probs.values()) + assert np.isclose(np.array(fs_probs).sum(), 1.0) + fs_selected = np.random.choice(fs_steps, p=fs_probs) + weights[fs_selected] = 1.0 + else: + raise ValueError(f"Spike type {spike_type} is not defined") + return weights