diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..37768f9 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,42 @@ +name: Linting + +on: + push: + branches: [ main, dev, refactor/** ] + pull_request: + branches: [ main, dev ] + +jobs: + lint: + runs-on: ubuntu-latest + continue-on-error: true # Don't block on linting errors + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install linting tools + run: | + python -m pip install --upgrade pip + pip install ruff black isort + + - name: Run Ruff + continue-on-error: true + run: ruff check src/ --output-format=github || true + + - name: Check Black formatting + continue-on-error: true + run: black --check src/ || true + + - name: Check import sorting + continue-on-error: true + run: isort --check-only src/ || true + + - name: Lint summary + if: always() + run: echo "Linting completed (issues are reported but don't block the workflow)" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..2a86724 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,59 @@ +name: Tests + +on: + push: + branches: [ main, dev, refactor/** ] + pull_request: + branches: [ main, dev ] + +jobs: + test: + runs-on: ubuntu-latest + continue-on-error: true # Don't fail the workflow if tests fail + + strategy: + matrix: + python-version: ['3.10', '3.11'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: recursive # Include AION, AstroCLIP, astroPT submodules + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y libhdf5-dev + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pytest-cov + + # Install package in editable mode + if [ -f pyproject.toml ]; then + pip install -e . + fi + + # Install test requirements if they exist + if [ -f requirements-test.txt ]; then + pip install -r requirements-test.txt + elif [ -f requirements.txt ]; then + pip install -r requirements.txt + fi + + - name: Run tests + continue-on-error: true # Don't fail even if tests fail + run: | + pytest tests/ -v --tb=short --continue-on-collection-errors || true + + - name: Test summary + if: always() + run: echo "Tests completed (failures are reported but don't block the workflow)" diff --git a/LINTING.md b/LINTING.md new file mode 100644 index 0000000..7803d2d --- /dev/null +++ b/LINTING.md @@ -0,0 +1,37 @@ +# Linting and Formatting + +This project uses Black, isort, and Ruff for code quality. 
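+
+The CI lint workflow (`.github/workflows/lint.yml`) runs the same three tools in check-only mode, so you can reproduce its checks locally with:
+
+```bash
+ruff check src/
+black --check src/
+isort --check-only src/
+```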
+ +## Run Formatters Locally + +```bash +# Install tools +pip install black isort ruff + +# Format code with Black +black src/ + +# Sort imports with isort +isort src/ + +# Check linting with Ruff +ruff check src/ + +# Auto-fix Ruff issues +ruff check src/ --fix +``` + +## Pre-commit Hook + +Install pre-commit to automatically format before commits: + +```bash +pip install pre-commit +pre-commit install +``` + +## Configuration Files + +- **Black**: Uses default settings (88 chars, py310+) +- **isort**: Compatible with Black +- **Ruff**: Configured in `pyproject.toml` diff --git a/src/fmb/__init__.py b/src/fmb/__init__.py index a19f48f..494cb4c 100644 --- a/src/fmb/__init__.py +++ b/src/fmb/__init__.py @@ -4,4 +4,3 @@ Module: fmb.__init__ Description: FMB module: fmb.__init__ """ - diff --git a/src/fmb/analysis/displacement.py b/src/fmb/analysis/displacement.py index f7fc562..7d077dc 100644 --- a/src/fmb/analysis/displacement.py +++ b/src/fmb/analysis/displacement.py @@ -6,25 +6,15 @@ Description: Embedding space displacement analysis """ -""" -Unified Displacement Analysis Module. -Combines: -1. Multi-Model Analysis (Model A vs B) -2. Cross-Modality Analysis (Modality X vs Y) -3. Extensive Analysis (9x9 Pairwise) -""" - import argparse -import sys -import yaml from pathlib import Path -from typing import Sequence, Optional, Dict, List, Tuple +from typing import Dict, Optional import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns -import matplotlib +import yaml from fmb.paths import load_paths from fmb.viz.style import set_style @@ -33,172 +23,277 @@ set_style() KEYS = { - "AION": {"Images": "embedding_hsc", "Spectra": "embedding_spectrum", "Joint": "embedding_hsc_desi"}, - "AstroPT": {"Images": "embedding_images", "Spectra": "embedding_spectra", "Joint": "embedding_joint"}, - "AstroCLIP": {"Images": "embedding_images", "Spectra": "embedding_spectra", "Joint": "embedding_joint"} + "AION": { + "Images": "embedding_hsc", + "Spectra": "embedding_spectrum", + "Joint": "embedding_hsc_desi", + }, + "AstroPT": { + "Images": "embedding_images", + "Spectra": "embedding_spectra", + "Joint": "embedding_joint", + }, + "AstroCLIP": { + "Images": "embedding_images", + "Spectra": "embedding_spectra", + "Joint": "embedding_joint", + }, } COLORS = {"AION": "#1f77b4", "AstroPT": "#ff7f0e", "AstroCLIP": "#2ca02c"} + def load_data(path: Path) -> pd.DataFrame: """Load anomaly scores from CSV.""" df = pd.read_csv(path) # Ensure object_id is string - df['object_id'] = df['object_id'].astype(str) + df["object_id"] = df["object_id"].astype(str) return df + def find_score_file(model_name: str, base_dir: Path) -> Optional[Path]: """Auto-detect score file for a model.""" # Pattern: anomaly_scores_{model}.csv or *{model}*.csv candidates = list(base_dir.glob(f"anomaly_scores_{model_name.lower()}.csv")) if not candidates: candidates = list(base_dir.glob(f"*{model_name.lower()}*.csv")) - + return candidates[0] if candidates else None + def plot_panel_simple(ax, ranks_src, ranks_tgt, n_total, color, label=None): """Plot single histogram panel (for Multi-Model).""" top_1_threshold = n_total * 0.01 subset_mask = ranks_src <= top_1_threshold - + if subset_mask.sum() == 0: - ax.text(0.5, 0.5, "No data", transform=ax.transAxes, ha='center') + ax.text(0.5, 0.5, "No data", transform=ax.transAxes, ha="center") return ranks_pct = (ranks_tgt[subset_mask] / n_total) * 100 - - ax.hist(ranks_pct, bins=40, range=(0, 100), color=color, edgecolor='black', alpha=0.7, linewidth=0.5, zorder=3, 
label=label) - - ax.axvline(x=1, color='red', linestyle='--', linewidth=1, zorder=4) - ax.axvline(x=10, color='orange', linestyle='--', linewidth=1, zorder=4) - ax.axvspan(0, 1, color='red', alpha=0.1, zorder=1) - ax.axvspan(1, 10, color='orange', alpha=0.1, zorder=1) - - retained_1 = (ranks_tgt[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 - retained_10 = (ranks_tgt[subset_mask] <= (n_total * 0.1)).sum() / subset_mask.sum() * 100 - + + ax.hist( + ranks_pct, + bins=40, + range=(0, 100), + color=color, + edgecolor="black", + alpha=0.7, + linewidth=0.5, + zorder=3, + label=label, + ) + + ax.axvline(x=1, color="red", linestyle="--", linewidth=1, zorder=4) + ax.axvline(x=10, color="orange", linestyle="--", linewidth=1, zorder=4) + ax.axvspan(0, 1, color="red", alpha=0.1, zorder=1) + ax.axvspan(1, 10, color="orange", alpha=0.1, zorder=1) + + retained_1 = ( + (ranks_tgt[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 + ) + retained_10 = ( + (ranks_tgt[subset_mask] <= (n_total * 0.1)).sum() / subset_mask.sum() * 100 + ) + stats_text = ( f"\\textbf{{Retained}}:\n" f"Top 1\\%: {retained_1:.1f}\\%\n" f"Top 10\\%: {retained_10:.1f}\\%" ) - ax.text(0.95, 0.92, stats_text, transform=ax.transAxes, ha='right', va='top', - bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', boxstyle='round,pad=0.2'), zorder=5, fontsize=7) - ax.grid(True, linestyle=':', alpha=0.4, zorder=0) + ax.text( + 0.95, + 0.92, + stats_text, + transform=ax.transAxes, + ha="right", + va="top", + bbox=dict( + facecolor="white", alpha=0.8, edgecolor="none", boxstyle="round,pad=0.2" + ), + zorder=5, + fontsize=7, + ) + ax.grid(True, linestyle=":", alpha=0.4, zorder=0) + def run_multi_model(dfs: Dict[str, pd.DataFrame], output_dir: Path): """3x3 Grid: Rows=Modality, Cols=Model Pair comparisons.""" print("Running Multi-Model Analysis...") embedding_types = ["Images", "Spectra", "Joint"] comparisons = [("AION", "AstroPT"), ("AstroPT", "AstroCLIP"), ("AstroCLIP", "AION")] - - fig, axes = plt.subplots(3, 3, figsize=(10, 9), sharex=True, sharey='row') - + + fig, axes = plt.subplots(3, 3, figsize=(10, 9), sharex=True, sharey="row") + for row_idx, emb_type in enumerate(embedding_types): # Extract subsets subsets = {} for m, df in dfs.items(): key = KEYS[m].get(emb_type) if key: - subsets[m] = df[df['embedding_key'] == key].copy() - + subsets[m] = df[df["embedding_key"] == key].copy() + # Merge if not all(m in subsets for m in ["AION", "AstroPT", "AstroCLIP"]): print(f"Skipping {emb_type} (missing data)") continue - - merged = subsets["AION"][['object_id', 'rank']].rename(columns={'rank': 'rank_aion'}) - merged = merged.merge(subsets["AstroPT"][['object_id', 'rank']].rename(columns={'rank': 'rank_astropt'}), on='object_id') - merged = merged.merge(subsets["AstroCLIP"][['object_id', 'rank']].rename(columns={'rank': 'rank_astroclip'}), on='object_id') - + + merged = subsets["AION"][["object_id", "rank"]].rename( + columns={"rank": "rank_aion"} + ) + merged = merged.merge( + subsets["AstroPT"][["object_id", "rank"]].rename( + columns={"rank": "rank_astropt"} + ), + on="object_id", + ) + merged = merged.merge( + subsets["AstroCLIP"][["object_id", "rank"]].rename( + columns={"rank": "rank_astroclip"} + ), + on="object_id", + ) + n_total = len(merged) - + for col_idx, (src, tgt) in enumerate(comparisons): ax = axes[row_idx, col_idx] plot_panel_simple( - ax, - merged[f"rank_{src.lower()}"].values, - merged[f"rank_{tgt.lower()}"].values, - n_total, - COLORS[src] + ax, + merged[f"rank_{src.lower()}"].values, + 
merged[f"rank_{tgt.lower()}"].values, + n_total, + COLORS[src], ) ax.set_title(rf"\textbf{{{src}}} $\rightarrow$ \textbf{{{tgt}}}", pad=5) - - if col_idx == 0: ax.set_ylabel(rf"\textbf{{{emb_type}}}" + "\nCount") - if row_idx == 2: ax.set_xlabel("Percentile Rank in Target") + + if col_idx == 0: + ax.set_ylabel(rf"\textbf{{{emb_type}}}" + "\nCount") + if row_idx == 2: + ax.set_xlabel("Percentile Rank in Target") plt.tight_layout() fig.savefig(output_dir / "displacement_multi_model.png", dpi=300) print(f"Saved {output_dir / 'displacement_multi_model.png'}") plt.close(fig) + def run_cross_modality(dfs: Dict[str, pd.DataFrame], output_dir: Path): """3x3 Grid: Rows=Source Mod, Cols=Target Mod.""" print("Running Cross-Modality Analysis...") modalities = ["Images", "Spectra", "Joint"] - + # Merge all into one big DF big_merged = None for m_name, df in dfs.items(): m_sub = None for mod in modalities: key = KEYS[m_name].get(mod) - if not key: continue - - d = df[df['embedding_key'] == key][['object_id', 'rank']].rename(columns={'rank': f'rank_{m_name}_{mod}'}) - m_sub = d if m_sub is None else m_sub.merge(d, on='object_id') - - big_merged = m_sub if big_merged is None else big_merged.merge(m_sub, on='object_id') - - if big_merged is None: return + if not key: + continue + + d = df[df["embedding_key"] == key][["object_id", "rank"]].rename( + columns={"rank": f"rank_{m_name}_{mod}"} + ) + m_sub = d if m_sub is None else m_sub.merge(d, on="object_id") + + big_merged = ( + m_sub if big_merged is None else big_merged.merge(m_sub, on="object_id") + ) + + if big_merged is None: + return n_total = len(big_merged) - - fig, axes = plt.subplots(3, 3, figsize=(11, 10), sharex=True, sharey='row') - + + fig, axes = plt.subplots(3, 3, figsize=(11, 10), sharex=True, sharey="row") + top_1_threshold = n_total * 0.01 - + for r, src_mod in enumerate(modalities): for c, tgt_mod in enumerate(modalities): ax = axes[r, c] stats_lines = [] - + for m_name in ["AION", "AstroPT", "AstroCLIP"]: - src_col = f'rank_{m_name}_{src_mod}' - tgt_col = f'rank_{m_name}_{tgt_mod}' - if src_col not in big_merged.columns or tgt_col not in big_merged.columns: continue - + src_col = f"rank_{m_name}_{src_mod}" + tgt_col = f"rank_{m_name}_{tgt_mod}" + if ( + src_col not in big_merged.columns + or tgt_col not in big_merged.columns + ): + continue + ranks_src = big_merged[src_col].values ranks_tgt = big_merged[tgt_col].values - + subset_mask = ranks_src <= top_1_threshold - if subset_mask.sum() == 0: continue - + if subset_mask.sum() == 0: + continue + ranks_pct = (ranks_tgt[subset_mask] / n_total) * 100 color = COLORS[m_name] - - ax.hist(ranks_pct, bins=40, range=(0, 100), color=color, histtype='step', alpha=0.9, linewidth=1.2, zorder=3, label=m_name) - ax.hist(ranks_pct, bins=40, range=(0, 100), color=color, alpha=0.1, zorder=2) - - retained = (ranks_tgt[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 + + ax.hist( + ranks_pct, + bins=40, + range=(0, 100), + color=color, + histtype="step", + alpha=0.9, + linewidth=1.2, + zorder=3, + label=m_name, + ) + ax.hist( + ranks_pct, bins=40, range=(0, 100), color=color, alpha=0.1, zorder=2 + ) + + retained = ( + (ranks_tgt[subset_mask] <= top_1_threshold).sum() + / subset_mask.sum() + * 100 + ) stats_lines.append(rf"{m_name}: {retained:.1f}\%") - - ax.axvline(x=1, color='red', linestyle='--', linewidth=0.8, alpha=0.5, zorder=4) - + + ax.axvline( + x=1, color="red", linestyle="--", linewidth=0.8, alpha=0.5, zorder=4 + ) + stats_txt = "\\textbf{Retained Top 1\\%}:\n" + 
"\n".join(stats_lines) - ax.text(0.95, 0.92, stats_txt, transform=ax.transAxes, ha='right', va='top', - bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', boxstyle='round,pad=0.2'), zorder=5, fontsize=6) - - ax.set_title(rf"\textbf{{{src_mod}}} $\rightarrow$ \textbf{{{tgt_mod}}}", pad=5) - if c == 0: ax.set_ylabel(rf"\textbf{{{src_mod}}}" + "\nCount") - if r == 2: ax.set_xlabel(rf"Percentile Rank in \textbf{{{tgt_mod}}}") - if r == 0 and c == 0: ax.legend(loc='lower right', fontsize=6) - ax.grid(True, linestyle=':', alpha=0.4) + ax.text( + 0.95, + 0.92, + stats_txt, + transform=ax.transAxes, + ha="right", + va="top", + bbox=dict( + facecolor="white", + alpha=0.8, + edgecolor="none", + boxstyle="round,pad=0.2", + ), + zorder=5, + fontsize=6, + ) + + ax.set_title( + rf"\textbf{{{src_mod}}} $\rightarrow$ \textbf{{{tgt_mod}}}", pad=5 + ) + if c == 0: + ax.set_ylabel(rf"\textbf{{{src_mod}}}" + "\nCount") + if r == 2: + ax.set_xlabel(rf"Percentile Rank in \textbf{{{tgt_mod}}}") + if r == 0 and c == 0: + ax.legend(loc="lower right", fontsize=6) + ax.grid(True, linestyle=":", alpha=0.4) plt.tight_layout() fig.savefig(output_dir / "displacement_cross_modality.png", dpi=300) print(f"Saved {output_dir / 'displacement_cross_modality.png'}") plt.close(fig) + def run_extensive(dfs: Dict[str, pd.DataFrame], output_dir: Path): """9x9 Grid.""" print("Running Extensive Analysis...") @@ -207,78 +302,113 @@ def run_extensive(dfs: Dict[str, pd.DataFrame], output_dir: Path): for m in ["AION", "AstroPT", "AstroCLIP"]: for mod in modalities: systems.append((m, mod)) - + # Merge big_merged = None for m, mod in systems: key = KEYS[m].get(mod) - if not key: continue - d = dfs[m][dfs[m]['embedding_key'] == key][['object_id', 'rank']].rename(columns={'rank': f'rank_{m}_{mod}'}) - big_merged = d if big_merged is None else big_merged.merge(d, on='object_id') - - if big_merged is None: return + if not key: + continue + d = dfs[m][dfs[m]["embedding_key"] == key][["object_id", "rank"]].rename( + columns={"rank": f"rank_{m}_{mod}"} + ) + big_merged = d if big_merged is None else big_merged.merge(d, on="object_id") + + if big_merged is None: + return n_total = len(big_merged) top_1_thr = n_total * 0.01 - + # Heatmap ret_matrix = np.zeros((9, 9)) for i, (src_m, src_mod) in enumerate(systems): - src_col = f'rank_{src_m}_{src_mod}' + src_col = f"rank_{src_m}_{src_mod}" mask = big_merged[src_col] <= top_1_thr for j, (tgt_m, tgt_mod) in enumerate(systems): - tgt_col = f'rank_{tgt_m}_{tgt_mod}' - retained = (big_merged.loc[mask, tgt_col] <= top_1_thr).sum() / mask.sum() * 100 + tgt_col = f"rank_{tgt_m}_{tgt_mod}" + retained = ( + (big_merged.loc[mask, tgt_col] <= top_1_thr).sum() / mask.sum() * 100 + ) ret_matrix[i, j] = retained - + plt.figure(figsize=(10, 8)) labels = [f"{m}-{mo}" for m, mo in systems] - sns.heatmap(ret_matrix, annot=True, fmt=".1f", xticklabels=labels, yticklabels=labels, cmap="YlGnBu") + sns.heatmap( + ret_matrix, + annot=True, + fmt=".1f", + xticklabels=labels, + yticklabels=labels, + cmap="YlGnBu", + ) plt.title("Extensive Retention Matrix") - plt.xticks(rotation=45, ha='right') + plt.xticks(rotation=45, ha="right") plt.tight_layout() plt.savefig(output_dir / "displacement_extensive_heatmap.png", dpi=300) plt.close() - + # Grid # (Skipping full 9x9 grid code for brevity unless requested, focusing on heatmap which is dense info) # The user asked for "everything", so I should include it. 
- - fig, axes = plt.subplots(9, 9, figsize=(18, 18), sharex=True, sharey='row') + + fig, axes = plt.subplots(9, 9, figsize=(18, 18), sharex=True, sharey="row") for i, (src_m, src_mod) in enumerate(systems): - src_col = f'rank_{src_m}_{src_mod}' + src_col = f"rank_{src_m}_{src_mod}" mask = big_merged[src_col] <= top_1_thr color = COLORS[src_m] - + for j, (tgt_m, tgt_mod) in enumerate(systems): - tgt_col = f'rank_{tgt_m}_{tgt_mod}' + tgt_col = f"rank_{tgt_m}_{tgt_mod}" ranks_pct = (big_merged.loc[mask, tgt_col] / n_total) * 100 - + ax = axes[i, j] - ax.hist(ranks_pct, bins=30, range=(0, 100), color=color, alpha=0.7, edgecolor='none') - ax.axvline(x=1, color='red', linestyle='--', linewidth=0.5) - ax.text(0.9, 0.9, f"{ret_matrix[i, j]:.0f}%", transform=ax.transAxes, ha='right', fontsize=6) - - if i == 0: ax.set_title(f"{tgt_m}\n{tgt_mod}", fontsize=7) - if j == 0: ax.set_ylabel(f"{src_m}\n{src_mod}", fontsize=7) - if i == 8: ax.set_xlabel("%", fontsize=7) + ax.hist( + ranks_pct, + bins=30, + range=(0, 100), + color=color, + alpha=0.7, + edgecolor="none", + ) + ax.axvline(x=1, color="red", linestyle="--", linewidth=0.5) + ax.text( + 0.9, + 0.9, + f"{ret_matrix[i, j]:.0f}%", + transform=ax.transAxes, + ha="right", + fontsize=6, + ) + + if i == 0: + ax.set_title(f"{tgt_m}\n{tgt_mod}", fontsize=7) + if j == 0: + ax.set_ylabel(f"{src_m}\n{src_mod}", fontsize=7) + if i == 8: + ax.set_xlabel("%", fontsize=7) ax.grid(alpha=0.2) - + plt.tight_layout() fig.savefig(output_dir / "displacement_extensive_grid.png", dpi=200) plt.close(fig) print(f"Saved extensive plots to {output_dir}") + def run_analysis(config_path: Optional[Path] = None, output_dir: Optional[Path] = None): paths = load_paths() if not config_path: config_path = paths.repo_root / "src/fmb/configs/analysis/displacement.yaml" - + with open(config_path, "r") as f: cfg = yaml.safe_load(f) - - out_dir_path = Path(output_dir) if output_dir else paths.repo_root / cfg.get("output_dir", "runs/analysis/displacement") + + out_dir_path = ( + Path(output_dir) + if output_dir + else paths.repo_root / cfg.get("output_dir", "runs/analysis/displacement") + ) out_dir_path.mkdir(parents=True, exist_ok=True) - + # Load Scores dfs = {} for m in cfg.get("models", ["AION", "AstroPT", "AstroCLIP"]): @@ -289,31 +419,32 @@ def run_analysis(config_path: Optional[Path] = None, output_dir: Optional[Path] else: # Auto-detect in outliers path p = find_score_file(m, paths.outliers) - + if p and p.exists(): print(f"Loading {m} scores from {p}...") dfs[m] = load_data(p) else: print(f"Warning: Scores for {m} not found.") - + if len(dfs) < 2: print("Need at least 2 models for displacement analysis.") return ptype = cfg.get("plot_type", "all") - + if ptype in ["all", "multi_model"]: run_multi_model(dfs, out_dir_path) - + if ptype in ["all", "cross_modality"]: run_cross_modality(dfs, out_dir_path) - + if ptype in ["all", "extensive"]: run_extensive(dfs, out_dir_path) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", help="Config file") args = parser.parse_args() - + run_analysis(Path(args.config) if args.config else None) diff --git a/src/fmb/analysis/outliers.py b/src/fmb/analysis/outliers.py index 29b0560..180229b 100644 --- a/src/fmb/analysis/outliers.py +++ b/src/fmb/analysis/outliers.py @@ -7,10 +7,10 @@ import argparse from pathlib import Path -from typing import Optional, List, Dict -import pandas as pd -import numpy as np +from typing import List + import matplotlib.pyplot as plt +import pandas as pd import seaborn as sns 
from fmb.paths import load_paths @@ -18,24 +18,28 @@ # Try to import venn try: from matplotlib_venn import venn2, venn3 + HAS_VENN = True except ImportError: HAS_VENN = False + class AnomalyAnalyzer: def __init__(self, input_csv: Path, output_dir: Path): self.input_csv = input_csv self.output_dir = output_dir self.output_dir.mkdir(parents=True, exist_ok=True) self.df = self._load_data() - + def _load_data(self) -> pd.DataFrame: if not self.input_csv.exists(): raise FileNotFoundError(f"Input file not found: {self.input_csv}") print(f"Loading analysis data from {self.input_csv}...") return pd.read_csv(self.input_csv, low_memory=False) - def plot_correlations(self, metrics: List[str] = ["rank_cosine", "rank_mm_geo", "p_joint"]) -> None: + def plot_correlations( + self, metrics: List[str] = ["rank_cosine", "rank_mm_geo", "p_joint"] + ) -> None: """Plot Spearman correlation of ranks between models.""" if "model" not in self.df.columns: print("[warn] 'model' column missing, skipping correlation analysis.") @@ -45,11 +49,11 @@ def plot_correlations(self, metrics: List[str] = ["rank_cosine", "rank_mm_geo", for m in metrics: if m not in self.df.columns: continue - + pivot = self.df.pivot(index="object_id", columns="model", values=m) if pivot.shape[1] < 2: continue - + plt.figure(figsize=(8, 6)) corr = pivot.corr(method="spearman") sns.heatmap(corr, annot=True, cmap="coolwarm", vmin=0, vmax=1) @@ -64,26 +68,26 @@ def plot_scatter_ranks(self) -> None: """Scatter plot of Image vs Spectrum percentiles.""" print("Generating rank scatter plots...") plt.figure(figsize=(9, 8)) - + hue = "model" if "model" in self.df.columns else None - + sns.scatterplot( - data=self.df, - x="p_img", - y="p_spec", - hue=hue, - alpha=0.3, + data=self.df, + x="p_img", + y="p_spec", + hue=hue, + alpha=0.3, s=15, - palette="viridis" if hue else None + palette="viridis" if hue else None, ) - + plt.xlabel("Image Anomaly Percentile (1.0 = Most Anomalous)") plt.ylabel("Spectrum Anomaly Percentile (1.0 = Most Anomalous)") plt.title("Image vs Spectrum Anomaly Percentiles") plt.plot([0, 1], [0, 1], "k--", alpha=0.5, label="y=x") - plt.legend(loc='upper right') + plt.legend(loc="upper right") plt.tight_layout() - + out_file = self.output_dir / "scatter_p_img_p_spec.png" plt.savefig(out_file) plt.close() @@ -97,40 +101,52 @@ def plot_uplift_distribution(self) -> None: print("Generating uplift distribution plot...") plt.figure(figsize=(8, 6)) hue = "model" if "model" in self.df.columns else None - + sns.histplot( - data=self.df, - x="uplift_mm", - hue=hue, - element="step", - kde=True, + data=self.df, + x="uplift_mm", + hue=hue, + element="step", + kde=True, bins=50, - common_norm=False + common_norm=False, ) - + plt.axvline(0, color="k", linestyle="--", alpha=0.5) - + # Add explanatory annotations # 1. Near Zero: Robust - plt.text(0.02, plt.gca().get_ylim()[1]*0.9, - "Robust Multimodal\nAnomalies", - fontsize=9, color="green", ha="left") - + plt.text( + 0.02, + plt.gca().get_ylim()[1] * 0.9, + "Robust Multimodal\nAnomalies", + fontsize=9, + color="green", + ha="left", + ) + # 2. 
Negative: Suppressed - plt.text(-0.2, plt.gca().get_ylim()[1]*0.9, - "Suppressed\nUnimodal Artifacts", - fontsize=9, color="red", ha="right") + plt.text( + -0.2, + plt.gca().get_ylim()[1] * 0.9, + "Suppressed\nUnimodal Artifacts", + fontsize=9, + color="red", + ha="right", + ) plt.title("Impact of Multimodal Fusion (Uplift Distribution)") - plt.xlabel("Uplift Score\n(Negative values indicate objects filtered out because modalities disagree)") + plt.xlabel( + "Uplift Score\n(Negative values indicate objects filtered out because modalities disagree)" + ) plt.ylabel("Count") - + # Move legend if needed, but default is usually fine. # Ensure the legend title is clear. # We don't call plt.legend() because it overwrites the hue legend. - + plt.tight_layout() - + out_file = self.output_dir / "uplift_hist.png" plt.savefig(out_file) plt.close() @@ -140,10 +156,12 @@ def plot_overlaps(self, top_k: int = 200) -> None: """Venn diagram of Top-K candidates across models.""" if "model" not in self.df.columns: return - + models = sorted(self.df["model"].unique()) if len(models) not in [2, 3]: - print(f"[info] Skipping Venn diagram (suitable for 2 or 3 models, found {len(models)})") + print( + f"[info] Skipping Venn diagram (suitable for 2 or 3 models, found {len(models)})" + ) return if not HAS_VENN: @@ -151,31 +169,41 @@ def plot_overlaps(self, top_k: int = 200) -> None: return print(f"Generating Venn diagram for Top-{top_k} anomalies...") - + # Identify top-k objects per model based on rank_mm_geo (or score_mm_geo) # Using sorted head if ranks are not pre-computed 1..N # Assume score_mm_geo descending is best. - + sets = {} for m in models: sub = self.df[self.df["model"] == m] if "score_mm_geo" in sub.columns: - top_objs = set(sub.sort_values("score_mm_geo", ascending=False).head(top_k)["object_id"]) + top_objs = set( + sub.sort_values("score_mm_geo", ascending=False).head(top_k)[ + "object_id" + ] + ) elif "rank_mm_geo" in sub.columns: - top_objs = set(sub.sort_values("rank_mm_geo", ascending=True).head(top_k)["object_id"]) + top_objs = set( + sub.sort_values("rank_mm_geo", ascending=True).head(top_k)[ + "object_id" + ] + ) else: print(f"[warn] No suitable ranking col for model {m}") continue sets[m] = top_objs - + plt.figure(figsize=(8, 8)) if len(models) == 2: venn2([sets[models[0]], sets[models[1]]], set_labels=models) elif len(models) == 3: - venn3([sets[models[0]], sets[models[1]], sets[models[2]]], set_labels=models) - + venn3( + [sets[models[0]], sets[models[1]], sets[models[2]]], set_labels=models + ) + plt.title(f"Overlap of Top-{top_k} Multimodal Anomalies") - + out_file = self.output_dir / f"venn_top{top_k}_models.png" plt.savefig(out_file) plt.close() @@ -190,19 +218,25 @@ def run_all(self, top_k: int = 200) -> None: def main(argv: List[str] = None): parser = argparse.ArgumentParser(description="Analyze Anomaly Detection Results") - parser.add_argument("--input_csv", type=str, help="Path to all_scores.csv (default: auto-detect from runs/outliers/multimodal/all_scores.csv)") - parser.add_argument("--top-k", type=int, default=200, help="Top-K threshold for overlaps") - + parser.add_argument( + "--input_csv", + type=str, + help="Path to all_scores.csv (default: auto-detect from runs/outliers/multimodal/all_scores.csv)", + ) + parser.add_argument( + "--top-k", type=int, default=200, help="Top-K threshold for overlaps" + ) + args = parser.parse_args(argv) - + paths = load_paths() - + # Auto-detect input input if not provided if args.input_csv: in_path = Path(args.input_csv) else: in_path = 
paths.outliers / "multimodal" / "all_scores.csv" - + if not in_path.exists(): print(f"[error] Input file not found: {in_path}") print("Please run 'python -m fmb.cli detect multimodal' first.") @@ -210,10 +244,11 @@ def main(argv: List[str] = None): # Output directory: runs/analysis/outliers out_dir = paths.runs_root / "analysis" / "outliers" - + analyzer = AnomalyAnalyzer(in_path, out_dir) analyzer.run_all(top_k=args.top_k) print(f"\n[success] Analysis plots saved to {out_dir}") + if __name__ == "__main__": main() diff --git a/src/fmb/analysis/regression/predict_physical_params.py b/src/fmb/analysis/regression/predict_physical_params.py index 0104fbe..a82e95b 100644 --- a/src/fmb/analysis/regression/predict_physical_params.py +++ b/src/fmb/analysis/regression/predict_physical_params.py @@ -11,217 +11,231 @@ Models: Ridge (Linear Baseline), LightGBM (Non-linear). Metrics: R², RMSE, PR, EPD, PES, CWP. """ -import sys -import yaml import argparse from pathlib import Path -from typing import Sequence, Tuple, Dict, List, Optional, Any +from typing import Dict, List, Optional, Tuple import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec import numpy as np import pandas as pd import seaborn as sns +import shap import torch +import yaml from astropy.io import fits +from lightgbm import LGBMRegressor +from sklearn.linear_model import Ridge from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from sklearn.model_selection import train_test_split -from sklearn.linear_model import Ridge -from lightgbm import LGBMRegressor -import shap -from tqdm import tqdm -from fmb.paths import load_paths, FMBPaths from fmb.data.utils import load_embeddings_file +from fmb.paths import load_paths from fmb.viz.style import set_style # Apply style set_style() + def load_config(config_path: Path) -> Dict: if not config_path.exists(): raise FileNotFoundError(f"Config not found: {config_path}") with open(config_path, "r") as f: return yaml.safe_load(f) -def load_catalog(path: Path, id_col_override: Optional[str] = None) -> Tuple[Dict, List[str], str]: + +def load_catalog( + path: Path, id_col_override: Optional[str] = None +) -> Tuple[Dict, List[str], str]: """Load FITS catalog.""" print(f"Loading catalog from {path}...") with fits.open(path) as hdul: data = hdul[1].data columns = hdul[1].columns.names - + id_column = id_col_override if not id_column: # Auto-detect ID - for cand in ['TARGETID', 'targetid', 'TargetID', 'object_id', 'objid']: + for cand in ["TARGETID", "targetid", "TargetID", "object_id", "objid"]: if cand in columns: id_column = cand break - + if not id_column: raise ValueError(f"Could not find ID column. 
Available: {columns}") - + print(f"Using '{id_column}' as ID column.") - + catalog_dict = {} for row in data: oid = str(row[id_column]).strip() # Store all columns provided catalog_dict[oid] = {col: row[col] for col in columns} - + # Detect numeric columns numeric_cols = [] for col in columns: - if col == id_column: continue + if col == id_column: + continue try: # Basic check on first element v = data[0][col] if isinstance(v, (int, float, np.number)): numeric_cols.append(col) - except: pass - + except: + pass + return catalog_dict, numeric_cols, id_column + def merge_data( records: List[dict], catalog: Dict, target_param: str, embedding_key: str, - subset_ids: Optional[set] = None + subset_ids: Optional[set] = None, ) -> Tuple[np.ndarray, np.ndarray, List[str]]: """Merge embeddings with catalog targets.""" # Build dictionary for fast lookup rec_dict = {str(r.get("object_id") or r.get("targetid", "")): r for r in records} - + if subset_ids is not None: target_ids = sorted(list(subset_ids)) else: target_ids = sorted(list(rec_dict.keys())) - + X_list = [] y_list = [] valid_ids = [] - + for obj_id in target_ids: if obj_id not in rec_dict or obj_id not in catalog: continue - + rec = rec_dict[obj_id] - + # Handle Embeddings - # Special logic for 'embedding_joint' if not present in file + # Special logic for 'embedding_joint' if not present in file # (Assuming extract_embedding_matrices logic or just load what's there) # Here we do simple lookup, computing joint on fly if needed emb_vec = None val = rec.get(embedding_key) - + if val is None and embedding_key == "embedding_joint": - # Try compute - img = rec.get("embedding_images") or rec.get("embedding_hsc") - spec = rec.get("embedding_spectra") or rec.get("embedding_spectrum") - if img is not None and spec is not None: - if not isinstance(img, np.ndarray): img = np.array(img).flatten() - if not isinstance(spec, np.ndarray): spec = np.array(spec).flatten() - emb_vec = np.concatenate([img, spec]) + # Try compute + img = rec.get("embedding_images") or rec.get("embedding_hsc") + spec = rec.get("embedding_spectra") or rec.get("embedding_spectrum") + if img is not None and spec is not None: + if not isinstance(img, np.ndarray): + img = np.array(img).flatten() + if not isinstance(spec, np.ndarray): + spec = np.array(spec).flatten() + emb_vec = np.concatenate([img, spec]) elif val is not None: - emb_vec = val - if not isinstance(emb_vec, np.ndarray): - if isinstance(emb_vec, torch.Tensor): - emb_vec = emb_vec.detach().cpu().numpy().flatten() - else: - emb_vec = np.array(emb_vec).flatten() - - if emb_vec is None: continue - + emb_vec = val + if not isinstance(emb_vec, np.ndarray): + if isinstance(emb_vec, torch.Tensor): + emb_vec = emb_vec.detach().cpu().numpy().flatten() + else: + emb_vec = np.array(emb_vec).flatten() + + if emb_vec is None: + continue + # Handle Target target_val = catalog[obj_id].get(target_param) try: target_val = float(target_val) - if np.isnan(target_val) or np.isinf(target_val): continue - except: continue - + if np.isnan(target_val) or np.isinf(target_val): + continue + except: + continue + X_list.append(emb_vec) y_list.append(target_val) valid_ids.append(obj_id) - + if not X_list: return np.array([]), np.array([]), [] - + return np.stack(X_list), np.array(y_list), valid_ids + def calculate_pr(shap_values: np.ndarray) -> Tuple[float, float, float]: """Calculate Participation Ratio and PR90.""" # Per-sample normalization (B2 fix) row_sums = np.abs(shap_values).sum(axis=1, keepdims=True) + 1e-12 shap_norm = shap_values / row_sums 
- + phi = np.abs(shap_norm).mean(axis=0) - + sum_phi = np.sum(phi) sum_phi_sq = np.sum(phi**2) - if sum_phi_sq == 0: return 0.0, 0.0, 0.0 - + if sum_phi_sq == 0: + return 0.0, 0.0, 0.0 + D = len(phi) pr = (sum_phi**2) / (D * sum_phi_sq) - + # PR90 sorted_phi = np.sort(phi)[::-1] cumsum = np.cumsum(sorted_phi) thresh = 0.90 * sum_phi n90 = np.searchsorted(cumsum, thresh) + 1 pr90 = n90 / D - + return pr, pr90, phi + def bootstrap_pr(shap_values: np.ndarray, n_boot: int = 50) -> Tuple[float, float]: """Bootstrap uncertainty for PR.""" prs = [] n_samples = shap_values.shape[0] - if n_samples < 50: return 0.0, 0.0 - + if n_samples < 50: + return 0.0, 0.0 + # Pre-normalize once to be safe or re-normalize per boot? # Logic: normalize SHAP per sample first, then bootstrap samples. row_sums = np.abs(shap_values).sum(axis=1, keepdims=True) + 1e-12 shap_norm = shap_values / row_sums - + for _ in range(n_boot): idx = np.random.choice(n_samples, n_samples, replace=True) sample_shap = shap_norm[idx] # Re-calc phi for this boot phi = np.abs(sample_shap).mean(axis=0) - + sum_phi = np.sum(phi) sum_phi_sq = np.sum(phi**2) if sum_phi_sq == 0: prs.append(0.0) else: - D = len(phi) - pr = (sum_phi**2) / (D * sum_phi_sq) - prs.append(pr) - + D = len(phi) + pr = (sum_phi**2) / (D * sum_phi_sq) + prs.append(pr) + return np.mean(prs), np.std(prs) + def train_and_evaluate(X_tr, y_tr, X_te, y_te, seed: int, run_shap: bool) -> Dict: results = {} - + # Ridge ridge = Ridge(random_state=seed) ridge.fit(X_tr, y_tr) pred_ridge = ridge.predict(X_te) results["r2_ridge"] = r2_score(y_te, pred_ridge) - + # LightGBM lgbm = LGBMRegressor(n_jobs=-1, random_state=seed, verbose=-1) lgbm.fit(X_tr, y_tr) pred_lgbm = lgbm.predict(X_te) - + results["r2"] = r2_score(y_te, pred_lgbm) results["rmse"] = np.sqrt(mean_squared_error(y_te, pred_lgbm)) results["mae"] = mean_absolute_error(y_te, pred_lgbm) results["y_test"] = y_te.tolist() results["y_pred"] = pred_lgbm.tolist() - + # SHAP if run_shap: try: @@ -229,14 +243,14 @@ def train_and_evaluate(X_tr, y_tr, X_te, y_te, seed: int, run_shap: bool) -> Dic # Subsample for SHAP X_shap = X_te[:2000] if len(X_te) > 2000 else X_te shap_values = explainer.shap_values(X_shap) - + pr, pr90, phi = calculate_pr(shap_values) results["pr"] = pr results["pr90"] = pr90 results["phi"] = phi - + # Bootstrap - pr_mean, pr_std = bootstrap_pr(shap_values, n_boot=50) + pr_mean, pr_std = bootstrap_pr(shap_values, n_boot=50) results["pr_std"] = pr_std # Metrics @@ -246,43 +260,47 @@ def train_and_evaluate(X_tr, y_tr, X_te, y_te, seed: int, run_shap: bool) -> Dic results["pes"] = results["r2"] / (pr + eps) results["cwp"] = results["r2"] * pr results["linear_gap"] = results["r2"] - results["r2_ridge"] - + except Exception as e: print(f"SHAP Error: {e}") - + return results + def get_random_embeddings(n, dim, seed): rng = np.random.default_rng(seed) return rng.standard_normal((n, dim)) + def plot_scatter(df: pd.DataFrame, out_dir: Path): """Generate scatter plots.""" params = df["target_param"].unique() for p in params: - sub = df[df["target_param"] == p] + df[df["target_param"] == p] # Grid plot loop similar to original... # Simplified for now to save space, assuming separate Analysis class or function - pass + pass # (Keeping original plotting logic is complex for in-place edit, skipping detailed re-implementation in this step for brevity locally testing first) # Actually, I should keep it runnable. I'll include a simple plotter. 
+ # --- Optimized Plotting --- class ResultPlotter: def __init__(self, df: pd.DataFrame, out_dir: Path): self.df = df self.out_dir = out_dir - + def plot_all(self): self.plot_scatter() self.plot_pareto() - + def plot_scatter(self): params = self.df["target_param"].unique() for p in params: dd = self.df[self.df["target_param"] == p] - if dd.empty: continue - + if dd.empty: + continue + # Simple grid g = sns.FacetGrid(dd, col="model", row="modality", height=3, aspect=1) g.map_dataframe(self._scatter_on, "y_test", "y_pred") @@ -291,72 +309,79 @@ def plot_scatter(self): g.fig.suptitle(f"Prediction: {p}") g.savefig(self.out_dir / f"scatter_{p}.png") plt.close() - + def _scatter_on(self, y_test, y_pred, color=None, label=None, data=None): - # Need to unwrap lists if specific format - # But Seaborn handles dataframes. y_test is list in cell. Explode? - # Data structure is one row per experiment. - # For plotting we need points. - # Explode: - row = data.iloc[0] # Should be unique per facet - yt, yp = np.array(row["y_test"]), np.array(row["y_pred"]) - r2 = row["r2"] - plt.scatter(yt, yp, alpha=0.1, s=1, color='k') - - mn, mx = min(yt.min(), yp.min()), max(yt.max(), yp.max()) - plt.plot([mn, mx], [mn, mx], 'r--') - plt.text(0.05, 0.9, f"R2={r2:.2f}", transform=plt.gca().transAxes) + # Need to unwrap lists if specific format + # But Seaborn handles dataframes. y_test is list in cell. Explode? + # Data structure is one row per experiment. + # For plotting we need points. + # Explode: + row = data.iloc[0] # Should be unique per facet + yt, yp = np.array(row["y_test"]), np.array(row["y_pred"]) + r2 = row["r2"] + plt.scatter(yt, yp, alpha=0.1, s=1, color="k") + + mn, mx = min(yt.min(), yp.min()), max(yt.max(), yp.max()) + plt.plot([mn, mx], [mn, mx], "r--") + plt.text(0.05, 0.9, f"R2={r2:.2f}", transform=plt.gca().transAxes) def plot_pareto(self): - if "pr" not in self.df.columns: return - plt.figure(figsize=(8,6)) - sns.scatterplot(data=self.df, x="pr", y="r2", hue="model", style="modality", s=100) + if "pr" not in self.df.columns: + return + plt.figure(figsize=(8, 6)) + sns.scatterplot( + data=self.df, x="pr", y="r2", hue="model", style="modality", s=100 + ) plt.title("Performance (R2) vs Efficiency (PR)") plt.savefig(self.out_dir / "pareto.png") plt.close() + def run_analysis( config_path: Optional[Path] = None, output_dir: Optional[Path] = None, - slurm: bool = False + slurm: bool = False, ): """Main execution function.""" paths = load_paths() if not config_path: config_path = paths.repo_root / "src/fmb/configs/analysis/regression.yaml" - + cfg = load_config(config_path) - + if output_dir: out_path = Path(output_dir) else: out_path = paths.analysis / "regression" out_path.mkdir(parents=True, exist_ok=True) - + # 1. Load Catalog # Catalog path: either in config or default location in data # Assuming catalog is in data folder. # Config might specify 'catalog_path' relative to data or absolute. - cat_path = paths.dataset / cfg.get("catalog_filename", "euclid_desi_catalog.fits") # Default? + cat_path = paths.dataset / cfg.get( + "catalog_filename", "euclid_desi_catalog.fits" + ) # Default? 
# Or search for generic catalog if not cat_path.exists(): # Try finding *catalog*.fits in data candidates = list(paths.dataset.glob("*catalog*.fits")) - if candidates: cat_path = candidates[0] + if candidates: + cat_path = candidates[0] else: - print(f"Catalog not found in {paths.dataset}") - return - + print(f"Catalog not found in {paths.dataset}") + return + catalog, num_cols, id_col = load_catalog(cat_path, cfg.get("catalog_id_column")) - + # 2. Determine Targets targets = cfg.get("targets", []) col_map = cfg.get("column_mapping", {}) - + # 3. Load Models models_to_run = cfg.get("models", ["AION", "AstroPT", "AstroCLIP"]) emb_files = [] - + # Auto-detect embeddings for m in models_to_run: # Search paths.embeddings @@ -365,97 +390,111 @@ def run_analysis( # Try subfolder cand2 = list((paths.embeddings / m.lower()).glob("*.pt")) candidates.extend(cand2) - + if candidates: emb_files.append((m, candidates[0])) else: print(f"[warn] No embeddings found for {m}") - + if not emb_files: print("No embeddings found.") return - + # Load all embeddings into memory loaded_data = [] for m, p in emb_files: recs = load_embeddings_file(p) loaded_data.append((m, recs)) - + results = [] seed = cfg.get("seed", 42) - + for t_name in targets: - col = col_map.get(t_name, t_name) # Map to catalog column + col = col_map.get(t_name, t_name) # Map to catalog column if col not in num_cols: - # Try uppercase - if col.upper() in num_cols: col = col.upper() - else: - print(f"Skipping target {t_name} (Col {col} not found)") - continue - + # Try uppercase + if col.upper() in num_cols: + col = col.upper() + else: + print(f"Skipping target {t_name} (Col {col} not found)") + continue + print(f"\nAnalyzing Target: {t_name} (Col: {col})") - + # Split logic: find common valid IDs for FAIR comparison? # Or split per dataset? # Strategy: Use only IDs valid in catalog for this param. - valid_ids = [k for k,v in catalog.items() if isinstance(v.get(col), (int, float, np.number)) and not np.isnan(float(v.get(col)))] - - train_ids, test_ids = train_test_split(valid_ids, test_size=cfg.get("test_size", 0.2), random_state=seed) + valid_ids = [ + k + for k, v in catalog.items() + if isinstance(v.get(col), (int, float, np.number)) + and not np.isnan(float(v.get(col))) + ] + + train_ids, test_ids = train_test_split( + valid_ids, test_size=cfg.get("test_size", 0.2), random_state=seed + ) train_set = set(train_ids) test_set = set(test_ids) - + for m_name, recs in loaded_data: # Determine keys available - if not recs: continue + if not recs: + continue sample = recs[0] keys = [k for k in sample.keys() if k.startswith("embedding_")] - + for k in keys: print(f" Model: {m_name}, Key: {k}") X_tr, y_tr, _ = merge_data(recs, catalog, col, k, subset_ids=train_set) X_te, y_te, _ = merge_data(recs, catalog, col, k, subset_ids=test_set) - + if len(X_tr) < 50: print(f" Insufficient data (Tr={len(X_tr)}).") continue - - metrics = train_and_evaluate(X_tr, y_tr, X_te, y_te, seed, cfg.get("run_shap", True)) - metrics.update({ - "target_param": t_name, - "model": m_name, - "modality": k.replace("embedding_", "") - }) + + metrics = train_and_evaluate( + X_tr, y_tr, X_te, y_te, seed, cfg.get("run_shap", True) + ) + metrics.update( + { + "target_param": t_name, + "model": m_name, + "modality": k.replace("embedding_", ""), + } + ) results.append(metrics) print(f" R2={metrics['r2']:.3f}") - + # Random Baseline (Restored!) 
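+        # Control run: random Gaussian embeddings (dim 512) pushed through the same
+        # train/evaluate pipeline. R2 here should sit near zero and serves as a
+        # floor when judging the model embeddings evaluated above.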
print(" Processing Random Baseline...") dim_random = 512 X_rand_tr = get_random_embeddings(len(train_ids), dim_random, seed) y_rand_tr = np.array([catalog[oid][col] for oid in train_ids]) - + X_rand_te = get_random_embeddings(len(test_ids), dim_random, seed + 1) y_rand_te = np.array([catalog[oid][col] for oid in test_ids]) - - met_rand = train_and_evaluate(X_rand_tr, y_rand_tr, X_rand_te, y_rand_te, seed, cfg.get("run_shap", True)) - met_rand.update({ - "target_param": t_name, - "model": "Random", - "modality": "embedding_random" - }) + + met_rand = train_and_evaluate( + X_rand_tr, y_rand_tr, X_rand_te, y_rand_te, seed, cfg.get("run_shap", True) + ) + met_rand.update( + {"target_param": t_name, "model": "Random", "modality": "embedding_random"} + ) results.append(met_rand) print(f" R2={met_rand.get('r2'):.3f} (Random)") - + # Save & Plot if results: df = pd.DataFrame(results) df_clean = df.drop(columns=["y_test", "y_pred"], errors="ignore") df_clean.to_csv(out_path / "results_summary.csv", index=False) print(f"Results saved to {out_path}") - + plotter = ResultPlotter(df, out_path) plotter.plot_all() + if __name__ == "__main__": # Barebones CLI for direct test parser = argparse.ArgumentParser() diff --git a/src/fmb/analysis/similarity.py b/src/fmb/analysis/similarity.py index 9c86d43..fb96a2c 100644 --- a/src/fmb/analysis/similarity.py +++ b/src/fmb/analysis/similarity.py @@ -5,32 +5,25 @@ Description: Visual similarity analysis in embedding space """ -import argparse from pathlib import Path -from typing import List, Dict, Tuple, Optional, Sequence -import torch -import torch.nn.functional as F -import pandas as pd -import numpy as np +from typing import List, Tuple + import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch -from fmb.paths import load_paths +from fmb.data.load_display_data import EuclidDESIDataset from fmb.data.utils import ( - read_object_ids, - collect_samples, - collect_samples_with_index, - load_index, + collect_samples, + extract_embedding_matrices, load_embeddings_file, - extract_embedding_matrices ) from fmb.viz.similarity import plot_vertical_panels -from fmb.data.load_display_data import EuclidDESIDataset + def find_nearest_neighbors( - query_ids: List[str], - all_ids: List[str], - emb_matrix: torch.Tensor, - k: int = 5 + query_ids: List[str], all_ids: List[str], emb_matrix: torch.Tensor, k: int = 5 ) -> List[str]: """ Returns flat list of [Q1, N1..Nk, Q2, N1..Nk]. 
@@ -38,7 +31,7 @@ def find_nearest_neighbors( id_map = {oid: i for i, oid in enumerate(all_ids)} valid_q_idxs = [] valid_q_ids = [] - + for q in query_ids: if q in id_map: valid_q_idxs.append(id_map[q]) @@ -49,20 +42,20 @@ def find_nearest_neighbors( if not valid_q_idxs: return [] - q_vecs = emb_matrix[valid_q_idxs] # (M, D) - + q_vecs = emb_matrix[valid_q_idxs] # (M, D) + # Sim (M, N) sim = torch.mm(q_vecs, emb_matrix.t()) - + # Top k+1+padding (to skip self) search_k = min(len(all_ids), k + 10) top_v, top_i = torch.topk(sim, k=search_k, dim=1) - + results = [] for i, qid in enumerate(valid_q_ids): # Always put Query first batch = [qid] - + candidates = top_i[i].tolist() found = 0 for idx in candidates: @@ -74,17 +67,18 @@ def find_nearest_neighbors( found += 1 if found == k: break - + results.extend(batch) - + return results + def visualize_similarity( query_ids: List[str], tasks: List[Tuple[str, Path]], n_similar: int, output_path: Path, - cache_dir: str + cache_dir: str, ): """Run similarity search and visualize results for multiple models.""" if not tasks: @@ -93,7 +87,7 @@ def visualize_similarity( all_annotated = [] row_lbls = [] - + # Pre-check all paths valid_tasks = [] for m, p in tasks: @@ -101,39 +95,42 @@ def visualize_similarity( print(f"[warn] Embedding file not found for {m}: {p}") else: valid_tasks.append((m, p)) - + if not valid_tasks: print("No valid embedding files found.") return - + # We load dataset once to save time if possible? # Actually collect_samples re-loads, but we can instantiation dataset once. ds = EuclidDESIDataset(split="all", cache_dir=cache_dir) - + for model_name, emb_path in valid_tasks: print(f"\nProcessing Model: {model_name}...") records = load_embeddings_file(emb_path) matrices, all_ids = extract_embedding_matrices(records) - + for mod_key, mat in matrices.items(): mod_pretty = mod_key.replace("embedding_", "").capitalize() # Special case renaming - if mod_key == "embedding_hsc": mod_pretty = "Image" - if mod_key == "embedding_spectra": mod_pretty = "Spectrum" - if mod_key == "embedding_joint": mod_pretty = "Joint" - + if mod_key == "embedding_hsc": + mod_pretty = "Image" + if mod_key == "embedding_spectra": + mod_pretty = "Spectrum" + if mod_key == "embedding_joint": + mod_pretty = "Joint" + print(f" Searching in modality: {mod_pretty} ({mod_key})") - + ordered_ids = find_nearest_neighbors(query_ids, all_ids, mat, k=n_similar) - if not ordered_ids: + if not ordered_ids: print(" No neighbors found.") continue - + samples = collect_samples(ds, ordered_ids, verbose=False) - + # Map back to ordered list (handle missing) s_map = {str(s.get("object_id") or s.get("targetid")): s for s in samples} - + for i, oid in enumerate(ordered_ids): # We need to construct rows. # ordered_ids contains [Q1, N1..Nk, Q2, N1..Nk...] 
@@ -141,7 +138,7 @@ def visualize_similarity( # Rename for grid cols = n_similar + 1 - + # Should have exactly len(ordered_ids) samples current_batch = [] for oid in ordered_ids: @@ -149,11 +146,11 @@ def visualize_similarity( current_batch.append(s_map[oid]) else: current_batch.append({"object_id": oid}) - + annotated_batch = [] for i, s in enumerate(current_batch): new_s = s.copy() - orig = str(new_s.get("object_id","")) + orig = str(new_s.get("object_id", "")) col_idx = i % cols if col_idx == 0: prefix = "[QUERY]" @@ -161,9 +158,9 @@ def visualize_similarity( prefix = f"[#{col_idx}]" new_s["object_id"] = f"{prefix} {orig}" annotated_batch.append(new_s) - + all_annotated.extend(annotated_batch) - + # Add label for each query row num_queries = len(ordered_ids) // cols lbl = f"{model_name}\n{mod_pretty}" @@ -173,27 +170,28 @@ def visualize_similarity( if not all_annotated: print("No results to visualize.") return - + print(f"\nGenerating combined visualization at {output_path}...") plot_vertical_panels( - all_annotated, - cols=n_similar+1, - save_path=output_path, + all_annotated, + cols=n_similar + 1, + save_path=output_path, show=False, - row_labels=row_lbls + row_labels=row_lbls, ) + def analyze_neighbor_ranks( query_ids: List[str], tasks: List[Tuple[str, Path, Path]], n_similar: int, - output_dir: Path + output_dir: Path, ): """Analyze rank distribution of neighbors for multiple models and combine into one plot.""" if not tasks: print("No tasks provided.") return - + # Pre-check valid_tasks = [] for m, ep, sp in tasks: @@ -204,7 +202,7 @@ def analyze_neighbor_ranks( print(f"[warn] Scores not found for {m}: {sp}") continue valid_tasks.append((m, ep, sp)) - + if not valid_tasks: print("No valid tasks found.") return @@ -215,50 +213,67 @@ def analyze_neighbor_ranks( print(f"\nProcessing Model: {model_name}...") records = load_embeddings_file(emb_path) matrices, all_ids = extract_embedding_matrices(records) - + print(f" Loading scores from {scores_path}...") scores_df = pd.read_csv(scores_path) scores_df["object_id"] = scores_df["object_id"].astype(str) - + for mod_key, mat in matrices.items(): mod_pretty = mod_key.replace("embedding_", "").capitalize() print(f" Analyzing neighbors for {mod_pretty} ({mod_key})...") - + # Match score key logic target_key = mod_key if target_key not in scores_df["embedding_key"].values: - if "hsc" in target_key and "embedding_hsc" in scores_df["embedding_key"].values: target_key = "embedding_hsc" - if "spectra" in target_key and "embedding_spectrum" in scores_df["embedding_key"].values: target_key = "embedding_spectrum" - if "joint" in target_key and "embedding_hsc_desi" in scores_df["embedding_key"].values: target_key = "embedding_hsc_desi" - - sub_scores = scores_df[scores_df["embedding_key"] == target_key].set_index("object_id") + if ( + "hsc" in target_key + and "embedding_hsc" in scores_df["embedding_key"].values + ): + target_key = "embedding_hsc" + if ( + "spectra" in target_key + and "embedding_spectrum" in scores_df["embedding_key"].values + ): + target_key = "embedding_spectrum" + if ( + "joint" in target_key + and "embedding_hsc_desi" in scores_df["embedding_key"].values + ): + target_key = "embedding_hsc_desi" + + sub_scores = scores_df[scores_df["embedding_key"] == target_key].set_index( + "object_id" + ) if sub_scores.empty: continue - + ordered_ids = find_nearest_neighbors(query_ids, all_ids, mat, k=n_similar) - if not ordered_ids: continue - + if not ordered_ids: + continue + neighbor_ranks = [] for i in range(0, len(ordered_ids), n_similar + 
1): - block = ordered_ids[i:i+n_similar+1] - neighbors = block[1:] + block = ordered_ids[i : i + n_similar + 1] + neighbors = block[1:] for nid in neighbors: if nid in sub_scores.index: r = sub_scores.loc[nid, "rank"] neighbor_ranks.append(r) - + if not neighbor_ranks: continue - + ranks = np.array(neighbor_ranks) total_obj = len(sub_scores) percentiles = (ranks / total_obj) * 100.0 - - plot_data.append({ - "title": f"{model_name}\n{mod_pretty}", - "data": percentiles, - "count": len(ranks) - }) + + plot_data.append( + { + "title": f"{model_name}\n{mod_pretty}", + "data": percentiles, + "count": len(ranks), + } + ) if not plot_data: print("No data collected for plotting.") @@ -268,44 +283,59 @@ def analyze_neighbor_ranks( n_plots = len(plot_data) cols = 3 rows = (n_plots + cols - 1) // cols - - fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3 * rows + 1), constrained_layout=True) + + fig, axes = plt.subplots( + rows, cols, figsize=(4 * cols, 3 * rows + 1), constrained_layout=True + ) axes = axes.flatten() - + # Global title? fig.suptitle(f"Neighbor Anomaly Ranks (Query N={len(query_ids)})", fontsize=16) for i, item in enumerate(plot_data): ax = axes[i] data = item["data"] - - ax.hist(data, bins=50, range=(0, 100), color='#1f77b4', edgecolor='black', alpha=0.7) + + ax.hist( + data, bins=50, range=(0, 100), color="#1f77b4", edgecolor="black", alpha=0.7 + ) ax.set_title(item["title"], fontsize=10) ax.set_xlabel("Anomaly Percentile") if i % cols == 0: ax.set_ylabel("Count") - + # Stats top1 = np.mean(data <= 1.0) * 100 stats = f"N={item['count']}\nTop1%={top1:.1f}%" - ax.text(0.95, 0.95, stats, transform=ax.transAxes, ha='right', va='top', - bbox=dict(facecolor='white', alpha=0.9, pad=2), fontsize=8) + ax.text( + 0.95, + 0.95, + stats, + transform=ax.transAxes, + ha="right", + va="top", + bbox=dict(facecolor="white", alpha=0.9, pad=2), + fontsize=8, + ) # Hide empty subplots for j in range(i + 1, len(axes)): - axes[j].axis('off') - + axes[j].axis("off") + out_file = output_dir / "neighbor_ranks_combined.png" plt.savefig(out_file, dpi=150) plt.close() print(f"\n[success] Combined plot saved to: {out_file}") + # --- CLI Entry Points --- + def main_search(argv: List[str] = None): # Argparse wrapper needed for cleaner integration maybe? # Or just use args passed from Typer - pass # Implemented in CLI directly via calls to visualize_similarity + pass # Implemented in CLI directly via calls to visualize_similarity + def main(argv: List[str] = None): # This main handles both? Or we expose separate functions. 
diff --git a/src/fmb/cli.py b/src/fmb/cli.py index 1977f9d..e731926 100644 --- a/src/fmb/cli.py +++ b/src/fmb/cli.py @@ -6,12 +6,13 @@ """ from __future__ import annotations -import typer + import subprocess -import os import sys from pathlib import Path +import typer + # --- Path setup for external dependencies --- repo_root = Path(__file__).resolve().parents[2] external_paths = [ @@ -25,24 +26,25 @@ sys.path.insert(0, str(p)) # -------------------------------------------- -from typing import Optional, List -import yaml +from typing import List, Optional from fmb.paths import load_paths - app = typer.Typer( help="FMB: Foundation Models Benchmark CLI - Refactored Pipeline", no_args_is_help=True, - context_settings={"allow_extra_args": True, "ignore_unknown_options": True} + context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, ) + def run_slurm(sbatch_file: str, name: str, extra_args: List[str]): """Submit a SLURM job using sbatch.""" cmd = ["sbatch", f"slurm/{sbatch_file}"] if extra_args: - typer.echo(f" Note: Extra arguments {extra_args} are NOT forwarded to sbatch automatically.") - + typer.echo( + f" Note: Extra arguments {extra_args} are NOT forwarded to sbatch automatically." + ) + typer.echo(f" Submitting Slurm job for {name}...") try: result = subprocess.run(cmd, capture_output=True, text=True, check=True) @@ -52,18 +54,28 @@ def run_slurm(sbatch_file: str, name: str, extra_args: List[str]): except FileNotFoundError: typer.echo("❌ 'sbatch' command not found. Are you on a Slurm cluster?") + def forward_args(ctx: typer.Context): """Clean up sys.argv to only contain the extra arguments for underlying argparse.""" # Typer consumes the command and arguments it knows. # We want to give the rest to the script's main(). sys.argv = [sys.argv[0]] + ctx.args -@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def retrain( ctx: typer.Context, - model: str = typer.Argument(..., help="Model to retrain (aion, astropt, astroclip)"), - config: Optional[str] = typer.Option(None, "--config", help="Path to YAML config file"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job instead of running locally") + model: str = typer.Argument( + ..., help="Model to retrain (aion, astropt, astroclip)" + ), + config: Optional[str] = typer.Option( + None, "--config", help="Path to YAML config file" + ), + slurm: bool = typer.Option( + False, "--slurm", help="Submit as a Slurm job instead of running locally" + ), ): """Stage 01: Retrain foundation models or codecs.""" if slurm: @@ -71,14 +83,14 @@ def retrain( return typer.echo(f"Running retrain for {model} locally...") - + # Build sys.argv for the underlying script sys.argv = [sys.argv[0]] if config: sys.argv.extend(["--config", config]) # Add any extra args from ctx.args sys.argv.extend(ctx.args) - + # Use new simplified entry points if model == "aion" or model == "aion_codec": # AION now only has the Euclid<->HSC adapter U-Net training @@ -90,15 +102,24 @@ def retrain( else: typer.echo(f"❌ Unknown model: {model}") raise typer.Exit(1) - + run_task() -@app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def embed( ctx: typer.Context, - model: str = typer.Argument(..., help="Model to use for embeddings (aion, astropt, astroclip)"), - config: 
Optional[str] = typer.Option(None, "--config", help="Path to model-specific configuration YAML"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job instead of running locally") + model: str = typer.Argument( + ..., help="Model to use for embeddings (aion, astropt, astroclip)" + ), + config: Optional[str] = typer.Option( + None, "--config", help="Path to model-specific configuration YAML" + ), + slurm: bool = typer.Option( + False, "--slurm", help="Submit as a Slurm job instead of running locally" + ), ): """Stage 02: Generate embeddings from foundation models.""" # Combine --config with extra args @@ -111,10 +132,11 @@ def embed( return typer.echo(f"Generating embeddings for {model} locally...") - + # Update sys.argv to pass extra args to the script # We keep the script name (argv[0]) and append our processed args import sys + sys.argv = [sys.argv[0]] + extra_args if model == "aion": @@ -127,32 +149,35 @@ def embed( else: typer.echo(f"❌ Unknown model: {model}") raise typer.Exit(1) - - run_task() - + run_task() # --- Data Commands --- data_app = typer.Typer(help="Stage 00: Data Setup & Indexing") app.add_typer(data_app, name="data") + @data_app.command() def index( ctx: typer.Context, - cache_dir: Optional[str] = typer.Option(None, "--cache-dir", help="Dataset cache directory"), + cache_dir: Optional[str] = typer.Option( + None, "--cache-dir", help="Dataset cache directory" + ), splits: str = typer.Option("all", "--splits", help="Comma-separated splits"), output: Optional[str] = typer.Option(None, "--output", help="Path to output CSV"), - overwrite: bool = typer.Option(False, "--overwrite", help="Overwrite existing file"), + overwrite: bool = typer.Option( + False, "--overwrite", help="Overwrite existing file" + ), ): """Create a CSV index of the dataset (object_id -> split/index).""" from fmb.data.index_dataset import run_indexing - + run_indexing( cache_dir=cache_dir, splits=[s.strip() for s in splits.split(",") if s.strip()], output=Path(output) if output else None, - overwrite=overwrite + overwrite=overwrite, ) @@ -160,12 +185,17 @@ def index( detect_app = typer.Typer(help="Stage 03: Detect anomalies using embeddings.") app.add_typer(detect_app, name="detect") -@detect_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@detect_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def outliers( ctx: typer.Context, - method: str = typer.Option("nfs", "--method", help="Method to use: 'nfs' (Normalizing Flows)"), + method: str = typer.Option( + "nfs", "--method", help="Method to use: 'nfs' (Normalizing Flows)" + ), config: Optional[str] = typer.Option(None, "--config", help="Path to config YAML"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job") + slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job"), ): """ Run Normalizing Flow-based outlier detection. @@ -174,31 +204,35 @@ def outliers( if slurm: # We might need a specific sbatch file for this new command # For now, let's assume we use the generic one or a new one - typer.echo("Slurm submission for 'detect outliers' not yet fully configured with new script. Running locally.") + typer.echo( + "Slurm submission for 'detect outliers' not yet fully configured with new script. Running locally." 
+ ) # run_slurm(f"03_detection/nfs.sbatch", "detect outliers", ctx.args) # return typer.echo(f"🔍 Running outlier detection ({method}) locally...") - + # Forward args to the new run script # We construct the argv manually to map options from fmb.detection import run - + # Extract known options from context if passed, or just forward everything # But Typer consumes config/method. We need to pass them to argparse if they are needed. # The run.py uses argparse. - # Reconstruct argv for argparse run_args = [] if config: run_args.extend(["--config", config]) - + # Pass through other arguments (like --aion-embeddings) run_args.extend(ctx.args) - + run.main(run_args) -@detect_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@detect_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def cosine( ctx: typer.Context, aion_embeddings: Optional[str] = typer.Option(None, "--aion-embeddings"), @@ -210,86 +244,134 @@ def cosine( Output: runs/outliers/cosine_scores_{model}.csv """ from fmb.detection import cosine + # Construct args manually (Typer -> argparse) args = [] - if aion_embeddings: args.extend(["--aion-embeddings", aion_embeddings]) - if astropt_embeddings: args.extend(["--astropt-embeddings", astropt_embeddings]) - if astroclip_embeddings: args.extend(["--astroclip-embeddings", astroclip_embeddings]) - + if aion_embeddings: + args.extend(["--aion-embeddings", aion_embeddings]) + if astropt_embeddings: + args.extend(["--astropt-embeddings", astropt_embeddings]) + if astroclip_embeddings: + args.extend(["--astroclip-embeddings", astroclip_embeddings]) + cosine.main(args) -@detect_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@detect_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def multimodal( ctx: typer.Context, - top_k: int = typer.Option(200, "--top-k", help="Number of top anomalies to export per model (ranked by fusion score)."), - fusion: str = typer.Option("geo", "--fusion", help="Fusion method: 'geo' (Geometric Mean), 'min' (Minimum), 'avg' (Average)."), - t_img: float = typer.Option(0.99, "--t-img", help="Filter: Only keep objects in top P percentile of Image Density Anomaly."), - t_spec: float = typer.Option(0.99, "--t-spec", help="Filter: Only keep objects in top P percentile of Spectrum Density Anomaly."), - t_mis: float = typer.Option(0.99, "--t-mis", help="Filter: Only keep objects in top P percentile of Cosine Mismatch."), + top_k: int = typer.Option( + 200, + "--top-k", + help="Number of top anomalies to export per model (ranked by fusion score).", + ), + fusion: str = typer.Option( + "geo", + "--fusion", + help="Fusion method: 'geo' (Geometric Mean), 'min' (Minimum), 'avg' (Average).", + ), + t_img: float = typer.Option( + 0.99, + "--t-img", + help="Filter: Only keep objects in top P percentile of Image Density Anomaly.", + ), + t_spec: float = typer.Option( + 0.99, + "--t-spec", + help="Filter: Only keep objects in top P percentile of Spectrum Density Anomaly.", + ), + t_mis: float = typer.Option( + 0.99, + "--t-mis", + help="Filter: Only keep objects in top P percentile of Cosine Mismatch.", + ), ): """ Combine & Filter Anomalies (Multimodal Fusion). - + Generates a final list of anomalies by combining: 1. Cosine Mismatch (Image vs Spectrum) 2. Image Density Outliers (Normalizing Flows) 3. 
Spectrum Density Outliers (Normalizing Flows) - + Outputs are saved to: runs/outliers/multimodal/ """ from fmb.detection import multimodal - + args = [ - "--top-k", str(top_k), - "--fusion", fusion, - "--t-img", str(t_img), - "--t-spec", str(t_spec), - "--t-mis", str(t_mis), + "--top-k", + str(top_k), + "--fusion", + fusion, + "--t-img", + str(t_img), + "--t-spec", + str(t_spec), + "--t-mis", + str(t_mis), ] multimodal.main(args) - # --- Analyze Commands --- analyze_app = typer.Typer(help="Stage 04: Analyze embeddings and detection results.") app.add_typer(analyze_app, name="analyze") -@analyze_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@analyze_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def predict_params( ctx: typer.Context, - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job instead of running locally") + slurm: bool = typer.Option( + False, "--slurm", help="Submit as a Slurm job instead of running locally" + ), ): """Predict physical parameters (redshift, etc.) from embeddings.""" if slurm: - run_slurm("04_analysis/predict_params.sbatch", "analysis predict_params", ctx.args) + run_slurm( + "04_analysis/predict_params.sbatch", "analysis predict_params", ctx.args + ) return - typer.echo(f"Running physical parameter prediction locally...") + typer.echo("Running physical parameter prediction locally...") forward_args(ctx) from fmb.analysis.predict_physical_params import main as run_task + run_task() -@analyze_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@analyze_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def tsne( ctx: typer.Context, - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job instead of running locally") + slurm: bool = typer.Option( + False, "--slurm", help="Submit as a Slurm job instead of running locally" + ), ): """Generate t-SNE comparison plots.""" if slurm: run_slurm("04_analysis/tsne.sbatch", "analysis tsne", ctx.args) return - typer.echo(f"Running t-SNE analysis locally...") + typer.echo("Running t-SNE analysis locally...") forward_args(ctx) from fmb.viz.tsne_comparison import main as run_task + run_task() -@analyze_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@analyze_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def outliers( ctx: typer.Context, input_csv: str = typer.Option(None, "--input-csv", help="Path to all_scores.csv"), top_k: int = typer.Option(200, "--top-k", help="Top-K threshold"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job") + slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job"), ): """ Analyze Multimodal Anomaly Results (Correlations, Uplift, Overlap). 
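# Illustrative sketch (not a helper defined in the codebase): the detect/analyze commands
# in this file all bridge Typer to argparse the same way. Typer parses its own options,
# an argv-style list of strings is rebuilt, and the underlying module's main() re-parses
# it. Flag names must therefore match the target argparse parser exactly.
from typing import List, Optional

def build_argv(input_csv: Optional[str], top_k: Optional[int]) -> List[str]:
    args: List[str] = []
    if input_csv:
        args.extend(["--input_csv", input_csv])  # matches the argparse flag expected downstream
    if top_k:
        args.extend(["--top-k", str(top_k)])     # numeric values are forwarded as strings
    return args  # e.g. outliers.main(build_argv("path/to/all_scores.csv", 200))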
@@ -302,22 +384,37 @@ def outliers( pass typer.echo("Analyzing anomaly results...") - + from fmb.analysis import outliers - + args = [] - if input_csv: args.extend(["--input_csv", input_csv]) - if top_k: args.extend(["--top-k", str(top_k)]) - + if input_csv: + args.extend(["--input_csv", input_csv]) + if top_k: + args.extend(["--top-k", str(top_k)]) + outliers.main(args) -@analyze_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@analyze_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def similarity( ctx: typer.Context, - emb_path: Optional[str] = typer.Option(None, "--embeddings", help="Path to embeddings .pt file (optional, auto-detected if omitted)"), - model: Optional[str] = typer.Option(None, "--model", help="Name of the model (or 'all'). Default: all"), - queries: Optional[List[str]] = typer.Option(None, "--query", help="Object ID(s) to query"), - query_csv: Optional[str] = typer.Option(None, "--query-csv", help="CSV with object_id column to query"), + emb_path: Optional[str] = typer.Option( + None, + "--embeddings", + help="Path to embeddings .pt file (optional, auto-detected if omitted)", + ), + model: Optional[str] = typer.Option( + None, "--model", help="Name of the model (or 'all'). Default: all" + ), + queries: Optional[List[str]] = typer.Option( + None, "--query", help="Object ID(s) to query" + ), + query_csv: Optional[str] = typer.Option( + None, "--query-csv", help="CSV with object_id column to query" + ), n_similar: int = typer.Option(5, "--n-similar", help="Number of neighbors"), save: Optional[str] = typer.Option(None, "--save", help="Output image path"), cache_dir: Optional[str] = typer.Option(None, "--cache-dir"), @@ -326,26 +423,28 @@ def similarity( Visual Similarity Search. Finds and displays nearest neighbors (Images + Spectra). """ + from pathlib import Path + from fmb.analysis import similarity from fmb.data.utils import read_object_ids from fmb.paths import load_paths - from pathlib import Path - + paths = load_paths() - + # 1. Resolve Queries q_ids = [] - if queries: q_ids.extend(queries) + if queries: + q_ids.extend(queries) if query_csv: q_ids.extend(read_object_ids([Path(query_csv)])) - + if not q_ids: typer.echo("❌ No query IDs provided.") raise typer.Exit(1) - + # 2. Resolve Tasks (Model, Path) tasks = [] - + if emb_path: # Explicit path provided tsk_name = model if model else "CustomModel" @@ -356,37 +455,48 @@ def similarity( if not emb_root.exists(): typer.echo(f"❌ Embeddings root not found: {emb_root}") raise typer.Exit(1) - + candidates = [] # Simple heuristic: scan directory for known .pt files pt_files = list(emb_root.glob("*embeddings*.pt")) - + # Also check subdirectories if organized by model? # paths.embeddings might contain 'aion/embeddings.pt', etc. # But commonly they are in root of run. 
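        # Auto-detection summary: every *embeddings*.pt file under paths.embeddings is mapped
        # to a model name from its filename ("astropt" -> AstroPT, "astroclip" -> AstroCLIP,
        # "aion" -> AION, anything else -> its capitalized stem), and the resulting candidate
        # list is then optionally narrowed by --model.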
- + for p in pt_files: fname = p.name.lower() m_name = "Unknown" - if "astropt" in fname: m_name = "AstroPT" - elif "astroclip" in fname: m_name = "AstroCLIP" - elif "aion" in fname: m_name = "AION" - else: m_name = fname.replace("embeddings", "").replace(".pt", "").strip("_").capitalize() + if "astropt" in fname: + m_name = "AstroPT" + elif "astroclip" in fname: + m_name = "AstroCLIP" + elif "aion" in fname: + m_name = "AION" + else: + m_name = ( + fname.replace("embeddings", "") + .replace(".pt", "") + .strip("_") + .capitalize() + ) candidates.append((m_name, p)) - + # Filter if model and model.lower() != "all": tasks = [t for t in candidates if t[0].lower() == model.lower()] if not tasks: - typer.echo(f"❌ No embeddings found matching model '{model}'. Found: {[c[0] for c in candidates]}") - raise typer.Exit(1) + typer.echo( + f"❌ No embeddings found matching model '{model}'. Found: {[c[0] for c in candidates]}" + ) + raise typer.Exit(1) else: tasks = candidates if not tasks: - typer.echo(f"❌ No embedding files found in {paths.embeddings}") - raise typer.Exit(1) - + typer.echo(f"❌ No embedding files found in {paths.embeddings}") + raise typer.Exit(1) + # 3. Resolve Paths if save: out_path = Path(save) @@ -399,30 +509,47 @@ def similarity( out_path = analysis_dir / f"similarity_{name}.png" else: out_path = analysis_dir / "similarity_combined.png" - + if not cache_dir: cache_dir = str(paths.dataset) typer.echo(f"🔍 Finding similar objects for {len(q_ids)} queries...") typer.echo(f" Tasks: {[t[0] for t in tasks]}") typer.echo(f" Output: {out_path}") - + similarity.visualize_similarity( query_ids=q_ids, tasks=tasks, n_similar=n_similar, output_path=out_path, - cache_dir=cache_dir + cache_dir=cache_dir, ) -@analyze_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@analyze_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def neighbor_ranks( ctx: typer.Context, - emb_path: Optional[str] = typer.Option(None, "--embeddings", help="Path to embeddings .pt file (optional, auto-detected if omitted)"), - scores_path: Optional[str] = typer.Option(None, "--scores", help="Path to anomaly_scores.csv (optional, auto-detected if omitted)"), - model: Optional[str] = typer.Option(None, "--model", help="Filter by model name (default: all)"), - queries: Optional[List[str]] = typer.Option(None, "--query", help="Object ID(s) to query"), - query_csv: Optional[str] = typer.Option(None, "--query-csv", help="CSV with object_id column to query"), + emb_path: Optional[str] = typer.Option( + None, + "--embeddings", + help="Path to embeddings .pt file (optional, auto-detected if omitted)", + ), + scores_path: Optional[str] = typer.Option( + None, + "--scores", + help="Path to anomaly_scores.csv (optional, auto-detected if omitted)", + ), + model: Optional[str] = typer.Option( + None, "--model", help="Filter by model name (default: all)" + ), + queries: Optional[List[str]] = typer.Option( + None, "--query", help="Object ID(s) to query" + ), + query_csv: Optional[str] = typer.Option( + None, "--query-csv", help="CSV with object_id column to query" + ), n_similar: int = typer.Option(10, "--n-similar", help="Number of neighbors"), out_dir: Optional[str] = typer.Option(None, "--out-dir", help="Output directory"), ): @@ -430,26 +557,28 @@ def neighbor_ranks( Analyze Rank Distribution of Neighbors. Checks if neighbors of anomalies are also anomalies. 
""" + from pathlib import Path + from fmb.analysis import similarity from fmb.data.utils import read_object_ids from fmb.paths import load_paths - from pathlib import Path - + paths = load_paths() - + # 1. Resolve Queries q_ids = [] - if queries: q_ids.extend(queries) + if queries: + q_ids.extend(queries) if query_csv: q_ids.extend(read_object_ids([Path(query_csv)])) - + if not q_ids: typer.echo("❌ No query IDs provided.") raise typer.Exit(1) # 2. Resolve Tasks (Model, EmbPath, ScorePath) tasks = [] - + if emb_path and scores_path: # Explicit m_name = model if model else "CustomModel" @@ -457,50 +586,63 @@ def neighbor_ranks( else: # Auto-detect emb_root = paths.embeddings - score_root = paths.outliers # Usually where anomaly_scores_*.csv live - + score_root = paths.outliers # Usually where anomaly_scores_*.csv live + candidates = [] pt_files = list(emb_root.glob("*embeddings*.pt")) - + for p in pt_files: fname = p.name.lower() m_name = "Unknown" - if "astropt" in fname: m_name = "AstroPT" - elif "astroclip" in fname: m_name = "AstroCLIP" - elif "aion" in fname: m_name = "AION" - else: m_name = fname.replace("embeddings", "").replace(".pt", "").strip("_").capitalize() - + if "astropt" in fname: + m_name = "AstroPT" + elif "astroclip" in fname: + m_name = "AstroCLIP" + elif "aion" in fname: + m_name = "AION" + else: + m_name = ( + fname.replace("embeddings", "") + .replace(".pt", "") + .strip("_") + .capitalize() + ) + # Try to find matching score slug = m_name.lower().replace(" ", "") possible_names = [ f"anomaly_scores_{slug}.csv", f"scores_{slug}.csv", - f"{slug}_scores.csv" + f"{slug}_scores.csv", ] - + found_score = None for sname in possible_names: sp = score_root / sname if sp.exists(): found_score = sp break - + if found_score: candidates.append((m_name, p, found_score)) else: pass - + # Filter if model and model.lower() != "all": tasks = [t for t in candidates if t[0].lower() == model.lower()] if not tasks: - typer.echo(f"❌ No valid tasks (embedding+scores) found matching model '{model}'.") - raise typer.Exit(1) + typer.echo( + f"❌ No valid tasks (embedding+scores) found matching model '{model}'." + ) + raise typer.Exit(1) else: tasks = candidates if not tasks: - typer.echo("❌ No valid tasks found. Ensure embeddings(.pt) and scores(.csv) exist and match.") + typer.echo( + "❌ No valid tasks found. Ensure embeddings(.pt) and scores(.csv) exist and match." + ) typer.echo(f" Embeddings: {paths.embeddings}") typer.echo(f" Scores: {paths.outliers}") raise typer.Exit(1) @@ -510,75 +652,81 @@ def neighbor_ranks( out_path = Path(out_dir) else: out_path = paths.analysis / "neighbors" - + out_path.mkdir(parents=True, exist_ok=True) - + typer.echo(f"📊 Analyzing neighbor ranks for {len(q_ids)} queries...") typer.echo(f" Tasks: {[(t[0], str(t[2].name)) for t in tasks]}") typer.echo(f" Output: {out_path}") - + similarity.analyze_neighbor_ranks( - query_ids=q_ids, - tasks=tasks, - n_similar=n_similar, - output_dir=out_path + query_ids=q_ids, tasks=tasks, n_similar=n_similar, output_dir=out_path ) + @analyze_app.command() def regression( ctx: typer.Context, - config: Optional[str] = typer.Option(None, "--config", help="Path to regression config.yaml"), + config: Optional[str] = typer.Option( + None, "--config", help="Path to regression config.yaml" + ), out_dir: Optional[str] = typer.Option(None, "--out-dir", help="Output directory"), ): """ Run physical parameter regression analysis. Predicts Redshift, Mass, SFR from embeddings. 
""" - from fmb.analysis.regression import predict_physical_params from pathlib import Path - + + from fmb.analysis.regression import predict_physical_params + cfg_path = Path(config) if config else None out_path = Path(out_dir) if out_dir else None - - predict_physical_params.run_analysis( - config_path=cfg_path, - output_dir=out_path - ) + + predict_physical_params.run_analysis(config_path=cfg_path, output_dir=out_path) + @analyze_app.command() def displacement( ctx: typer.Context, - config: Optional[str] = typer.Option(None, "--config", help="Path to displacement config.yaml"), + config: Optional[str] = typer.Option( + None, "--config", help="Path to displacement config.yaml" + ), out_dir: Optional[str] = typer.Option(None, "--out-dir", help="Output directory"), ): """ Run displacement analysis (retention across models/modalities). Generates Multi-Model, Cross-Modality, and Extensive plots. """ - from fmb.analysis import displacement from pathlib import Path - + + from fmb.analysis import displacement + cfg_path = Path(config) if config else None out_path = Path(out_dir) if out_dir else None - - displacement.run_analysis( - config_path=cfg_path, - output_dir=out_path - ) + + displacement.run_analysis(config_path=cfg_path, output_dir=out_path) + @app.command() def display( ctx: typer.Context, - split: str = typer.Option("train", "--split", help="Dataset split to load (train, test, all)"), + split: str = typer.Option( + "train", "--split", help="Dataset split to load (train, test, all)" + ), index: int = typer.Option(0, "--index", help="Index of sample to display"), save: Optional[str] = typer.Option(None, "--save", help="Path to save the figure"), - show_bands: bool = typer.Option(False, "--show-bands", help="Display spectrum/SED and individual bands"), - no_gui: bool = typer.Option(False, "--no-gui", help="Don't open GUI window (save only)") + show_bands: bool = typer.Option( + False, "--show-bands", help="Display spectrum/SED and individual bands" + ), + no_gui: bool = typer.Option( + False, "--no-gui", help="Don't open GUI window (save only)" + ), ): """Load and display dataset samples.""" typer.echo(f"📊 Loading dataset split '{split}'...") forward_args(ctx) - + # Build arguments for the display script sys.argv = [sys.argv[0], "--split", split, "--index", str(index)] if save: @@ -587,16 +735,20 @@ def display( sys.argv.append("--show-bands") if no_gui: sys.argv.append("--no-gui") - + from fmb.data.load_display_data import main as display_main + display_main() + @app.command() def paths( data: bool = typer.Option(False, "--data", help="Print DATA_ROOT only"), embeddings: bool = typer.Option(False, "--embeddings", help="Print EMB_ROOT only"), - checkpoints: bool = typer.Option(False, "--checkpoints", help="Print CKPT_ROOT only"), - runs: bool = typer.Option(False, "--runs", help="Print RUNS_ROOT only") + checkpoints: bool = typer.Option( + False, "--checkpoints", help="Print CKPT_ROOT only" + ), + runs: bool = typer.Option(False, "--runs", help="Print RUNS_ROOT only"), ): """Display current path configuration.""" P = load_paths() @@ -620,10 +772,15 @@ def paths( viz_app = typer.Typer(help="Stage 05: Visualizations for publication and inspection.") app.add_typer(viz_app, name="viz") -@viz_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@viz_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def paper_umap( ctx: typer.Context, - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job 
instead of running locally") + slurm: bool = typer.Option( + False, "--slurm", help="Submit as a Slurm job instead of running locally" + ), ): """Generate the publication-ready combined UMAP figure (AstroPT, AION, AstroCLIP).""" if slurm: @@ -633,16 +790,30 @@ def paper_umap( typer.echo("Generating publication combined UMAP plot locally...") forward_args(ctx) from fmb.viz.combined_umap import main as run_task + run_task() -@viz_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@viz_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def advanced_analysis( ctx: typer.Context, - aion_scores: Optional[str] = typer.Option(None, "--aion-scores", help="Path to AION scores CSV"), - astropt_scores: Optional[str] = typer.Option(None, "--astropt-scores", help="Path to AstroPT scores CSV"), - astroclip_scores: Optional[str] = typer.Option(None, "--astroclip-scores", help="Path to AstroCLIP scores CSV"), - save_prefix: Optional[str] = typer.Option(None, "--save-prefix", help="Prefix for output files"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job instead of running locally") + aion_scores: Optional[str] = typer.Option( + None, "--aion-scores", help="Path to AION scores CSV" + ), + astropt_scores: Optional[str] = typer.Option( + None, "--astropt-scores", help="Path to AstroPT scores CSV" + ), + astroclip_scores: Optional[str] = typer.Option( + None, "--astroclip-scores", help="Path to AstroCLIP scores CSV" + ), + save_prefix: Optional[str] = typer.Option( + None, "--save-prefix", help="Prefix for output files" + ), + slurm: bool = typer.Option( + False, "--slurm", help="Submit as a Slurm job instead of running locally" + ), ): """ Generate Advanced Analysis figures (Spearman, Jaccard, Disagreements). 
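# Illustrative sketch of the two agreement measures named above: Spearman rank correlation
# of the per-object anomaly scores and Jaccard overlap of the top-k sets. This is not the
# implementation in fmb.viz.outliers.advanced_analysis, and it assumes that a higher score
# means "more anomalous".
import numpy as np
from scipy.stats import spearmanr

def ranking_agreement(scores_a, scores_b, top_k: int = 200):
    rho, _ = spearmanr(scores_a, scores_b)             # monotonic agreement of the two rankings
    top_a = set(np.argsort(scores_a)[-top_k:])         # indices of the k highest scores in A
    top_b = set(np.argsort(scores_b)[-top_k:])         # indices of the k highest scores in B
    jaccard = len(top_a & top_b) / len(top_a | top_b)  # overlap of the two top-k sets
    return rho, jaccard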
@@ -653,27 +824,44 @@ def advanced_analysis( typer.echo("Running Advanced Analysis...") from fmb.viz.outliers.advanced_analysis import run_analysis - + run_analysis( aion_scores=aion_scores, astropt_scores=astropt_scores, astroclip_scores=astroclip_scores, - save_prefix=save_prefix + save_prefix=save_prefix, ) -@viz_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@viz_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def outlier_grid( ctx: typer.Context, - csv: List[str] = typer.Option(..., "--csv", help="CSV file(s) with object_id column"), + csv: List[str] = typer.Option( + ..., "--csv", help="CSV file(s) with object_id column" + ), split: str = typer.Option("all", "--split", help="Dataset split(s)"), - cache_dir: Optional[str] = typer.Option(None, "--cache-dir", help="Data cache directory"), - max_count: int = typer.Option(12, "--max", help="Maximum number of images to display"), + cache_dir: Optional[str] = typer.Option( + None, "--cache-dir", help="Data cache directory" + ), + max_count: int = typer.Option( + 12, "--max", help="Maximum number of images to display" + ), cols: int = typer.Option(3, "--cols", help="Number of columns"), - save: Optional[str] = typer.Option(None, "--save", help="Path to save the figure (default: analysis/outliers_grid.png)"), - show: bool = typer.Option(False, "--show/--no-show", help="Enable/Disable interactive display"), - index: Optional[str] = typer.Option(None, "--index", help="Optional CSV mapping object_id -> split/index"), + save: Optional[str] = typer.Option( + None, + "--save", + help="Path to save the figure (default: analysis/outliers_grid.png)", + ), + show: bool = typer.Option( + False, "--show/--no-show", help="Enable/Disable interactive display" + ), + index: Optional[str] = typer.Option( + None, "--index", help="Optional CSV mapping object_id -> split/index" + ), verbose: bool = typer.Option(False, "--verbose", help="Enable verbose logging"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job") + slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job"), ): """ Generate Publication Outlier Grid (Images + Spectra). 
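# Illustration of the input expected by the grid plot: each --csv file only needs an
# `object_id` column (for example, a top-K export from the multimodal detection stage).
# The IDs below are made up and the output filename is arbitrary.
import pandas as pd

pd.DataFrame({"object_id": ["1234567890", "2345678901", "3456789012"]}).to_csv(
    "example_outliers.csv", index=False
)
# The file can then be passed to the grid command via --csv example_outliers.csv.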
@@ -684,7 +872,7 @@ def outlier_grid( typer.echo("Generating Outlier Grid...") from fmb.viz.outliers.outlier_grid import run_grid_plot - + run_grid_plot( csv_paths=csv, split=split, @@ -694,19 +882,26 @@ def outlier_grid( save_path=save, show=show, index_path=index, - verbose=verbose + verbose=verbose, ) -@viz_app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) + +@viz_app.command( + context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def single_object( ctx: typer.Context, object_id: str = typer.Option(..., "--object-id", help="ID of the object to plot"), index: Optional[str] = typer.Option(None, "--index", help="Path to index CSV"), - cache_dir: Optional[str] = typer.Option(None, "--cache-dir", help="Data cache directory"), - save: Optional[str] = typer.Option(None, "--save", help="Output filename (default: analysis/object_{id}.pdf)"), + cache_dir: Optional[str] = typer.Option( + None, "--cache-dir", help="Data cache directory" + ), + save: Optional[str] = typer.Option( + None, "--save", help="Output filename (default: analysis/object_{id}.pdf)" + ), smooth: float = typer.Option(2.0, "--smooth", help="Smoothing for spectrum"), dpi: int = typer.Option(300, "--dpi", help="DPI for saving"), - slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job") + slurm: bool = typer.Option(False, "--slurm", help="Submit as a Slurm job"), ): """ Generate Single Object Visualization (Spectrum + Bands). @@ -717,14 +912,14 @@ def single_object( typer.echo(f"Plotting Object {object_id}...") from fmb.viz.outliers.single_object import run_single_object_plot - + run_single_object_plot( object_id=object_id, index_path=index, cache_dir=cache_dir, save_path=save, smooth=smooth, - dpi=dpi + dpi=dpi, ) diff --git a/src/fmb/data/astroclip_loader.py b/src/fmb/data/astroclip_loader.py index 7fe95f9..4509045 100644 --- a/src/fmb/data/astroclip_loader.py +++ b/src/fmb/data/astroclip_loader.py @@ -6,12 +6,13 @@ """ from pathlib import Path -from typing import Optional, List -import pandas as pd +from typing import Optional + import numpy as np +import pandas as pd import torch import torch.nn.functional as F -from datasets import load_from_disk, Dataset +from datasets import Dataset, load_from_disk # Mapping of split names to local directory names (in cache_dir) LOCAL_SPLIT_PATHS = { @@ -27,98 +28,105 @@ def load_local_arrow_dataset( seed: int = 42, ) -> pd.DataFrame: """Load dataset from local Arrow cache and convert to pandas DataFrame. - + Args: cache_dir: Directory containing the cached dataset (e.g., /n03data/ronceray/datasets) split: Which split to load ('train', 'test') max_samples: Optional limit on number of samples to load seed: Random seed for sampling - + Returns: pd.DataFrame with columns: spectrum, redshift, image, and other metadata """ cache_path = Path(cache_dir) - + # Check if split exists locally local_split_name = LOCAL_SPLIT_PATHS.get(split) - + # Handle "all" split by loading train and test and concatenating if split == "all": print("Dataset split 'all' requested. 
Loading 'train' and 'test'...") - df_train = load_local_arrow_dataset(cache_dir, "train", max_samples=None, seed=seed) - df_test = load_local_arrow_dataset(cache_dir, "test", max_samples=None, seed=seed) + df_train = load_local_arrow_dataset( + cache_dir, "train", max_samples=None, seed=seed + ) + df_test = load_local_arrow_dataset( + cache_dir, "test", max_samples=None, seed=seed + ) df_all = pd.concat([df_train, df_test], ignore_index=True) - + if max_samples is not None: - # Re-sample from the combined dataframe if needed - if len(df_all) > max_samples: + # Re-sample from the combined dataframe if needed + if len(df_all) > max_samples: print(f"Sampling {max_samples} from {len(df_all)} combined samples") - df_all = df_all.sample(n=max_samples, random_state=seed).reset_index(drop=True) + df_all = df_all.sample(n=max_samples, random_state=seed).reset_index( + drop=True + ) return df_all if not local_split_name: - raise ValueError(f"Unknown split '{split}'. Available: {list(LOCAL_SPLIT_PATHS.keys()) + ['all']}") - + raise ValueError( + f"Unknown split '{split}'. Available: {list(LOCAL_SPLIT_PATHS.keys()) + ['all']}" + ) + local_path = cache_path / local_split_name - + if not local_path.exists(): raise FileNotFoundError( f"Local Arrow dataset not found at {local_path}\n" f"Expected directory structure: {cache_dir}/{local_split_name}/" ) - + print(f"Loading dataset from {local_path}") - + # Load using HuggingFace's load_from_disk dataset: Dataset = load_from_disk(str(local_path)) - + print(f"✓ Loaded {len(dataset)} samples from '{split}' split") - + # Sample if requested if max_samples is not None and max_samples < len(dataset): print(f"📊 Sampling {max_samples} from {len(dataset)} samples") indices = list(range(len(dataset))) import random + random.seed(seed) random.shuffle(indices) indices = indices[:max_samples] dataset = dataset.select(indices) - + # Convert to pandas DataFrame df = dataset.to_pandas() - + print(f"✓ Converted to DataFrame with {len(df)} rows") print(f" Columns: {list(df.columns)}") - + # Verify required columns required_cols = ["spectrum", "redshift"] missing_cols = [col for col in required_cols if col not in df.columns] if missing_cols: raise ValueError(f"Missing required columns: {missing_cols}") - + # Check for image columns (might be rgb_image or separate bands) - image_cols = [col for col in df.columns if 'image' in col.lower()] + image_cols = [col for col in df.columns if "image" in col.lower()] print(f" Image columns found: {image_cols}") - + return df def prepare_image_from_euclid_bands( - sample: dict, - image_size: int = 224, - target_key: str = "image" + sample: dict, image_size: int = 224, target_key: str = "image" ) -> torch.Tensor: """Prepare image tensor from Euclid bands or RGB image. - + This function handles different image formats in the dataset: - If 'RGB_image' or 'rgb_image' exists, use it directly - Otherwise, stack individual bands (VIS_image, NISP_Y_image, etc.) 
- + Args: sample: Dictionary from the dataset row image_size: Target size for the image target_key: Key to look for in the sample (e.g., 'rgb_image') - + Returns: torch.Tensor of shape (C, H, W) """ @@ -128,7 +136,7 @@ def prepare_image_from_euclid_bands( if key in sample: rgb_key = key break - + if rgb_key: image = sample[rgb_key] if isinstance(image, torch.Tensor): @@ -139,7 +147,7 @@ def prepare_image_from_euclid_bands( # Try to stack individual bands (uppercase first, then lowercase) band_keys_upper = ["VIS_image", "NISP_Y_image", "NISP_J_image", "NISP_H_image"] band_keys_lower = ["vis_image", "nisp_y_image", "nisp_j_image", "nisp_h_image"] - + bands = [] # Try uppercase first for key in band_keys_upper: @@ -148,7 +156,7 @@ def prepare_image_from_euclid_bands( if band.ndim == 3 and band.shape[0] == 1: band = band.squeeze(0) bands.append(band) - + # If no uppercase bands found, try lowercase if not bands: for key in band_keys_lower: @@ -157,23 +165,23 @@ def prepare_image_from_euclid_bands( if band.ndim == 3 and band.shape[0] == 1: band = band.squeeze(0) bands.append(band) - + if not bands: raise ValueError( f"No image data found in sample. " f"Checked: RGB_image, rgb_image, {band_keys_upper}, {band_keys_lower}" ) - + # Stack to create multi-channel image img = torch.stack(bands, dim=0) - + # Normalize and clean img = torch.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0) - + # Ensure proper channel ordering (C, H, W) if img.ndim == 3 and img.shape[0] not in (1, 3, 4): img = img.permute(2, 0, 1).contiguous() - + # Resize if needed if image_size and (img.shape[-1] != image_size or img.shape[-2] != image_size): img = F.interpolate( @@ -182,11 +190,11 @@ def prepare_image_from_euclid_bands( mode="bilinear", align_corners=False, ).squeeze(0) - + # For RGB (3 channels), clamp to [0, 1] if img.shape[0] == 3: img = img.clamp(0.0, 1.0) - + return img @@ -195,131 +203,143 @@ def convert_dataset_to_astroclip_format( image_size: int = 144, ) -> pd.DataFrame: """Convert DataFrame from EuclidDESI format to AstroCLIP expected format. - + AstroCLIP expects: - 'image': torch.Tensor (C, H, W) - 'spectrum': dict with 'flux' and optionally 'wavelength' - 'redshift': float - + Args: df: Input DataFrame from load_local_arrow_dataset image_size: Target image size - + Returns: DataFrame with processed columns """ print(f"🔄 Converting {len(df)} samples to AstroCLIP format...") - + try: from PIL import Image except ImportError: Image = None - + processed_rows = [] - + for idx, row in df.iterrows(): try: # Get RGB image (uppercase key) - rgb_image = row.get('RGB_image') + rgb_image = row.get("RGB_image") if rgb_image is None: continue - + # Convert PIL Image to numpy array if Image is not None and isinstance(rgb_image, Image.Image): rgb_image = np.array(rgb_image) - + # SPECIAL CASE: If it's a dict (Arrow struct), try to convert if isinstance(rgb_image, dict): # Sometimes images are stored as dicts in Arrow (e.g. 
{'bytes': ..., 'path': ...}) - if 'bytes' in rgb_image: - import io - if Image is None: - from PIL import Image - rgb_image = np.array(Image.open(io.BytesIO(rgb_image['bytes']))) - elif 'array' in rgb_image: # hypothetical - rgb_image = np.array(rgb_image['array']) - + if "bytes" in rgb_image: + import io + + if Image is None: + from PIL import Image + rgb_image = np.array(Image.open(io.BytesIO(rgb_image["bytes"]))) + elif "array" in rgb_image: # hypothetical + rgb_image = np.array(rgb_image["array"]) + # Convert to tensor format (C, H, W) if isinstance(rgb_image, np.ndarray): if rgb_image.ndim == 3: - image_tensor = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0 + image_tensor = ( + torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0 + ) else: - image_tensor = torch.from_numpy(rgb_image).unsqueeze(0).float() / 255.0 + image_tensor = ( + torch.from_numpy(rgb_image).unsqueeze(0).float() / 255.0 + ) else: # Fallback: if already tensor-like image_tensor = torch.as_tensor(rgb_image).float() if image_tensor.ndim == 3 and image_tensor.shape[0] not in (1, 3): # Try to convert to (C, H, W) image_tensor = image_tensor.permute(2, 0, 1).contiguous() - + # Resize if needed - if image_size and (image_tensor.shape[-1] != image_size or image_tensor.shape[-2] != image_size): + if image_size and ( + image_tensor.shape[-1] != image_size + or image_tensor.shape[-2] != image_size + ): image_tensor = F.interpolate( image_tensor.unsqueeze(0), size=(image_size, image_size), mode="bilinear", align_corners=False, ).squeeze(0) - + # Process spectrum (should already be a dict) - spectrum = row.get('spectrum') + spectrum = row.get("spectrum") if spectrum is None: continue # Skip samples without spectrum - + # Ensure spectrum is properly formatted if not isinstance(spectrum, dict): # If it's an array, convert to dict if isinstance(spectrum, (list, np.ndarray)): spectrum = {"flux": spectrum} - + # Get redshift - redshift = row.get('redshift', 0.0) - - processed_rows.append({ - "image": image_tensor, - "spectrum": spectrum, - "redshift": redshift, - "object_id": row.get('object_id'), - "targetid": row.get('targetid'), - }) - + redshift = row.get("redshift", 0.0) + + processed_rows.append( + { + "image": image_tensor, + "spectrum": spectrum, + "redshift": redshift, + "object_id": row.get("object_id"), + "targetid": row.get("targetid"), + } + ) + except Exception as e: print(f"Warning: Failed to process row {idx}: {e}") import traceback + traceback.print_exc() continue - + result_df = pd.DataFrame(processed_rows) print(f"✓ Successfully converted {len(result_df)} samples") - + return result_df if __name__ == "__main__": # Test loading import argparse + parser = argparse.ArgumentParser() parser.add_argument("--cache-dir", default="/n03data/ronceray/datasets") parser.add_argument("--split", default="train") parser.add_argument("--max-samples", type=int, default=10) args = parser.parse_args() - + df = load_local_arrow_dataset( cache_dir=args.cache_dir, split=args.split, max_samples=args.max_samples, ) - - print(f"\n📊 Dataset info:") + + print("\n📊 Dataset info:") print(f" Shape: {df.shape}") print(f" Columns: {list(df.columns)}") - print(f"\n First row preview:") + print("\n First row preview:") print(f" Redshift: {df.iloc[0]['redshift']}") print(f" Spectrum type: {type(df.iloc[0].get('spectrum'))}") - + # Test conversion converted = convert_dataset_to_astroclip_format(df, image_size=144) - print(f"\n✓ Conversion test successful!") + print("\n✓ Conversion test successful!") print(f" Converted shape: 
{converted.shape}") if len(converted) > 0: print(f" Image shape: {converted.iloc[0]['image'].shape}") diff --git a/src/fmb/data/astroclip_parquet.py b/src/fmb/data/astroclip_parquet.py index e5922c6..160f618 100644 --- a/src/fmb/data/astroclip_parquet.py +++ b/src/fmb/data/astroclip_parquet.py @@ -13,7 +13,7 @@ import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import numpy as np import pandas as pd @@ -74,7 +74,9 @@ def resolve_parquet_path(raw_path: str) -> str: files = list_repo_files(repo_id=repo_id, repo_type="dataset") guesses = [f for f in files if inner_path.split("/")[-1] in f] hint = f" Exemples trouvés: {guesses[:5]}" if guesses else "" - raise FileNotFoundError(f"{inner_path} introuvable sur Hugging Face.{hint}") from exc + raise FileNotFoundError( + f"{inner_path} introuvable sur Hugging Face.{hint}" + ) from exc path_obj = Path(raw_path).expanduser() if not path_obj.exists(): @@ -164,7 +166,9 @@ def _load_impl(self) -> pd.DataFrame: if "image" not in df.columns and "RGB_image" in df.columns: df["image"] = df["RGB_image"].apply(self._preprocess_image) elif "image" not in df.columns: - raise ValueError("Le parquet doit contenir une colonne 'image' ou 'RGB_image'.") + raise ValueError( + "Le parquet doit contenir une colonne 'image' ou 'RGB_image'." + ) if "redshift" not in df.columns: raise ValueError("La colonne 'redshift' est absente du parquet.") @@ -175,7 +179,11 @@ def _load_impl(self) -> pd.DataFrame: if self.focus_high_z: df = df.nlargest(self.sample_size, "redshift").reset_index(drop=True) else: - df = df.sample(self.sample_size, random_state=42).sort_index().reset_index(drop=True) + df = ( + df.sample(self.sample_size, random_state=42) + .sort_index() + .reset_index(drop=True) + ) df["pair_id"] = np.arange(len(df)) return df diff --git a/src/fmb/data/datasets.py b/src/fmb/data/datasets.py index 7371fad..543a1f1 100644 --- a/src/fmb/data/datasets.py +++ b/src/fmb/data/datasets.py @@ -5,18 +5,18 @@ Description: PyTorch dataset wrappers for foundation models """ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union +import numpy as np import torch import torch.nn.functional as F from torch.utils.data import Dataset -import numpy as np -from dataclasses import dataclass, field -from typing import Optional, List, Dict, Any, Union -import sys # Attempt to import AION types for AionDataset try: from aion.modalities import EuclidImage + AION_AVAILABLE = True except ImportError: AION_AVAILABLE = False @@ -33,41 +33,42 @@ "nisp_h_image": 918.35, } + @dataclass class FMBDataConfig: """Shared configuration for FMB datasets.""" + split: str = "train" cache_dir: Optional[str] = None image_size: int = 96 max_entries: Optional[int] = None # For AION/AstroCLIP - crop_size: int = 96 + crop_size: int = 96 # For AstroCLIP/AstroPT spectrum_length: int = 7781 # For AstroCLIP - spectrum_norm: str = "zscore" # none, zscore, minmax + spectrum_norm: str = "zscore" # none, zscore, minmax include_wavelength: bool = False - slice_length: Optional[int] = None # Alias for spectrum_length in AstroCLIP config - + slice_length: Optional[int] = None # Alias for spectrum_length in AstroCLIP config + def __post_init__(self): # Unify slice_length and spectrum_length if self.slice_length is not None: self.spectrum_length = self.slice_length + class FMBBaseDataset(Dataset): """Base class for FMB datasets wrapping EuclidDESIDataset.""" - + def __init__(self, config: 
FMBDataConfig, verbose: bool = False): self.config = config self.base = EuclidDESIDataset( - split=config.split, - cache_dir=config.cache_dir, - verbose=verbose + split=config.split, cache_dir=config.cache_dir, verbose=verbose ) - + self._indices = list(range(len(self.base))) if config.max_entries and config.max_entries > 0: - self._indices = self._indices[:config.max_entries] + self._indices = self._indices[: config.max_entries] def __len__(self) -> int: return len(self._indices) @@ -76,27 +77,31 @@ def _get_base_sample(self, idx: int) -> Dict[str, Any]: base_idx = self._indices[idx] return self.base[base_idx] - def _process_image(self, image_data: Any, resize_to: Optional[int] = None) -> torch.Tensor: + def _process_image( + self, image_data: Any, resize_to: Optional[int] = None + ) -> torch.Tensor: """Standard image tensor conversion and optional resizing.""" if image_data is None: - # Return zero tensor or handle as needed by subclass? - # Usually subclasses check for None. Here we assume valid input or handle logic outside. - return torch.zeros((3, resize_to or 64, resize_to or 64), dtype=torch.float32) + # Return zero tensor or handle as needed by subclass? + # Usually subclasses check for None. Here we assume valid input or handle logic outside. + return torch.zeros( + (3, resize_to or 64, resize_to or 64), dtype=torch.float32 + ) img = torch.as_tensor(image_data, dtype=torch.float32) img = torch.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0) - + # Ensure (C, H, W) - if img.ndim == 3 and img.shape[0] not in (1, 3, 4): + if img.ndim == 3 and img.shape[0] not in (1, 3, 4): # If shape is like (H, W, C) -> permute img = img.permute(2, 0, 1).contiguous() elif img.ndim == 2: img = img.unsqueeze(0) - + # Clamp usually to [0, 1] for safety in many models, but specific models might differ. # AstroCLIP clamps to [0, 1]. AstroPT clamps to [0, 1]. AION converts units. # We'll leave clamping to specific implementations or define a default here. 
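        # Shape contract of _process_image, summarized for readers of the subclasses below:
        #   (H, W)      -> (1, H, W)
        #   (H, W, C)   -> (C, H, W)   (permuted when the leading dim is not 1, 3 or 4)
        #   (C, H, W)   -> unchanged
        # NaN/Inf values are zeroed first; with resize_to set, the output is bilinearly
        # resized to (C, resize_to, resize_to). Value clamping is deliberately left to the
        # subclasses (AstroCLIP and AstroPT clamp to [0, 1], AION converts units).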
- # img = img.clamp(0.0, 1.0) + # img = img.clamp(0.0, 1.0) if resize_to and (img.shape[-1] != resize_to or img.shape[-2] != resize_to): img = F.interpolate( @@ -105,13 +110,17 @@ def _process_image(self, image_data: Any, resize_to: Optional[int] = None) -> to mode="bilinear", align_corners=False, ).squeeze(0) - + return img - def _pad_or_trim_spectrum(self, flux: torch.Tensor, target_len: int) -> torch.Tensor: + def _pad_or_trim_spectrum( + self, flux: torch.Tensor, target_len: int + ) -> torch.Tensor: if flux.numel() < target_len: pad_len = target_len - flux.numel() - flux = torch.cat([flux, torch.zeros(pad_len, dtype=flux.dtype, device=flux.device)]) + flux = torch.cat( + [flux, torch.zeros(pad_len, dtype=flux.dtype, device=flux.device)] + ) elif flux.numel() > target_len: flux = flux[:target_len] return flux @@ -119,9 +128,10 @@ def _pad_or_trim_spectrum(self, flux: torch.Tensor, target_len: int) -> torch.Te class AstroClipDataset(FMBBaseDataset): """Dataset for AstroCLIP training (Image + Spectrum).""" - + def _normalise_spectrum(self, tensor: torch.Tensor, mode: str) -> torch.Tensor: - if mode == "none": return tensor + if mode == "none": + return tensor if mode == "zscore": std = tensor.std(unbiased=False).clamp(min=1e-6) return (tensor - tensor.mean()) / std @@ -132,36 +142,42 @@ def _normalise_spectrum(self, tensor: torch.Tensor, mode: str) -> torch.Tensor: def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: sample = self._get_base_sample(idx) - + # Image - img = self._process_image(sample.get("rgb_image"), resize_to=self.config.image_size) - img = img.clamp(0.0, 1.0) # AstroCLIP specific - + img = self._process_image( + sample.get("rgb_image"), resize_to=self.config.image_size + ) + img = img.clamp(0.0, 1.0) # AstroCLIP specific + # Spectrum spec_dict = sample.get("spectrum") or {} flux = np.asarray(spec_dict.get("flux", [])) - + # Handle wavelength if needed (defaults to linear dummy if empty/not requested) wavelength = np.asarray(spec_dict.get("wavelength", [])) if len(wavelength) == 0 and len(flux) > 0: - wavelength = np.linspace(0, 1, len(flux), dtype=np.float32) + wavelength = np.linspace(0, 1, len(flux), dtype=np.float32) if len(flux) == 0: - # Handle empty spectrum case gracefully? - flux = np.zeros(self.config.spectrum_length, dtype=np.float32) - wavelength = np.zeros(self.config.spectrum_length, dtype=np.float32) + # Handle empty spectrum case gracefully? 
+ flux = np.zeros(self.config.spectrum_length, dtype=np.float32) + wavelength = np.zeros(self.config.spectrum_length, dtype=np.float32) flux_tensor = torch.as_tensor(flux, dtype=torch.float32) - flux_tensor = self._pad_or_trim_spectrum(flux_tensor, self.config.spectrum_length) + flux_tensor = self._pad_or_trim_spectrum( + flux_tensor, self.config.spectrum_length + ) flux_tensor = self._normalise_spectrum(flux_tensor, self.config.spectrum_norm) - + if self.config.include_wavelength: wave_tensor = torch.as_tensor(wavelength, dtype=torch.float32) - wave_tensor = self._pad_or_trim_spectrum(wave_tensor, self.config.spectrum_length) + wave_tensor = self._pad_or_trim_spectrum( + wave_tensor, self.config.spectrum_length + ) spectrum = torch.stack([flux_tensor, wave_tensor], dim=-1) else: spectrum = flux_tensor.unsqueeze(-1) - + return { "image": img, "spectrum": spectrum, @@ -173,8 +189,10 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: class AstroPTDataset(FMBBaseDataset): """Dataset for AstroPT training (Multimodal).""" - - def _prepare_spectrum_astropt(self, spectrum_dict: Optional[dict]) -> Optional[torch.Tensor]: + + def _prepare_spectrum_astropt( + self, spectrum_dict: Optional[dict] + ) -> Optional[torch.Tensor]: if spectrum_dict is None or spectrum_dict.get("flux") is None: return None flux = torch.as_tensor(spectrum_dict["flux"], dtype=torch.float32) @@ -188,10 +206,12 @@ def _prepare_spectrum_astropt(self, spectrum_dict: Optional[dict]) -> Optional[t def __getitem__(self, idx: int) -> Dict[str, Any]: sample = self._get_base_sample(idx) - - image = self._process_image(sample.get("rgb_image"), resize_to=self.config.image_size) + + image = self._process_image( + sample.get("rgb_image"), resize_to=self.config.image_size + ) image = image.clamp(0.0, 1.0) - + spectrum = self._prepare_spectrum_astropt(sample.get("spectrum")) targetid = sample.get("targetid") @@ -199,7 +219,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: targetid_val = int(targetid) if targetid is not None else -1 except Exception: targetid_val = -1 - + redshift = sample.get("redshift") redshift_val = float(redshift) if redshift is not None else 0.0 @@ -214,7 +234,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: class AionDataset(FMBBaseDataset): """Dataset for AION training (Euclid 4-band images).""" - + def __getitem__(self, idx: int) -> Union[EuclidImage, Any]: sample = self._get_base_sample(idx) @@ -236,8 +256,10 @@ def __getitem__(self, idx: int) -> Union[EuclidImage, Any]: t = t.squeeze(0) if t.ndim != 2: # Some samples might be 3D, take first channel or squeeze - if t.ndim == 3: t = t[0] - else: raise ValueError(f"Expected 2D band, got {tuple(t.shape)}") + if t.ndim == 3: + t = t[0] + else: + raise ValueError(f"Expected 2D band, got {tuple(t.shape)}") bands.append(t) flux = torch.stack(bands, dim=0) # (4,H,W) @@ -245,19 +267,25 @@ def __getitem__(self, idx: int) -> Union[EuclidImage, Any]: # AION typically expects resizing too, but uses crop often? # retrain_aion.py had: # if self.config.resize and ... 
interpolation - - if self.config.image_size and (flux.shape[-1] != self.config.image_size or flux.shape[-2] != self.config.image_size): + + if self.config.image_size and ( + flux.shape[-1] != self.config.image_size + or flux.shape[-2] != self.config.image_size + ): flux = F.interpolate( flux.unsqueeze(0), size=(self.config.image_size, self.config.image_size), mode="bilinear", align_corners=False, ).squeeze(0) - + if AION_AVAILABLE: return EuclidImage(flux=flux, bands=EUCLID_BANDS) else: - return {"flux": flux, "bands": EUCLID_BANDS} # Fallback for testing without AION + return { + "flux": flux, + "bands": EUCLID_BANDS, + } # Fallback for testing without AION class AionMultimodalDataset(AionDataset): @@ -266,13 +294,13 @@ class AionMultimodalDataset(AionDataset): def __getitem__(self, idx: int) -> Dict[str, Any]: # Get base sample sample = self._get_base_sample(idx) - + # 1. Process Image using AionDataset logic (but we need to call it manually or reuse logic) # Inheriting from AionDataset allows reuse if we factor out image processing? # AionDataset.__getitem__ returns just the image. # Let's call super().__getitem__ to get the image object euclid_image = super().__getitem__(idx) - + # 2. Process Spectrum spec = sample.get("spectrum") # Return dict with both @@ -280,5 +308,5 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: "object_id": sample.get("object_id") or sample.get("targetid"), "redshift": sample.get("redshift"), "image": euclid_image, - "spectrum": spec # Raw dict, processed by collator/script usually + "spectrum": spec, # Raw dict, processed by collator/script usually } diff --git a/src/fmb/data/index_dataset.py b/src/fmb/data/index_dataset.py index 761e1e6..b7ca7e2 100644 --- a/src/fmb/data/index_dataset.py +++ b/src/fmb/data/index_dataset.py @@ -8,7 +8,7 @@ import argparse import csv from pathlib import Path -from typing import Sequence, Optional, List +from typing import List, Optional, Sequence from datasets import get_dataset_split_names, load_dataset, load_from_disk from tqdm import tqdm @@ -29,13 +29,13 @@ def run_indexing( splits: Sequence[str] = ("all",), output: Optional[Path] = None, overwrite: bool = False, - hf_dataset_id: Optional[str] = None + hf_dataset_id: Optional[str] = None, ) -> None: """ Main entry point for indexing the dataset. """ paths = load_paths() - + # Defaults if cache_dir is None: cache_dir = str(paths.dataset) @@ -50,7 +50,7 @@ def run_indexing( # Determine splits final_splits: List[str] = [] - + # Check if "all" is requested if "all" in [s.lower() for s in splits]: # Priority to local directories in cache_dir @@ -58,7 +58,7 @@ def run_indexing( for s_name, local_name in LOCAL_SPLITS.items(): if (Path(cache_dir) / local_name).is_dir(): local_found.append(s_name) - + if local_found: final_splits = local_found print(f"Found local splits: {final_splits}") @@ -84,7 +84,7 @@ def _load_split(split_name: str): if path.is_dir(): print(f" Loading split '{split_name}' from local directory: {path}") return load_from_disk(str(path)) - + print(f" Loading split '{split_name}' from HF dataset {hf_dataset_id}") return load_dataset(hf_dataset_id, split=split_name, cache_dir=cache_dir) @@ -100,14 +100,14 @@ def _load_split(split_name: str): # Use 'targetid' or 'object_id' # Euclid dataset often uses 'object_id', older versions might use 'targetid' # We check first sample to be efficient? No, iterate all. 
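    # Layout of the index CSV written below, one row per sample. The IDs are made up and the
    # header names are assumed from the "object_id -> split/index" mapping described in the
    # CLI help (the header itself is written outside the lines shown here):
    #
    #   object_id,split,index
    #   1234567890,train,0
    #   1234567891,train,1
    #
    # Downstream commands (e.g. the single-object plot's --index option) use this file to
    # locate a sample by object_id.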
- + count = 0 for idx, sample in enumerate(tqdm(ds, desc=f"{split}", unit="sample")): oid = sample.get("object_id") or sample.get("targetid") if oid is not None: writer.writerow([oid, split, idx]) count += 1 - + print(f" Recorded {count} entries for split '{split}'.") print(f"Index written to {output}") @@ -118,18 +118,21 @@ def main(argv: Sequence[str] | None = None) -> None: parser.add_argument("--cache-dir", default=None, help="Dataset cache directory") parser.add_argument("--splits", default="all", help="Comma-separated splits") parser.add_argument("--output", default=None, help="Path to output CSV") - parser.add_argument("--overwrite", action="store_true", help="Overwrite existing file") - + parser.add_argument( + "--overwrite", action="store_true", help="Overwrite existing file" + ) + args = parser.parse_args(argv) splits_list = [s.strip() for s in args.splits.split(",") if s.strip()] - + run_indexing( cache_dir=args.cache_dir, splits=splits_list, output=Path(args.output) if args.output else None, - overwrite=args.overwrite + overwrite=args.overwrite, ) + if __name__ == "__main__": main() diff --git a/src/fmb/data/load_display_data.py b/src/fmb/data/load_display_data.py index 0852fc7..a34bb2e 100644 --- a/src/fmb/data/load_display_data.py +++ b/src/fmb/data/load_display_data.py @@ -17,21 +17,23 @@ import argparse import os + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -from random import sample import sys import warnings # Add src to pythonpath FIRST, before any local imports from pathlib import Path + src_path = Path(__file__).resolve().parents[2] if str(src_path) not in sys.path: sys.path.insert(0, str(src_path)) +from typing import Optional, Sequence + # Matplotlib defaults to interactive; switch to "Agg" if --no-gui is passed import matplotlib -from typing import Optional, Sequence def _maybe_switch_to_agg(no_gui: bool): if no_gui: @@ -43,6 +45,7 @@ def _maybe_switch_to_agg(no_gui: bool): except Exception: matplotlib.use("Agg") + import matplotlib.pyplot as plt import numpy as np import torch @@ -53,19 +56,21 @@ def _maybe_switch_to_agg(no_gui: bool): Image = None try: - from datasets import load_dataset, concatenate_datasets, load_from_disk + from datasets import concatenate_datasets, load_dataset, load_from_disk except ImportError as e: raise SystemExit( "The 'datasets' package is required. 
Install it with: pip install datasets" ) from e -from torch.utils.data import DataLoader + from fmb.paths import load_paths -HF_DATASET_ID = "msiudek/astroPT_euclid_Q1_desi_dr1_dataset" # Fallback if not in paths, but paths has default +HF_DATASET_ID = "msiudek/astroPT_euclid_Q1_desi_dr1_dataset" # Fallback if not in paths, but paths has default + class EuclidDESIDataset(torch.utils.data.Dataset): """PyTorch Dataset wrapper for the Euclid+DESI HuggingFace dataset.""" + def __init__( self, split="train", @@ -74,14 +79,15 @@ def __init__( verbose: bool = False, ): import os + paths = load_paths() if cache_dir is None: cache_dir = str(paths.dataset) - + os.makedirs(cache_dir, exist_ok=True) self.verbose = verbose self.transform = transform - + self.hf_dataset_id = getattr(paths, "dataset_hf_id", HF_DATASET_ID) requested_splits: list[str] @@ -89,18 +95,22 @@ def __init__( local_split_paths = {} if paths.dataset_train and paths.dataset_train.exists(): - local_split_paths["train"] = paths.dataset_train + local_split_paths["train"] = paths.dataset_train if paths.dataset_test and paths.dataset_test.exists(): - local_split_paths["test"] = paths.dataset_test + local_split_paths["test"] = paths.dataset_test def _load_split(split_name: str): """Load a split from local disk if available, otherwise from HF.""" if split_name in local_split_paths: if self.verbose: - print(f"Loading split '{split_name}' from {local_split_paths[split_name]}") + print( + f"Loading split '{split_name}' from {local_split_paths[split_name]}" + ) return load_from_disk(str(local_split_paths[split_name])) if self.verbose: - print(f"Loading split '{split_name}' from HF dataset {self.hf_dataset_id}") + print( + f"Loading split '{split_name}' from HF dataset {self.hf_dataset_id}" + ) return load_dataset( self.hf_dataset_id, split=split_name, @@ -112,7 +122,9 @@ def _load_split(split_name: str): if normalized.lower() in {"all", "*"}: requested_splits = list(local_split_paths) or ["train", "test"] else: - requested_splits = [part.strip() for part in normalized.split(",") if part.strip()] + requested_splits = [ + part.strip() for part in normalized.split(",") if part.strip() + ] if not requested_splits: raise ValueError("No valid split names provided") for split_name in requested_splits: @@ -142,15 +154,14 @@ def _load_split(split_name: str): self.splits = requested_splits if self.verbose: per_split_sizes = { - name: len(ds) - for name, ds in zip(self.splits, datasets_to_concat) + name: len(ds) for name, ds in zip(self.splits, datasets_to_concat) } print( f"Loaded EuclidDESIDataset with splits={self.splits} total_samples={len(self.dataset)}" ) print(f"Per-split sizes: {per_split_sizes}") preview = [ - (self.dataset[i].get("object_id") or self.dataset[i].get("targetid")) + (self.dataset[i].get("object_id") or self.dataset[i].get("targetid")) for i in range(min(3, len(self.dataset))) ] print(f"Object ID preview: {preview}") @@ -163,14 +174,16 @@ def __getitem__(self, idx): sample = self.dataset[idx] # Convert PIL image to tensor - rgb_image = sample['RGB_image'] + rgb_image = sample["RGB_image"] if Image is not None and isinstance(rgb_image, Image.Image): rgb_image = np.array(rgb_image) # Convert to tensor format (C, H, W) if isinstance(rgb_image, np.ndarray): if rgb_image.ndim == 3: - rgb_image_t = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0 + rgb_image_t = ( + torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0 + ) else: rgb_image_t = torch.from_numpy(rgb_image).unsqueeze(0).float() / 255.0 else: @@ -184,56 
+197,66 @@ def __getitem__(self, idx): # Process spectrum data spectrum_data = None - if sample.get('spectrum') is not None: - flux = sample['spectrum'].get('flux') - wavelength = sample['spectrum'].get('wavelength') - error = sample['spectrum'].get('error') + if sample.get("spectrum") is not None: + flux = sample["spectrum"].get("flux") + wavelength = sample["spectrum"].get("wavelength") + error = sample["spectrum"].get("error") flux = np.array(flux) if flux is not None else None wavelength = np.array(wavelength) if wavelength is not None else None error = np.array(error) if error is not None else None - ivar = 1.0 / (error ** 2) if error is not None else None + ivar = 1.0 / (error**2) if error is not None else None # mask not provided → make a "valid empty" boolean mask mask = np.zeros_like(flux, dtype=bool) if flux is not None else None if flux is not None: spectrum_data = { - 'flux': torch.from_numpy(flux).float(), - 'wavelength': torch.from_numpy(wavelength).float() if wavelength is not None else None, - 'error': torch.from_numpy(error).float() if error is not None else None, - 'ivar': torch.from_numpy(ivar).float() if ivar is not None else None, - 'mask': torch.from_numpy(mask).bool() if mask is not None else None, + "flux": torch.from_numpy(flux).float(), + "wavelength": ( + torch.from_numpy(wavelength).float() + if wavelength is not None + else None + ), + "error": ( + torch.from_numpy(error).float() if error is not None else None + ), + "ivar": ( + torch.from_numpy(ivar).float() if ivar is not None else None + ), + "mask": torch.from_numpy(mask).bool() if mask is not None else None, } # Process SED data sed_fluxes = None - if sample.get('sed_data') is not None: - flux_keys = [k for k in sample['sed_data'].keys() if k.startswith('flux_')] + if sample.get("sed_data") is not None: + flux_keys = [k for k in sample["sed_data"].keys() if k.startswith("flux_")] if flux_keys: - sed_fluxes = torch.tensor([sample['sed_data'][k] for k in flux_keys]).float() + sed_fluxes = torch.tensor( + [sample["sed_data"][k] for k in flux_keys] + ).float() # Individual band images (optional) def _to_tensor_img(x): return torch.from_numpy(np.array(x)).float() if x is not None else None - vis_image = _to_tensor_img(sample.get('VIS_image')) - nisp_y_image = _to_tensor_img(sample.get('NISP_Y_image')) - nisp_j_image = _to_tensor_img(sample.get('NISP_J_image')) - nisp_h_image = _to_tensor_img(sample.get('NISP_H_image')) + vis_image = _to_tensor_img(sample.get("VIS_image")) + nisp_y_image = _to_tensor_img(sample.get("NISP_Y_image")) + nisp_j_image = _to_tensor_img(sample.get("NISP_J_image")) + nisp_h_image = _to_tensor_img(sample.get("NISP_H_image")) return { - 'object_id': sample.get('object_id') or sample.get('targetid'), - 'targetid': sample.get('targetid'), - 'redshift': sample.get('redshift'), - 'rgb_image': rgb_image_t, - 'vis_image': vis_image, - 'nisp_y_image': nisp_y_image, - 'nisp_j_image': nisp_j_image, - 'nisp_h_image': nisp_h_image, - 'spectrum': spectrum_data, - 'sed_fluxes': sed_fluxes, + "object_id": sample.get("object_id") or sample.get("targetid"), + "targetid": sample.get("targetid"), + "redshift": sample.get("redshift"), + "rgb_image": rgb_image_t, + "vis_image": vis_image, + "nisp_y_image": nisp_y_image, + "nisp_j_image": nisp_j_image, + "nisp_h_image": nisp_h_image, + "spectrum": spectrum_data, + "sed_fluxes": sed_fluxes, } @@ -267,7 +290,7 @@ def display_one_sample( fig, ax_rgb = plt.subplots(figsize=(5, 5)) # ----- RGB ----- - rgb = sample['rgb_image'] + rgb = sample["rgb_image"] if 
rgb.ndim == 3 and rgb.shape[0] in (1, 3): rgb_np = rgb.permute(1, 2, 0).numpy() if rgb_np.shape[2] == 1: # grayscale @@ -282,10 +305,14 @@ def display_one_sample( if show_bands: # ----- Spectrum (if available) ----- - spec = sample.get('spectrum') - if spec is not None and spec.get('flux') is not None: - flux = spec['flux'].numpy() - wavelength = spec['wavelength'].numpy() if spec.get('wavelength') is not None else np.arange(len(flux)) + spec = sample.get("spectrum") + if spec is not None and spec.get("flux") is not None: + flux = spec["flux"].numpy() + wavelength = ( + spec["wavelength"].numpy() + if spec.get("wavelength") is not None + else np.arange(len(flux)) + ) ax_spec.plot(wavelength, flux, linewidth=0.8) ax_spec.set_title("DESI Spectrum") ax_spec.set_xlabel("Wavelength (Å)") @@ -295,7 +322,7 @@ def display_one_sample( ax_spec.set_axis_off() # ----- SED (if available) ----- - sed = sample.get('sed_fluxes') + sed = sample.get("sed_fluxes") if sed is not None: ax_sed.bar(range(len(sed)), sed.numpy()) ax_sed.set_title(f"SED ({len(sed)} bands)") @@ -307,10 +334,10 @@ def display_one_sample( # ----- Individual bands ----- for ax, band_tensor, label in [ - (ax_vis, sample.get('vis_image'), "VIS"), - (ax_y, sample.get('nisp_y_image'), "NIR-Y"), - (ax_j, sample.get('nisp_j_image'), "NIR-J"), - (ax_h, sample.get('nisp_h_image'), "NIR-H"), + (ax_vis, sample.get("vis_image"), "VIS"), + (ax_y, sample.get("nisp_y_image"), "NIR-Y"), + (ax_j, sample.get("nisp_j_image"), "NIR-J"), + (ax_h, sample.get("nisp_h_image"), "NIR-H"), ]: if band_tensor is not None: im = ax.imshow(band_tensor.numpy(), cmap="viridis") @@ -337,21 +364,42 @@ def parse_args(argv=None): p = argparse.ArgumentParser( description="Load and display a sample from the Euclid+DESI dataset." ) - p.add_argument("--index", type=int, default=0, help="Index of the sample to display (default: 0)") + p.add_argument( + "--index", + type=int, + default=0, + help="Index of the sample to display (default: 0)", + ) p.add_argument("--split", type=str, default="train", help="HF split to use") - + # Default to configured path try: default_cache = str(load_paths().dataset) except Exception: default_cache = "./data" - p.add_argument("--cache-dir", type=str, - default=default_cache, - help="HuggingFace cache directory") - p.add_argument("--save", type=str, default=None, help="Save path for the figure (png/jpg, optional)") - p.add_argument("--no-gui", action="store_true", help="Do not open a window (save only if --save is provided)") - p.add_argument("--show-bands", action="store_true", help="Display spectrum/SED + individual bands if available") + p.add_argument( + "--cache-dir", + type=str, + default=default_cache, + help="HuggingFace cache directory", + ) + p.add_argument( + "--save", + type=str, + default=None, + help="Save path for the figure (png/jpg, optional)", + ) + p.add_argument( + "--no-gui", + action="store_true", + help="Do not open a window (save only if --save is provided)", + ) + p.add_argument( + "--show-bands", + action="store_true", + help="Display spectrum/SED + individual bands if available", + ) return p.parse_args(argv) diff --git a/src/fmb/data/utils.py b/src/fmb/data/utils.py index 19eec36..d30657c 100644 --- a/src/fmb/data/utils.py +++ b/src/fmb/data/utils.py @@ -7,14 +7,18 @@ import csv from pathlib import Path -from typing import List, Dict, Optional, Sequence, Tuple -import torch +from typing import Dict, List, Optional, Sequence, Tuple + import numpy as np +import torch from tqdm import tqdm from fmb.data.load_display_data 
import EuclidDESIDataset -def read_object_ids(csv_paths: Sequence[Path], limit: Optional[int] = None, verbose: bool = False) -> List[str]: + +def read_object_ids( + csv_paths: Sequence[Path], limit: Optional[int] = None, verbose: bool = False +) -> List[str]: """Read object IDs from a list of CSV files.""" ids = [] for p in csv_paths: @@ -25,24 +29,25 @@ def read_object_ids(csv_paths: Sequence[Path], limit: Optional[int] = None, verb with open(p, "r") as f: reader = csv.DictReader(f) if "object_id" not in reader.fieldnames: - # Check if it's single column no header or different name? - # For now strict. - print(f"[warn] 'object_id' column missing in {p}") - continue + # Check if it's single column no header or different name? + # For now strict. + print(f"[warn] 'object_id' column missing in {p}") + continue for row in reader: ids.append(str(row["object_id"])) if limit and len(ids) >= limit: break except Exception as e: print(f"[error] Failed to read {p}: {e}") - + if limit and len(ids) >= limit: break - + if verbose: print(f"Loaded {len(ids)} object IDs.") return ids + def load_index(path: Path) -> Dict[str, tuple]: """Load index CSV mapping object_id -> (split, index).""" mapping = {} @@ -56,117 +61,138 @@ def load_index(path: Path) -> Dict[str, tuple]: mapping[oid] = (split, idx) return mapping -def collect_samples(dataset: EuclidDESIDataset, object_ids: List[str], verbose: bool = False) -> List[Dict]: + +def collect_samples( + dataset: EuclidDESIDataset, object_ids: List[str], verbose: bool = False +) -> List[Dict]: """ Collect samples from dataset matching object_ids. WARNING: Linear scan if no index is used. Very slow for large datasets. """ target_set = set(object_ids) samples = [] - + # We scan the dataset # Optim: if dataset provides a way to get ID quickly without loading full sample? # EuclidDESIDataset loads full sample. # But usually we call this on a small list of query IDs. - + if verbose: - print(f"Scanning dataset ({len(dataset)} samples) for {len(target_set)} targets...") - + print( + f"Scanning dataset ({len(dataset)} samples) for {len(target_set)} targets..." + ) + found_map = {} - + # Heuristic: iterate full dataset once # Or rely on dataset having an internal index? It doesn't. - + for i in tqdm(range(len(dataset)), desc="Scanning Dataset", disable=not verbose): # We need to access just the ID if possible to be fast - # But dataset[i] does heavy loading (images etc). + # But dataset[i] does heavy loading (images etc). # This function is inherently slow without an auxiliary index. # But we must support it as fallback. - + # Accessing raw HF dataset might be faster for just ID check? 
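        # Indexing into dataset.dataset[i] returns the raw HuggingFace row, so the
        # ID check below skips the image/spectrum tensor conversion performed by
        # EuclidDESIDataset.__getitem__; the processed sample is only built once an
        # object_id actually matches the target set.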
# self.dataset[i]['object_id'] raw_sample = dataset.dataset[i] oid = str(raw_sample.get("object_id") or raw_sample.get("targetid")) - + if oid in target_set: # Now load full processed sample found_map[oid] = dataset[i] if len(found_map) == len(target_set): break - + # preserve order for oid in object_ids: if oid in found_map: samples.append(found_map[oid]) - + return samples -def collect_samples_with_index(cache_dir: str, object_ids: List[str], index_map: Dict[str, tuple], verbose: bool = False) -> List[Dict]: + +def collect_samples_with_index( + cache_dir: str, + object_ids: List[str], + index_map: Dict[str, tuple], + verbose: bool = False, +) -> List[Dict]: """Collect samples efficiently using a pre-computed index.""" samples = [] - + # Group by split to minimize dataset reloads by_split = {} for oid in object_ids: if oid in index_map: split, idx = index_map[oid] - if split not in by_split: by_split[split] = [] + if split not in by_split: + by_split[split] = [] by_split[split].append((oid, idx)) - + gathered = {} - + for split, items in by_split.items(): - if verbose: print(f"Loading split '{split}' for {len(items)} samples...") + if verbose: + print(f"Loading split '{split}' for {len(items)} samples...") ds = EuclidDESIDataset(split=split, cache_dir=cache_dir, verbose=False) for oid, idx in items: gathered[oid] = ds[idx] - + # preserve order for oid in object_ids: if oid in gathered: samples.append(gathered[oid]) - + return samples + def prepare_rgb_image(sample: Dict) -> np.ndarray: """Extract RGB image from sample as (H, W, 3/1) numpy array.""" # From tensor (C, H, W) to (H, W, C) numpy img_t = sample.get("rgb_image") if img_t is None: return np.zeros((64, 64, 3), dtype=np.uint8) - + img_np = img_t.permute(1, 2, 0).cpu().numpy() # Clip 0..1 img_np = np.clip(img_np, 0, 1) return img_np + # --- Embedding Loading Utilities --- + def load_embeddings_file(path: Path) -> List[Dict]: """Load raw embeddings list of dicts.""" print(f"Loading embeddings from {path}...") try: data = torch.load(path, map_location="cpu", weights_only=False) - if isinstance(data, list): return data - if isinstance(data, dict): return [data] + if isinstance(data, list): + return data + if isinstance(data, dict): + return [data] raise ValueError(f"Unknown format in {path}") except Exception as e: print(f"Error loading {path}: {e}") return [] -def extract_embedding_matrices(records: List[Dict]) -> Tuple[Dict[str, torch.Tensor], List[str]]: + +def extract_embedding_matrices( + records: List[Dict], +) -> Tuple[Dict[str, torch.Tensor], List[str]]: """ Extract tensors for all available modalities. Returns map {modality_key: Tensor(N, D)} and list of object_ids. """ import torch.nn.functional as F - + if not records: return {}, [] - + sample = records[0] keys = [] - + # Heuristics for modality keys if "embedding_images" in sample and "embedding_spectra" in sample: # AstroPT/Clip style @@ -176,49 +202,51 @@ def extract_embedding_matrices(records: List[Dict]) -> Tuple[Dict[str, torch.Ten keys = ["embedding_hsc", "embedding_spectrum"] if "embedding_hsc_desi" in sample: keys.append("embedding_hsc_desi") - + # Fallback/Filtering (only keys present in sample) final_keys = [] # If explicit keys derived, verify them if keys: - final_keys = [k for k in keys if k in sample] + final_keys = [k for k in keys if k in sample] else: - final_keys = [k for k in sample.keys() if k.startswith("embedding_")] - + final_keys = [k for k in sample.keys() if k.startswith("embedding_")] + # Ensure joint calculation if missing but components exist? 
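        # Illustrative sketch only (not implemented here): a joint key could be
        # synthesised from the per-modality vectors by concatenation, e.g.
        #   joint = torch.cat([rec["embedding_images"].flatten(),
        #                      rec["embedding_spectra"].flatten()])
        # which mirrors the AstroPT fallback used in fmb.detection.run.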
# For AstroCLIP/PT, usually 'embedding_joint' is saved. - # If not, existing scripts handled it specifically. + # If not, existing scripts handled it specifically. # For generic loader, we keep it simple: return what is there. - + print(f" Detected modalities: {final_keys}") - + vectors_map = {k: [] for k in final_keys} oids = [] - + for r in records: oid = str(r.get("object_id") or r.get("targetid", "")) - if not oid: continue - + if not oid: + continue + current = {} valid = True for k in final_keys: v = r.get(k) - if v is None: - valid = False; break + if v is None: + valid = False + break if not isinstance(v, torch.Tensor): v = torch.tensor(v) current[k] = v.flatten().float() - + if valid: oids.append(oid) for k in final_keys: vectors_map[k].append(current[k]) - + # Stack and Normalize matrices = {} for k, vlist in vectors_map.items(): mat = torch.stack(vlist) mat = F.normalize(mat, p=2, dim=1) matrices[k] = mat - + return matrices, oids diff --git a/src/fmb/detection/__init__.py b/src/fmb/detection/__init__.py index e7e8cb3..04b1950 100644 --- a/src/fmb/detection/__init__.py +++ b/src/fmb/detection/__init__.py @@ -4,4 +4,3 @@ Module: fmb.detection.__init__ Description: FMB module: fmb.detection.__init__ """ - diff --git a/src/fmb/detection/cosine.py b/src/fmb/detection/cosine.py index 2418b26..9cb1db3 100644 --- a/src/fmb/detection/cosine.py +++ b/src/fmb/detection/cosine.py @@ -6,26 +6,30 @@ """ import argparse -import sys from pathlib import Path -from typing import List, Dict, Tuple, Optional +from typing import Dict, List, Tuple import numpy as np import torch -from fmb.paths import load_paths + from fmb.detection import utils +from fmb.paths import load_paths # Reusing the known key pairs logic but adapted KNOWN_KEY_PAIRS = { "astropt": ("embedding_images", "embedding_spectra"), - "aion": ("embedding_hsc", "embedding_spectrum"), # heuristic, might be embedding_hsc_desi - "astroclip": ("embedding_images", "embedding_spectra") + "aion": ( + "embedding_hsc", + "embedding_spectrum", + ), # heuristic, might be embedding_hsc_desi + "astroclip": ("embedding_images", "embedding_spectra"), } + def detect_keys(record: Dict, model_name: str) -> Tuple[str, str]: """Detect image/spectrum keys for a record.""" keys = set(record.keys()) - + # Check explicit known pairs first if model_name in KNOWN_KEY_PAIRS: img, spec = KNOWN_KEY_PAIRS[model_name] @@ -38,28 +42,27 @@ def detect_keys(record: Dict, model_name: str) -> Tuple[str, str]: return "embedding_hsc_desi", "embedding_spectrum" # Fallback heuristics - img_candidates = [k for k in keys if 'image' in k.lower() or 'hsc' in k.lower()] - spec_candidates = [k for k in keys if 'spectr' in k.lower()] - + img_candidates = [k for k in keys if "image" in k.lower() or "hsc" in k.lower()] + spec_candidates = [k for k in keys if "spectr" in k.lower()] + if len(img_candidates) >= 1 and len(spec_candidates) >= 1: # Pick shortest or first? 
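        # As implemented, we simply return the first candidate from each list;
        # since `keys` is a set, that choice is effectively arbitrary.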
return img_candidates[0], spec_candidates[0] - + raise ValueError(f"Could not key pair in {list(keys)}") + def compute_cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float: norm1 = np.linalg.norm(vec1) norm2 = np.linalg.norm(vec2) - if norm1 == 0 or norm2 == 0: return 0.0 + if norm1 == 0 or norm2 == 0: + return 0.0 return float(np.dot(vec1 / norm1, vec2 / norm2)) -def process_file( - input_path: Path, - output_path: Path, - model_name: str -) -> None: + +def process_file(input_path: Path, output_path: Path, model_name: str) -> None: print(f"\n---> Processing Cosine for {model_name} in {input_path}") - + try: records = utils.load_records(input_path) except Exception as e: @@ -81,32 +84,34 @@ def process_file( # We need aligned lists. extract_embeddings returns arrays and IDs. # But we need to ensure we match image and spec for the SAME object. # The utils.extract_embeddings extracts one key. - + # Let's iterate manually to ensure pairing rows = [] skipped = 0 - + for i, rec in enumerate(records): obj_id = rec.get("object_id", str(i)) - + # Helper to get numpy def get_np(k): v = rec.get(k) - if v is None: return None - if isinstance(v, torch.Tensor): return v.detach().cpu().numpy().flatten() + if v is None: + return None + if isinstance(v, torch.Tensor): + return v.detach().cpu().numpy().flatten() return np.asarray(v).flatten() v_img = get_np(img_key) v_spec = get_np(spec_key) - + if v_img is None or v_spec is None: skipped += 1 continue - + if np.any(np.isnan(v_img)) or np.any(np.isnan(v_spec)): skipped += 1 continue - + sim = compute_cosine_similarity(v_img, v_spec) rows.append({"object_id": str(obj_id), "cosine_similarity": sim}) @@ -125,21 +130,27 @@ def get_np(k): # Save import csv + output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=["object_id", "cosine_similarity", "rank"]) + writer = csv.DictWriter( + f, fieldnames=["object_id", "cosine_similarity", "rank"] + ) writer.writeheader() writer.writerows(rows) - + print(f"[success] Saved {len(rows)} cosine scores to {output_path}") + def main(argv: List[str] = None): - parser = argparse.ArgumentParser(description="Compute Cosine Similarity (Image vs Spectrum)") + parser = argparse.ArgumentParser( + description="Compute Cosine Similarity (Image vs Spectrum)" + ) # Optional overrides parser.add_argument("--aion-embeddings", type=str) parser.add_argument("--astropt-embeddings", type=str) parser.add_argument("--astroclip-embeddings", type=str) - + args = parser.parse_args(argv) paths = load_paths() out_dir = paths.outliers @@ -148,18 +159,33 @@ def main(argv: List[str] = None): # Resolve Inputs # Similar logic to run.py inputs = [] - + # AION - p = Path(args.aion_embeddings) if args.aion_embeddings else (paths.embeddings / "aions_embeddings.pt") - if p.exists(): inputs.append((p, "aion")) - + p = ( + Path(args.aion_embeddings) + if args.aion_embeddings + else (paths.embeddings / "aions_embeddings.pt") + ) + if p.exists(): + inputs.append((p, "aion")) + # AstroPT - p = Path(args.astropt_embeddings) if args.astropt_embeddings else (paths.embeddings / "astropt_embeddings.pt") - if p.exists(): inputs.append((p, "astropt")) - + p = ( + Path(args.astropt_embeddings) + if args.astropt_embeddings + else (paths.embeddings / "astropt_embeddings.pt") + ) + if p.exists(): + inputs.append((p, "astropt")) + # AstroCLIP - p = Path(args.astroclip_embeddings) if args.astroclip_embeddings else (paths.embeddings / "embeddings_astroclip.pt") - if p.exists(): 
inputs.append((p, "astroclip")) + p = ( + Path(args.astroclip_embeddings) + if args.astroclip_embeddings + else (paths.embeddings / "embeddings_astroclip.pt") + ) + if p.exists(): + inputs.append((p, "astroclip")) if not inputs: print("[error] No embedding files found.") @@ -169,5 +195,6 @@ def main(argv: List[str] = None): out_file = out_dir / f"cosine_scores_{name}.csv" process_file(p, out_file, name) + if __name__ == "__main__": main() diff --git a/src/fmb/detection/models.py b/src/fmb/detection/models.py index d97e1b2..1f7c62c 100644 --- a/src/fmb/detection/models.py +++ b/src/fmb/detection/models.py @@ -5,9 +5,9 @@ Description: Normalizing Flow architectures for anomaly detection """ -import torch -import torch.nn as nn import normflows as nf +import torch.nn as nn + def build_coupling_flow( dim: int, @@ -22,14 +22,14 @@ def build_coupling_flow( raise RuntimeError( f"Embedding dimension {dim} is not supported for RealNVP-style coupling (needs an even dimension >= 2)." ) - + base = nf.distributions.base.DiagGaussian(dim) flows: list[nn.Module] = [] - + # Split dimensions for coupling cond_dim = dim // 2 transformed_dim = dim - cond_dim - + for _ in range(num_transforms): # MLP for the affine transformation parameters net = nf.nets.MLP( @@ -42,7 +42,7 @@ def build_coupling_flow( flows.append(nf.flows.Permute(dim, mode="swap")) # ActNorm for training stability flows.append(nf.flows.ActNorm((dim,))) - + return nf.NormalizingFlow(base, flows) @@ -57,15 +57,19 @@ def build_autoregressive_flow( """ base = nf.distributions.base.DiagGaussian(dim) flows: list[nn.Module] = [] - + for _ in range(num_transforms): # Masked Affine Autoregressive Flow - flows.append(nf.flows.MaskedAffineAutoregressive(features=dim, hidden_features=hidden_features)) + flows.append( + nf.flows.MaskedAffineAutoregressive( + features=dim, hidden_features=hidden_features + ) + ) # Permutation to mix dimensions flows.append(nf.flows.Permute(dim, mode="swap")) # ActNorm flows.append(nf.flows.ActNorm(dim)) - + return nf.NormalizingFlow(base, flows) @@ -82,4 +86,6 @@ def build_flow( elif flow_type in ("autoregressive", "maf", "ar"): return build_autoregressive_flow(dim, hidden_features, num_transforms) else: - raise ValueError(f"Unknown flow_type: {flow_type}. Options: 'coupling', 'autoregressive'") + raise ValueError( + f"Unknown flow_type: {flow_type}. 
Options: 'coupling', 'autoregressive'" + ) diff --git a/src/fmb/detection/multimodal.py b/src/fmb/detection/multimodal.py index ec0dc0a..b87bf3e 100644 --- a/src/fmb/detection/multimodal.py +++ b/src/fmb/detection/multimodal.py @@ -6,14 +6,11 @@ """ import argparse -import sys from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Optional -import pandas as pd import numpy as np -import torch -import torch.nn.functional as F +import pandas as pd from fmb.paths import load_paths @@ -21,12 +18,16 @@ # Loaders # ------------------------------------------------------------------------- + def normalize_modality(key: str) -> Optional[str]: k = key.lower() - if "hsc" in k or "image" in k: return "img" - if "spec" in k: return "spec" + if "hsc" in k or "image" in k: + return "img" + if "spec" in k: + return "spec" return None + def load_cosine_csv(path: Path, model_name: str) -> pd.DataFrame: if not path.exists(): print(f"[warn] Cosine CSV not found: {path} (skipping model {model_name})") @@ -39,6 +40,7 @@ def load_cosine_csv(path: Path, model_name: str) -> pd.DataFrame: df = df.rename(columns={"rank": "rank_cosine"}) return df + def load_nf_csv(path: Path, model_name: str) -> pd.DataFrame: if not path.exists(): print(f"[warn] NF CSV not found: {path} (skipping model {model_name})") @@ -46,67 +48,86 @@ def load_nf_csv(path: Path, model_name: str) -> pd.DataFrame: df = pd.read_csv(path) # Expected: object_id, embedding_key, anomaly_sigma, rank df["object_id"] = df["object_id"].astype(str) - + # Pivot logic df["modality"] = df["embedding_key"].apply(normalize_modality) df = df.dropna(subset=["modality"]) - + # Deduplicate keeping min rank (most anomalous) - df = df.sort_values("rank", ascending=True).drop_duplicates(["object_id", "modality"]) - - p_sigma = df.pivot(index="object_id", columns="modality", values="anomaly_sigma").add_prefix("nf_").add_suffix("_sigma") - p_rank = df.pivot(index="object_id", columns="modality", values="rank").add_prefix("rank_") - + df = df.sort_values("rank", ascending=True).drop_duplicates( + ["object_id", "modality"] + ) + + p_sigma = ( + df.pivot(index="object_id", columns="modality", values="anomaly_sigma") + .add_prefix("nf_") + .add_suffix("_sigma") + ) + p_rank = df.pivot(index="object_id", columns="modality", values="rank").add_prefix( + "rank_" + ) + wide = p_sigma.join(p_rank).reset_index() wide["model"] = model_name return wide + # ------------------------------------------------------------------------- # Scoring # ------------------------------------------------------------------------- + def rank_to_p(rank_col: pd.Series, N: int) -> pd.Series: # rank 1 is top anomaly. p=1.0. rank N is p=0.0 return (N - rank_col + 1.0) / N + def compute_scores(df: pd.DataFrame) -> pd.DataFrame: - N = len(df) - # We want global percentiles usually, or per-model? + len(df) + # We want global percentiles usually, or per-model? # The script used global. Let's do global for simplicity of comparison. # But usually ranks are relative to the sub-population. # Let's do groupings if we have multiple models. 
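    # Worked example of the per-model percentiles and fusion scores computed below
    # (using rank_to_p above, with Nm = 1000 objects in a model's sub-population):
    #   rank 1    -> p = (1000 - 1 + 1) / 1000    = 1.000  (most anomalous)
    #   rank 10   -> p = (1000 - 10 + 1) / 1000   = 0.991
    #   rank 1000 -> p = (1000 - 1000 + 1) / 1000 = 0.001
    # For an object with p_mis = 0.99, p_img = 0.95, p_spec = 0.40:
    #   score_mm_geo = 0.99 * sqrt(0.95 * 0.40) ~= 0.61
    #   score_mm_min = 0.99 * min(0.95, 0.40)   = 0.396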
- + OUT = [] for model, sub in df.groupby("model"): sub = sub.copy() Nm = len(sub) # Handle missing cols if some model misses a modality - if "rank_cosine" in sub: sub["p_mis"] = rank_to_p(sub["rank_cosine"], Nm) - else: sub["p_mis"] = 0.0 - - if "rank_img" in sub: sub["p_img"] = rank_to_p(sub["rank_img"], Nm) - else: sub["p_img"] = 0.0 # Not anomalous - - if "rank_spec" in sub: sub["p_spec"] = rank_to_p(sub["rank_spec"], Nm) - else: sub["p_spec"] = 0.0 + if "rank_cosine" in sub: + sub["p_mis"] = rank_to_p(sub["rank_cosine"], Nm) + else: + sub["p_mis"] = 0.0 + + if "rank_img" in sub: + sub["p_img"] = rank_to_p(sub["rank_img"], Nm) + else: + sub["p_img"] = 0.0 # Not anomalous + + if "rank_spec" in sub: + sub["p_spec"] = rank_to_p(sub["rank_spec"], Nm) + else: + sub["p_spec"] = 0.0 # Fusion sub["score_mm_geo"] = sub["p_mis"] * np.sqrt(sub["p_img"] * sub["p_spec"]) sub["score_mm_min"] = sub["p_mis"] * np.minimum(sub["p_img"], sub["p_spec"]) - + # Uplift: How much does multimodal fusion add over the best single modality? # (Assuming p_mis is part of the fusion, so we compare fusion result vs single modality ranks) # Note: p_img/p_spec are percentiles (0..1). sub["uplift_mm"] = sub["score_mm_geo"] - np.maximum(sub["p_img"], sub["p_spec"]) - + OUT.append(sub) - + return pd.concat(OUT, ignore_index=True) if OUT else pd.DataFrame() + # ------------------------------------------------------------------------- # Main # ------------------------------------------------------------------------- + def main(argv: List[str] = None): parser = argparse.ArgumentParser() # If not provided, we infer from paths @@ -116,42 +137,46 @@ def main(argv: List[str] = None): parser.add_argument("--t-img", type=float, default=0.0) parser.add_argument("--t-spec", type=float, default=0.0) parser.add_argument("--t-mis", type=float, default=0.0) - + args = parser.parse_args(argv) - + paths = load_paths() root = paths.outliers - + # 1. Load Everything models = ["aion", "astropt", "astroclip"] - + all_cos = [] all_nf = [] - + for m in models: # Cosine p_cos = root / f"cosine_scores_{m}.csv" df_c = load_cosine_csv(p_cos, m) - if not df_c.empty: all_cos.append(df_c) - + if not df_c.empty: + all_cos.append(df_c) + # NF p_nf = root / f"anomaly_scores_{m}.csv" df_n = load_nf_csv(p_nf, m) - if not df_n.empty: all_nf.append(df_n) - + if not df_n.empty: + all_nf.append(df_n) + if not all_cos or not all_nf: - print("[error] Missing input CSVs (run 'detect outliers' and 'detect cosine' first).") + print( + "[error] Missing input CSVs (run 'detect outliers' and 'detect cosine' first)." + ) return df_cos = pd.concat(all_cos, ignore_index=True) df_nf = pd.concat(all_nf, ignore_index=True) - + # Merge print(f"Merging {len(df_cos)} cosine rows with {len(df_nf)} NF rows...") # Inner join mainly df = pd.merge(df_cos, df_nf, on=["object_id", "model"], how="inner") print(f"Merged total: {len(df)}") - + if df.empty: print("[error] Empty merge result.") return @@ -161,17 +186,21 @@ def main(argv: List[str] = None): # 3. 
Filter & Sort # Apply thresholds - mask = (df["p_img"] >= args.t_img) & (df["p_spec"] >= args.t_spec) & (df["p_mis"] >= args.t_mis) + mask = ( + (df["p_img"] >= args.t_img) + & (df["p_spec"] >= args.t_spec) + & (df["p_mis"] >= args.t_mis) + ) df_filt = df[mask].copy() - + key = f"score_mm_{args.fusion}" # Sort descending (score high -> anomalous) df_filt = df_filt.sort_values(key, ascending=False) - + OUT_DIR = root / "multimodal" OUT_DIR.mkdir(parents=True, exist_ok=True) print(f"\nOutput Directory: {OUT_DIR}") - + # Save All out_all = OUT_DIR / "all_scores.csv" df.to_csv(out_all, index=False) @@ -179,14 +208,14 @@ def main(argv: List[str] = None): # --- Export Unfiltered Ranked Lists (for inspection/plotting) --- # The user requested "properly weighted" lists per model for plotting - + RANKED_DIR = OUT_DIR / "ranked" RANKED_DIR.mkdir(parents=True, exist_ok=True) - + # We sort the full dataframe by fusion score # And export top-K per model # This ignores t_img/t_spec/t_mis thresholds - + print(f"\nExporting ranked (unfiltered) lists to: {RANKED_DIR}") for model, sub in df.groupby("model"): # Sort desc by fusion score @@ -198,19 +227,24 @@ def main(argv: List[str] = None): # --- Filtered Lists --- if df_filt.empty: print("\n[warn] No anomalies passed the STRICT filtering thresholds!") - print(f" Thresholds: P_img>={args.t_img}, P_spec>={args.t_spec}, P_mis>={args.t_mis}") - print(" (You can use the 'ranked' lists above for plotting regardless of these filters)") + print( + f" Thresholds: P_img>={args.t_img}, P_spec>={args.t_spec}, P_mis>={args.t_mis}" + ) + print( + " (You can use the 'ranked' lists above for plotting regardless of these filters)" + ) return # Save Filtered Top-K per model FILTERED_DIR = OUT_DIR / "filtered" FILTERED_DIR.mkdir(parents=True, exist_ok=True) - + for model, sub in df_filt.groupby("model"): top = sub.head(args.top_k) fname = FILTERED_DIR / f"top{args.top_k}_{model}_{args.fusion}_filtered.csv" top.to_csv(fname, index=False) print(f"[success] Saved filtered top-{len(top)} '{model}' to {fname}") + if __name__ == "__main__": main() diff --git a/src/fmb/detection/run.py b/src/fmb/detection/run.py index d2d8270..7cef2e7 100644 --- a/src/fmb/detection/run.py +++ b/src/fmb/detection/run.py @@ -7,14 +7,15 @@ import argparse import sys -import yaml from pathlib import Path from typing import List import torch +import yaml +from fmb.detection import models, train, utils from fmb.paths import load_paths -from fmb.detection import utils, models, train + def load_config(config_path: Path) -> dict: if not config_path.exists(): @@ -22,14 +23,15 @@ def load_config(config_path: Path) -> dict: with open(config_path, "r") as f: return yaml.safe_load(f) + def process_single_embedding( path: Path, possible_keys: List[str], config: dict, model_name: str, - output_dir: Path + output_dir: Path, ) -> None: - + # 1. Load Data try: records = utils.load_records(path) @@ -43,33 +45,41 @@ def process_single_embedding( return valid_keys = [k for k in possible_keys if k in records[0]] - + # AstroPT Synthesis Fallback - if model_name == "astropt" and "embedding_joint" in possible_keys and "embedding_joint" not in valid_keys: + if ( + model_name == "astropt" + and "embedding_joint" in possible_keys + and "embedding_joint" not in valid_keys + ): if "embedding_images" in records[0] and "embedding_spectra" in records[0]: print(" [info] AstroPT: 'embedding_joint' not found. 
Synthesizing...") valid_keys.append("embedding_joint") - # We synthesize on the fly during extraction below if needed, + # We synthesize on the fly during extraction below if needed, # OR we update records here. Let's update records here for simplicity. for rec in records: img = rec.get("embedding_images") spec = rec.get("embedding_spectra") if img is not None and spec is not None: - # Ensure numpy - if isinstance(img, torch.Tensor): img = img.cpu().numpy() - if isinstance(spec, torch.Tensor): spec = spec.cpu().numpy() - rec["embedding_joint"] = import_numpy().concatenate([img.flatten(), spec.flatten()]) - + # Ensure numpy + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + if isinstance(spec, torch.Tensor): + spec = spec.cpu().numpy() + rec["embedding_joint"] = import_numpy().concatenate( + [img.flatten(), spec.flatten()] + ) + if not valid_keys: print(f"[warn] No valid keys found in {path}. Expected: {possible_keys}") return # 3. Process Each Key all_rows = [] - + for key in valid_keys: print(f"\n---> Processing {model_name} key: '{key}'") - + # Extract embeddings, ids = utils.extract_embeddings(records, key) if len(embeddings) == 0: @@ -79,7 +89,7 @@ def process_single_embedding( # Clean embeddings_tensor = torch.from_numpy(embeddings).float() embeddings_tensor, ids = utils.filter_nonfinite_rows(embeddings_tensor, ids) - + if len(embeddings_tensor) < 2: print(f"[warn] Not enough data for {key}") continue @@ -90,7 +100,9 @@ def process_single_embedding( if pca_comps < embeddings_tensor.shape[1]: embeddings_tensor, _ = utils.apply_pca(embeddings_tensor, pca_comps) else: - print(f"[info] Skipping PCA (requested {pca_comps} >= dim {embeddings_tensor.shape[1]})") + print( + f"[info] Skipping PCA (requested {pca_comps} >= dim {embeddings_tensor.shape[1]})" + ) # Standardize if config.get("standardize", True): @@ -108,9 +120,9 @@ def process_single_embedding( hidden_features=config.get("hidden_features", 256), num_transforms=config.get("num_transforms", 8), ) - + print(f"[{key}] Training {config.get('flow_type')} flow (dim={dim})...") - + # Train train.train_flow_model( flow=flow, @@ -123,14 +135,14 @@ def process_single_embedding( grad_clip=config.get("grad_clip", 5.0), weight_decay=config.get("weight_decay", 1e-5), ) - + # Score log_probs = train.compute_log_probs( - flow, - embeddings_tensor, - config.get("device", "cuda" if torch.cuda.is_available() else "cpu") + flow, + embeddings_tensor, + config.get("device", "cuda" if torch.cuda.is_available() else "cpu"), ) - + # Collate rows = utils.collate_rows(ids, key, log_probs) all_rows.extend(rows) @@ -146,37 +158,51 @@ def process_single_embedding( # Cleanup del records import gc - gc.collect() + + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def import_numpy(): import numpy as np + return np + def main(argv: List[str] = None): paths = load_paths() - - parser = argparse.ArgumentParser(description="Run Normalizing Flow Outlier Detection") - parser.add_argument("--config", type=str, default=str(paths.repo_root / "src/fmb/configs/detection/anomalies.yaml"), help="Path to config YAML") - + + parser = argparse.ArgumentParser( + description="Run Normalizing Flow Outlier Detection" + ) + parser.add_argument( + "--config", + type=str, + default=str(paths.repo_root / "src/fmb/configs/detection/anomalies.yaml"), + help="Path to config YAML", + ) + # Embedding inputs (defaults from paths.py) # If not provided in CLI, we check paths.py properties # But paths.py properties usually return Paths, which might not exist if 
not downloaded. # We will use reasonable defaults or check existence. - + # We allow explicit CLI overrides, otherwise we check default locations parser.add_argument("--aion-embeddings", type=str, help="Path to AION embeddings") - parser.add_argument("--astropt-embeddings", type=str, help="Path to AstroPT embeddings") - parser.add_argument("--astroclip-embeddings", type=str, help="Path to AstroCLIP embeddings") - + parser.add_argument( + "--astropt-embeddings", type=str, help="Path to AstroPT embeddings" + ) + parser.add_argument( + "--astroclip-embeddings", type=str, help="Path to AstroCLIP embeddings" + ) + args = parser.parse_args(argv) - + # Load config config = load_config(Path(args.config)) utils.set_random_seed(config.get("random_seed", 42)) - + # Determine Output Directory # Use paths.outliers (from paths_local.yaml) output_dir = paths.outliers @@ -188,32 +214,46 @@ def main(argv: List[str] = None): # Or just require them? # The previous script used absolute paths. # We can try to guess defaults based on known filenames if not provided. - + tasks = [] - + # AION - aion_path = Path(args.aion_embeddings) if args.aion_embeddings else (paths.embeddings / "aions_embeddings.pt") + aion_path = ( + Path(args.aion_embeddings) + if args.aion_embeddings + else (paths.embeddings / "aions_embeddings.pt") + ) if aion_path.exists(): tasks.append((aion_path, utils.KEYS_AION, "aion")) elif args.aion_embeddings: print(f"[error] AION file not found: {aion_path}") # AstroPT - astropt_path = Path(args.astropt_embeddings) if args.astropt_embeddings else (paths.embeddings / "astropt_embeddings.pt") + astropt_path = ( + Path(args.astropt_embeddings) + if args.astropt_embeddings + else (paths.embeddings / "astropt_embeddings.pt") + ) if astropt_path.exists(): tasks.append((astropt_path, utils.KEYS_ASTROPT, "astropt")) elif args.astropt_embeddings: print(f"[error] AstroPT file not found: {astropt_path}") # AstroCLIP - astroclip_path = Path(args.astroclip_embeddings) if args.astroclip_embeddings else (paths.embeddings / "embeddings_astroclip.pt") + astroclip_path = ( + Path(args.astroclip_embeddings) + if args.astroclip_embeddings + else (paths.embeddings / "embeddings_astroclip.pt") + ) if astroclip_path.exists(): tasks.append((astroclip_path, utils.KEYS_ASTROCLIP, "astroclip")) elif args.astroclip_embeddings: print(f"[error] AstroCLIP file not found: {astroclip_path}") if not tasks: - print("[error] No embedding files found. Please provide paths via --aion-embeddings etc. or ensure they exist in configured 'embeddings_path'.") + print( + "[error] No embedding files found. Please provide paths via --aion-embeddings etc. or ensure they exist in configured 'embeddings_path'." 
+ ) return for fpath, keys, name in tasks: @@ -225,8 +265,10 @@ def main(argv: List[str] = None): sys.exit(1) except Exception as e: import traceback + traceback.print_exc() print(f"[error] Failed processing {name}: {e}") + if __name__ == "__main__": main() diff --git a/src/fmb/detection/train.py b/src/fmb/detection/train.py index ca73e73..dde06aa 100644 --- a/src/fmb/detection/train.py +++ b/src/fmb/detection/train.py @@ -11,6 +11,7 @@ from torch.utils.data import DataLoader, TensorDataset from tqdm import tqdm + def train_flow_model( flow: nn.Module, data: torch.Tensor, @@ -31,25 +32,25 @@ def train_flow_model( shuffle=True, drop_last=False, ) - + flow.to(device_obj) optimizer = torch.optim.Adam(flow.parameters(), lr=lr, weight_decay=weight_decay) flow.train() - + # Epoch loop with progress bar epoch_pbar = tqdm(range(1, epochs + 1), desc="Training Epochs", unit="epoch") - + for epoch in epoch_pbar: total_loss = 0.0 total_items = 0 skipped_batches = 0 - + # Batch loop - hide progress for individual batches unless very slow - batch_pbar = tqdm(loader, desc=f"Epoch {epoch}", leave=False, unit="batch", disable=True) - + tqdm(loader, desc=f"Epoch {epoch}", leave=False, unit="batch", disable=True) + for (batch,) in loader: batch = batch.to(device_obj) - + # Loss is negative log likelihood (or forward KLD for some flows) try: # Try standard log_prob first (for RealNVP) @@ -60,39 +61,43 @@ def train_flow_model( # Fallback for some wrappers (like MAF in older versions maybe?) # standardized wrapper usually ensures log_prob exists loss = flow.forward_kld(batch) - + if not torch.isfinite(loss): optimizer.zero_grad(set_to_none=True) skipped_batches += 1 continue - + optimizer.zero_grad(set_to_none=True) loss.backward() - + if grad_clip is not None and grad_clip > 0: torch.nn.utils.clip_grad_norm_(flow.parameters(), grad_clip) - + optimizer.step() - + total_loss += loss.item() * batch.size(0) total_items += batch.size(0) - + avg_loss = total_loss / max(total_items, 1) if total_items > 0 else float("nan") - + # Update epoch pbar epoch_pbar.set_postfix({"loss": f"{avg_loss:.4f}"}) - + if not np.isfinite(avg_loss): - tqdm.write(f"[error] Epoch {epoch}: Loss is NaN/Inf. Stopping training for this key.") - break + tqdm.write( + f"[error] Epoch {epoch}: Loss is NaN/Inf. Stopping training for this key." 
+ ) + break if skipped_batches == len(loader): - tqdm.write(f"[warn] Epoch {epoch}: All batches skipped (NaN/Inf).") - break + tqdm.write(f"[warn] Epoch {epoch}: All batches skipped (NaN/Inf).") + break if log_every > 0 and (epoch == 1 or epoch == epochs or epoch % log_every == 0): - tqdm.write(f"[{flow.__class__.__name__}] epoch {epoch:03d}/{epochs:03d} | loss={avg_loss:.4f}") - + tqdm.write( + f"[{flow.__class__.__name__}] epoch {epoch:03d}/{epochs:03d} | loss={avg_loss:.4f}" + ) + flow.eval() @@ -101,12 +106,12 @@ def compute_log_probs(flow: nn.Module, data: torch.Tensor, device: str) -> np.nd device_obj = torch.device(device) loader = DataLoader(TensorDataset(data), batch_size=2048, shuffle=False) log_probs_list = [] - + flow.eval() with torch.no_grad(): for (batch,) in loader: batch = batch.to(device_obj) lp = flow.log_prob(batch) log_probs_list.append(lp.cpu().numpy()) - + return np.concatenate(log_probs_list) diff --git a/src/fmb/detection/utils.py b/src/fmb/detection/utils.py index 092ac69..f05e2bb 100644 --- a/src/fmb/detection/utils.py +++ b/src/fmb/detection/utils.py @@ -8,11 +8,10 @@ import csv import random from pathlib import Path -from typing import Sequence, Dict, Tuple, List, Optional, Iterable, Union +from typing import Iterable, Sequence import numpy as np import torch -from torch.utils.data import TensorDataset, DataLoader # Default keys per model type KEYS_AION = ["embedding_hsc_desi", "embedding_hsc", "embedding_spectrum"] @@ -39,26 +38,28 @@ def load_records(path: Path) -> list[dict]: raise ValueError(f"Unsupported embeddings format: {type(data)}") -def extract_embeddings(records: Sequence[dict], key: str) -> tuple[np.ndarray, list[str]]: +def extract_embeddings( + records: Sequence[dict], key: str +) -> tuple[np.ndarray, list[str]]: """Extract vectors and object_ids for a specific key.""" vectors: list[np.ndarray] = [] object_ids: list[str] = [] - + found_any = False for rec in records: tensor = rec.get(key) if tensor is None: continue found_any = True - + if isinstance(tensor, torch.Tensor): array = tensor.detach().cpu().numpy().copy() else: array = np.asarray(tensor).copy() - + if array.ndim > 1: array = array.flatten() - + vectors.append(array) object_id = rec.get("object_id", "") object_ids.append(str(object_id)) @@ -78,13 +79,13 @@ def filter_nonfinite_rows( mask = torch.isfinite(tensor).all(dim=1) if mask.all(): return tensor, list(object_ids) - + filtered_tensor = tensor[mask] filtered_ids = [obj for obj, keep in zip(object_ids, mask.tolist()) if keep] dropped = len(object_ids) - len(filtered_ids) if dropped > 0: print(f"[warn] dropped {dropped} rows containing NaN/inf values") - + return filtered_tensor, filtered_ids @@ -99,7 +100,9 @@ def clip_embeddings_by_sigma(tensor: torch.Tensor, sigma: float) -> torch.Tensor return torch.clamp(tensor, min=lower, max=upper) -def standardize_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def standardize_tensor( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Standardize (z-score) the tensor.""" mean = tensor.mean(dim=0, keepdim=True) std = tensor.std(dim=0, keepdim=True).clamp_min(1e-6) @@ -110,22 +113,26 @@ def standardize_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor def apply_pca(tensor: torch.Tensor, n_components: int) -> tuple[torch.Tensor, object]: """Reduce dimensions using PCA.""" from sklearn.decomposition import PCA - + n_samples, n_features = tensor.shape max_components = min(n_samples, n_features) - + if n_components > 
max_components: - print(f" [PCA] Reducing n_components from {n_components} to {max_components} (min(samples, features))") + print( + f" [PCA] Reducing n_components from {n_components} to {max_components} (min(samples, features))" + ) n_components = max_components - - print(f" [PCA] Fitting PCA with n_components={n_components} on shape {tensor.shape}...") + + print( + f" [PCA] Fitting PCA with n_components={n_components} on shape {tensor.shape}..." + ) data_np = tensor.cpu().numpy() pca = PCA(n_components=n_components) transformed_np = pca.fit_transform(data_np) - + explained = pca.explained_variance_ratio_.sum() print(f" [PCA] Explained variance: {explained:.4f}") - + return torch.from_numpy(transformed_np).float(), pca @@ -149,7 +156,7 @@ def collate_rows( order = np.argsort(-sigma_scores) ranks = np.empty_like(order) ranks[order] = np.arange(1, len(order) + 1) - + rows: list[dict] = [] for idx, object_id in enumerate(object_ids): rows.append( @@ -169,7 +176,14 @@ def save_scores_csv(path: Path, rows: Iterable[dict]) -> None: """Save results to CSV.""" if not rows: return - fieldnames = ["object_id", "embedding_key", "log_prob", "neg_log_prob", "anomaly_sigma", "rank"] + fieldnames = [ + "object_id", + "embedding_key", + "log_prob", + "neg_log_prob", + "anomaly_sigma", + "rank", + ] path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) diff --git a/src/fmb/embeddings/generate_embeddings_astropt.py b/src/fmb/embeddings/generate_embeddings_astropt.py index 24cef30..c27c411 100644 --- a/src/fmb/embeddings/generate_embeddings_astropt.py +++ b/src/fmb/embeddings/generate_embeddings_astropt.py @@ -10,23 +10,22 @@ Run inference with the trained astroPT multimodal model and export embeddings. 
""" import argparse -from pathlib import Path -from typing import Dict, List, Optional, Any -import sys -import yaml import os +import sys import warnings - -# Use fmb.paths -from fmb.paths import load_paths - -# Imports from local package -from fmb.data.datasets import AstroPTDataset, FMBDataConfig +from pathlib import Path +from typing import Any, Dict, List, Optional import torch +import yaml from torch.utils.data import DataLoader from tqdm import tqdm +# Imports from local package +from fmb.data.datasets import AstroPTDataset, FMBDataConfig +# Use fmb.paths +from fmb.paths import load_paths + # Add external/astroPT/src to path astropt_path = Path(__file__).resolve().parents[3] / "external" / "astroPT" / "src" if str(astropt_path) not in sys.path: @@ -34,12 +33,10 @@ # Imports from external AstroPT try: - from astropt.model import GPT, GPTConfig, ModalityRegistry, ModalityConfig + from astropt.model import GPT, GPTConfig, ModalityConfig, ModalityRegistry # Use internal path for dataloader which was migrated from fmb.models.astropt.euclid_desi_dataset.multimodal_dataloader import ( - multimodal_collate_fn, - prepare_multimodal_batch, - ) + multimodal_collate_fn, prepare_multimodal_batch) except ImportError as e: print(f"Error importing AstroPT components: {e}") sys.exit(1) diff --git a/src/fmb/models/aion/codec_manager.py b/src/fmb/models/aion/codec_manager.py index 5e27418..2730cae 100644 --- a/src/fmb/models/aion/codec_manager.py +++ b/src/fmb/models/aion/codec_manager.py @@ -7,18 +7,16 @@ from __future__ import annotations -import json import inspect -from functools import lru_cache +import json from pathlib import Path from typing import Any import torch -from huggingface_hub import hf_hub_download - +from aion.codecs.config import HF_REPO_ID, MODALITY_CODEC_MAPPING, CodecType from aion.codecs.manager import CodecManager -from aion.codecs.config import MODALITY_CODEC_MAPPING, CodecType, HF_REPO_ID from aion.modalities import Modality +from huggingface_hub import hf_hub_download class LocalCodecManager(CodecManager): diff --git a/src/fmb/models/aion/config.py b/src/fmb/models/aion/config.py index 65504e3..a1c78a9 100644 --- a/src/fmb/models/aion/config.py +++ b/src/fmb/models/aion/config.py @@ -16,7 +16,7 @@ class AIONTrainingConfig(BaseTrainingConfig): """ Configuration for AION adapter training. - + Parameters ---------- hidden : int @@ -38,29 +38,29 @@ class AIONTrainingConfig(BaseTrainingConfig): max_entries : int Maximum dataset entries (0 for all). """ - + # Output out_dir: Path = load_paths().retrained_weights / "aion" - + # Model (U-Net) hidden: int = 16 use_unet_checkpointing: bool = False - + # Codec codec_grad: str = "ste" disable_codec_checkpointing: bool = False - + # Preprocessing resize: int = 96 crop_size: int = 96 max_abs: float = 100.0 cpu_crop: bool = False - + # Data cache_dir: str = str(load_paths().dataset) split: str = "all" max_entries: int = 0 - + # Training defaults (override base) epochs: int = 15 learning_rate: float = 1e-4 diff --git a/src/fmb/models/aion/load_weights.py b/src/fmb/models/aion/load_weights.py index f4b3a2e..807d6cb 100644 --- a/src/fmb/models/aion/load_weights.py +++ b/src/fmb/models/aion/load_weights.py @@ -9,9 +9,12 @@ import torch from aion.model import AION -from .codec_manager import LocalCodecManager + from fmb.paths import load_paths +from .codec_manager import LocalCodecManager + + # Use lazy loading for default to allow config to be set before access if possible, or just load now. 
# Since this is toplevel constant, it will load on import. def _get_default_model_dir() -> Path: @@ -20,6 +23,7 @@ def _get_default_model_dir() -> Path: except Exception: return Path("/pbs/throng/training/astroinfo2025/model") + def load_model_and_codec( model_dir: Path | None = None, device: torch.device | None = None, @@ -34,7 +38,7 @@ def load_model_and_codec( """ if model_dir is None: model_dir = _get_default_model_dir() - + print(f"Loading model from {model_dir}...") model_dir = Path(model_dir) codec_repo = Path(codec_dir) if codec_dir is not None else model_dir @@ -47,4 +51,4 @@ def load_model_and_codec( if __name__ == "__main__": model, codec_manager = load_model_and_codec() - print(f"Model and codec manager loaded") + print("Model and codec manager loaded") diff --git a/src/fmb/models/aion/model.py b/src/fmb/models/aion/model.py index c26a2ae..028eebe 100644 --- a/src/fmb/models/aion/model.py +++ b/src/fmb/models/aion/model.py @@ -6,20 +6,18 @@ """ import json -import math -from pathlib import Path from typing import List, Optional, Tuple +import safetensors.torch as st import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint_utils -from torch.utils.data import Dataset -import safetensors.torch as st from huggingface_hub import hf_hub_download +from torch.utils.data import Dataset from fmb.data.load_display_data import EuclidDESIDataset -from fmb.models.aion.modalities import EuclidImage, HSCImage, Image +from fmb.models.aion.modalities import EuclidImage # Constants EUCLID_BANDS = ["EUCLID-VIS", "EUCLID-Y", "EUCLID-J", "EUCLID-H"] @@ -36,8 +34,10 @@ # U-Net blocks # ----------------------------- + class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" + def __init__(self, in_ch: int, out_ch: int, hidden_dim: Optional[int] = None): super().__init__() mid_ch = hidden_dim if hidden_dim else out_ch @@ -56,6 +56,7 @@ def forward(self, x): class Down(nn.Module): """Downscaling with maxpool then double conv""" + def __init__(self, in_ch: int, out_ch: int): super().__init__() self.net = nn.Sequential( @@ -69,6 +70,7 @@ def forward(self, x): class Up(nn.Module): """Upscaling then double conv""" + def __init__(self, in_ch: int, out_ch: int, bilinear: bool = True): super().__init__() if bilinear: @@ -87,8 +89,7 @@ def forward(self, x1, x2): diffY = x2.size(2) - x1.size(2) diffX = x2.size(3) - x1.size(3) - x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, - diffY // 2, diffY - diffY // 2]) + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) x = torch.cat([x2, x1], dim=1) return self.conv(x) @@ -104,7 +105,14 @@ def forward(self, x): class SimpleUNet(nn.Module): """Small U-Net for domain adaptation.""" - def __init__(self, n_channels: int, n_classes: int, hidden: int, use_checkpointing: bool = False): + + def __init__( + self, + n_channels: int, + n_classes: int, + hidden: int, + use_checkpointing: bool = False, + ): super().__init__() self.use_checkpointing = use_checkpointing @@ -134,27 +142,45 @@ def forward(self, x): class EuclidToHSC(SimpleUNet): """Adapter to translate Euclid (4 bands) into HSC-like (5 bands).""" + def __init__(self, hidden: int, use_checkpointing: bool): - super().__init__(n_channels=4, n_classes=5, hidden=hidden, use_checkpointing=use_checkpointing) + super().__init__( + n_channels=4, + n_classes=5, + hidden=hidden, + use_checkpointing=use_checkpointing, + ) class HSCToEuclid(SimpleUNet): """Adapter to translate HSC (5 bands) back to Euclid (4 bands).""" + def 
__init__(self, hidden: int, use_checkpointing: bool): - super().__init__(n_channels=5, n_classes=4, hidden=hidden, use_checkpointing=use_checkpointing) + super().__init__( + n_channels=5, + n_classes=4, + hidden=hidden, + use_checkpointing=use_checkpointing, + ) # ----------------------------- # Dataset # ----------------------------- + class EuclidImageDataset(Dataset): """Dataset wrapper for Euclid images from DESI.""" - def __init__(self, split: str, cache_dir: str, max_entries: Optional[int], resize: int): + + def __init__( + self, split: str, cache_dir: str, max_entries: Optional[int], resize: int + ): self.base = EuclidDESIDataset(split=split, cache_dir=cache_dir, verbose=False) self.resize = resize - self._indices = list(range(len(self.base))) if not max_entries or max_entries <= 0 else list( - range(min(len(self.base), max_entries)) + self._indices = ( + list(range(len(self.base))) + if not max_entries or max_entries <= 0 + else list(range(min(len(self.base), max_entries))) ) def __len__(self) -> int: @@ -186,7 +212,9 @@ def __getitem__(self, idx: int) -> EuclidImage: flux = torch.stack(bands, dim=0) # (4,H,W) - if self.resize and (flux.shape[-1] != self.resize or flux.shape[-2] != self.resize): + if self.resize and ( + flux.shape[-1] != self.resize or flux.shape[-2] != self.resize + ): flux = F.interpolate( flux.unsqueeze(0), size=(self.resize, self.resize), @@ -207,14 +235,19 @@ def collate_euclid(batch: List[EuclidImage]) -> EuclidImage: # Helpers # ----------------------------- + def load_frozen_codec(device: torch.device) -> Tuple[nn.Module, dict]: """Load frozen AION ImageCodec from HF Hub.""" # Import here to avoid circular dependency from aion.codecs import ImageCodec from aion.codecs.config import HF_REPO_ID - cfg_path = hf_hub_download(HF_REPO_ID, "codecs/image/config.json", local_files_only=True) - weights_path = hf_hub_download(HF_REPO_ID, "codecs/image/model.safetensors", local_files_only=True) + cfg_path = hf_hub_download( + HF_REPO_ID, "codecs/image/config.json", local_files_only=True + ) + weights_path = hf_hub_download( + HF_REPO_ID, "codecs/image/model.safetensors", local_files_only=True + ) with open(cfg_path) as f: codec_cfg = json.load(f) @@ -223,6 +256,7 @@ def load_frozen_codec(device: torch.device) -> Tuple[nn.Module, dict]: # Patch band registry to avoid collisions with Euclid bands during codec init from aion.codecs.preprocessing.band_to_index import BAND_TO_INDEX + original_bands = dict(BAND_TO_INDEX) try: keys_to_remove = [k for k in list(BAND_TO_INDEX.keys()) if "EUCLID" in k] @@ -246,7 +280,9 @@ def load_frozen_codec(device: torch.device) -> Tuple[nn.Module, dict]: state = st.load_file(weights_path, device="cpu") missing, unexpected = codec.load_state_dict(state, strict=False) if missing or unexpected: - print(f"[info] Codec load: missing={len(missing)}, unexpected={len(unexpected)}") + print( + f"[info] Codec load: missing={len(missing)}, unexpected={len(unexpected)}" + ) for p in codec.parameters(): p.requires_grad = False @@ -261,6 +297,10 @@ def load_aion_components( ) -> Tuple[EuclidToHSC, HSCToEuclid, nn.Module]: """Load adapters and codec.""" codec, _ = load_frozen_codec(device) - euclid_to_hsc = EuclidToHSC(hidden=hidden, use_checkpointing=use_checkpointing).to(device) - hsc_to_euclid = HSCToEuclid(hidden=hidden, use_checkpointing=use_checkpointing).to(device) + euclid_to_hsc = EuclidToHSC(hidden=hidden, use_checkpointing=use_checkpointing).to( + device + ) + hsc_to_euclid = HSCToEuclid(hidden=hidden, use_checkpointing=use_checkpointing).to( + 
device + ) return euclid_to_hsc, hsc_to_euclid, codec diff --git a/src/fmb/models/aion/retrain_euclid_hsc_adapter_unet.py b/src/fmb/models/aion/retrain_euclid_hsc_adapter_unet.py index 471876a..bd8f05f 100644 --- a/src/fmb/models/aion/retrain_euclid_hsc_adapter_unet.py +++ b/src/fmb/models/aion/retrain_euclid_hsc_adapter_unet.py @@ -18,45 +18,43 @@ """ import argparse -import gc -import json import math from pathlib import Path -from typing import List, Optional, Tuple +# CRITICAL: Monkey-patch aion.modalities to use our local classes +import aion.modalities import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision.transforms import RandomCrop from tqdm.auto import tqdm -import safetensors.torch as st -from fmb.paths import load_paths from fmb.models.aion.modalities import EuclidImage, HSCImage, Image from fmb.models.aion.model import ( + EUCLID_BANDS, + HSC_BANDS, + EuclidImageDataset, EuclidToHSC, HSCToEuclid, - EuclidImageDataset, collate_euclid, load_frozen_codec, - EUCLID_BANDS, - HSC_BANDS, ) +from fmb.paths import load_paths -# CRITICAL: Monkey-patch aion.modalities to use our local classes -import aion.modalities aion.modalities.Image = Image aion.modalities.HSCImage = HSCImage aion.modalities.EuclidImage = EuclidImage # Fix typo/legacy method in aion library from aion.codecs.preprocessing.image import CenterCrop, RescaleToLegacySurvey + + def _fixed_reverse_zeropoint(self, scale): return 22.5 - 2.5 * math.log10(scale) + if not hasattr(RescaleToLegacySurvey, "_reverse_zeropoint"): RescaleToLegacySurvey._reverse_zeropoint = _fixed_reverse_zeropoint RescaleToLegacySurvey.reverse_zeropoint = _fixed_reverse_zeropoint @@ -65,7 +63,7 @@ def _fixed_reverse_zeropoint(self, scale): def parse_args() -> argparse.Namespace: """Parse arguments from CLI and YAML config.""" paths = load_paths() - + # First pass: Get config file parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--config", type=str, default=None) @@ -75,47 +73,90 @@ def parse_args() -> argparse.Namespace: defaults = {} if early_args.config: import yaml - with open(early_args.config, 'r') as f: + + with open(early_args.config, "r") as f: defaults = yaml.safe_load(f) or {} # Second pass: Full arguments parser = argparse.ArgumentParser(description="Retrain AION adapters.") - parser.add_argument("--config", type=str, default=None, help="Path to YAML config file") - + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + # Data - parser.add_argument("--cache-dir", type=str, default=defaults.get("cache_dir", str(paths.dataset))) + parser.add_argument( + "--cache-dir", type=str, default=defaults.get("cache_dir", str(paths.dataset)) + ) parser.add_argument("--split", type=str, default=defaults.get("split", "train")) - parser.add_argument("--max-entries", type=int, default=defaults.get("max_entries", 0)) - + parser.add_argument( + "--max-entries", type=int, default=defaults.get("max_entries", 0) + ) + # Training parser.add_argument("--batch-size", type=int, default=defaults.get("batch_size", 8)) parser.add_argument("--epochs", type=int, default=defaults.get("epochs", 5)) - parser.add_argument("--lr", type=float, default=float(defaults.get("learning_rate", 1e-4))) - parser.add_argument("--grad-clip", type=float, default=defaults.get("grad_clip", 1.0)) - parser.add_argument("--accum-steps", type=int, default=defaults.get("accum_steps", 1)) - 
parser.add_argument("--amp-dtype", type=str, default=defaults.get("amp_dtype", "float16"), choices=["float16", "bfloat16"]) - + parser.add_argument( + "--lr", type=float, default=float(defaults.get("learning_rate", 1e-4)) + ) + parser.add_argument( + "--grad-clip", type=float, default=defaults.get("grad_clip", 1.0) + ) + parser.add_argument( + "--accum-steps", type=int, default=defaults.get("accum_steps", 1) + ) + parser.add_argument( + "--amp-dtype", + type=str, + default=defaults.get("amp_dtype", "float16"), + choices=["float16", "bfloat16"], + ) + # Preprocessing parser.add_argument("--resize", type=int, default=defaults.get("resize", 96)) parser.add_argument("--crop-size", type=int, default=defaults.get("crop_size", 96)) parser.add_argument("--max-abs", type=float, default=defaults.get("max_abs", 100.0)) - parser.add_argument("--cpu-crop", action="store_true", default=defaults.get("cpu_crop", False)) - + parser.add_argument( + "--cpu-crop", action="store_true", default=defaults.get("cpu_crop", False) + ) + # Model parser.add_argument("--hidden", type=int, default=defaults.get("hidden", 16)) - parser.add_argument("--use-unet-checkpointing", action="store_true", default=defaults.get("use_unet_checkpointing", False)) - parser.add_argument("--codec-grad", type=str, default=defaults.get("codec_grad", "ste"), choices=["ste", "full"]) - parser.add_argument("--disable-codec-checkpointing", action="store_true", default=defaults.get("disable_codec_checkpointing", False)) - + parser.add_argument( + "--use-unet-checkpointing", + action="store_true", + default=defaults.get("use_unet_checkpointing", False), + ) + parser.add_argument( + "--codec-grad", + type=str, + default=defaults.get("codec_grad", "ste"), + choices=["ste", "full"], + ) + parser.add_argument( + "--disable-codec-checkpointing", + action="store_true", + default=defaults.get("disable_codec_checkpointing", False), + ) + # Output default_out = paths.retrained_weights / "aion" - parser.add_argument("--output", type=str, default=defaults.get("out_dir", str(default_out))) - parser.add_argument("--resume-adapter", type=str, default=defaults.get("resume_adapter", None)) - parser.add_argument("--auto-resume", action="store_true", default=defaults.get("auto_resume", True)) - + parser.add_argument( + "--output", type=str, default=defaults.get("out_dir", str(default_out)) + ) + parser.add_argument( + "--resume-adapter", type=str, default=defaults.get("resume_adapter", None) + ) + parser.add_argument( + "--auto-resume", action="store_true", default=defaults.get("auto_resume", True) + ) + # Logging - parser.add_argument("--num-workers", type=int, default=defaults.get("num_workers", 0)) - parser.add_argument("--log-gpu-mem-every", type=int, default=defaults.get("log_gpu_mem_every", 50)) + parser.add_argument( + "--num-workers", type=int, default=defaults.get("num_workers", 0) + ) + parser.add_argument( + "--log-gpu-mem-every", type=int, default=defaults.get("log_gpu_mem_every", 50) + ) return parser.parse_args() @@ -138,16 +179,25 @@ def main() -> None: if args.auto_resume and not args.resume_adapter: ckpts = list(out_dir.glob("adapters_epoch_*.pt")) if ckpts: + def get_epoch(p: Path) -> int: - try: return int(p.stem.split("_")[-1]) - except: return -1 + try: + return int(p.stem.split("_")[-1]) + except: + return -1 + latest = max(ckpts, key=get_epoch) print(f"[info] Auto-resume from latest: {latest}") args.resume_adapter = str(latest) # Data max_entries = None if args.max_entries <= 0 else args.max_entries - dataset = 
EuclidImageDataset(split=args.split, cache_dir=args.cache_dir, max_entries=max_entries, resize=args.resize) + dataset = EuclidImageDataset( + split=args.split, + cache_dir=args.cache_dir, + max_entries=max_entries, + resize=args.resize, + ) loader = DataLoader( dataset, batch_size=args.batch_size, @@ -161,14 +211,22 @@ def get_epoch(p: Path) -> int: # Models codec, codec_cfg = load_frozen_codec(device) - euclid_to_hsc = EuclidToHSC(hidden=args.hidden, use_checkpointing=args.use_unet_checkpointing).to(device) - hsc_to_euclid = HSCToEuclid(hidden=args.hidden, use_checkpointing=args.use_unet_checkpointing).to(device) - - optimizer = torch.optim.Adam(list(euclid_to_hsc.parameters()) + list(hsc_to_euclid.parameters()), lr=args.lr) + euclid_to_hsc = EuclidToHSC( + hidden=args.hidden, use_checkpointing=args.use_unet_checkpointing + ).to(device) + hsc_to_euclid = HSCToEuclid( + hidden=args.hidden, use_checkpointing=args.use_unet_checkpointing + ).to(device) + + optimizer = torch.optim.Adam( + list(euclid_to_hsc.parameters()) + list(hsc_to_euclid.parameters()), lr=args.lr + ) criterion = nn.MSELoss(reduction="mean") amp_dtype = torch.float16 if args.amp_dtype == "float16" else torch.bfloat16 - scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and amp_dtype == torch.float16)) + scaler = torch.amp.GradScaler( + "cuda", enabled=(device.type == "cuda" and amp_dtype == torch.float16) + ) if args.resume_adapter: ckpt = torch.load(args.resume_adapter, map_location="cpu") @@ -195,6 +253,7 @@ def codec_bridge(hsc_flux: torch.Tensor) -> torch.Tensor: else: # Full gradients through codec from torch.utils.checkpoint import checkpoint + if not args.disable_codec_checkpointing and hsc_flux.requires_grad: return checkpoint(run_codec_roundtrip, hsc_flux, use_reentrant=False) return run_codec_roundtrip(hsc_flux) @@ -206,12 +265,14 @@ def codec_bridge(hsc_flux: torch.Tensor) -> torch.Tensor: def process_batch(x_full: torch.Tensor) -> float: B = x_full.shape[0] try: - with torch.amp.autocast("cuda", dtype=amp_dtype, enabled=(device.type == "cuda")): + with torch.amp.autocast( + "cuda", dtype=amp_dtype, enabled=(device.type == "cuda") + ): hsc_like = euclid_to_hsc(x_full) hsc_dec = codec_bridge(hsc_like) euclid_rec = hsc_to_euclid(hsc_dec) loss = criterion(euclid_rec, x_full) / args.accum_steps - + if not torch.isfinite(loss): raise FloatingPointError("Loss is non-finite") @@ -230,11 +291,11 @@ def process_batch(x_full: torch.Tensor) -> float: # Training Loop training_losses = [] print(f"Starting training (Epoch {start_epoch} to {args.epochs})...") - + for epoch in range(start_epoch, args.epochs + 1): if device.type == "cuda": torch.cuda.reset_peak_memory_stats() - + euclid_to_hsc.train() hsc_to_euclid.train() optimizer.zero_grad(set_to_none=True) @@ -254,7 +315,11 @@ def process_batch(x_full: torch.Tensor) -> float: if (step + 1) % args.accum_steps == 0: if args.grad_clip > 0: scaler.unscale_(optimizer) - nn.utils.clip_grad_norm_(list(euclid_to_hsc.parameters()) + list(hsc_to_euclid.parameters()), args.grad_clip) + nn.utils.clip_grad_norm_( + list(euclid_to_hsc.parameters()) + + list(hsc_to_euclid.parameters()), + args.grad_clip, + ) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) @@ -263,7 +328,9 @@ def process_batch(x_full: torch.Tensor) -> float: if device.type == "cuda" and (step + 1) % args.log_gpu_mem_every == 0: alloc = _mem_gb(torch.cuda.memory_allocated()) peak = _mem_gb(torch.cuda.max_memory_allocated()) - progress.set_postfix(loss=f"{loss_val:.4f}", 
mem=f"{alloc:.1f}/{peak:.1f}G") + progress.set_postfix( + loss=f"{loss_val:.4f}", mem=f"{alloc:.1f}/{peak:.1f}G" + ) else: progress.set_postfix(loss=f"{loss_val:.4f}") @@ -273,21 +340,26 @@ def process_batch(x_full: torch.Tensor) -> float: # Save Checkpoint ckpt_path = out_dir / f"adapters_epoch_{epoch:03d}.pt" - torch.save({ - "epoch": epoch, + torch.save( + { + "epoch": epoch, + "euclid_to_hsc": euclid_to_hsc.state_dict(), + "hsc_to_euclid": hsc_to_euclid.state_dict(), + "optimizer": optimizer.state_dict(), + "args": vars(args), + }, + ckpt_path, + ) + + # Final Save and Plots + torch.save( + { "euclid_to_hsc": euclid_to_hsc.state_dict(), "hsc_to_euclid": hsc_to_euclid.state_dict(), - "optimizer": optimizer.state_dict(), "args": vars(args), - }, ckpt_path) - - # Final Save and Plots - torch.save({ - "euclid_to_hsc": euclid_to_hsc.state_dict(), - "hsc_to_euclid": hsc_to_euclid.state_dict(), - "args": vars(args), - }, out_dir / "adapters_final.pt") - + }, + out_dir / "adapters_final.pt", + ) # Plot sample results def save_sample_grid(): @@ -302,17 +374,17 @@ def save_sample_grid(): euclid_rec = hsc_to_euclid(hsc_dec) fig, axes = plt.subplots(4, 5, figsize=(15, 10)) - for i in range(4): # Euclid bands + for i in range(4): # Euclid bands axes[0, i].imshow(x[0, i].cpu(), origin="lower") axes[0, i].set_title(f"Input {EUCLID_BANDS[i]}") axes[3, i].imshow(euclid_rec[0, i].cpu(), origin="lower") axes[3, i].set_title(f"Rec {EUCLID_BANDS[i]}") - for i in range(5): # HSC bands + for i in range(5): # HSC bands axes[1, i].imshow(hsc_pred[0, i].cpu(), origin="lower") axes[1, i].set_title(f"Pred {HSC_BANDS[i]}") axes[2, i].imshow(hsc_dec[0, i].cpu(), origin="lower") axes[2, i].set_title(f"Codec {HSC_BANDS[i]}") - + plt.tight_layout() plt.savefig(out_dir / "sample_grid.png") plt.close() @@ -320,5 +392,6 @@ def save_sample_grid(): save_sample_grid() print(f"Finished. Results saved in {out_dir}") + if __name__ == "__main__": main() diff --git a/src/fmb/models/aion/trainer.py b/src/fmb/models/aion/trainer.py index a75e8c8..e47c96c 100644 --- a/src/fmb/models/aion/trainer.py +++ b/src/fmb/models/aion/trainer.py @@ -12,17 +12,17 @@ import torch.utils.checkpoint as checkpoint_utils from torchvision.transforms import RandomCrop -from fmb.models.base.trainer import BaseTrainer from fmb.models.aion.config import AIONTrainingConfig +from fmb.models.base.trainer import BaseTrainer class AIONTrainer(BaseTrainer): """ Trainer for AION Euclid ↔ HSC adapters. - + Trains two U-Net adapters to translate between Euclid and HSC image spaces, using a frozen AION codec as a bridge. - + Parameters ---------- euclid_to_hsc : nn.Module @@ -38,7 +38,7 @@ class AIONTrainer(BaseTrainer): val_loader : Optional[DataLoader] Validation data loader. 
""" - + def __init__( self, euclid_to_hsc: nn.Module, @@ -52,48 +52,52 @@ def __init__( self.hsc_to_euclid = hsc_to_euclid self.codec = codec self.aion_config = config - + # Combine adapters into a single model for BaseTrainer - model = nn.ModuleDict({ - "euclid_to_hsc": euclid_to_hsc, - "hsc_to_euclid": hsc_to_euclid, - }) - + model = nn.ModuleDict( + { + "euclid_to_hsc": euclid_to_hsc, + "hsc_to_euclid": hsc_to_euclid, + } + ) + super().__init__(model, config, train_loader, val_loader) - + # Setup criterion self.criterion = nn.MSELoss(reduction="mean") - + # Setup cropping self.crop = RandomCrop(size=config.crop_size) - + # Setup codec gradient mode self.use_codec_ckpt = not config.disable_codec_checkpointing - + if config.codec_grad == "ste": self.codec_bridge = self._codec_roundtrip_ste print("🔧 Codec gradient mode: STE (straight-through estimator)") else: self.codec_bridge = self._codec_roundtrip_full print("🔧 Codec gradient mode: Full backprop") - + def _create_optimizer(self) -> torch.optim.Optimizer: """Create optimizer for both adapters.""" - params = list(self.euclid_to_hsc.parameters()) + list(self.hsc_to_euclid.parameters()) + params = list(self.euclid_to_hsc.parameters()) + list( + self.hsc_to_euclid.parameters() + ) return torch.optim.Adam( params, lr=self.config.learning_rate, ) - + def _codec_roundtrip_flux(self, hsc_flux: torch.Tensor) -> torch.Tensor: """ Pass HSC flux through codec roundtrip. - + Parameters ---------- hsc_flux : torch.Tensor HSC flux tensor (B, 5, H, W). - + Returns ------- torch.Tensor @@ -102,21 +106,23 @@ def _codec_roundtrip_flux(self, hsc_flux: torch.Tensor) -> torch.Tensor: try: from aion.modalities import HSCImage except ImportError: - raise ImportError("AION not found. Initialize submodules: git submodule update --init") - + raise ImportError( + "AION not found. Initialize submodules: git submodule update --init" + ) + HSC_BANDS = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"] - + hsc_obj = HSCImage(flux=hsc_flux, bands=HSC_BANDS) toks = self.codec.encode(hsc_obj) hsc_rec = self.codec.decode(toks, bands=HSC_BANDS) return hsc_rec.flux - + def _codec_roundtrip_ste(self, hsc_flux: torch.Tensor) -> torch.Tensor: """Codec roundtrip with straight-through estimator (no gradient).""" with torch.no_grad(): y = self._codec_roundtrip_flux(hsc_flux) return hsc_flux + (y - hsc_flux).detach() - + def _codec_roundtrip_full(self, hsc_flux: torch.Tensor) -> torch.Tensor: """Codec roundtrip with full gradient (optionally checkpointed).""" if self.use_codec_ckpt and hsc_flux.requires_grad: @@ -126,16 +132,16 @@ def _codec_roundtrip_full(self, hsc_flux: torch.Tensor) -> torch.Tensor: use_reentrant=False, ) return self._codec_roundtrip_flux(hsc_flux) - + def _preprocess_and_crop(self, euclid_flux_cpu: torch.Tensor) -> torch.Tensor: """ Preprocess and crop Euclid flux. - + Parameters ---------- euclid_flux_cpu : torch.Tensor Euclid flux on CPU. - + Returns ------- torch.Tensor @@ -148,23 +154,23 @@ def _preprocess_and_crop(self, euclid_flux_cpu: torch.Tensor) -> torch.Tensor: x = torch.clamp(x, -self.aion_config.max_abs, self.aion_config.max_abs) x = self.crop(x) return x.to(self.device, non_blocking=True) - + # Preprocess on GPU x = euclid_flux_cpu.to(self.device, non_blocking=False) x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) if self.aion_config.max_abs > 0: x = torch.clamp(x, -self.aion_config.max_abs, self.aion_config.max_abs) return self.crop(x) - + def train_step(self, batch: Any) -> Dict[str, float]: """ Execute one AION training step. 
- + Parameters ---------- batch : EuclidImage Batch of Euclid images. - + Returns ------- Dict[str, float] @@ -172,29 +178,29 @@ def train_step(self, batch: Any) -> Dict[str, float]: """ # Preprocess x = self._preprocess_and_crop(batch.flux) - + # Forward: Euclid → HSC → Codec → HSC → Euclid hsc_like = self.euclid_to_hsc(x) hsc_dec = self.codec_bridge(hsc_like) euclid_rec = self.hsc_to_euclid(hsc_dec) - + # Compute loss loss = self.criterion(euclid_rec, x) - + if not torch.isfinite(loss): raise FloatingPointError("Non-finite loss encountered") - + return {"loss": loss} - + def val_step(self, batch: Any) -> Dict[str, float]: """ Execute one AION validation step. - + Parameters ---------- batch : EuclidImage Batch of Euclid images. - + Returns ------- Dict[str, float] @@ -202,11 +208,11 @@ def val_step(self, batch: Any) -> Dict[str, float]: """ # Same as train step but without gradient x = self._preprocess_and_crop(batch.flux) - + hsc_like = self.euclid_to_hsc(x) hsc_dec = self.codec_bridge(hsc_like) euclid_rec = self.hsc_to_euclid(hsc_dec) - + loss = self.criterion(euclid_rec, x) - + return {"loss": loss} diff --git a/src/fmb/models/astroclip/config.py b/src/fmb/models/astroclip/config.py index ddf6b7f..9c8e2cd 100644 --- a/src/fmb/models/astroclip/config.py +++ b/src/fmb/models/astroclip/config.py @@ -16,7 +16,7 @@ class AstroCLIPTrainingConfig(BaseTrainingConfig): """ Configuration for AstroCLIP fine-tuning. - + Parameters ---------- checkpoint : str @@ -54,16 +54,16 @@ class AstroCLIPTrainingConfig(BaseTrainingConfig): split : str Dataset split to use ('train', 'test'). """ - + # Output out_dir: Path = load_paths().retrained_weights / "astroclip" - + # Model checkpoint: str = "" # Required learnable_scale: bool = False finetune_spectrum: bool = False unfreeze_backbone_blocks: int = 0 - + # Data cache_dir: str = str(load_paths().dataset) use_arrow: bool = True @@ -75,7 +75,7 @@ class AstroCLIPTrainingConfig(BaseTrainingConfig): spectrum_norm: str = "none" include_wavelength: bool = False focus_high_z: bool = False - + # Training defaults (override base) epochs: int = 5 batch_size: int = 256 @@ -84,10 +84,10 @@ class AstroCLIPTrainingConfig(BaseTrainingConfig): grad_clip: float = 1.0 accumulate_steps: int = 1 amp_dtype: str = "float16" - + # Learning rate schedule warmup_steps: int = 0 # 0 means auto (10% of total) - + # Early stopping patience: int = 3 min_delta: float = 1e-4 diff --git a/src/fmb/models/astroclip/core/__init__.py b/src/fmb/models/astroclip/core/__init__.py index 1adb38f..ca1925e 100644 --- a/src/fmb/models/astroclip/core/__init__.py +++ b/src/fmb/models/astroclip/core/__init__.py @@ -4,4 +4,3 @@ Module: fmb.models.astroclip.core.__init__ Description: FMB module: fmb.models.astroclip.core.__init__ """ - diff --git a/src/fmb/models/astroclip/core/astroclip.py b/src/fmb/models/astroclip/core/astroclip.py index 33a7011..4e544f7 100644 --- a/src/fmb/models/astroclip/core/astroclip.py +++ b/src/fmb/models/astroclip/core/astroclip.py @@ -50,7 +50,7 @@ def __init__( learnable_logit_scale (bool): Whether the logit scale should be learnable. 
""" super().__init__() - self.save_hyperparameters(ignore=['image_encoder', 'spectrum_encoder']) + self.save_hyperparameters(ignore=["image_encoder", "spectrum_encoder"]) # Define the image and spectrum encoder self.image_encoder = image_encoder diff --git a/src/fmb/models/astroclip/finetune.py b/src/fmb/models/astroclip/finetune.py index f74bc3a..8d2f420 100644 --- a/src/fmb/models/astroclip/finetune.py +++ b/src/fmb/models/astroclip/finetune.py @@ -11,7 +11,7 @@ import json import math import random -import warnings # Added imports +import warnings # Added imports from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterable, List, Optional, Tuple @@ -29,16 +29,16 @@ import torch.nn as nn import torch.nn.functional as F import torch.nn.utils as nn_utils -from torch.cuda.amp import GradScaler, autocast +import yaml # Added for config support +from torch.cuda.amp import autocast from torch.utils.data import DataLoader, Dataset from tqdm import tqdm # ADAPTED IMPORTS for fmb package structure from fmb.data.astroclip_parquet import ParquetDataSource from fmb.models.astroclip.core.astroclip import AstroClipModel, CLIPLoss -from fmb.paths import load_paths # Added paths support +from fmb.paths import load_paths # Added paths support -import yaml # Added for config support def parse_args() -> argparse.Namespace: # First pass: Check for config file @@ -55,29 +55,29 @@ def parse_args() -> argparse.Namespace: # Flatten config if needed or map keys to match argparse dests # For this script, keys in yaml mostly match dests (e.g. checkpoint, output_path) # but dashes in yaml keys are usually handled. Argparse dests usually use underscores. - + # Key mapping aliases (YAML key -> Argparse dest) key_map = { "learning_rate": "lr", "max_epochs": "epochs", } - + for key, value in config_dict.items(): norm_key = key.replace("-", "_") # Apply mapping alias if exists if norm_key in key_map: norm_key = key_map[norm_key] defaults[norm_key] = value - + print(f"Loaded configuration from {known.config}") except Exception as e: print(f"Error loading config file: {e}") parser = argparse.ArgumentParser( description="Fine-tune AstroCLIP image encoder.", - parents=[conf_parser] # Inherit --config arg + parents=[conf_parser], # Inherit --config arg ) - + # Defaults from environment paths = load_paths() @@ -85,7 +85,7 @@ def parse_args() -> argparse.Namespace: "--parquet-path", dest="parquet_paths", nargs="+", - required=False, # Now optional if provided in config/arrow used + required=False, # Now optional if provided in config/arrow used help="Paths to parquet files (hf:// supported). Ignored if --use-arrow is used.", ) parser.add_argument( @@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--cache-dir", type=str, - default=str(paths.dataset), # Use local data path + default=str(paths.dataset), # Use local data path help="Hugging Face cache directory (used with --use-arrow).", ) parser.add_argument( @@ -116,14 +116,22 @@ def parse_args() -> argparse.Namespace: help="Fine-tune the spectrum encoder as well (otherwise it stays frozen).", ) # Removing required=True because they might come from config defaults - parser.add_argument("--checkpoint", required=False, help="AstroCLIP Lightning checkpoint.") - parser.add_argument("--output-path", required=False, help="Output file for the new encoder weights (.pt).") + parser.add_argument( + "--checkpoint", required=False, help="AstroCLIP Lightning checkpoint." 
+ ) + parser.add_argument( + "--output-path", + required=False, + help="Output file for the new encoder weights (.pt).", + ) parser.add_argument( "--output-ckpt", default=None, help="Optional path to save a full AstroCLIP Lightning checkpoint with the fine-tuned encoder.", ) - parser.add_argument("--device", default="cuda", help="Training device (cuda or cpu).") + parser.add_argument( + "--device", default="cuda", help="Training device (cuda or cpu)." + ) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--lr", type=float, default=3e-6) @@ -132,10 +140,17 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--image-size", type=int, default=144) parser.add_argument("--max-samples", type=int, default=None) parser.add_argument("--num-workers", type=int, default=4) - parser.add_argument("--val-ratio", type=float, default=0.1, help="Fraction used for validation (0 to disable).") + parser.add_argument( + "--val-ratio", + type=float, + default=0.1, + help="Fraction used for validation (0 to disable).", + ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--amp", action="store_true", help="Enable AMP training.") - parser.add_argument("--disable-augment", action="store_true", help="Disable image augmentations.") + parser.add_argument( + "--disable-augment", action="store_true", help="Disable image augmentations." + ) parser.add_argument( "--spectrum-norm", choices=["zscore", "minmax", "none"], @@ -152,12 +167,42 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Lors du sampling parquet, privilégier les galaxies à haut redshift.", ) - parser.add_argument("--warmup-steps", type=int, default=0, help="Nombre d'itérations de warmup pour le scheduler (0 = 10% des steps).") - parser.add_argument("--patience", type=int, default=3, help="Patience pour l'early stopping (epochs).") - parser.add_argument("--min-delta", type=float, default=1e-4, help="Amélioration minimale requise pour reset la patience.") - parser.add_argument("--grad-clip", type=float, default=1.0, help="Clip des gradients (<=0 pour désactiver).") - parser.add_argument("--accumulate-steps", type=int, default=1, help="Nombre d'itérations pour accumuler les gradients.") - parser.add_argument("--log-interval", type=int, default=20, help="Intervalle d'affichage des logs (nombre de batchs).") + parser.add_argument( + "--warmup-steps", + type=int, + default=0, + help="Nombre d'itérations de warmup pour le scheduler (0 = 10% des steps).", + ) + parser.add_argument( + "--patience", + type=int, + default=3, + help="Patience pour l'early stopping (epochs).", + ) + parser.add_argument( + "--min-delta", + type=float, + default=1e-4, + help="Amélioration minimale requise pour reset la patience.", + ) + parser.add_argument( + "--grad-clip", + type=float, + default=1.0, + help="Clip des gradients (<=0 pour désactiver).", + ) + parser.add_argument( + "--accumulate-steps", + type=int, + default=1, + help="Nombre d'itérations pour accumuler les gradients.", + ) + parser.add_argument( + "--log-interval", + type=int, + default=20, + help="Intervalle d'affichage des logs (nombre de batchs).", + ) parser.add_argument( "--unfreeze-backbone-blocks", type=int, @@ -180,7 +225,7 @@ def parse_args() -> argparse.Namespace: default=None, help="Chemin vers un fichier de poids (.pt ou .ckpt) pour reprendre l'entraînement.", ) - + # Handle max_epochs from config file (alias to epochs) parser.add_argument("--max-epochs", type=int, 
default=None, dest="max_epochs_alt") @@ -217,7 +262,7 @@ def load_dataframe( split: str = "train", ) -> pd.DataFrame: """Load dataset from Arrow cache or Parquet files. - + Args: parquet_paths: Paths to parquet files (ignored if use_arrow=True) max_samples: Maximum number of samples to load @@ -233,10 +278,10 @@ def load_dataframe( # Load from local Arrow cache (HuggingFace format) # ADAPTED IMPORT from fmb.data.astroclip_loader import ( - load_local_arrow_dataset, convert_dataset_to_astroclip_format, + load_local_arrow_dataset, ) - + print(f"Loading from Arrow cache: {cache_dir} (split={split})") df = load_local_arrow_dataset( cache_dir=cache_dir, @@ -244,15 +289,15 @@ def load_dataframe( max_samples=max_samples, seed=seed, ) - + # Convert to AstroCLIP format df = convert_dataset_to_astroclip_format(df, image_size=image_size) - + else: # Original parquet loading logic if not parquet_paths: raise ValueError("Either --use-arrow or --parquet-path must be specified") - + frames = [] for path in parquet_paths: ds = ParquetDataSource( @@ -277,11 +322,13 @@ def load_dataframe( for col in required_columns: if col not in df.columns: raise ValueError(f"Column '{col}' is missing from final DataFrame.") - + return df.reset_index(drop=True) -def train_val_split(df: pd.DataFrame, val_ratio: float, seed: int) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: +def train_val_split( + df: pd.DataFrame, val_ratio: float, seed: int +) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: if val_ratio <= 0 or len(df) < 2: return df, None @@ -364,7 +411,11 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: spectrum = torch.stack([flux_tensor, wavelength_tensor], dim=-1) image_tensor = row["image"] - image = image_tensor if isinstance(image_tensor, torch.Tensor) else torch.as_tensor(image_tensor) + image = ( + image_tensor + if isinstance(image_tensor, torch.Tensor) + else torch.as_tensor(image_tensor) + ) image = image.float() return { @@ -382,7 +433,9 @@ def build_dataloader( shuffle: bool, num_workers: int, ) -> DataLoader: - dataset = AstroClipFineTuneDataset(df, slice_length, spectrum_norm, include_wavelength) + dataset = AstroClipFineTuneDataset( + df, slice_length, spectrum_norm, include_wavelength + ) return DataLoader( dataset, batch_size=batch_size, @@ -523,22 +576,25 @@ def export_full_checkpoint(model: AstroClipModel, output_ckpt: Path) -> None: def main(args: Optional[argparse.Namespace] = None) -> None: parsed = parse_args() if args is None else args set_seed(parsed.seed) - + # Load paths explicitly in main scope paths = load_paths() - + # Handle epochs logic from config if hasattr(parsed, "max_epochs_alt") and parsed.max_epochs_alt is not None: - # If --epochs was not explicit or is default, prefer max_epochs from config - # Since default for epochs is 5, we check if it is 5 and max_epochs isn't - if parsed.epochs == 5: - parsed.epochs = parsed.max_epochs_alt + # If --epochs was not explicit or is default, prefer max_epochs from config + # Since default for epochs is 5, we check if it is 5 and max_epochs isn't + if parsed.epochs == 5: + parsed.epochs = parsed.max_epochs_alt - print(f"Configuration: Epochs={parsed.epochs}, Batch={parsed.batch_size}, LR={parsed.lr}") + print( + f"Configuration: Epochs={parsed.epochs}, Batch={parsed.batch_size}, LR={parsed.lr}" + ) # Ensure output_path is set if not parsed.output_path: from datetime import datetime + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # Use configured retrained_weights path default_dir = paths.retrained_weights / "astroclip" @@ 
-588,14 +644,17 @@ def main(args: Optional[argparse.Namespace] = None) -> None: # Define unsafe load context manager from contextlib import contextmanager + @contextmanager def unsafe_torch_load_context(): original_load = torch.load + def unsafe_load(*args, **kwargs): # Force weights_only to False even if provided as None # because PyTorch 2.6 defaults None/missing to True - kwargs['weights_only'] = False + kwargs["weights_only"] = False return original_load(*args, **kwargs) + torch.load = unsafe_load try: yield @@ -605,40 +664,46 @@ def unsafe_load(*args, **kwargs): # Load initial model print(f"Loading AstroCLIP checkpoint from: {parsed.checkpoint}") with unsafe_torch_load_context(): - model = AstroClipModel.load_from_checkpoint(parsed.checkpoint, map_location=device) - print(f"Checkpoint loaded successfully") - + model = AstroClipModel.load_from_checkpoint( + parsed.checkpoint, map_location=device + ) + print("Checkpoint loaded successfully") + image_encoder = model.image_encoder.to(device) spectrum_encoder = model.spectrum_encoder.to(device) - + # Handle resume/warm-start from specific weights if parsed.resume_path: print(f"Resuming from weights: {parsed.resume_path}") with unsafe_torch_load_context(): resume_payload = torch.load(parsed.resume_path, map_location=device) - + # Case 1: Custom .pt format if isinstance(resume_payload, dict) and "image_encoder" in resume_payload: image_encoder.load_state_dict(resume_payload["image_encoder"]) if parsed.finetune_spectrum and "spectrum_encoder" in resume_payload: spectrum_encoder.load_state_dict(resume_payload["spectrum_encoder"]) - + # Restore scale if parsed.learnable_scale and "logit_scale" in resume_payload: saved_scale = resume_payload["logit_scale"] if isinstance(model.logit_scale, torch.nn.Parameter): with torch.no_grad(): - model.logit_scale.data.copy_(saved_scale if isinstance(saved_scale, torch.Tensor) else torch.tensor(saved_scale)) + model.logit_scale.data.copy_( + saved_scale + if isinstance(saved_scale, torch.Tensor) + else torch.tensor(saved_scale) + ) print("Weights loaded from .pt format") - + # Case 2: Lightning checkpoint (.ckpt) elif isinstance(resume_payload, dict) and "state_dict" in resume_payload: model.load_state_dict(resume_payload["state_dict"], strict=False) print("Weights loaded from Lightning .ckpt format") - + else: print(f"Warning: Format unrecognized for {parsed.resume_path}") - + # Handle Spectrum Encoder training if parsed.finetune_spectrum: spectrum_encoder.train() @@ -656,17 +721,26 @@ def unsafe_load(*args, **kwargs): # Setup parameters to optimize trainable_params = list([p for p in image_encoder.parameters() if p.requires_grad]) if parsed.finetune_spectrum: - trainable_params.extend([p for p in spectrum_encoder.parameters() if p.requires_grad]) - + trainable_params.extend( + [p for p in spectrum_encoder.parameters() if p.requires_grad] + ) + # Handle logit scale (temperature) if parsed.learnable_scale: if isinstance(model.logit_scale, torch.Tensor): model.logit_scale.requires_grad = True else: import math - initial_val = model.logit_scale if isinstance(model.logit_scale, float) else model.logit_scale.item() - model.logit_scale = torch.nn.Parameter(torch.tensor(initial_val, device=device)) - + + initial_val = ( + model.logit_scale + if isinstance(model.logit_scale, float) + else model.logit_scale.item() + ) + model.logit_scale = torch.nn.Parameter( + torch.tensor(initial_val, device=device) + ) + trainable_params.append(model.logit_scale) print("Note: CLIP temperature/scale is LEARNABLE.") else: @@ -674,23 
+748,25 @@ def unsafe_load(*args, **kwargs): model.logit_scale.requires_grad = False print("Note: CLIP temperature/scale is FIXED.") - optimizer = torch.optim.AdamW(trainable_params, lr=parsed.lr, weight_decay=parsed.weight_decay) + optimizer = torch.optim.AdamW( + trainable_params, lr=parsed.lr, weight_decay=parsed.weight_decay + ) total_steps = (len(train_loader) * parsed.epochs) // max(1, parsed.accumulate_steps) warmup_steps = parsed.warmup_steps or max(1, int(0.1 * total_steps)) scheduler_factors = _cosine_scheduler(total_steps, warmup_steps) criterion = CLIPLoss() - + def get_current_scale(): if parsed.learnable_scale: return model.logit_scale if isinstance(model.logit_scale, torch.Tensor): - return model.logit_scale.exp().item() + return model.logit_scale.exp().item() return math.exp(model.logit_scale) # Fixed GradScaler for torch > 2.0 - scaler = torch.amp.GradScaler('cuda', enabled=parsed.amp and device.type == "cuda") + scaler = torch.amp.GradScaler("cuda", enabled=parsed.amp and device.type == "cuda") history: List[HistoryEntry] = [] best_val_loss = float("inf") @@ -708,10 +784,14 @@ def get_current_scale(): optimizer.zero_grad(set_to_none=True) # Added TQDM Progress Bar - pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{parsed.epochs}", - ncols=120, leave=True, - mininterval=0.5, - smoothing=0.1) + pbar = tqdm( + train_loader, + desc=f"Epoch {epoch}/{parsed.epochs}", + ncols=120, + leave=True, + mininterval=0.5, + smoothing=0.1, + ) for batch_idx, batch in enumerate(pbar, start=1): images = batch["image"].to(device, non_blocking=True) @@ -720,10 +800,10 @@ def get_current_scale(): with autocast(enabled=scaler.is_enabled()): image_features = image_encoder(images) spectrum_features = spectrum_encoder(spectrum) - + current_scale = get_current_scale() loss = criterion(image_features, spectrum_features, current_scale) - + cosine = F.cosine_similarity( F.normalize(image_features, dim=-1), F.normalize(spectrum_features, dim=-1), @@ -755,18 +835,27 @@ def get_current_scale(): # Update progress metrics current_loss = loss.item() * parsed.accumulate_steps - pbar.set_postfix({ - 'loss': f'{current_loss:.4f}', - 'cos': f'{cosine.item():.4f}', - 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}' - }) + pbar.set_postfix( + { + "loss": f"{current_loss:.4f}", + "cos": f"{cosine.item():.4f}", + "lr": f'{optimizer.param_groups[0]["lr"]:.2e}', + } + ) train_loss = cumulative_loss / max(1, sample_count) train_cos = cumulative_cosine / max(1, sample_count) val_metrics = {"loss": None, "cosine": None} if val_loader is not None: - val_metrics = evaluate(image_encoder, spectrum_encoder, val_loader, criterion, device, get_current_scale()) + val_metrics = evaluate( + image_encoder, + spectrum_encoder, + val_loader, + criterion, + device, + get_current_scale(), + ) history.append( HistoryEntry( @@ -778,21 +867,32 @@ def get_current_scale(): ) ) - val_loss_for_patience = val_metrics["loss"] if val_metrics["loss"] is not None else train_loss + val_loss_for_patience = ( + val_metrics["loss"] if val_metrics["loss"] is not None else train_loss + ) improved = val_loss_for_patience + parsed.min_delta < best_val_loss if improved: best_val_loss = val_loss_for_patience patience_counter = 0 - + # Save checkpoints checkpoint_payload = { - "image_encoder": {k: v.detach().cpu() for k, v in image_encoder.state_dict().items()}, - "spectrum_encoder": {k: v.detach().cpu() for k, v in spectrum_encoder.state_dict().items()}, + "image_encoder": { + k: v.detach().cpu() for k, v in image_encoder.state_dict().items() + }, + 
"spectrum_encoder": { + k: v.detach().cpu() + for k, v in spectrum_encoder.state_dict().items() + }, } if parsed.learnable_scale: - scale_val = model.logit_scale.detach().cpu() if isinstance(model.logit_scale, torch.Tensor) else model.logit_scale + scale_val = ( + model.logit_scale.detach().cpu() + if isinstance(model.logit_scale, torch.Tensor) + else model.logit_scale + ) checkpoint_payload["logit_scale"] = scale_val - + torch.save(checkpoint_payload, parsed.output_path) print(f"[Epoch {epoch}] New best model saved to {parsed.output_path}") else: @@ -805,7 +905,7 @@ def get_current_scale(): print(f"Reloading best model from {parsed.output_path}...") with unsafe_torch_load_context(): best_payload = torch.load(parsed.output_path, map_location="cpu") - + if "image_encoder" in best_payload: image_encoder.load_state_dict(best_payload["image_encoder"]) if parsed.finetune_spectrum and "spectrum_encoder" in best_payload: @@ -815,15 +915,21 @@ def get_current_scale(): if parsed.output_ckpt: with unsafe_torch_load_context(): - export_model = AstroClipModel.load_from_checkpoint(parsed.checkpoint, map_location="cpu") - + export_model = AstroClipModel.load_from_checkpoint( + parsed.checkpoint, map_location="cpu" + ) + export_model.image_encoder.load_state_dict(image_encoder.state_dict()) export_model.spectrum_encoder.load_state_dict(spectrum_encoder.state_dict()) - + export_full_checkpoint(export_model, Path(parsed.output_ckpt)) print(f"[OK] Full checkpoint saved to {parsed.output_ckpt}") - history_path = Path(parsed.history_json) if parsed.history_json else Path(parsed.output_path).with_suffix(".history.json") + history_path = ( + Path(parsed.history_json) + if parsed.history_json + else Path(parsed.output_path).with_suffix(".history.json") + ) history_payload = [ { "epoch": entry.epoch, @@ -838,7 +944,11 @@ def get_current_scale(): history_path.write_text(json.dumps(history_payload, indent=2)) print(f"[OK] History saved to {history_path}") - plot_path = Path(parsed.history_plot) if parsed.history_plot else Path(parsed.output_path).with_suffix(".history.png") + plot_path = ( + Path(parsed.history_plot) + if parsed.history_plot + else Path(parsed.output_path).with_suffix(".history.png") + ) _plot_history(history, plot_path) print(f"[OK] Training curves saved to {plot_path}") diff --git a/src/fmb/models/astropt/config.py b/src/fmb/models/astropt/config.py index 2afe15c..1dee0aa 100644 --- a/src/fmb/models/astropt/config.py +++ b/src/fmb/models/astropt/config.py @@ -16,7 +16,7 @@ class AstroPTTrainingConfig(BaseTrainingConfig): """ Configuration for AstroPT multimodal training. - + Parameters ---------- block_size : int @@ -54,10 +54,10 @@ class AstroPTTrainingConfig(BaseTrainingConfig): compile : bool Use torch.compile for model. 
""" - + # Output out_dir: Path = load_paths().retrained_weights / "astropt" - + # Model architecture block_size: int = 1024 image_patch_size: int = 16 @@ -68,14 +68,14 @@ class AstroPTTrainingConfig(BaseTrainingConfig): n_chan: int = 3 dropout: float = 0.0 bias: bool = False - + # Data cache_dir: str = str(load_paths().dataset) train_split: str = "train" val_split: str = "test" image_size: int = 224 spectrum_length: int = 7781 - + # Training defaults (override base) epochs: int = 30 batch_size: int = 8 @@ -84,16 +84,16 @@ class AstroPTTrainingConfig(BaseTrainingConfig): grad_clip: float = 1.0 gradient_accumulation_steps: int = 4 amp_dtype: str = "bfloat16" - + # Learning rate schedule warmup_iters: int = 2000 lr_decay_iters: int = 30000 min_lr: float = 6e-5 - + # Evaluation eval_interval: int = 100 eval_iters: int = 50 - + # System compile: bool = True - max_iters: int = 30000 + max_iters: int = 30000 diff --git a/src/fmb/models/astropt/euclid_desi_dataset/multimodal_dataloader.py b/src/fmb/models/astropt/euclid_desi_dataset/multimodal_dataloader.py index 7fa65af..7b5e705 100644 --- a/src/fmb/models/astropt/euclid_desi_dataset/multimodal_dataloader.py +++ b/src/fmb/models/astropt/euclid_desi_dataset/multimodal_dataloader.py @@ -5,29 +5,30 @@ Description: Custom collate functions for multimodal data """ +from typing import Any, Dict, List + +import numpy as np import torch import torch.nn.functional as F -import numpy as np -from torch.utils.data import Dataset from datasets import load_dataset from PIL import Image -from typing import Dict, Any, List, Optional +from torch.utils.data import Dataset class EuclidDESIMultimodalDataset(Dataset): """PyTorch Dataset for Euclid images + DESI spectra multimodal training.""" - + def __init__( - self, + self, split: str = "test_batch_1", image_size: int = 224, spectrum_length: int = 7781, # Standard DESI spectrum length image_transform=None, - data_dir: str = "/pbs/home/a/astroinfo09/data/astroPT_euclid_desi_dataset" + data_dir: str = "/pbs/home/a/astroinfo09/data/astroPT_euclid_desi_dataset", ): """ Initialize the multimodal dataset. 
- + Args: split: Which split to load from local data image_size: Target size for images (will be resized) @@ -39,97 +40,107 @@ def __init__( self.image_size = image_size self.spectrum_length = spectrum_length self.image_transform = image_transform - + def __len__(self): return len(self.dataset) - + def __getitem__(self, idx): """Get a single sample with both image and spectrum data.""" sample = self.dataset[idx] - + # Process RGB image - rgb_image = sample['RGB_image'] + rgb_image = sample["RGB_image"] if isinstance(rgb_image, Image.Image): # Resize to target size - rgb_image = rgb_image.resize((self.image_size, self.image_size), Image.LANCZOS) + rgb_image = rgb_image.resize( + (self.image_size, self.image_size), Image.LANCZOS + ) rgb_image = np.array(rgb_image) - + # Convert to tensor format (C, H, W) and normalize if rgb_image.ndim == 3: rgb_image = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0 else: rgb_image = torch.from_numpy(rgb_image).unsqueeze(0).float() / 255.0 - + # Apply transforms if provided if self.image_transform is not None: rgb_image = self.image_transform(rgb_image) - + # Process spectrum data spectrum_flux = None - if sample['spectrum'] is not None and sample['spectrum']['flux'] is not None: - flux = np.array(sample['spectrum']['flux'], dtype=np.float32) - + if sample["spectrum"] is not None and sample["spectrum"]["flux"] is not None: + flux = np.array(sample["spectrum"]["flux"], dtype=np.float32) + # Handle spectrum length - pad or truncate to standard length if len(flux) < self.spectrum_length: # Pad with zeros padded_flux = np.zeros(self.spectrum_length, dtype=np.float32) - padded_flux[:len(flux)] = flux + padded_flux[: len(flux)] = flux spectrum_flux = torch.from_numpy(padded_flux) elif len(flux) > self.spectrum_length: # Truncate - spectrum_flux = torch.from_numpy(flux[:self.spectrum_length]) + spectrum_flux = torch.from_numpy(flux[: self.spectrum_length]) else: spectrum_flux = torch.from_numpy(flux) - + # Normalize spectrum (basic normalization - could be improved) if spectrum_flux.sum() > 0: spectrum_flux = spectrum_flux / (spectrum_flux.std() + 1e-8) - + return { - 'object_id': sample['object_id'], - 'targetid': sample['targetid'], - 'redshift': sample['redshift'], - 'image': rgb_image, # Shape: (C, H, W) - 'spectrum': spectrum_flux, # Shape: (spectrum_length,) or None + "object_id": sample["object_id"], + "targetid": sample["targetid"], + "redshift": sample["redshift"], + "image": rgb_image, # Shape: (C, H, W) + "spectrum": spectrum_flux, # Shape: (spectrum_length,) or None } def multimodal_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """ Collate function for multimodal training batches. - + Handles cases where some samples might not have spectra. 
""" # Separate samples with and without spectra image_samples = [] spectrum_samples = [] - + for sample in batch: - if sample['image'] is not None: + if sample["image"] is not None: image_samples.append(sample) - if sample['spectrum'] is not None: + if sample["spectrum"] is not None: spectrum_samples.append(sample) - + # Collate images collated = {} if image_samples: - collated['images'] = torch.stack([s['image'] for s in image_samples]) - collated['image_object_ids'] = [s['object_id'] for s in image_samples] - collated['image_targetids'] = torch.tensor([s['targetid'] for s in image_samples]) - collated['image_redshifts'] = torch.tensor([s['redshift'] for s in image_samples]) - + collated["images"] = torch.stack([s["image"] for s in image_samples]) + collated["image_object_ids"] = [s["object_id"] for s in image_samples] + collated["image_targetids"] = torch.tensor( + [s["targetid"] for s in image_samples] + ) + collated["image_redshifts"] = torch.tensor( + [s["redshift"] for s in image_samples] + ) + # Collate spectra if spectrum_samples: - collated['spectra'] = torch.stack([s['spectrum'] for s in spectrum_samples]) - collated['spectrum_object_ids'] = [s['object_id'] for s in spectrum_samples] - collated['spectrum_targetids'] = torch.tensor([s['targetid'] for s in spectrum_samples]) - collated['spectrum_redshifts'] = torch.tensor([s['redshift'] for s in spectrum_samples]) - + collated["spectra"] = torch.stack([s["spectrum"] for s in spectrum_samples]) + collated["spectrum_object_ids"] = [s["object_id"] for s in spectrum_samples] + collated["spectrum_targetids"] = torch.tensor( + [s["targetid"] for s in spectrum_samples] + ) + collated["spectrum_redshifts"] = torch.tensor( + [s["redshift"] for s in spectrum_samples] + ) + # Also include all metadata for reference - collated['all_object_ids'] = [s['object_id'] for s in batch] - collated['all_targetids'] = torch.tensor([s['targetid'] for s in batch]) - collated['all_redshifts'] = torch.tensor([s['redshift'] for s in batch]) - + collated["all_object_ids"] = [s["object_id"] for s in batch] + collated["all_targetids"] = torch.tensor([s["targetid"] for s in batch]) + collated["all_redshifts"] = torch.tensor([s["redshift"] for s in batch]) + return collated @@ -138,101 +149,110 @@ def prepare_multimodal_batch( image_patch_size: int, spectrum_patch_size: int, device: torch.device, - modality_registry + modality_registry, ) -> Dict[str, torch.Tensor]: """ Prepare a multimodal batch for AstroPT model input. 
- + Args: batch: Collated batch from multimodal_collate_fn image_patch_size: Size of image patches (e.g., 16) spectrum_patch_size: Size of spectrum patches (e.g., 256) device: Target device modality_registry: AstroPT modality registry - + Returns: Dictionary with model inputs """ inputs = {} - + # Process images if present - if 'images' in batch: - images = batch['images'].to(device) # (B, C, H, W) + if "images" in batch: + images = batch["images"].to(device) # (B, C, H, W) B, C, H, W = images.shape - + # Ensure image dimensions are divisible by patch size H_pad = (image_patch_size - (H % image_patch_size)) % image_patch_size W_pad = (image_patch_size - (W % image_patch_size)) % image_patch_size if H_pad or W_pad: images = F.pad(images, (0, W_pad, 0, H_pad)) H, W = images.shape[2], images.shape[3] - + # Create image patches patches_h = H // image_patch_size patches_w = W // image_patch_size num_patches = patches_h * patches_w - + # Reshape to patches: (B, C, H, W) -> (B, num_patches, patch_size*patch_size*C) - image_patches = images.unfold(2, image_patch_size, image_patch_size)\ - .unfold(3, image_patch_size, image_patch_size)\ - .contiguous()\ - .view(B, C, patches_h, patches_w, image_patch_size, image_patch_size)\ - .permute(0, 2, 3, 1, 4, 5)\ - .contiguous()\ - .view(B, num_patches, -1) - + image_patches = ( + images.unfold(2, image_patch_size, image_patch_size) + .unfold(3, image_patch_size, image_patch_size) + .contiguous() + .view(B, C, patches_h, patches_w, image_patch_size, image_patch_size) + .permute(0, 2, 3, 1, 4, 5) + .contiguous() + .view(B, num_patches, -1) + ) + # Create position indices for image patches image_positions = torch.arange(num_patches, device=device, dtype=torch.long) image_positions = image_positions.unsqueeze(0).expand(B, -1) - - inputs['images'] = image_patches - inputs['images_positions'] = image_positions - + + inputs["images"] = image_patches + inputs["images_positions"] = image_positions + # Process spectra if present - if 'spectra' in batch: - spectra = batch['spectra'].to(device) # (B, L) + if "spectra" in batch: + spectra = batch["spectra"].to(device) # (B, L) B, L = spectra.shape - + # Pad spectra to be divisible by patch size pad = (spectrum_patch_size - (L % spectrum_patch_size)) % spectrum_patch_size if pad: spectra = F.pad(spectra, (0, pad)) - + # Reshape into patches: (B, L) -> (B, num_patches, patch_size) spectrum_patches = spectra.view(B, -1, spectrum_patch_size) num_spectrum_patches = spectrum_patches.size(1) - + # Create position indices for spectrum patches - spectrum_positions = torch.arange(num_spectrum_patches, device=device, dtype=torch.long) + spectrum_positions = torch.arange( + num_spectrum_patches, device=device, dtype=torch.long + ) spectrum_positions = spectrum_positions.unsqueeze(0).expand(B, -1) - - inputs['spectra'] = spectrum_patches - inputs['spectra_positions'] = spectrum_positions - + + inputs["spectra"] = spectrum_patches + inputs["spectra_positions"] = spectrum_positions + return inputs if __name__ == "__main__": # Test the dataset and dataloader print("Testing Euclid+DESI multimodal dataset...") - + dataset = EuclidDESIMultimodalDataset(split="train_batch_1") print(f"Dataset loaded with {len(dataset)} samples") - + # Test a single sample sample = dataset[0] print(f"Sample keys: {sample.keys()}") - print(f"Image shape: {sample['image'].shape if sample['image'] is not None else 'None'}") - print(f"Spectrum shape: {sample['spectrum'].shape if sample['spectrum'] is not None else 'None'}") + print( + f"Image shape: 
{sample['image'].shape if sample['image'] is not None else 'None'}" + ) + print( + f"Spectrum shape: {sample['spectrum'].shape if sample['spectrum'] is not None else 'None'}" + ) print(f"Object ID: {sample['object_id']}") print(f"Redshift: {sample['redshift']}") - + # Test dataloader from torch.utils.data import DataLoader + dataloader = DataLoader(dataset, batch_size=4, collate_fn=multimodal_collate_fn) batch = next(iter(dataloader)) print(f"\nBatch keys: {batch.keys()}") - if 'images' in batch: + if "images" in batch: print(f"Batch images shape: {batch['images'].shape}") - if 'spectra' in batch: - print(f"Batch spectra shape: {batch['spectra'].shape}") \ No newline at end of file + if "spectra" in batch: + print(f"Batch spectra shape: {batch['spectra'].shape}") diff --git a/src/fmb/models/astropt/retrain_spectra_images.py b/src/fmb/models/astropt/retrain_spectra_images.py index 372c0d9..5c06343 100644 --- a/src/fmb/models/astropt/retrain_spectra_images.py +++ b/src/fmb/models/astropt/retrain_spectra_images.py @@ -8,31 +8,30 @@ import argparse import math import os +import sys import time from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Optional, Tuple +from pathlib import Path +from typing import Optional, Tuple +import matplotlib.pyplot as plt +import numpy as np import torch -import torch.nn.functional as F from torch.distributed import destroy_process_group, init_process_group from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -import matplotlib.pyplot as plt -import numpy as np -import sys -from pathlib import Path # Add src to pythonpath src_path = Path(__file__).resolve().parents[3] if str(src_path) not in sys.path: sys.path.insert(0, str(src_path)) -# CHANGED: Import from fmb.data instead of scratch -from fmb.data.load_display_data import EuclidDESIDataset from fmb.data.datasets import AstroPTDataset, FMBDataConfig + +# CHANGED: Import from fmb.data instead of scratch from fmb.paths import load_paths # Add external/astroPT/src to path for astropt package @@ -54,17 +53,22 @@ # CHANGED: Relative import to local euclid_desi_dataset package # Ensure src/fmb/models/astropt/euclid_desi_dataset exists and has multimodal_dataloader try: - from fmb.models.astropt.euclid_desi_dataset.multimodal_dataloader import multimodal_collate_fn, prepare_multimodal_batch + from fmb.models.astropt.euclid_desi_dataset.multimodal_dataloader import ( + multimodal_collate_fn, + prepare_multimodal_batch, + ) except ImportError: - # Fallback if running as module differently? - from .euclid_desi_dataset.multimodal_dataloader import multimodal_collate_fn, prepare_multimodal_batch - + # Fallback if running as module differently? 
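For readers skimming this hunk, the `unfold`-based reshaping in `prepare_multimodal_batch` above is the densest part of the change. Below is a minimal, self-contained sketch of that same patchification in plain PyTorch, outside the diff, so the intended shapes are easy to verify; it is illustrative only and uses no `fmb` imports.

```python
import torch


def patchify_images(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) images into flattened patches of shape
    (B, num_patches, patch_size * patch_size * C), mirroring the
    unfold chain used in prepare_multimodal_batch."""
    B, C, H, W = images.shape
    assert H % patch_size == 0 and W % patch_size == 0, "pad first if needed"
    ph, pw = H // patch_size, W // patch_size
    patches = (
        images.unfold(2, patch_size, patch_size)   # (B, C, ph, W, p)
        .unfold(3, patch_size, patch_size)         # (B, C, ph, pw, p, p)
        .contiguous()
        .view(B, C, ph, pw, patch_size, patch_size)
        .permute(0, 2, 3, 1, 4, 5)                 # (B, ph, pw, C, p, p)
        .contiguous()
        .view(B, ph * pw, -1)                      # (B, num_patches, p*p*C)
    )
    return patches


if __name__ == "__main__":
    x = torch.randn(2, 3, 224, 224)
    print(patchify_images(x, patch_size=16).shape)  # torch.Size([2, 196, 768])
```

With the defaults in this file (image_size=224, image_patch_size=16, 3 channels) this yields 196 patches of 3 * 16 * 16 = 768 values per image, which is where the 196 in the modality loss weights comes from.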
+ from .euclid_desi_dataset.multimodal_dataloader import ( + multimodal_collate_fn, + prepare_multimodal_batch, + ) @dataclass class TrainingConfig: """Configuration for multimodal training.""" - + # Output and logging out_dir: str = str(load_paths().retrained_weights / "astropt") eval_interval: int = 100 @@ -72,7 +76,7 @@ class TrainingConfig: log_interval: int = 20 checkpoint_interval: int = 5000 always_save_checkpoint: bool = False - + # Data train_split: str = "train" val_split: str = "test" @@ -82,7 +86,7 @@ class TrainingConfig: image_size: int = 224 spectrum_length: int = 7781 cache_dir: str = str(load_paths().dataset) - + # Model architecture block_size: int = 1024 image_patch_size: int = 16 @@ -93,7 +97,7 @@ class TrainingConfig: n_chan: int = 3 dropout: float = 0.0 bias: bool = False - + # Training learning_rate: float = 6e-4 weight_decay: float = 0.1 @@ -105,11 +109,11 @@ class TrainingConfig: lr_decay_iters: int = 30000 min_lr: float = 6e-5 max_iters: int = 3000 - + # System device: str = "cuda" dtype: str = "bfloat16" - compile: bool = False # Default to False on Windows + compile: bool = False # Default to False on Windows log_via_wandb: bool = False wandb_project: str = "astropt-multimodal" wandb_run_name: str = None @@ -118,7 +122,7 @@ class TrainingConfig: def parse_args() -> Tuple[TrainingConfig, Optional[str]]: """Parse command line arguments.""" paths = load_paths() - + # First pass: get config file parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--config", type=str, default=None) @@ -128,7 +132,8 @@ def parse_args() -> Tuple[TrainingConfig, Optional[str]]: yaml_config = {} if early_args.config: import yaml - with open(early_args.config, 'r') as f: + + with open(early_args.config, "r") as f: yaml_config = yaml.safe_load(f) or {} # Helper to get default from (YAML or TrainingConfig) @@ -136,49 +141,89 @@ def get_default(key, default_val): return yaml_config.get(key, default_val) # Second pass: full arguments - parser = argparse.ArgumentParser(description="Train AstroPT on Euclid images + DESI spectra") - parser.add_argument("--config", type=str, default=None, help="Path to YAML config file") - + parser = argparse.ArgumentParser( + description="Train AstroPT on Euclid images + DESI spectra" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + # Output and Data - parser.add_argument("--out-dir", default=get_default("out_dir", str(paths.retrained_weights / "astropt"))) - parser.add_argument("--cache-dir", default=get_default("cache_dir", str(paths.dataset))) + parser.add_argument( + "--out-dir", + default=get_default("out_dir", str(paths.retrained_weights / "astropt")), + ) + parser.add_argument( + "--cache-dir", default=get_default("cache_dir", str(paths.dataset)) + ) parser.add_argument("--train-split", default=get_default("train_split", "train")) parser.add_argument("--val-split", default=get_default("val_split", "test")) - + # Optimization parser.add_argument("--batch-size", type=int, default=get_default("batch_size", 8)) - parser.add_argument("--grad-accum", dest="gradient_accumulation_steps", type=int, - default=get_default("gradient_accumulation_steps", 4)) - parser.add_argument("--learning-rate", "--lr", type=float, default=get_default("learning_rate", 6e-4)) + parser.add_argument( + "--grad-accum", + dest="gradient_accumulation_steps", + type=int, + default=get_default("gradient_accumulation_steps", 4), + ) + parser.add_argument( + "--learning-rate", + "--lr", + type=float, + 
default=get_default("learning_rate", 6e-4), + ) parser.add_argument("--max-iters", type=int, default=get_default("max_iters", 3000)) - + # Logging - parser.add_argument("--log-interval", type=int, default=get_default("log_interval", 20)) - parser.add_argument("--eval-interval", type=int, default=get_default("eval_interval", 100)) + parser.add_argument( + "--log-interval", type=int, default=get_default("log_interval", 20) + ) + parser.add_argument( + "--eval-interval", type=int, default=get_default("eval_interval", 100) + ) parser.add_argument("--eval-iters", type=int, default=get_default("eval_iters", 50)) - parser.add_argument("--log-wandb", dest="log_via_wandb", action="store_true", default=get_default("log_via_wandb", False)) - parser.add_argument("--wandb-project", default=get_default("wandb_project", "astropt-multimodal")) + parser.add_argument( + "--log-wandb", + dest="log_via_wandb", + action="store_true", + default=get_default("log_via_wandb", False), + ) + parser.add_argument( + "--wandb-project", default=get_default("wandb_project", "astropt-multimodal") + ) parser.add_argument("--wandb-run-name", default=get_default("wandb_run_name", None)) - + # Resume - parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume from") - parser.add_argument("--resume-best", action="store_true", default=False, help="Resume from best checkpoint in out_dir") - + parser.add_argument( + "--resume", type=str, default=None, help="Path to checkpoint to resume from" + ) + parser.add_argument( + "--resume-best", + action="store_true", + default=False, + help="Resume from best checkpoint in out_dir", + ) + # System parser.add_argument("--device", default=get_default("device", "cuda")) - parser.add_argument("--num-workers", type=int, default=get_default("num_workers", 0)) - parser.add_argument("--compile", action="store_true", default=get_default("compile", False)) + parser.add_argument( + "--num-workers", type=int, default=get_default("num_workers", 0) + ) + parser.add_argument( + "--compile", action="store_true", default=get_default("compile", False) + ) args = parser.parse_args() - + # Construct final config config_dict = TrainingConfig().__dict__.copy() - + # Update with YAML config_dict.update(yaml_config) - + # Update with CLI (only if explicitly passed, but argparse already handled defaults correctly now) - # Actually, we can just use vars(args) and filter None? + # Actually, we can just use vars(args) and filter None? # But argparse defaults are NOT None. # The better way is what we did above: set argparse defaults to yaml_values. for k, v in vars(args).items(): @@ -186,23 +231,25 @@ def get_default(key, default_val): config_dict[k] = v config = TrainingConfig(**config_dict) - + # Type conversion for key in ["learning_rate", "min_lr", "weight_decay", "grad_clip", "dropout"]: - if hasattr(config, key): setattr(config, key, float(getattr(config, key))) + if hasattr(config, key): + setattr(config, key, float(getattr(config, key))) for key in ["batch_size", "gradient_accumulation_steps", "image_size", "max_iters"]: - if hasattr(config, key): setattr(config, key, int(getattr(config, key))) + if hasattr(config, key): + setattr(config, key, int(getattr(config, key))) # Resume path resume_path = args.resume if args.resume_best: resume_path = os.path.join(config.out_dir, "ckpt_best.pt") - + # Validate WandB if config.log_via_wandb and not _WANDB_AVAILABLE: print("WandB not available. 
Disabling.") config.log_via_wandb = False - + return config, resume_path @@ -228,8 +275,16 @@ def setup_ddp(config: TrainingConfig): device = torch.device(config.device) master_process = True seed_offset = 0 - - return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, master_process, seed_offset + + return ( + ddp, + ddp_rank, + ddp_local_rank, + ddp_world_size, + device, + master_process, + seed_offset, + ) def create_modality_registry(config: TrainingConfig) -> ModalityRegistry: @@ -237,9 +292,11 @@ def create_modality_registry(config: TrainingConfig) -> ModalityRegistry: modalities = [ ModalityConfig( name="images", - input_size=config.image_patch_size * config.image_patch_size * config.n_chan, + input_size=config.image_patch_size + * config.image_patch_size + * config.n_chan, patch_size=config.image_patch_size, - loss_weight=779/196, # Mathematically balanced: 779/196 ≈ 3.97 + loss_weight=779 / 196, # Mathematically balanced: 779/196 ≈ 3.97 embed_pos=True, pos_input_size=1, ), @@ -248,34 +305,38 @@ def create_modality_registry(config: TrainingConfig) -> ModalityRegistry: input_size=config.spectrum_patch_size, patch_size=config.spectrum_patch_size, pos_input_size=1, - loss_weight=196/779, # Mathematically balanced: 196/779 ≈ 0.25 + loss_weight=196 / 779, # Mathematically balanced: 196/779 ≈ 0.25 embed_pos=True, ), ] return ModalityRegistry(modalities) -def create_datasets_and_loaders(config: TrainingConfig, ddp: bool, ddp_rank: int, ddp_world_size: int, device): +def create_datasets_and_loaders( + config: TrainingConfig, ddp: bool, ddp_rank: int, ddp_world_size: int, device +): """Create datasets and data loaders.""" - + # Create datasets def mk_config(split_name): - # We handle split logic (removing +) inside FMBDataConfig or Dataset? + # We handle split logic (removing +) inside FMBDataConfig or Dataset? # FMBBaseDataset passes split to EuclidDESIDataset which handles +, so just pass it. - # But wait, AstroPT implementation replaced + with , manually. + # But wait, AstroPT implementation replaced + with , manually. # EuclidDESIDataset supports comma separation. # Let's ensure compatibility. 
- normalized = split_name.replace("+", ",") if isinstance(split_name, str) else split_name + normalized = ( + split_name.replace("+", ",") if isinstance(split_name, str) else split_name + ) return FMBDataConfig( split=normalized, image_size=config.image_size, spectrum_length=config.spectrum_length, - cache_dir=config.cache_dir + cache_dir=config.cache_dir, ) train_dataset = AstroPTDataset(mk_config(config.train_split)) val_dataset = AstroPTDataset(mk_config(config.val_split)) - + # Create samplers for DDP train_sampler = None val_sampler = None @@ -285,19 +346,19 @@ def mk_config(split_name): num_replicas=ddp_world_size, rank=ddp_rank, shuffle=True, - drop_last=True + drop_last=True, ) val_sampler = DistributedSampler( val_dataset, num_replicas=ddp_world_size, rank=ddp_rank, shuffle=False, - drop_last=False + drop_last=False, ) - + # Use pin_memory only for CUDA - use_pin_memory = device.type == 'cuda' - + use_pin_memory = device.type == "cuda" + # Create data loaders train_loader = DataLoader( train_dataset, @@ -309,7 +370,7 @@ def mk_config(split_name): drop_last=True, sampler=train_sampler, ) - + val_loader = DataLoader( val_dataset, batch_size=config.batch_size, @@ -320,7 +381,7 @@ def mk_config(split_name): drop_last=False, sampler=val_sampler, ) - + return train_dataset, val_dataset, train_loader, val_loader @@ -333,86 +394,93 @@ def get_lr(it: int, config: TrainingConfig) -> float: if it > config.lr_decay_iters: return config.min_lr # Cosine decay - decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters) + decay_ratio = (it - config.warmup_iters) / ( + config.lr_decay_iters - config.warmup_iters + ) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return config.min_lr + coeff * (config.learning_rate - config.min_lr) @torch.no_grad() -def estimate_loss(model, train_loader, val_loader, config, modality_registry, device, ctx): +def estimate_loss( + model, train_loader, val_loader, config, modality_registry, device, ctx +): """Estimate loss on train and validation sets.""" model.eval() losses = {} - + for split, loader in [("train", train_loader), ("val", val_loader)]: split_losses = {modality: [] for modality in ["images", "spectra"]} - + for i, batch in enumerate(loader): if i >= config.eval_iters: break - + # Prepare batch inputs = prepare_multimodal_batch( - batch, config.image_patch_size, config.spectrum_patch_size, - device, modality_registry + batch, + config.image_patch_size, + config.spectrum_patch_size, + device, + modality_registry, ) - + if not inputs: # Skip if no valid inputs print(f"Warning: Empty inputs for {split} batch {i}") continue - - with ctx: # Proper target preparation for autoregressive training # Model outputs seq_len-1 for autoregressive modalities, full seq_len for others targets = {} for modality in inputs.keys(): - if modality.endswith('_positions'): + if modality.endswith("_positions"): continue # Skip position tensors - - if modality == 'images': + + if modality == "images": # For autoregressive modality: target = input[1:] (remove first token) targets[modality] = inputs[modality][:, 1:, :] else: # For non-autoregressive modality: target = input (full sequence) targets[modality] = inputs[modality] - + logits, loss = model(inputs, targets=targets) - + # Debug: check loss if loss is None: print(f"Warning: model returned None loss for {split} batch {i}") print(f" Model inputs: {inputs.keys()}") continue - + # Collect losses per modality - use total loss for all modalities # since the model 
returns aggregated loss split_losses["images"].append(loss.item()) split_losses["spectra"].append(loss.item()) - + # Average losses avg_losses = {} for modality, losses_list in split_losses.items(): if losses_list: avg_losses[modality] = sum(losses_list) / len(losses_list) else: - avg_losses[modality] = float('inf') - + avg_losses[modality] = float("inf") + losses[split] = avg_losses - + model.train() return losses -def save_checkpoint(model, optimizer, iter_num, best_val_loss, config, ddp, filename="ckpt.pt"): +def save_checkpoint( + model, optimizer, iter_num, best_val_loss, config, ddp, filename="ckpt.pt" +): """Save model checkpoint.""" os.makedirs(config.out_dir, exist_ok=True) - + # Get the raw model (unwrap DDP if needed) raw_model = model.module if ddp else model - + checkpoint = { "model": raw_model.state_dict(), "optimizer": optimizer.state_dict(), @@ -428,70 +496,77 @@ def save_checkpoint(model, optimizer, iter_num, best_val_loss, config, ddp, file "best_val_loss": best_val_loss, "config": config.__dict__, } - + checkpoint_path = os.path.join(config.out_dir, filename) torch.save(checkpoint, checkpoint_path) print(f"Saved checkpoint to {checkpoint_path}") @torch.no_grad() -def visualize_reconstructions(model, val_loader, config, modality_registry, device, ctx, iter_num): +def visualize_reconstructions( + model, val_loader, config, modality_registry, device, ctx, iter_num +): """Generate and save reconstruction visualizations.""" model.eval() - + # Get one batch for visualization batch = next(iter(val_loader)) inputs = prepare_multimodal_batch( - batch, config.image_patch_size, config.spectrum_patch_size, - device, modality_registry + batch, + config.image_patch_size, + config.spectrum_patch_size, + device, + modality_registry, ) - + if not inputs: print("Warning: No valid inputs for visualization") return - + # Forward pass to get reconstructions with ctx: logits, _ = model(inputs) - + # Convert to float32 for matplotlib compatibility if logits: for key in logits: if isinstance(logits[key], torch.Tensor): logits[key] = logits[key].float() - + # Create visualization fig, axes = plt.subplots(3, 5, figsize=(20, 12)) - fig.suptitle(f'Reconstructions at Iteration {iter_num}', fontsize=16) - + fig.suptitle(f"Reconstructions at Iteration {iter_num}", fontsize=16) + # Process each modality - for i in range(min(5, len(batch['all_object_ids']))): # Show up to 5 examples - + for i in range(min(5, len(batch["all_object_ids"]))): # Show up to 5 examples + # Row 0: Original images - if 'images' in batch and len(batch['images']) > i: - orig_img = batch['images'][i].cpu().numpy() + if "images" in batch and len(batch["images"]) > i: + orig_img = batch["images"][i].cpu().numpy() if orig_img.shape[0] == 3: # RGB orig_img = np.transpose(orig_img, (1, 2, 0)) # Normalize to [0,1] for display - orig_img = (orig_img - orig_img.min()) / (orig_img.max() - orig_img.min()) + orig_img = (orig_img - orig_img.min()) / ( + orig_img.max() - orig_img.min() + ) axes[0, i].imshow(orig_img) - axes[0, i].set_title(f'Original Image {i+1}') - axes[0, i].axis('off') + axes[0, i].set_title(f"Original Image {i+1}") + axes[0, i].axis("off") else: - axes[0, i].axis('off') - + axes[0, i].axis("off") + # Row 1: Reconstructed images - if 'images' in inputs and 'images' in logits and len(logits['images']) > i: - recon_patches = logits['images'][i].cpu().numpy() # Shape: [195, 768] - + if "images" in inputs and "images" in logits and len(logits["images"]) > i: + recon_patches = logits["images"][i].cpu().numpy() # Shape: [195, 
768] + # Reshape patches back to image format # Each patch is 16x16x3 = 768 values patch_size = config.image_patch_size n_channels = config.n_chan - + # Calculate grid dimensions (14x14 patches for 224x224 image) patches_per_side = config.image_size // patch_size - + # Only use the first 196 patches to reconstruct 14x14 grid if recon_patches.shape[0] >= patches_per_side * patches_per_side - 1: # We have 195 patches but need 196 for 14x14 grid @@ -500,83 +575,109 @@ def visualize_reconstructions(model, val_loader, config, modality_registry, devi # Add one zero patch at the end zero_patch = np.zeros((1, recon_patches.shape[1])) recon_patches = np.vstack([recon_patches, zero_patch]) - + # Take first 196 patches (14x14) - patches_to_use = recon_patches[:patches_per_side*patches_per_side] - + patches_to_use = recon_patches[: patches_per_side * patches_per_side] + # Reshape to [14, 14, 768] - patch_grid = patches_to_use.reshape(patches_per_side, patches_per_side, -1) - + patch_grid = patches_to_use.reshape( + patches_per_side, patches_per_side, -1 + ) + # Reshape each patch from 768 -> 16x16x3 - recon_img = np.zeros((patches_per_side*patch_size, patches_per_side*patch_size, n_channels)) - + recon_img = np.zeros( + ( + patches_per_side * patch_size, + patches_per_side * patch_size, + n_channels, + ) + ) + for py in range(patches_per_side): for px in range(patches_per_side): - patch_data = patch_grid[py, px].reshape(patch_size, patch_size, n_channels) - y_start, y_end = py*patch_size, (py+1)*patch_size - x_start, x_end = px*patch_size, (px+1)*patch_size + patch_data = patch_grid[py, px].reshape( + patch_size, patch_size, n_channels + ) + y_start, y_end = py * patch_size, (py + 1) * patch_size + x_start, x_end = px * patch_size, (px + 1) * patch_size recon_img[y_start:y_end, x_start:x_end] = patch_data - + # Normalize for display - recon_img = (recon_img - recon_img.min()) / (recon_img.max() - recon_img.min() + 1e-8) + recon_img = (recon_img - recon_img.min()) / ( + recon_img.max() - recon_img.min() + 1e-8 + ) recon_img = np.clip(recon_img, 0, 1) - + axes[1, i].imshow(recon_img) - axes[1, i].set_title(f'Reconstructed Image {i+1}') - axes[1, i].axis('off') + axes[1, i].set_title(f"Reconstructed Image {i+1}") + axes[1, i].axis("off") else: # Not enough patches - show black image axes[1, i].imshow(np.zeros((224, 224, 3))) - axes[1, i].set_title(f'Insufficient patches ({recon_patches.shape[0]})') - axes[1, i].axis('off') + axes[1, i].set_title(f"Insufficient patches ({recon_patches.shape[0]})") + axes[1, i].axis("off") else: - axes[1, i].axis('off') - + axes[1, i].axis("off") + # Row 2: Original and reconstructed spectra on same plot - if 'spectra' in batch and len(batch['spectra']) > i: - orig_spec = batch['spectra'][i].cpu().numpy() - axes[2, i].plot(orig_spec, label='Original', alpha=0.8, linewidth=1, color='blue') - + if "spectra" in batch and len(batch["spectra"]) > i: + orig_spec = batch["spectra"][i].cpu().numpy() + axes[2, i].plot( + orig_spec, label="Original", alpha=0.8, linewidth=1, color="blue" + ) + # Add reconstructed spectrum if available - if 'spectra' in inputs and 'spectra' in logits and len(logits['spectra']) > i: - recon_patches = logits['spectra'][i].cpu().numpy() # Shape: [31, 256] - + if ( + "spectra" in inputs + and "spectra" in logits + and len(logits["spectra"]) > i + ): + recon_patches = logits["spectra"][i].cpu().numpy() # Shape: [31, 256] + # Flatten patches back to spectrum recon_spec = recon_patches.flatten() # Shape: [31*256] = [7936] - + # Truncate to original spectrum 
length if len(orig_spec) <= len(recon_spec): - recon_spec = recon_spec[:len(orig_spec)] + recon_spec = recon_spec[: len(orig_spec)] else: # Pad if needed - recon_spec = np.pad(recon_spec, (0, len(orig_spec) - len(recon_spec)), 'constant') - - axes[2, i].plot(recon_spec, label='Reconstructed', alpha=0.8, linewidth=1, color='orange') - - axes[2, i].set_ylim([orig_spec.min()-1, orig_spec.max()+1]) - axes[2, i].set_title(f'Spectrum {i+1}') + recon_spec = np.pad( + recon_spec, (0, len(orig_spec) - len(recon_spec)), "constant" + ) + + axes[2, i].plot( + recon_spec, + label="Reconstructed", + alpha=0.8, + linewidth=1, + color="orange", + ) + + axes[2, i].set_ylim([orig_spec.min() - 1, orig_spec.max() + 1]) + axes[2, i].set_title(f"Spectrum {i+1}") axes[2, i].legend(fontsize=8) axes[2, i].grid(True, alpha=0.3) - axes[2, i].set_xlabel('Wavelength Index') - axes[2, i].set_ylabel('Flux') + axes[2, i].set_xlabel("Wavelength Index") + axes[2, i].set_ylabel("Flux") else: - axes[2, i].axis('off') - + axes[2, i].axis("off") + # Hide unused subplots for i in range(5): - if i >= len(batch['all_object_ids']): + if i >= len(batch["all_object_ids"]): for row in range(3): - axes[row, i].axis('off') - + axes[row, i].axis("off") + plt.tight_layout() - + # Save visualization - vis_dir = os.path.join(config.out_dir, 'visualizations') + vis_dir = os.path.join(config.out_dir, "visualizations") os.makedirs(vis_dir, exist_ok=True) - vis_path = os.path.join(vis_dir, f'reconstructions_iter_{iter_num:06d}.png') - plt.savefig(vis_path, dpi=150, bbox_inches='tight') + vis_path = os.path.join(vis_dir, f"reconstructions_iter_{iter_num:06d}.png") + plt.savefig(vis_path, dpi=150, bbox_inches="tight") plt.close() - + print(f"Saved reconstruction visualization: {vis_path}") model.train() @@ -585,20 +686,20 @@ def load_checkpoint(checkpoint_path, model, optimizer, device): """Load model and optimizer state from checkpoint.""" if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - + print(f"Loading checkpoint from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=device) - + # Load model state model.load_state_dict(checkpoint["model"]) - + # Load optimizer state optimizer.load_state_dict(checkpoint["optimizer"]) - + # Get training state iter_num = checkpoint.get("iter_num", 0) - best_val_loss = checkpoint.get("best_val_loss", float('inf')) - + best_val_loss = checkpoint.get("best_val_loss", float("inf")) + print(f"Resumed from iteration {iter_num}, best val loss: {best_val_loss:.6f}") return iter_num, best_val_loss @@ -607,7 +708,7 @@ def print_training_summary(iter_num, max_iters, best_val_loss, current_loss): """Print a training progress summary.""" progress = (iter_num / max_iters) * 100 print(f"\n{'='*60}") - print(f"TRAINING PROGRESS SUMMARY") + print("TRAINING PROGRESS SUMMARY") print(f"{'='*60}") print(f"Progress: {progress:.1f}% ({iter_num}/{max_iters} iterations)") print(f"Current loss: {current_loss:.6f}") @@ -619,51 +720,64 @@ def print_training_summary(iter_num, max_iters, best_val_loss, current_loss): def main(): """Main training loop.""" config, resume_path = parse_args() - + # Setup DDP - ddp, ddp_rank, ddp_local_rank, ddp_world_size, device, master_process, seed_offset = setup_ddp(config) - + ( + ddp, + ddp_rank, + ddp_local_rank, + ddp_world_size, + device, + master_process, + seed_offset, + ) = setup_ddp(config) + # Seed for reproducibility torch.manual_seed(1337 + seed_offset) torch.backends.cuda.matmul.allow_tf32 = True 
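    # TF32 (available on Ampere and newer NVIDIA GPUs) trades a few mantissa bits
    # for much faster float32 matmuls; the matching cuDNN flag is enabled just below.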
torch.backends.cudnn.allow_tf32 = True - + # Create output directory if master_process: os.makedirs(config.out_dir, exist_ok=True) print(f"Output directory: {config.out_dir}") print(f"Training on device: {device}") print(f"DDP: {ddp}, World size: {ddp_world_size}") - + # Create modality registry modality_registry = create_modality_registry(config) - + # Create datasets and loaders train_dataset, val_dataset, train_loader, val_loader = create_datasets_and_loaders( config, ddp, ddp_rank, ddp_world_size, device ) - + if master_process: print(f"Train dataset: {len(train_dataset)} samples") print(f"Val dataset: {len(val_dataset)} samples") print(f"Modalities: {modality_registry.names()}") - + # Print optimization settings - print(f"\nOptimization settings:") + print("\nOptimization settings:") for modality_name in modality_registry.names(): modality_config = modality_registry.get_config(modality_name) - print(f" {modality_config.name}: loss_weight={modality_config.loss_weight}, patch_size={modality_config.patch_size}") + print( + f" {modality_config.name}: loss_weight={modality_config.loss_weight}, patch_size={modality_config.patch_size}" + ) print(f" batch_size: {config.batch_size}") print(f" spectrum_patch_size: {config.spectrum_patch_size}") - + # Calculate tokens per iteration tokens_per_iter = ( - config.gradient_accumulation_steps * ddp_world_size * - config.batch_size * config.block_size * len(modality_registry.names()) + config.gradient_accumulation_steps + * ddp_world_size + * config.batch_size + * config.block_size + * len(modality_registry.names()) ) if master_process: print(f"Tokens per iteration: {tokens_per_iter:,}") - + # Create model gpt_config = GPTConfig( block_size=config.block_size, @@ -674,23 +788,23 @@ def main(): dropout=config.dropout, attn_type="causal", ) - + model = GPT(gpt_config, modality_registry) model.to(device) - + if master_process: print(f"Model parameters: {model.get_num_params() / 1e6:.1f}M") - + # Compile model if requested if config.compile: if master_process: print("Compiling model...") model = torch.compile(model) - + # Wrap with DDP if needed if ddp: model = DDP(model, device_ids=[ddp_local_rank]) - + # Create optimizer base_model = model.module if ddp else model optimizer = base_model.configure_optimizers( @@ -699,7 +813,7 @@ def main(): betas=(config.beta1, config.beta2), device_type=device.type, ) - + # Setup mixed precision dtype_map = { "float32": torch.float32, @@ -711,9 +825,10 @@ def main(): scaler = torch.amp.GradScaler(device.type, enabled=(target_dtype == torch.float16)) ctx = ( torch.amp.autocast(device_type=device.type, dtype=target_dtype) - if use_amp else nullcontext() + if use_amp + else nullcontext() ) - + # Initialize wandb if config.log_via_wandb and master_process: wandb.init( @@ -721,97 +836,137 @@ def main(): name=config.wandb_run_name, config=config.__dict__, ) - + # Initialize loss logging loss_log_file = None if master_process: # Create timestamped log file to avoid overwrites import datetime + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - loss_log_file = os.path.join(config.out_dir, f'training_log_{timestamp}.txt') - + loss_log_file = os.path.join(config.out_dir, f"training_log_{timestamp}.txt") + # Always create new log file with header - with open(loss_log_file, 'w') as f: + with open(loss_log_file, "w") as f: f.write("iter,train_loss,val_loss,lr,time_ms\n") print(f"Training log will be saved to: {loss_log_file}") - + # Load checkpoint if resuming start_iter = 0 if resume_path and master_process: try: - 
start_iter, best_val_loss = load_checkpoint(resume_path, model, optimizer, device) + start_iter, best_val_loss = load_checkpoint( + resume_path, model, optimizer, device + ) except Exception as e: print(f"Failed to load checkpoint: {e}") print("Starting from scratch...") start_iter = 0 - best_val_loss = float('inf') + best_val_loss = float("inf") else: - best_val_loss = float('inf') - + best_val_loss = float("inf") + # Training loop if master_process: if resume_path and start_iter > 0: print(f"Resuming training from iteration {start_iter}...") else: print("Starting training from scratch...") - + model.train() iter_num = start_iter - t0 = time.time() + time.time() t0_start = time.time() - + # Convert loaders to iterators # Convert loaders to iterators train_iter = iter(train_loader) - + from tqdm import tqdm - pbar = tqdm(total=config.max_iters, initial=start_iter, desc="Training", dynamic_ncols=True) - + + pbar = tqdm( + total=config.max_iters, initial=start_iter, desc="Training", dynamic_ncols=True + ) + micro_step = 0 while iter_num < config.max_iters: # Set learning rate lr = get_lr(iter_num, config) if config.decay_lr else config.learning_rate for param_group in optimizer.param_groups: - param_group['lr'] = lr - + param_group["lr"] = lr + # Evaluation and checkpointing - if iter_num % config.eval_interval == 0 and master_process and micro_step % config.gradient_accumulation_steps == 0: + if ( + iter_num % config.eval_interval == 0 + and master_process + and micro_step % config.gradient_accumulation_steps == 0 + ): # Only evaluate at the start of a macro-step to avoid breaking accumulation - metrics = estimate_loss(model, train_loader, val_loader, config, modality_registry, device, ctx) - train_loss = metrics["train"]["spectra"] # Using spectra loss as representative + metrics = estimate_loss( + model, train_loader, val_loader, config, modality_registry, device, ctx + ) + train_loss = metrics["train"][ + "spectra" + ] # Using spectra loss as representative val_loss = metrics["val"]["spectra"] - - pbar.write(f"step {iter_num}: train loss {train_loss:.4f}, val loss {val_loss:.4f}, lr {lr:.2e}") - + + pbar.write( + f"step {iter_num}: train loss {train_loss:.4f}, val loss {val_loss:.4f}, lr {lr:.2e}" + ) + # Print summary every now and then if iter_num > 0 and iter_num % (config.eval_interval * 5) == 0: - print_training_summary(iter_num, config.max_iters, best_val_loss, train_loss) - + print_training_summary( + iter_num, config.max_iters, best_val_loss, train_loss + ) + # Log to wandb if config.log_via_wandb: - wandb.log({ - "iter": iter_num, - "train/loss_spectra": metrics["train"]["spectra"], - "train/loss_images": metrics["train"]["images"], - "val/loss_spectra": metrics["val"]["spectra"], - "val/loss_images": metrics["val"]["images"], - "lr": lr, - }) - + wandb.log( + { + "iter": iter_num, + "train/loss_spectra": metrics["train"]["spectra"], + "train/loss_images": metrics["train"]["images"], + "val/loss_spectra": metrics["val"]["spectra"], + "val/loss_images": metrics["val"]["images"], + "lr": lr, + } + ) + # Save best checkpoint if val_loss < best_val_loss: best_val_loss = val_loss - save_checkpoint(model, optimizer, iter_num, best_val_loss, config, ddp, "ckpt_best.pt") + save_checkpoint( + model, + optimizer, + iter_num, + best_val_loss, + config, + ddp, + "ckpt_best.pt", + ) pbar.write(f"New best validation loss: {best_val_loss:.4f}") - + # Regular checkpoint - if config.always_save_checkpoint or (iter_num > 0 and iter_num % config.checkpoint_interval == 0): - save_checkpoint(model, 
optimizer, iter_num, best_val_loss, config, ddp, f"ckpt_{iter_num}.pt") - + if config.always_save_checkpoint or ( + iter_num > 0 and iter_num % config.checkpoint_interval == 0 + ): + save_checkpoint( + model, + optimizer, + iter_num, + best_val_loss, + config, + ddp, + f"ckpt_{iter_num}.pt", + ) + # Visualize reconstructions if iter_num > 0 and iter_num % (config.eval_interval * 2) == 0: - visualize_reconstructions(model, val_loader, config, modality_registry, device, ctx, iter_num) - + visualize_reconstructions( + model, val_loader, config, modality_registry, device, ctx, iter_num + ) + # Training step # Handle data loading with potential end of epoch try: @@ -819,35 +974,38 @@ def main(): except StopIteration: train_iter = iter(train_loader) batch = next(train_iter) - + # Prepare inputs inputs = prepare_multimodal_batch( - batch, config.image_patch_size, config.spectrum_patch_size, - device, modality_registry + batch, + config.image_patch_size, + config.spectrum_patch_size, + device, + modality_registry, ) - + if not inputs: continue - + with ctx: - # Proper target preparation (same as in estimate_loss) + # Proper target preparation (same as in estimate_loss) targets = {} for modality in inputs.keys(): - if modality.endswith('_positions'): + if modality.endswith("_positions"): continue - if modality == 'images': + if modality == "images": targets[modality] = inputs[modality][:, 1:, :] else: targets[modality] = inputs[modality] - + logits, loss = model(inputs, targets=targets) loss = loss / config.gradient_accumulation_steps - + # Backward pass scaler.scale(loss).backward() - + micro_step += 1 - + # Step optimizer if micro_step % config.gradient_accumulation_steps == 0: if config.grad_clip != 0.0: @@ -858,33 +1016,33 @@ def main(): optimizer.zero_grad(set_to_none=True) iter_num += 1 pbar.update(1) - + # Log training step if iter_num % config.log_interval == 0 and master_process: loss_val = loss.item() * config.gradient_accumulation_steps pbar.set_postfix({"loss": f"{loss_val:.4f}", "lr": f"{lr:.2e}"}) - + # Write to log file if loss_log_file: - with open(loss_log_file, 'a') as f: + with open(loss_log_file, "a") as f: f.write(f"{iter_num},{loss_val:.6f},,{lr:.2e},\n") - + if master_process: pbar.write("Training finished!") pbar.close() - + # Calculate final stats total_time = time.time() - t0_start hours = int(total_time // 3600) minutes = int((total_time % 3600) // 60) seconds = int(total_time % 60) - + # Write detailed summary to log file if loss_log_file: - with open(loss_log_file, 'a') as f: - f.write("\n" + "="*80 + "\n") + with open(loss_log_file, "a") as f: + f.write("\n" + "=" * 80 + "\n") f.write("TRAINING COMPLETE SUMMARY\n") - f.write("="*80 + "\n") + f.write("=" * 80 + "\n") f.write(f"Total Duration: {hours:02d}:{minutes:02d}:{seconds:02d}\n") f.write(f"Total Iterations: {iter_num}\n") f.write(f"Best Validation Loss: {best_val_loss:.6f}\n") @@ -892,21 +1050,26 @@ def main(): f.write("\nCONFIGURATION:\n") for key, val in config.__dict__.items(): f.write(f"{key}: {val}\n") - f.write("="*80 + "\n") - + f.write("=" * 80 + "\n") + # Final reconstruction visualization print("Generating final reconstructions...") try: - visualize_reconstructions(model, val_loader, config, modality_registry, device, ctx, iter_num) + visualize_reconstructions( + model, val_loader, config, modality_registry, device, ctx, iter_num + ) except Exception as e: print(f"Warning: Failed to generate final visualizations: {e}") - - save_checkpoint(model, optimizer, iter_num, best_val_loss, config, ddp, 
"ckpt_final.pt") + + save_checkpoint( + model, optimizer, iter_num, best_val_loss, config, ddp, "ckpt_final.pt" + ) if config.log_via_wandb: wandb.finish() - + if ddp: destroy_process_group() + if __name__ == "__main__": main() diff --git a/src/fmb/models/base/config.py b/src/fmb/models/base/config.py index 8116763..c85881a 100644 --- a/src/fmb/models/base/config.py +++ b/src/fmb/models/base/config.py @@ -16,7 +16,7 @@ class BaseTrainingConfig: """ Base configuration for all FMB model trainers. - + Parameters ---------- out_dir : Path @@ -44,10 +44,10 @@ class BaseTrainingConfig: num_workers : int Number of dataloader workers. """ - + # Paths out_dir: Path = load_paths().retrained_weights - + # Training epochs: int = 10 batch_size: int = 8 @@ -55,18 +55,18 @@ class BaseTrainingConfig: weight_decay: float = 0.1 grad_clip: float = 1.0 gradient_accumulation_steps: int = 1 - + # System device: str = "cuda" seed: int = 42 amp_dtype: str = "bfloat16" - + # Logging log_interval: int = 20 checkpoint_interval: int = 1 - + # Data num_workers: int = 0 - + # Resume resume_checkpoint: Optional[str] = None diff --git a/src/fmb/models/base/trainer.py b/src/fmb/models/base/trainer.py index 2dab30a..b8ced69 100644 --- a/src/fmb/models/base/trainer.py +++ b/src/fmb/models/base/trainer.py @@ -15,13 +15,13 @@ from tqdm.auto import tqdm from fmb.models.base.config import BaseTrainingConfig -from fmb.models.base.utils import set_seed, setup_amp, format_memory +from fmb.models.base.utils import format_memory, set_seed, setup_amp class BaseTrainer(ABC): """ Abstract base trainer for all FMB models. - + Provides standardized training loop with: - Automatic mixed precision (AMP) - Gradient accumulation @@ -29,7 +29,7 @@ class BaseTrainer(ABC): - Checkpointing - Validation - Loss history tracking - + Parameters ---------- model : nn.Module @@ -41,7 +41,7 @@ class BaseTrainer(ABC): val_loader : Optional[DataLoader] Validation data loader (optional). """ - + def __init__( self, model: nn.Module, @@ -53,42 +53,42 @@ def __init__( self.config = config self.train_loader = train_loader self.val_loader = val_loader - + # Setup device self.device = torch.device(config.device) self.model.to(self.device) - + # Setup seed set_seed(config.seed) - + # Setup optimizer self.optimizer = self._create_optimizer() - + # Setup AMP self.scaler, self.amp_ctx = setup_amp( device=config.device, dtype=config.amp_dtype, ) - + # Training state self.current_epoch = 0 self.global_step = 0 self.best_val_loss = float("inf") - + # History self.history: Dict[str, List[float]] = { "train_loss": [], "val_loss": [], } - + # Resume if checkpoint provided if config.resume_checkpoint: self.load_checkpoint(config.resume_checkpoint) - + def _create_optimizer(self) -> torch.optim.Optimizer: """ Create optimizer. Can be overridden by subclasses. - + Returns ------- torch.optim.Optimizer @@ -99,56 +99,56 @@ def _create_optimizer(self) -> torch.optim.Optimizer: lr=self.config.learning_rate, weight_decay=self.config.weight_decay, ) - + @abstractmethod def train_step(self, batch: Any) -> Dict[str, float]: """ Execute one training step. - + Must be implemented by subclasses. Should perform forward pass, compute loss, and return metrics dictionary. - + Parameters ---------- batch : Any Batch of training data. - + Returns ------- Dict[str, float] Dictionary of metrics. Must contain 'loss' key. """ pass - + @abstractmethod def val_step(self, batch: Any) -> Dict[str, float]: """ Execute one validation step. - + Must be implemented by subclasses. 
Should perform forward pass and return metrics dictionary. - + Parameters ---------- batch : Any Batch of validation data. - + Returns ------- Dict[str, float] Dictionary of metrics. Must contain 'loss' key. """ pass - + def train_epoch(self, epoch: int) -> float: """ Train for one complete epoch. - + Parameters ---------- epoch : int Current epoch number (1-indexed). - + Returns ------- float @@ -157,27 +157,27 @@ def train_epoch(self, epoch: int) -> float: self.model.train() total_loss = 0.0 num_batches = 0 - + progress = tqdm( self.train_loader, desc=f"Epoch {epoch}/{self.config.epochs}", leave=True, ) - + self.optimizer.zero_grad(set_to_none=True) - + for step, batch in enumerate(progress): # Forward pass with AMP with self.amp_ctx: metrics = self.train_step(batch) loss = metrics["loss"] - + # Scale loss for gradient accumulation loss = loss / self.config.gradient_accumulation_steps - + # Backward pass self.scaler.scale(loss).backward() - + # Optimizer step (with gradient accumulation) if (step + 1) % self.config.gradient_accumulation_steps == 0: # Gradient clipping @@ -187,37 +187,37 @@ def train_epoch(self, epoch: int) -> float: self.model.parameters(), self.config.grad_clip, ) - + # Optimizer step self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad(set_to_none=True) - + self.global_step += 1 - + # Track loss (unscaled) total_loss += loss.item() * self.config.gradient_accumulation_steps num_batches += 1 - + # Logging if (step + 1) % self.config.log_interval == 0: current_loss = loss.item() * self.config.gradient_accumulation_steps postfix = {"loss": f"{current_loss:.6f}"} - + # Add GPU memory if available if self.device.type == "cuda": mem_allocated = torch.cuda.memory_allocated() postfix["mem"] = format_memory(mem_allocated) - + progress.set_postfix(postfix) - + return total_loss / max(num_batches, 1) - + @torch.no_grad() def validate(self) -> float: """ Validate the model on validation set. - + Returns ------- float @@ -225,21 +225,21 @@ def validate(self) -> float: """ if self.val_loader is None: return float("nan") - + self.model.eval() total_loss = 0.0 num_batches = 0 - + for batch in tqdm(self.val_loader, desc="Validating", leave=False): with self.amp_ctx: metrics = self.val_step(batch) loss = metrics["loss"] - + total_loss += loss.item() num_batches += 1 - + return total_loss / max(num_batches, 1) - + def save_checkpoint( self, epoch: int, @@ -248,7 +248,7 @@ def save_checkpoint( ) -> None: """ Save model checkpoint. - + Parameters ---------- epoch : int @@ -259,7 +259,7 @@ def save_checkpoint( Custom filename (default: checkpoint_epoch_{epoch:03d}.pt). """ self.config.out_dir.mkdir(parents=True, exist_ok=True) - + checkpoint = { "epoch": epoch, "global_step": self.global_step, @@ -270,25 +270,25 @@ def save_checkpoint( "history": self.history, "best_val_loss": self.best_val_loss, } - + # Regular checkpoint if filename is None: filename = f"checkpoint_epoch_{epoch:03d}.pt" - + ckpt_path = self.config.out_dir / filename torch.save(checkpoint, ckpt_path) print(f"💾 Saved checkpoint: {ckpt_path}") - + # Best checkpoint if is_best: best_path = self.config.out_dir / "checkpoint_best.pt" torch.save(checkpoint, best_path) print(f"New best checkpoint: {best_path}") - + def load_checkpoint(self, checkpoint_path: str) -> None: """ Load checkpoint and resume training. 
- + Parameters ---------- checkpoint_path : str @@ -297,29 +297,29 @@ def load_checkpoint(self, checkpoint_path: str) -> None: ckpt_path = Path(checkpoint_path) if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - + print(f"Loading checkpoint: {checkpoint_path}") checkpoint = torch.load(ckpt_path, map_location=self.device) - + # Load state self.model.load_state_dict(checkpoint["model"]) self.optimizer.load_state_dict(checkpoint["optimizer"]) - + if "scaler" in checkpoint: self.scaler.load_state_dict(checkpoint["scaler"]) - + # Restore training state self.current_epoch = checkpoint.get("epoch", 0) self.global_step = checkpoint.get("global_step", 0) self.best_val_loss = checkpoint.get("best_val_loss", float("inf")) self.history = checkpoint.get("history", {"train_loss": [], "val_loss": []}) - + print(f"Resumed from epoch {self.current_epoch}, step {self.global_step}") - + def train(self) -> None: """ Execute complete training loop. - + Trains for the specified number of epochs, validates after each epoch, and saves checkpoints according to the configuration. """ @@ -332,40 +332,40 @@ def train(self) -> None: print(f"Learning rate: {self.config.learning_rate}") print(f"AMP dtype: {self.config.amp_dtype}") print("=" * 60) - + start_epoch = self.current_epoch + 1 - + for epoch in range(start_epoch, self.config.epochs + 1): self.current_epoch = epoch - + # Train train_loss = self.train_epoch(epoch) - + # Validate val_loss = self.validate() - + # Update history self.history["train_loss"].append(train_loss) self.history["val_loss"].append(val_loss) - + # Print summary print(f"\n Epoch {epoch} Summary:") print(f" Train Loss: {train_loss:.6f}") if not torch.isnan(torch.tensor(val_loss)): print(f" Val Loss: {val_loss:.6f}") - + # Check if best is_best = val_loss < self.best_val_loss if is_best: self.best_val_loss = val_loss - print(f" New best validation loss!") - + print(" New best validation loss!") + # Save checkpoint if epoch % self.config.checkpoint_interval == 0: self.save_checkpoint(epoch, is_best=is_best) - + print() - + print("=" * 60) print("Training Complete!") print(f"Best validation loss: {self.best_val_loss:.6f}") diff --git a/src/fmb/models/base/utils.py b/src/fmb/models/base/utils.py index af38ac5..abca284 100644 --- a/src/fmb/models/base/utils.py +++ b/src/fmb/models/base/utils.py @@ -16,7 +16,7 @@ def set_seed(seed: int) -> None: """ Set random seed for reproducibility across all libraries. - + Parameters ---------- seed : int @@ -35,14 +35,14 @@ def setup_amp( ) -> Tuple[torch.amp.GradScaler, object]: """ Setup Automatic Mixed Precision (AMP) training. - + Parameters ---------- device : str Device type ('cuda', 'cpu', etc.). dtype : str AMP dtype: 'float16', 'bfloat16', or 'float32'. - + Returns ------- scaler : torch.amp.GradScaler @@ -55,35 +55,34 @@ def setup_amp( "bfloat16": torch.bfloat16, "float16": torch.float16, } - + target_dtype = dtype_map.get(dtype, torch.float32) use_amp = target_dtype in {torch.bfloat16, torch.float16} - + # GradScaler only needed for float16 scaler = torch.amp.GradScaler( - device, - enabled=(device == "cuda" and target_dtype == torch.float16) + device, enabled=(device == "cuda" and target_dtype == torch.float16) ) - + # Autocast context ctx = ( torch.amp.autocast(device_type=device, dtype=target_dtype) if use_amp else nullcontext() ) - + return scaler, ctx def format_memory(bytes_val: int) -> str: """ Format memory size in bytes to human-readable string. 
- + Parameters ---------- bytes_val : int Memory size in bytes. - + Returns ------- str diff --git a/src/fmb/models/external_imports.py b/src/fmb/models/external_imports.py index 72488ae..02a748b 100644 --- a/src/fmb/models/external_imports.py +++ b/src/fmb/models/external_imports.py @@ -13,7 +13,7 @@ def get_repo_root() -> Path: """ Get repository root directory. - + Returns ------- Path @@ -26,57 +26,57 @@ def get_repo_root() -> Path: def setup_external_paths(libraries: List[str] = None) -> None: """ Add external library paths to sys.path. - + Parameters ---------- libraries : List[str], optional List of library names to add. Defaults to all: ['AION', 'astroPT', 'AstroCLIP']. - + Raises ------ FileNotFoundError If external directory doesn't exist. - + Notes ----- This function is idempotent: calling it multiple times is safe. """ if libraries is None: libraries = ["AION", "astroPT", "AstroCLIP"] - + repo_root = get_repo_root() external_dir = repo_root / "external" - + if not external_dir.exists(): raise FileNotFoundError( f"External directory not found: {external_dir}\n" "Initialize submodules with: git submodule update --init --recursive" ) - + for lib in libraries: lib_path = external_dir / lib - + # Special handling for astroPT which has src/ subdirectory if lib == "astroPT": lib_path = lib_path / "src" - + if lib_path.exists() and str(lib_path) not in sys.path: sys.path.insert(0, str(lib_path)) print(f" Added to sys.path: {lib_path}") elif not lib_path.exists(): print(f" Warning: {lib} not found at {lib_path}") - print(f" Run: git submodule update --init --recursive") + print(" Run: git submodule update --init --recursive") def check_external_available(library: str) -> bool: """ Check if an external library is available for import. - + Parameters ---------- library : str Library name ('AION', 'astroPT', or 'AstroCLIP'). - + Returns ------- bool diff --git a/src/fmb/paths.py b/src/fmb/paths.py index d018297..d3c2512 100644 --- a/src/fmb/paths.py +++ b/src/fmb/paths.py @@ -7,26 +7,31 @@ from __future__ import annotations +import datetime as dt import os from dataclasses import dataclass from pathlib import Path from typing import Optional -import datetime as dt + import yaml + def _repo_root() -> Path: return Path(__file__).resolve().parents[2] + def _expand_vars(s: str) -> str: return os.path.expandvars(os.path.expanduser(s)) + def _p(v: str | Path) -> Path: return Path(_expand_vars(str(v))).resolve() + @dataclass(frozen=True) class FMBPaths: repo_root: Path - + # Specific specialized paths dataset: Path dataset_train: Path @@ -35,7 +40,6 @@ class FMBPaths: dataset_index: Path base_weights: Path - # Model specific base weights (can be configured separately) base_weights_aion: Path base_weights_astropt: Path @@ -45,21 +49,29 @@ class FMBPaths: nfs_weights: Path outliers: Path analysis: Path - + # Kept for backward compat or generic usage if needed embeddings: Path cache: Path - + # Fallback/base roots storage_root: Path runs_root: Path - + def ensure(self) -> "FMBPaths": # Create directories that are meant to be output directories # dataset and base_weights are input dirs usually, so we might not want to mkdir them blindly? # But if they are just roots, it's safer to ensure they exist or warn. # Let's ensure output dirs. 
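        # Only the output roots listed below (retrained weights, NFS weights,
        # outliers, analysis, embeddings, cache, runs) are created here; the
        # dataset and base-weights roots are treated as inputs and left untouched.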
- for p in [self.retrained_weights, self.nfs_weights, self.outliers, self.analysis, self.embeddings, self.cache, self.runs_root]: + for p in [ + self.retrained_weights, + self.nfs_weights, + self.outliers, + self.analysis, + self.embeddings, + self.cache, + self.runs_root, + ]: p.mkdir(parents=True, exist_ok=True) return self @@ -67,7 +79,7 @@ def embeddings_dir(self, model: str) -> Path: d = self.embeddings / model d.mkdir(parents=True, exist_ok=True) return d - + # Helper for generic "run" outputs if needed def new_run_dir(self, tag: str) -> Path: stamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S") @@ -75,8 +87,10 @@ def new_run_dir(self, tag: str) -> Path: d.mkdir(parents=True, exist_ok=True) return d + _CACHED_PATHS: Optional[FMBPaths] = None + def load_paths(config_path: Optional[Path] = None, *, ensure: bool = True) -> FMBPaths: global _CACHED_PATHS if _CACHED_PATHS is not None and config_path is None: @@ -101,11 +115,11 @@ def load_paths(config_path: Optional[Path] = None, *, ensure: bool = True) -> FM if c.exists(): config_path = c break - + if config_path is None or not config_path.exists(): - # Fallback to template - config_path = repo_root / "src" / "fmb" / "configs" / "paths.template.yaml" - + # Fallback to template + config_path = repo_root / "src" / "fmb" / "configs" / "paths.template.yaml" + # Load config if valid cfg = {} if config_path.exists(): @@ -135,26 +149,34 @@ def resolve_optional(key: str, default_abs: Path) -> Path: dataset_path = resolve("dataset_path", "data") dataset_path_train = resolve_optional("dataset_path_train", dataset_path / "train") dataset_path_test = resolve_optional("dataset_path_test", dataset_path / "test") - dataset_hf_id = cfg.get("dataset_hf_id", "msiudek/astroPT_euclid_Q1_desi_dr1_dataset") + dataset_hf_id = cfg.get( + "dataset_hf_id", "msiudek/astroPT_euclid_Q1_desi_dr1_dataset" + ) dataset_index = resolve_optional("dataset_index_path", dataset_path / "index.csv") base_weights_path = resolve("base_weights_path", "checkpoints/base") - + # Resolve model specific weights, defaulting to base_weights/ - base_weights_aion = resolve_optional("base_weights_path_aion", base_weights_path / "aion") - base_weights_astropt = resolve_optional("base_weights_path_astropt", base_weights_path / "astropt") - base_weights_astroclip = resolve_optional("base_weights_path_astroclip", base_weights_path / "astroclip") + base_weights_aion = resolve_optional( + "base_weights_path_aion", base_weights_path / "aion" + ) + base_weights_astropt = resolve_optional( + "base_weights_path_astropt", base_weights_path / "astropt" + ) + base_weights_astroclip = resolve_optional( + "base_weights_path_astroclip", base_weights_path / "astroclip" + ) retrained_weights_path = resolve("retrained_weights_path", "checkpoints/retrained") nfs_weights_path = resolve("nfs_weights_path", "checkpoints/nfs") outliers_path = resolve("outliers_path", "outputs/outliers") analysis_path = resolve("analysis_path", "outputs/analysis") - + # Generic/Legacy - emb_path = resolve("embeddings_path", "embeddings") # Also support embeddings_path + emb_path = resolve("embeddings_path", "embeddings") # Also support embeddings_path if "emb_root" in cfg: emb_path = _p(cfg["emb_root"]) - + cache_path = resolve("cache_root", "cache") runs_path = resolve("runs_root", "runs") @@ -181,6 +203,6 @@ def resolve_optional(key: str, default_abs: Path) -> Path: if ensure: paths.ensure() - + _CACHED_PATHS = paths return paths diff --git a/src/fmb/setup/check_environment_aion.py b/src/fmb/setup/check_environment_aion.py index 
a8c3aac..3e302c1 100644 --- a/src/fmb/setup/check_environment_aion.py +++ b/src/fmb/setup/check_environment_aion.py @@ -6,7 +6,7 @@ Description: Validate AION environment and dependencies """ -# Original script comming from https://github.com/mhuertascompany/camels-aion/tree/main and modified +# Original script comming from https://github.com/mhuertascompany/camels-aion/tree/main and modified # to be used as a test for the fmb package """Environment sanity checks AION model loading.""" @@ -31,6 +31,7 @@ from fmb.paths import load_paths + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -45,7 +46,7 @@ def parse_args() -> argparse.Namespace: default="aion-base", help="Hugging Face model identifier to load.", ) - + # Default to configured path try: default_model_dir = load_paths().base_weights_aion @@ -112,31 +113,32 @@ def check_hf_auth() -> None: print(f" Details: {exc}") -def check_aion(model_name: str, model_dir: Path | None, device: str, skip_codecs: bool) -> None: +def check_aion( + model_name: str, model_dir: Path | None, device: str, skip_codecs: bool +) -> None: print("\n== AION Model ==") import json - from huggingface_hub import hf_hub_download + from aion import AION # Lazy import to provide clearer error if missing + from huggingface_hub import hf_hub_download if model_dir is not None: model = AION.from_pretrained(model_dir) repo_id = str(model_dir) config = None - codec_repo: str | Path = model_dir else: repo_id = model_name if "/" in model_name else f"polymathic-ai/{model_name}" config_path = hf_hub_download(repo_id, "config.json") with open(config_path, "r", encoding="utf-8") as fh: config = json.load(fh) model = AION.from_pretrained(repo_id, config=config) - codec_repo = repo_id model = model.to(device) model.eval() print(f"Loaded `{repo_id}` and moved to `{device}`.") print(f"Loaded `{repo_id}` and moved to `{device}`.") print("Model loaded successfully (weights validation passed).") - + if not skip_codecs: print("Note: Codec verification skipped") diff --git a/src/fmb/setup/check_environment_astroclip.py b/src/fmb/setup/check_environment_astroclip.py index 6fd0cdc..a6609db 100644 --- a/src/fmb/setup/check_environment_astroclip.py +++ b/src/fmb/setup/check_environment_astroclip.py @@ -20,8 +20,10 @@ sys.path.insert(0, str(src_path)) import torch + from fmb.paths import load_paths + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -30,7 +32,7 @@ def parse_args() -> argparse.Namespace: default="cuda" if torch.cuda.is_available() else "cpu", help="Target device for model loading test.", ) - + # Default to configured path try: default_model_dir = load_paths().base_weights_astroclip @@ -45,9 +47,10 @@ def parse_args() -> argparse.Namespace: ) return parser.parse_args() + def check_astroclip(model_dir: Path, device: str) -> None: print("\n== AstroCLIP Model ==") - + if not model_dir or not model_dir.exists(): print(f"Error: Model directory {model_dir} does not exist.") sys.exit(1) @@ -55,37 +58,39 @@ def check_astroclip(model_dir: Path, device: str) -> None: # Check for the 3 expected files required_files = ["astrodino.ckpt", "specformer.ckpt", "astroclip.ckpt"] missing = [f for f in required_files if not (model_dir / f).exists()] - + if missing: print(f"Error: Missing required weight files in {model_dir}: {missing}") print("Please run `python src/fmb/setup/download_weights_astroclip.py` first.") sys.exit(1) print(f"Found all required weight 
files in {model_dir}") - + # Try loading the checkpoints simply with torch.load to verify integrity # We don't need to instantiate the full model here, just check if weights are readable try: for f in required_files: p = model_dir / f print(f"Verifying {f}...", end=" ", flush=True) - # map_location=device to test device memory allocation too if needed, + # map_location=device to test device memory allocation too if needed, # but usually cpu is safer strict integrity check without OOM # weights_only=False is required because these checkpoints might contain # lightning globals (AttributeDict, etc). - ckpt = torch.load(p, map_location="cpu", weights_only=False) + ckpt = torch.load(p, map_location="cpu", weights_only=False) print("OK") del ckpt - + except Exception as e: print(f"\nError check-loading weights: {e}") sys.exit(1) print("AstroCLIP weights verified successfully.") + def main() -> None: args = parse_args() check_astroclip(args.model_dir, args.device) + if __name__ == "__main__": sys.exit(main()) diff --git a/src/fmb/setup/download_data.py b/src/fmb/setup/download_data.py index c342c60..85ca9e2 100644 --- a/src/fmb/setup/download_data.py +++ b/src/fmb/setup/download_data.py @@ -5,18 +5,14 @@ Description: Download Euclid+DESI dataset from HuggingFace """ -import numpy as np -from tqdm import tqdm -import matplotlib.pyplot as plt -import pandas as pd -from datasets import load_dataset, load_from_disk +import os -import torch -from torch.utils.data import DataLoader +import matplotlib.pyplot as plt import numpy as np +import torch +from datasets import load_dataset from PIL import Image - -import os +from torch.utils.data import DataLoader """ Script to download the Euclid+DESI dataset from Hugging Face. @@ -28,193 +24,242 @@ """ os.environ["HF_HOME"] = "/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home" -os.environ["HF_HUB_CACHE"] = "/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home/hub" -os.environ["HF_DATASETS_CACHE"] = "/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home/datasets" +os.environ["HF_HUB_CACHE"] = ( + "/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home/hub" +) +os.environ["HF_DATASETS_CACHE"] = ( + "/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home/datasets" +) os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" - +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" class EuclidDESIDataset(torch.utils.data.Dataset): """PyTorch Dataset wrapper for the Euclid+DESI HuggingFace dataset.""" - - def __init__(self, split="train_batch_1", transform=None, - cache_dir="/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home/datasets"): + + def __init__( + self, + split="train_batch_1", + transform=None, + cache_dir="/pbs/throng/training/astroinfo2025/model/euclid_desi/hf_home/datasets", + ): import os + os.makedirs(cache_dir, exist_ok=True) print(f"Loading dataset (split={split}) into cache_dir={cache_dir}") self.dataset = load_dataset( - "msiudek/astroPT_euclid_desi_dataset", - split=split, - cache_dir=cache_dir + "msiudek/astroPT_euclid_desi_dataset", split=split, cache_dir=cache_dir ) self.transform = transform - + def __len__(self): return len(self.dataset) - + def __getitem__(self, idx): """Get a single sample from the dataset.""" sample = self.dataset[idx] - + # Convert PIL image to tensor - rgb_image = sample['RGB_image'] + rgb_image = sample["RGB_image"] if isinstance(rgb_image, Image.Image): rgb_image = np.array(rgb_image) - + # Convert to tensor format (C, H, W) if rgb_image.ndim 
== 3: rgb_image = torch.from_numpy(rgb_image).permute(2, 0, 1).float() / 255.0 else: rgb_image = torch.from_numpy(rgb_image).unsqueeze(0).float() / 255.0 - + # Process spectrum data spectrum_data = None - if sample['spectrum'] is not None: - flux = sample['spectrum']['flux'] - wavelength = sample['spectrum']['wavelength'] + if sample["spectrum"] is not None: + flux = sample["spectrum"]["flux"] + wavelength = sample["spectrum"]["wavelength"] if flux is not None: spectrum_data = { - 'flux': torch.from_numpy(np.array(flux)).float(), - 'wavelength': torch.from_numpy(np.array(wavelength)).float() if wavelength is not None else None + "flux": torch.from_numpy(np.array(flux)).float(), + "wavelength": ( + torch.from_numpy(np.array(wavelength)).float() + if wavelength is not None + else None + ), } - + # Process SED data sed_fluxes = None - if sample['sed_data'] is not None: - flux_keys = [k for k in sample['sed_data'].keys() if k.startswith('flux_')] + if sample["sed_data"] is not None: + flux_keys = [k for k in sample["sed_data"].keys() if k.startswith("flux_")] if flux_keys: - sed_fluxes = torch.tensor([sample['sed_data'][k] for k in flux_keys]).float() - + sed_fluxes = torch.tensor( + [sample["sed_data"][k] for k in flux_keys] + ).float() + # Process individual band images vis_image = None nisp_y_image = None nisp_j_image = None nisp_h_image = None - - if 'VIS_image' in sample and sample['VIS_image'] is not None: - vis_image = torch.from_numpy(np.array(sample['VIS_image'])).float() - if 'NISP_Y_image' in sample and sample['NISP_Y_image'] is not None: - nisp_y_image = torch.from_numpy(np.array(sample['NISP_Y_image'])).float() - if 'NISP_J_image' in sample and sample['NISP_J_image'] is not None: - nisp_j_image = torch.from_numpy(np.array(sample['NISP_J_image'])).float() - if 'NISP_H_image' in sample and sample['NISP_H_image'] is not None: - nisp_h_image = torch.from_numpy(np.array(sample['NISP_H_image'])).float() - + + if "VIS_image" in sample and sample["VIS_image"] is not None: + vis_image = torch.from_numpy(np.array(sample["VIS_image"])).float() + if "NISP_Y_image" in sample and sample["NISP_Y_image"] is not None: + nisp_y_image = torch.from_numpy(np.array(sample["NISP_Y_image"])).float() + if "NISP_J_image" in sample and sample["NISP_J_image"] is not None: + nisp_j_image = torch.from_numpy(np.array(sample["NISP_J_image"])).float() + if "NISP_H_image" in sample and sample["NISP_H_image"] is not None: + nisp_h_image = torch.from_numpy(np.array(sample["NISP_H_image"])).float() + return { - 'object_id': sample['object_id'], - 'targetid': sample['targetid'], - 'redshift': sample['redshift'], - 'rgb_image': rgb_image, - 'vis_image': vis_image, - 'nisp_y_image': nisp_y_image, - 'nisp_j_image': nisp_j_image, - 'nisp_h_image': nisp_h_image, - 'spectrum': spectrum_data, - 'sed_fluxes': sed_fluxes, + "object_id": sample["object_id"], + "targetid": sample["targetid"], + "redshift": sample["redshift"], + "rgb_image": rgb_image, + "vis_image": vis_image, + "nisp_y_image": nisp_y_image, + "nisp_j_image": nisp_j_image, + "nisp_h_image": nisp_h_image, + "spectrum": spectrum_data, + "sed_fluxes": sed_fluxes, } + def test_dataloader(): """Test the PyTorch DataLoader functionality.""" print("Creating PyTorch dataset...") - + try: # Create dataset dataset = EuclidDESIDataset(split="train_batch_1") print(f"Dataset loaded with {len(dataset)} samples") - + # Create dataloader dataloader = DataLoader(dataset, batch_size=4, shuffle=True) - print(f"DataLoader created with batch size 4") - + print("DataLoader created 
with batch size 4") + # Test loading a batch batch = next(iter(dataloader)) - - print(f"\nBatch info:") + + print("\nBatch info:") print(f"RGB images shape: {batch['rgb_image'].shape}") print(f"Object IDs: {batch['object_id']}") print(f"Redshifts: {batch['redshift']}") - - if batch['sed_fluxes'] is not None and len(batch['sed_fluxes']) > 0: - print(f"SED fluxes shape: {batch['sed_fluxes'][0].shape if batch['sed_fluxes'][0] is not None else 'None'}") - + + if batch["sed_fluxes"] is not None and len(batch["sed_fluxes"]) > 0: + print( + f"SED fluxes shape: {batch['sed_fluxes'][0].shape if batch['sed_fluxes'][0] is not None else 'None'}" + ) + # Check individual band availability - has_vis = batch['vis_image'][0] is not None - has_nisp_y = batch['nisp_y_image'][0] is not None - has_nisp_j = batch['nisp_j_image'][0] is not None - has_nisp_h = batch['nisp_h_image'][0] is not None - print(f"Individual bands available: VIS={has_vis}, Y={has_nisp_y}, J={has_nisp_j}, H={has_nisp_h}") - + has_vis = batch["vis_image"][0] is not None + has_nisp_y = batch["nisp_y_image"][0] is not None + has_nisp_j = batch["nisp_j_image"][0] is not None + has_nisp_h = batch["nisp_h_image"][0] is not None + print( + f"Individual bands available: VIS={has_vis}, Y={has_nisp_y}, J={has_nisp_j}, H={has_nisp_h}" + ) + # Create visualization with RGB, spectrum, and individual bands fig, axes = plt.subplots(2, 3, figsize=(18, 12)) - fig.suptitle(f"Object {batch['object_id'][0]} (z={batch['redshift'][0]:.4f})", fontsize=16) - + fig.suptitle( + f"Object {batch['object_id'][0]} (z={batch['redshift'][0]:.4f})", + fontsize=16, + ) + # Show RGB image - rgb = batch['rgb_image'][0].permute(1, 2, 0).numpy() + rgb = batch["rgb_image"][0].permute(1, 2, 0).numpy() axes[0, 0].imshow(rgb) axes[0, 0].set_title("RGB Composite") - axes[0, 0].axis('off') - + axes[0, 0].axis("off") + # Show spectrum if available - if batch['spectrum'][0] is not None and batch['spectrum'][0]['flux'] is not None: - flux = batch['spectrum'][0]['flux'].numpy() - wavelength = batch['spectrum'][0]['wavelength'].numpy() if batch['spectrum'][0]['wavelength'] is not None else np.arange(len(flux)) - - axes[0, 1].plot(wavelength, flux, 'b-', linewidth=0.8) - axes[0, 1].set_xlabel('Wavelength (Å)') - axes[0, 1].set_ylabel('Flux') - axes[0, 1].set_title('DESI Spectrum') + if ( + batch["spectrum"][0] is not None + and batch["spectrum"][0]["flux"] is not None + ): + flux = batch["spectrum"][0]["flux"].numpy() + wavelength = ( + batch["spectrum"][0]["wavelength"].numpy() + if batch["spectrum"][0]["wavelength"] is not None + else np.arange(len(flux)) + ) + + axes[0, 1].plot(wavelength, flux, "b-", linewidth=0.8) + axes[0, 1].set_xlabel("Wavelength (Å)") + axes[0, 1].set_ylabel("Flux") + axes[0, 1].set_title("DESI Spectrum") axes[0, 1].grid(True, alpha=0.3) else: - axes[0, 1].text(0.5, 0.5, 'No spectrum data', ha='center', va='center', transform=axes[0, 1].transAxes) - axes[0, 1].set_title('Spectrum (Not Available)') - + axes[0, 1].text( + 0.5, + 0.5, + "No spectrum data", + ha="center", + va="center", + transform=axes[0, 1].transAxes, + ) + axes[0, 1].set_title("Spectrum (Not Available)") + # Show SED photometry - if batch['sed_fluxes'][0] is not None: - sed_fluxes = batch['sed_fluxes'][0].numpy() + if batch["sed_fluxes"][0] is not None: + sed_fluxes = batch["sed_fluxes"][0].numpy() axes[0, 2].bar(range(len(sed_fluxes)), sed_fluxes) - axes[0, 2].set_xlabel('Filter Index') - axes[0, 2].set_ylabel('Flux') - axes[0, 2].set_title(f'SED Photometry ({len(sed_fluxes)} bands)') + axes[0, 
2].set_xlabel("Filter Index") + axes[0, 2].set_ylabel("Flux") + axes[0, 2].set_title(f"SED Photometry ({len(sed_fluxes)} bands)") else: - axes[0, 2].text(0.5, 0.5, 'No SED data', ha='center', va='center', transform=axes[0, 2].transAxes) - axes[0, 2].set_title('SED (Not Available)') - + axes[0, 2].text( + 0.5, + 0.5, + "No SED data", + ha="center", + va="center", + transform=axes[0, 2].transAxes, + ) + axes[0, 2].set_title("SED (Not Available)") + # Show individual band images band_data = [ - (batch['vis_image'][0], 'VIS Band'), - (batch['nisp_y_image'][0], 'NIR-Y Band'), - (batch['nisp_j_image'][0], 'NIR-J Band') + (batch["vis_image"][0], "VIS Band"), + (batch["nisp_y_image"][0], "NIR-Y Band"), + (batch["nisp_j_image"][0], "NIR-J Band"), ] - + for i, (band_image, title) in enumerate(band_data): if band_image is not None: image_data = band_image.numpy() - im = axes[1, i].imshow(image_data, cmap='viridis') + im = axes[1, i].imshow(image_data, cmap="viridis") axes[1, i].set_title(title) - axes[1, i].axis('off') + axes[1, i].axis("off") plt.colorbar(im, ax=axes[1, i], fraction=0.046, pad=0.04) else: - axes[1, i].text(0.5, 0.5, f'No {title}', ha='center', va='center', transform=axes[1, i].transAxes) - axes[1, i].set_title(f'{title} (Missing)') - axes[1, i].axis('off') - + axes[1, i].text( + 0.5, + 0.5, + f"No {title}", + ha="center", + va="center", + transform=axes[1, i].transAxes, + ) + axes[1, i].set_title(f"{title} (Missing)") + axes[1, i].axis("off") + plt.tight_layout() plt.show() # Save the figure explicitly - output_file = 'dataloader_test_batch.png' - plt.savefig(output_file, dpi=150, bbox_inches='tight') + output_file = "dataloader_test_batch.png" + plt.savefig(output_file, dpi=150, bbox_inches="tight") print(f"Figure saved as '{output_file}'") - + # Close the figure to free memory plt.close(fig) - + print("DataLoader test successful!") - + except Exception as e: print(f"Error in dataloader test: {e}") + if __name__ == "__main__": print("=== PyTorch DataLoader Test ===") - test_dataloader() \ No newline at end of file + test_dataloader() diff --git a/src/fmb/setup/download_weights_aion.py b/src/fmb/setup/download_weights_aion.py index f56fc66..3905dcd 100644 --- a/src/fmb/setup/download_weights_aion.py +++ b/src/fmb/setup/download_weights_aion.py @@ -14,15 +14,13 @@ from __future__ import annotations import argparse -import os -import json import inspect +import json +import sys from pathlib import Path from typing import Optional -import sys -import torch -from huggingface_hub import snapshot_download, hf_hub_download +from huggingface_hub import hf_hub_download, snapshot_download # Add src to pythonpath so we can import 'fmb' package if not installed src_path = Path(__file__).resolve().parents[2] @@ -35,6 +33,7 @@ try: from aion.codecs.config import MODALITY_CODEC_MAPPING from aion.modalities import LegacySurveyImage + _AION_AVAILABLE = True except ImportError: _AION_AVAILABLE = False @@ -43,6 +42,7 @@ # Constants DEFAULT_AION_REPO = "polymathic-ai/aion-base" + def download_aion_model( repo_id: str, revision: Optional[str], @@ -72,25 +72,37 @@ def prime_aion_codecs( # For now, we only handle LegacySurveyImage as in the original script # Simplified logic: just ensure config is there and try to load - - print(f"Priming AION Codecs from '{repo_id}' (using local dir '{dest_dir}' if valid)...") - + + print( + f"Priming AION Codecs from '{repo_id}' (using local dir '{dest_dir}' if valid)..." 
+ ) + # Check if we can use local config repo_ref = str(dest_dir) if (dest_dir / "config.json").exists() else repo_id - + # Download codec config if needed try: if Path(repo_ref).exists(): - config_path = Path(repo_ref) / "codecs" / LegacySurveyImage.name / "config.json" + config_path = ( + Path(repo_ref) / "codecs" / LegacySurveyImage.name / "config.json" + ) if not config_path.exists(): - print("Local codec config missing, falling back to HF Hub download for config...") - config_path = Path(hf_hub_download(repo_id, f"codecs/{LegacySurveyImage.name}/config.json")) + print( + "Local codec config missing, falling back to HF Hub download for config..." + ) + config_path = Path( + hf_hub_download( + repo_id, f"codecs/{LegacySurveyImage.name}/config.json" + ) + ) else: - config_path = Path(hf_hub_download(repo_id, f"codecs/{LegacySurveyImage.name}/config.json")) - + config_path = Path( + hf_hub_download(repo_id, f"codecs/{LegacySurveyImage.name}/config.json") + ) + with open(config_path, "r", encoding="utf-8") as fh: codec_config = json.load(fh) - + # Instantiate codec_cls = MODALITY_CODEC_MAPPING[LegacySurveyImage] init_params = inspect.signature(codec_cls.__init__).parameters @@ -99,10 +111,10 @@ def prime_aion_codecs( for name in init_params if name != "self" and name in codec_config } - + # Use local path if it exists, else repo_id load_repo = str(dest_dir) if (dest_dir / "config.json").exists() else repo_id - + print(f"Instantiating codec from {load_repo}...") codec = codec_cls.from_pretrained( load_repo, @@ -111,32 +123,43 @@ def prime_aion_codecs( ) codec.to(device).eval() print("Codec instantiated successfully (weights should be cached).") - + except Exception as e: print(f"Error priming codecs: {e}") - print("You may need to run this again or check your AION installation/connection.") + print( + "You may need to run this again or check your AION installation/connection." + ) def main() -> None: paths = load_paths() - + parser = argparse.ArgumentParser(description="Download AION model weights.") - parser.add_argument("--repo", default=DEFAULT_AION_REPO, help="HF repo ID for AION.") + parser.add_argument( + "--repo", default=DEFAULT_AION_REPO, help="HF repo ID for AION." + ) parser.add_argument("--revision", default=None, help="Revision for AION.") - parser.add_argument("--force-codecs", action="store_true", help="Force codec priming even if model download skipped.") - + parser.add_argument( + "--force-codecs", + action="store_true", + help="Force codec priming even if model download skipped.", + ) + args = parser.parse_args() dest = paths.base_weights_aion if not dest.exists() or len(list(dest.glob("*"))) == 0: - download_aion_model(args.repo, args.revision, dest) - prime_aion_codecs(args.repo, dest) + download_aion_model(args.repo, args.revision, dest) + prime_aion_codecs(args.repo, dest) else: - print(f"AION directory {dest} already exists and is not empty. Skipping download.") - if args.force_codecs: - prime_aion_codecs(args.repo, dest) - + print( + f"AION directory {dest} already exists and is not empty. Skipping download." 
+ ) + if args.force_codecs: + prime_aion_codecs(args.repo, dest) + print("AION weight setup process finished.") + if __name__ == "__main__": main() diff --git a/src/fmb/setup/download_weights_astroclip.py b/src/fmb/setup/download_weights_astroclip.py index ac3c20e..bb24589 100644 --- a/src/fmb/setup/download_weights_astroclip.py +++ b/src/fmb/setup/download_weights_astroclip.py @@ -20,6 +20,7 @@ sys.path.insert(0, str(src_path)) from huggingface_hub import hf_hub_download + from fmb.paths import load_paths # Weights to download: (repo_id, filename) @@ -29,11 +30,12 @@ ("polymathic-ai/astroclip", "astroclip.ckpt"), ] + def download_astroclip_weights(dest_dir: Path) -> None: """Download the 3 required AstroCLIP checkpoints.""" dest_dir.mkdir(parents=True, exist_ok=True) print(f"Downloading AstroCLIP weights to {dest_dir}...") - + for repo_id, filename in ASTROCLIP_WEIGHTS: print(f" Downloading {filename} from {repo_id}...") hf_hub_download( @@ -44,6 +46,7 @@ def download_astroclip_weights(dest_dir: Path) -> None: ) print("AstroCLIP weights download complete.") + def main() -> None: paths = load_paths() parser = argparse.ArgumentParser(description=__doc__) @@ -54,8 +57,9 @@ def main() -> None: help="Destination directory for the checkpoints.", ) args = parser.parse_args() - + download_astroclip_weights(args.dest) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/combined_umap.py b/src/fmb/viz/combined_umap.py index cb7386f..5927642 100644 --- a/src/fmb/viz/combined_umap.py +++ b/src/fmb/viz/combined_umap.py @@ -7,24 +7,23 @@ import argparse from pathlib import Path -from typing import Sequence, Optional, Tuple +from typing import Sequence, Tuple import matplotlib.pyplot as plt import numpy as np import torch -from astropy.io import fits -from tqdm import tqdm import umap.umap_ as umap +from astropy.io import fits +from fmb.data.load_display_data import EuclidDESIDataset +from fmb.paths import load_paths from fmb.viz.utils import ( - load_index, collect_samples, collect_samples_with_index, - prepare_rgb_image, + load_index, load_viz_style, + prepare_rgb_image, ) -from fmb.data.load_display_data import EuclidDESIDataset -from fmb.paths import load_paths # --- Publication Style Settings --- load_viz_style() @@ -39,6 +38,7 @@ KEY_AION_CACHE = "aion_embedding_hsc_desi" KEY_ASTROCLIP_CACHE = "astroclip_embedding_joint" + def load_embeddings(path: Path) -> list[dict]: """Load embedding records from a .pt file.""" data = torch.load(path, map_location="cpu", weights_only=False) @@ -48,12 +48,13 @@ def load_embeddings(path: Path) -> list[dict]: return [data] raise ValueError(f"Unsupported embeddings format: {type(data)}") + def load_coordinates(load_path: Path) -> dict[str, np.ndarray]: """Load previously computed coordinates (t-SNE or UMAP). Returns empty dict if not found.""" if not load_path.exists(): print(f" Cache file {load_path} not found. Will start with empty cache.") return {} - + try: coords_map = torch.load(load_path, map_location="cpu", weights_only=False) # Convert tensors to numpy if needed @@ -63,81 +64,108 @@ def load_coordinates(load_path: Path) -> dict[str, np.ndarray]: print(f" Loaded coordinates from {load_path}") return coords_map except Exception as e: - print(f"Warning: Failed to load cache {load_path}: {e}. Starting with empty cache.") + print( + f"Warning: Failed to load cache {load_path}: {e}. Starting with empty cache." + ) return {} + def stack_embeddings(records: Sequence[dict], key: str) -> Tuple[np.ndarray, list[str]]: """Stack embeddings. 
If 'embedding_joint' missing, construct from images/spectra.""" vectors = [] ids = [] - + for rec in records: tensor = rec.get(key) - + # Special handling for Joint embedding (concatenation) # Check if we should try to construct it if "embedding_joint" in key and tensor is None: - # Try un-prefixed components (common case in raw files) - img = rec.get("embedding_images") - spec = rec.get("embedding_spectra") - - if img is not None and spec is not None: - img = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else np.asarray(img) - spec = spec.detach().cpu().numpy() if isinstance(spec, torch.Tensor) else np.asarray(spec) - tensor = np.concatenate([img, spec]) - + # Try un-prefixed components (common case in raw files) + img = rec.get("embedding_images") + spec = rec.get("embedding_spectra") + + if img is not None and spec is not None: + img = ( + img.detach().cpu().numpy() + if isinstance(img, torch.Tensor) + else np.asarray(img) + ) + spec = ( + spec.detach().cpu().numpy() + if isinstance(spec, torch.Tensor) + else np.asarray(spec) + ) + tensor = np.concatenate([img, spec]) + if tensor is None: continue - + if isinstance(tensor, torch.Tensor): vectors.append(tensor.detach().cpu().numpy()) else: vectors.append(np.asarray(tensor)) - + # ID oid = rec.get("object_id", "") if isinstance(oid, torch.Tensor): - oid = oid.item() if oid.numel() == 1 else str(oid.tolist()) + oid = oid.item() if oid.numel() == 1 else str(oid.tolist()) ids.append(str(oid)) - + if not vectors: - raise ValueError(f"No embeddings found for key '{key}' and construction failed.") - + raise ValueError( + f"No embeddings found for key '{key}' and construction failed." + ) + return np.stack(vectors, axis=0), ids + def compute_umap(embeddings: np.ndarray, random_state: int) -> np.ndarray: - print(f" Computing UMAP for {len(embeddings)} samples (dim={embeddings.shape[1]})...") + print( + f" Computing UMAP for {len(embeddings)} samples (dim={embeddings.shape[1]})..." + ) # Using presets similar to visualize script (balanced) # Using n_neighbors=30 to maintain structure, metric=cosine standard reducer = umap.UMAP(random_state=random_state, n_neighbors=15, min_dist=0.1) return reducer.fit_transform(embeddings) + def load_fits_catalog(path: Path) -> tuple[dict, str]: """Load FITS catalog and return dict mapping object_id -> row.""" with fits.open(path) as hdul: data = hdul[1].data columns = hdul[1].columns.names - + catalog_dict = {} id_column = None - + # Find ID column - for priority_col in ['TARGETID', 'targetid', 'TargetID', 'object_id', 'objid', 'id']: + for priority_col in [ + "TARGETID", + "targetid", + "TargetID", + "object_id", + "objid", + "id", + ]: if priority_col in columns: id_column = priority_col break - + if id_column is None: - raise ValueError(f"Could not find object ID column in FITS. Available: {columns}") - + raise ValueError( + f"Could not find object ID column in FITS. 
Available: {columns}" + ) + print(f"Using '{id_column}' as object ID column") - + for row in data: obj_id = str(row[id_column]) catalog_dict[obj_id] = {col: row[col] for col in columns} - + return catalog_dict, id_column + def normalize_simple(coords: np.ndarray) -> np.ndarray: c_min = coords.min(axis=0) c_max = coords.max(axis=0) @@ -151,15 +179,16 @@ def robust_normalize(coords: np.ndarray) -> np.ndarray: # Use 1st and 99th percentile for robust range c_min = np.percentile(coords, 1, axis=0) c_max = np.percentile(coords, 99, axis=0) - + # Clip extreme outliers for the purpose of grid assignment/visualization boundaries clipped = np.clip(coords, c_min, c_max) - + denom = c_max - c_min denom[denom == 0] = 1e-9 - + return (clipped - c_min) / denom + def assign_to_grid( object_ids: Sequence[str], coords: np.ndarray, @@ -186,7 +215,7 @@ def assign_to_grid( rng = np.random.default_rng(random_state) selected_ids: list[str] = [] cell_positions: list[tuple[int, int]] = [] - + for gy in range(grid_rows): for gx in range(grid_cols): cell = (gx, gy) @@ -199,6 +228,7 @@ def assign_to_grid( return selected_ids, cell_positions + def add_thumbnails( ax: plt.Axes, cell_positions: list[tuple[int, int]], @@ -217,7 +247,7 @@ def add_thumbnails( xmin, xmax = gx, gx + 1 ymin, ymax = gy, gy + 1 - + # Add image ax.imshow( image, @@ -227,14 +257,24 @@ def add_thumbnails( aspect="auto", zorder=10, ) - + # Add frame - rect = plt.Rectangle((xmin, ymin), 1, 1, linewidth=0.5, edgecolor='white', facecolor='none', zorder=11, alpha=0.5) + rect = plt.Rectangle( + (xmin, ymin), + 1, + 1, + linewidth=0.5, + edgecolor="white", + facecolor="none", + zorder=11, + alpha=0.5, + ) ax.add_patch(rect) - - except Exception as exc: + + except Exception: pass + def plot_scatter_panel( ax: plt.Axes, coords: np.ndarray, @@ -244,67 +284,76 @@ def plot_scatter_panel( vmax: float, cmap: str = "plasma", point_size: float = 3.0, - use_hexbin: bool = True, # Default to Hexbin + use_hexbin: bool = True, # Default to Hexbin ) -> None: mask = ~np.isnan(values) - + # Use robust normalization for plotting to ensure main structure fills frame norm_coords = robust_normalize(coords) - + mappable = None - + if use_hexbin: # Hexbin plot mappable = ax.hexbin( - norm_coords[mask, 0], - norm_coords[mask, 1], - C=values[mask], - gridsize=100, # Increased resolution + norm_coords[mask, 0], + norm_coords[mask, 1], + C=values[mask], + gridsize=100, # Increased resolution reduce_C_function=np.mean, mincnt=1, - linewidths=0.0, # Smoother look - edgecolors='face', + linewidths=0.0, # Smoother look + edgecolors="face", cmap=cmap, vmin=vmin, vmax=vmax, - rasterized=True + rasterized=True, ) else: # Scatter plot # Plot NaNs if (~mask).any(): - ax.scatter(norm_coords[~mask, 0], norm_coords[~mask, 1], s=point_size, c='lightgray', alpha=0.3, rasterized=True) - + ax.scatter( + norm_coords[~mask, 0], + norm_coords[~mask, 1], + s=point_size, + c="lightgray", + alpha=0.3, + rasterized=True, + ) + # Plot valid if mask.any(): mappable = ax.scatter( - norm_coords[mask, 0], - norm_coords[mask, 1], - s=point_size, - c=values[mask], - cmap=cmap, - alpha=0.7, + norm_coords[mask, 0], + norm_coords[mask, 1], + s=point_size, + c=values[mask], + cmap=cmap, + alpha=0.7, edgecolors="none", rasterized=True, - vmin=vmin, vmax=vmax + vmin=vmin, + vmax=vmax, ) - + ax.set_title(title, fontsize=24, pad=15) ax.set_xticks([]) ax.set_yticks([]) - ax.set_aspect('auto') + ax.set_aspect("auto") # Set explicit limits since we normalized to 0-1 ax.set_xlim(0, 1) ax.set_ylim(0, 1) - + # Add frame 
for spine in ax.spines.values(): spine.set_visible(True) spine.set_linewidth(1.5) - spine.set_color('black') - + spine.set_color("black") + return mappable + def plot_thumbnail_panel( ax: plt.Axes, thumb_ids: list[str], @@ -318,27 +367,28 @@ def plot_thumbnail_panel( id_to_sample = {str(s.get("object_id")): s for s in samples} ordered_samples = [] ordered_cells = [] - + for oid, cell in zip(thumb_ids, cell_positions): sample = id_to_sample.get(str(oid)) if sample: ordered_samples.append(sample) ordered_cells.append(cell) - + add_thumbnails(ax, ordered_cells, ordered_samples, grid_rows, grid_cols) - + ax.set_title(title, fontsize=24, pad=15) ax.set_xlim(0, grid_cols) ax.set_ylim(0, grid_rows) - ax.set_aspect('auto') - + ax.set_aspect("auto") + # Add frame but hide ticks ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(True) spine.set_linewidth(1.5) - spine.set_color('black') + spine.set_color("black") + def plot_similarity_histogram( ax: plt.Axes, @@ -349,67 +399,124 @@ def plot_similarity_histogram( ) -> None: # Remove NaNs valid_values = values[~np.isnan(values)] - + if len(valid_values) == 0: return # Plot histogram ax.hist( - np.abs(valid_values), - bins=bins, - range=(0, 1), - density=True, - color=color, - alpha=0.7, - edgecolor='none', - rasterized=True + np.abs(valid_values), + bins=bins, + range=(0, 1), + density=True, + color=color, + alpha=0.7, + edgecolor="none", + rasterized=True, ) - + # Add statistics lines mean_val = np.mean(np.abs(valid_values)) median_val = np.median(np.abs(valid_values)) - - ax.axvline(mean_val, color='black', linestyle='--', linewidth=1.5, label=f'Mean: {mean_val:.2f}') - ax.axvline(median_val, color='red', linestyle=':', linewidth=1.5, label=f'Median: {median_val:.2f}') - + + ax.axvline( + mean_val, + color="black", + linestyle="--", + linewidth=1.5, + label=f"Mean: {mean_val:.2f}", + ) + ax.axvline( + median_val, + color="red", + linestyle=":", + linewidth=1.5, + label=f"Median: {median_val:.2f}", + ) + ax.set_title(title, fontsize=24, pad=15) ax.set_xlim(0, 1) - + # Hide y-axis ticks/labels as density is relative ax.set_yticks([]) ax.set_ylabel("Density", fontsize=32) ax.set_xlabel("|Cosine Similarity|", fontsize=32) - - ax.legend(loc='upper left', fontsize=26, frameon=False) - + + ax.legend(loc="upper left", fontsize=26, frameon=False) + # Add frame for spine in ax.spines.values(): spine.set_visible(True) spine.set_linewidth(1.5) - spine.set_color('black') + spine.set_color("black") + def main(argv: Sequence[str] | None = None) -> None: paths = load_paths() - - parser = argparse.ArgumentParser(description="Generate publication combined UMAP figure (AstroPT, AION, AstroCLIP)") - parser.add_argument("--aion-embeddings", default=str(paths.embeddings / "aions_embeddings.pt"), help="AION .pt file") - parser.add_argument("--astropt-embeddings", default=str(paths.embeddings / "astropt_embeddings.pt"), help="AstroPT .pt file") - parser.add_argument("--astroclip-embeddings", default=str(paths.embeddings / "embeddings_astroclip.pt"), help="AstroCLIP .pt file") - parser.add_argument("--catalog", default=str(paths.dataset / "DESI_DR1_Euclid_Q1_dataset_catalog_EM.fits"), help="FITS catalog") - parser.add_argument("--index", default=str(paths.dataset_index), help="Index CSV mapping object_id -> split/index") - parser.add_argument("--coords-cache", default=str(paths.analysis / "umap" / "coords_cache.pt"), help="Path to pre-computed coords .pt file") + + parser = argparse.ArgumentParser( + description="Generate publication 
combined UMAP figure (AstroPT, AION, AstroCLIP)" + ) + parser.add_argument( + "--aion-embeddings", + default=str(paths.embeddings / "aions_embeddings.pt"), + help="AION .pt file", + ) + parser.add_argument( + "--astropt-embeddings", + default=str(paths.embeddings / "astropt_embeddings.pt"), + help="AstroPT .pt file", + ) + parser.add_argument( + "--astroclip-embeddings", + default=str(paths.embeddings / "embeddings_astroclip.pt"), + help="AstroCLIP .pt file", + ) + parser.add_argument( + "--catalog", + default=str(paths.dataset / "DESI_DR1_Euclid_Q1_dataset_catalog_EM.fits"), + help="FITS catalog", + ) + parser.add_argument( + "--index", + default=str(paths.dataset_index), + help="Index CSV mapping object_id -> split/index", + ) + parser.add_argument( + "--coords-cache", + default=str(paths.analysis / "umap" / "coords_cache.pt"), + help="Path to pre-computed coords .pt file", + ) parser.add_argument("--physical-param", default="Z", help="Parameter to color by") - parser.add_argument("--save", default=str(paths.analysis / "umap" / "combined_umap.png"), help="Output filename") + parser.add_argument( + "--save", + default=str(paths.analysis / "umap" / "combined_umap.png"), + help="Output filename", + ) parser.add_argument("--grid-rows", type=int, default=25, help="Grid rows") parser.add_argument("--grid-cols", type=int, default=25, help="Grid cols") parser.add_argument("--random-state", type=int, default=42, help="Random state") - parser.add_argument("--cache-dir", default=str(paths.dataset), help="Dataset cache dir") - parser.add_argument("--hexbin", action="store_true", default=True, help="Use hexbin plot instead of scatter (Default: True)") - parser.add_argument("--show-similarity", action="store_true", default=True, help="Add 3rd row with cosine similarity (Images vs Spectra)") # Default True now - parser.add_argument("--no-hexbin", action="store_false", dest="hexbin", help="Disable hexbin plot") - + parser.add_argument( + "--cache-dir", default=str(paths.dataset), help="Dataset cache dir" + ) + parser.add_argument( + "--hexbin", + action="store_true", + default=True, + help="Use hexbin plot instead of scatter (Default: True)", + ) + parser.add_argument( + "--show-similarity", + action="store_true", + default=True, + help="Add 3rd row with cosine similarity (Images vs Spectra)", + ) # Default True now + parser.add_argument( + "--no-hexbin", action="store_false", dest="hexbin", help="Disable hexbin plot" + ) + parser.set_defaults(hexbin=True) - + args = parser.parse_args(argv) # Helper for similarity @@ -421,9 +528,11 @@ def compute_cosine_similarity(records): spec = rec.get("embedding_spectra") # AION specific keys - if img is None: img = rec.get("embedding_hsc") - if spec is None: spec = rec.get("embedding_spectrum") - + if img is None: + img = rec.get("embedding_hsc") + if spec is None: + spec = rec.get("embedding_spectrum") + # If not found, try to look for keys ending in _images/_spectra/_spectrum if img is None: for k in rec.keys(): @@ -435,20 +544,24 @@ def compute_cosine_similarity(records): if k.endswith("_spectra") or k.endswith("_spectrum"): spec = rec[k] break - + if img is None or spec is None: sims.append(np.nan) continue - - if hasattr(img, "detach"): img = img.detach().cpu().numpy() - else: img = np.array(img) - - if hasattr(spec, "detach"): spec = spec.detach().cpu().numpy() - else: spec = np.array(spec) - + + if hasattr(img, "detach"): + img = img.detach().cpu().numpy() + else: + img = np.array(img) + + if hasattr(spec, "detach"): + spec = spec.detach().cpu().numpy() + 
else: + spec = np.array(spec) + img = img.flatten() spec = spec.flatten() - + ni = np.linalg.norm(img) ns = np.linalg.norm(spec) if ni == 0 or ns == 0: @@ -456,23 +569,23 @@ def compute_cosine_similarity(records): else: sims.append(np.dot(img, spec) / (ni * ns)) return np.array(sims) - + # 1. Load Data print("Loading embeddings...") aion_recs = load_embeddings(Path(args.aion_embeddings)) astropt_recs = load_embeddings(Path(args.astropt_embeddings)) astroclip_recs = load_embeddings(Path(args.astroclip_embeddings)) - + print("Loading catalog...") catalog, _ = load_fits_catalog(Path(args.catalog)) - + print("Loading coordinates cache...") coords_cache_path = Path(args.coords_cache) coords_map = load_coordinates(coords_cache_path) - + # helper to get or compute cache_dirty = False - + def get_or_compute(records, native_key, cache_key): nonlocal cache_dirty if cache_key in coords_map: @@ -488,19 +601,21 @@ def get_or_compute(records, native_key, cache_key): print("Checking AION coords...") coords_aion = get_or_compute(aion_recs, KEY_AION_MSG, KEY_AION_CACHE) - + print("Checking AstroPT coords...") coords_astropt = get_or_compute(astropt_recs, KEY_ASTROPT_MSG, KEY_ASTROPT_CACHE) - + print("Checking AstroCLIP coords...") - coords_astroclip = get_or_compute(astroclip_recs, KEY_ASTROCLIP_MSG, KEY_ASTROCLIP_CACHE) - + coords_astroclip = get_or_compute( + astroclip_recs, KEY_ASTROCLIP_MSG, KEY_ASTROCLIP_CACHE + ) + # Save cache if needed if cache_dirty: print(f"Saving updated cache to {coords_cache_path}...") coords_cache_path.parent.mkdir(parents=True, exist_ok=True) torch.save(coords_map, coords_cache_path) - + # Resolve physical parameter name (handle aliases like 'redshift' -> 'Z') param = args.physical_param if param not in catalog: @@ -510,15 +625,17 @@ def get_or_compute(records, native_key, cache_key): "z": "Z", "Redshift": "Z", "mass": "LOGM", - "sfr": "LOGSFR" + "sfr": "LOGSFR", } if param in aliases and aliases[param] in catalog: - print(f"Parameter '{param}' not found in catalog, using alias '{aliases[param]}'") + print( + f"Parameter '{param}' not found in catalog, using alias '{aliases[param]}'" + ) param = aliases[param] - + # 2. 
Extract IDs and Values for Coloring (Top Row) print(f"Extracting parameter '{param}'...") - + def get_ids(records, native_key): _, ids = stack_embeddings(records, native_key) return ids @@ -526,7 +643,7 @@ def get_ids(records, native_key): aion_ids_full = get_ids(aion_recs, KEY_AION_MSG) astropt_ids_full = get_ids(astropt_recs, KEY_ASTROPT_MSG) astroclip_ids_full = get_ids(astroclip_recs, KEY_ASTROCLIP_MSG) - + def get_values(ids): vals = [] matches = 0 @@ -535,8 +652,10 @@ def get_values(ids): if len(ids) > 0: print(f"DEBUG: Sample ID (embedding): '{ids[0]}' (type: {type(ids[0])})") cat_sample = next(iter(catalog.keys())) - print(f"DEBUG: Sample ID (catalog): '{cat_sample}' (type: {type(cat_sample)})") - + print( + f"DEBUG: Sample ID (catalog): '{cat_sample}' (type: {type(cat_sample)})" + ) + for oid in ids: # Ensure strict string matching oid_str = str(oid).strip() @@ -546,71 +665,90 @@ def get_values(ids): try: raw = catalog[oid_str].get(param) if raw is not None: - val = float(raw) if not hasattr(raw, 'item') else float(raw.item()) + val = ( + float(raw) + if not hasattr(raw, "item") + else float(raw.item()) + ) except: pass vals.append(val) samples += 1 - print(f"DEBUG: Found {matches}/{samples} matches in catalog for param '{param}'") + print( + f"DEBUG: Found {matches}/{samples} matches in catalog for param '{param}'" + ) return np.array(vals) - + values_aion = get_values(aion_ids_full) values_astro = get_values(astropt_ids_full) values_clip = get_values(astroclip_ids_full) - + # Determine color limits - all_values = np.concatenate([ - values_astro[~np.isnan(values_astro)], - values_aion[~np.isnan(values_aion)], - values_clip[~np.isnan(values_clip)] - ]) - + all_values = np.concatenate( + [ + values_astro[~np.isnan(values_astro)], + values_aion[~np.isnan(values_aion)], + values_clip[~np.isnan(values_clip)], + ] + ) + if len(all_values) == 0: vmin, vmax = 0, 1 else: vmin = np.percentile(all_values, 2) vmax = np.percentile(all_values, 98) - if vmax - vmin < 1e-6: vmin, vmax = all_values.min(), all_values.max() - + if vmax - vmin < 1e-6: + vmin, vmax = all_values.min(), all_values.max() + print(f"Color scale: [{vmin:.3f}, {vmax:.3f}]") - + # 2b. Compute Similarity if requested val_sim_aion = None val_sim_astro = None val_sim_clip = None vmin_sim, vmax_sim = 0, 1 - + if args.show_similarity: print("Computing cosine similarities...") val_sim_aion = compute_cosine_similarity(aion_recs) val_sim_astro = compute_cosine_similarity(astropt_recs) val_sim_clip = compute_cosine_similarity(astroclip_recs) - - all_sims = np.concatenate([ - val_sim_aion[~np.isnan(val_sim_aion)], - val_sim_astro[~np.isnan(val_sim_astro)], - val_sim_clip[~np.isnan(val_sim_clip)] - ]) - + + all_sims = np.concatenate( + [ + val_sim_aion[~np.isnan(val_sim_aion)], + val_sim_astro[~np.isnan(val_sim_astro)], + val_sim_clip[~np.isnan(val_sim_clip)], + ] + ) + if len(all_sims) > 0: vmin_sim = np.percentile(all_sims, 2) vmax_sim = np.percentile(all_sims, 98) print(f"Similarity Color scale: [{vmin_sim:.3f}, {vmax_sim:.3f}]") else: print("Warning: No valid similarities found.") - + # 3. 
Process Grid Assignments (Bottom Row) print("Assigning to grid...") thumbs_aion, cells_aion = assign_to_grid( aion_ids_full, coords_aion, args.grid_rows, args.grid_cols, args.random_state ) thumbs_astro, cells_astro = assign_to_grid( - astropt_ids_full, coords_astropt, args.grid_rows, args.grid_cols, args.random_state + astropt_ids_full, + coords_astropt, + args.grid_rows, + args.grid_cols, + args.random_state, ) thumbs_clip, cells_clip = assign_to_grid( - astroclip_ids_full, coords_astroclip, args.grid_rows, args.grid_cols, args.random_state + astroclip_ids_full, + coords_astroclip, + args.grid_rows, + args.grid_cols, + args.random_state, ) - + # 4. Fetch Images print("Fetching image samples...") all_thumb_ids = list(set(thumbs_aion) | set(thumbs_astro) | set(thumbs_clip)) @@ -621,122 +759,154 @@ def get_values(ids): index_map=index_map, verbose=True, ) - + retrieved_ids = {str(s.get("object_id")) for s in samples} missing = [oid for oid in all_thumb_ids if oid not in retrieved_ids] if missing: print(f"Fetching {len(missing)} missing samples...") dataset = EuclidDESIDataset(split="all", cache_dir=args.cache_dir, verbose=True) samples.extend(collect_samples(dataset, missing, verbose=True)) - + # 5. Plot print("Plotting...") rows = 3 if args.show_similarity else 2 fig_height = 18 if args.show_similarity else 12 fig, axes = plt.subplots(rows, 3, figsize=(24, fig_height)) - + # Top Row: Scatter (Physical Param) sc = plot_scatter_panel( - axes[0, 0], coords_astropt, values_astro, - r"\textbf{AstroPT} (Spectra + Images)", vmin, vmax, - use_hexbin=args.hexbin + axes[0, 0], + coords_astropt, + values_astro, + r"\textbf{AstroPT} (Spectra + Images)", + vmin, + vmax, + use_hexbin=args.hexbin, ) plot_scatter_panel( - axes[0, 1], coords_aion, values_aion, - r"\textbf{AION} (Spectra + Images)", vmin, vmax, - use_hexbin=args.hexbin + axes[0, 1], + coords_aion, + values_aion, + r"\textbf{AION} (Spectra + Images)", + vmin, + vmax, + use_hexbin=args.hexbin, ) plot_scatter_panel( - axes[0, 2], coords_astroclip, values_clip, - r"\textbf{AstroCLIP} (Spectra + Images)", vmin, vmax, - use_hexbin=args.hexbin + axes[0, 2], + coords_astroclip, + values_clip, + r"\textbf{AstroCLIP} (Spectra + Images)", + vmin, + vmax, + use_hexbin=args.hexbin, ) - - # Switch rows if similarity is added? - # Standard: Row 1=Param, Row 2=Thumbnails. + + # Switch rows if similarity is added? + # Standard: Row 1=Param, Row 2=Thumbnails. # Requested: Add 3rd line. # Let's put Thumbnails in Row 2 (index 1), Similarity in Row 3 (index 2). # Or Similarity in Row 2, Thumbnails in Row 3? # Keeping Thumbnails at bottom (last row) is often cleanest. - # But user asked to "add a 3rd line". + # But user asked to "add a 3rd line". # I will put Similarity in the middle (Row 2), Thumbnails at bottom (Row 3). # Wait, existing code puts Thumbnails in Row 2 (axes[1, ...]). # If I add similarity, I will put it as axes[2, ...] (Row 3). # This matches "Add a 3rd line". 
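The comments above settle on a fixed panel layout: row 0 holds the latent-space maps coloured by the physical parameter, row 1 the representative thumbnails, and row 2 (present only when --show-similarity is set) the |cosine similarity| histograms. A minimal sketch of that grid, using matplotlib only and placeholder titles instead of the real panel-drawing helpers:

```python
import matplotlib.pyplot as plt

# Sketch of the figure layout built below; the actual panels are drawn by
# plot_scatter_panel, plot_thumbnail_panel and plot_similarity_histogram.
show_similarity = True  # mirrors args.show_similarity
rows = 3 if show_similarity else 2
fig, axes = plt.subplots(rows, 3, figsize=(24, 18 if show_similarity else 12))
for col, model in enumerate(["AstroPT", "AION", "AstroCLIP"]):
    axes[0, col].set_title(f"{model}: latent space (coloured by parameter)")
    axes[1, col].set_title(f"{model}: representative thumbnails")
    if show_similarity:
        axes[2, col].set_title(f"{model}: |cosine similarity| histogram")
```

Appending the similarity row at the bottom keeps the row indices of the existing scatter and thumbnail panels unchanged, which is why the code below continues to address them as axes[0, ...] and axes[1, ...].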
- - thumb_row_idx = 1 - sim_row_idx = 2 - + # ADD ROW TITLES # We place text relative to first axes of each row # Row 1: Latent Space - axes[0, 0].text(-0.15, 0.5, "Latent Space\n(Colored by Param)", transform=axes[0, 0].transAxes, - rotation=90, va='center', ha='right', fontsize=28, fontweight='bold') + axes[0, 0].text( + -0.15, + 0.5, + "Latent Space\n(Colored by Param)", + transform=axes[0, 0].transAxes, + rotation=90, + va="center", + ha="right", + fontsize=28, + fontweight="bold", + ) # Row 2: Thumbnails - axes[1, 0].text(-0.15, 0.5, "Representative\nThumbnails", transform=axes[1, 0].transAxes, - rotation=90, va='center', ha='right', fontsize=28, fontweight='bold') - + axes[1, 0].text( + -0.15, + 0.5, + "Representative\nThumbnails", + transform=axes[1, 0].transAxes, + rotation=90, + va="center", + ha="right", + fontsize=28, + fontweight="bold", + ) + # Middle Row: Thumbnails (Original Row 2) # If we want Similarity in the middle, we'd change indices. # Let's append Similarity at the bottom to follow "Add 3rd line" literally. - + plot_thumbnail_panel( - axes[1, 0], thumbs_astro, cells_astro, samples, - "", args.grid_rows, args.grid_cols + axes[1, 0], + thumbs_astro, + cells_astro, + samples, + "", + args.grid_rows, + args.grid_cols, ) plot_thumbnail_panel( - axes[1, 1], thumbs_aion, cells_aion, samples, - "", args.grid_rows, args.grid_cols + axes[1, 1], thumbs_aion, cells_aion, samples, "", args.grid_rows, args.grid_cols ) plot_thumbnail_panel( - axes[1, 2], thumbs_clip, cells_clip, samples, - "", args.grid_rows, args.grid_cols + axes[1, 2], thumbs_clip, cells_clip, samples, "", args.grid_rows, args.grid_cols ) - - sc_sim = None + if args.show_similarity: # Row 3: Similarity - axes[2, 0].text(-0.15, 0.5, "Cosine Similarity\n(Images vs Spectra)", transform=axes[2, 0].transAxes, - rotation=90, va='center', ha='right', fontsize=28, fontweight='bold') - + axes[2, 0].text( + -0.15, + 0.5, + "Cosine Similarity\n(Images vs Spectra)", + transform=axes[2, 0].transAxes, + rotation=90, + va="center", + ha="right", + fontsize=28, + fontweight="bold", + ) + # Bottom Row: Similarity Histograms # Use different colors for each model if desired, or uniform plot_similarity_histogram( - axes[2, 0], val_sim_astro, - "", color="C0" # Removed title to avoid overlap - ) - plot_similarity_histogram( - axes[2, 1], val_sim_aion, - "", color="C1" - ) - plot_similarity_histogram( - axes[2, 2], val_sim_clip, - "", color="C2" + axes[2, 0], val_sim_astro, "", color="C0" # Removed title to avoid overlap ) + plot_similarity_histogram(axes[2, 1], val_sim_aion, "", color="C1") + plot_similarity_histogram(axes[2, 2], val_sim_clip, "", color="C2") # Colorbars - fig.subplots_adjust(left=0.1, right=0.9, wspace=0.1, hspace=0.2) # Increased hspace for titles, added left margin - + fig.subplots_adjust( + left=0.1, right=0.9, wspace=0.1, hspace=0.2 + ) # Increased hspace for titles, added left margin + # 1. 
Colorbar for Param if sc: pos_top_right = axes[0, 2].get_position() cbar_ax = fig.add_axes([0.92, pos_top_right.ymin, 0.015, pos_top_right.height]) - + label = param if label == "Z" or label.lower() == "redshift": label = r"Redshift $z$" - + cbar = fig.colorbar(sc, cax=cbar_ax) cbar.set_label(label, fontsize=24) cbar.ax.tick_params(labelsize=18) cbar.solids.set_edgecolor("face") - - Path(args.save).parent.mkdir(parents=True, exist_ok=True) fig.savefig(args.save, dpi=600, bbox_inches="tight") print(f"Saved figure to {args.save}") plt.close(fig) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/displacement/plot_paper_displacement.py b/src/fmb/viz/displacement/plot_paper_displacement.py index f01b1b3..9ecacda 100644 --- a/src/fmb/viz/displacement/plot_paper_displacement.py +++ b/src/fmb/viz/displacement/plot_paper_displacement.py @@ -7,160 +7,209 @@ import argparse from pathlib import Path -from typing import Sequence, Optional, Dict, List +from typing import Sequence import matplotlib.pyplot as plt -import numpy as np import pandas as pd -import matplotlib # --- Publication Style Settings --- try: - plt.rcParams.update({ - "text.usetex": True, - "font.family": "serif", - "font.serif": ["Computer Modern Roman"], - }) + plt.rcParams.update( + { + "text.usetex": True, + "font.family": "serif", + "font.serif": ["Computer Modern Roman"], + } + ) except Exception: print("Warning: LaTeX not available, falling back to STIX fonts.") - plt.rcParams.update({ - "text.usetex": False, - "mathtext.fontset": "stix", - "font.family": "STIXGeneral", - }) - -plt.rcParams.update({ - "font.size": 10, - "axes.labelsize": 10, - "axes.titlesize": 11, - "xtick.labelsize": 8, - "ytick.labelsize": 8, - "legend.fontsize": 8, - "figure.titlesize": 14, - "axes.linewidth": 1.0, - "xtick.major.width": 1.0, - "ytick.major.width": 1.0, - "xtick.minor.width": 0.8, - "ytick.minor.width": 0.8, - "xtick.direction": "in", - "ytick.direction": "in", - "lines.linewidth": 1.0, - "savefig.bbox": "tight", - "savefig.pad_inches": 0.05, -}) + plt.rcParams.update( + { + "text.usetex": False, + "mathtext.fontset": "stix", + "font.family": "STIXGeneral", + } + ) + +plt.rcParams.update( + { + "font.size": 10, + "axes.labelsize": 10, + "axes.titlesize": 11, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "legend.fontsize": 8, + "figure.titlesize": 14, + "axes.linewidth": 1.0, + "xtick.major.width": 1.0, + "ytick.major.width": 1.0, + "xtick.minor.width": 0.8, + "ytick.minor.width": 0.8, + "xtick.direction": "in", + "ytick.direction": "in", + "lines.linewidth": 1.0, + "savefig.bbox": "tight", + "savefig.pad_inches": 0.05, + } +) # Mapping of embedding types to keys for each model KEYS = { "AION": { "Images": "embedding_hsc", "Spectra": "embedding_spectrum", - "Joint": "embedding_hsc_desi" + "Joint": "embedding_hsc_desi", }, "AstroPT": { "Images": "embedding_images", "Spectra": "embedding_spectra", - "Joint": "embedding_joint" + "Joint": "embedding_joint", }, "AstroCLIP": { "Images": "embedding_images", "Spectra": "embedding_spectra", - "Joint": "embedding_joint" - } + "Joint": "embedding_joint", + }, } COLORS = { - "AION": "#1f77b4", # Blue + "AION": "#1f77b4", # Blue "AstroPT": "#ff7f0e", # Orange - "AstroCLIP": "#2ca02c" # Green + "AstroCLIP": "#2ca02c", # Green } + def load_data(path: Path) -> pd.DataFrame: """Load anomaly scores from CSV.""" df = pd.read_csv(path) - df['object_id'] = df['object_id'].astype(str) + df["object_id"] = df["object_id"].astype(str) return df -def plot_panel(ax, source_name, target_name, 
source_ranks, target_ranks, n_total, color): + +def plot_panel( + ax, source_name, target_name, source_ranks, target_ranks, n_total, color +): """Plot a single displacement panel.""" # Normalize to percentile (0-100) top_1_threshold = n_total * 0.01 subset_mask = source_ranks <= top_1_threshold - + if subset_mask.sum() == 0: - ax.text(0.5, 0.5, "No data", transform=ax.transAxes, ha='center') + ax.text(0.5, 0.5, "No data", transform=ax.transAxes, ha="center") return ranks_pct = (target_ranks[subset_mask] / n_total) * 100 - + # Plot Histogram - ax.hist(ranks_pct, bins=40, range=(0, 100), color=color, edgecolor='black', alpha=0.7, linewidth=0.5, zorder=3) - + ax.hist( + ranks_pct, + bins=40, + range=(0, 100), + color=color, + edgecolor="black", + alpha=0.7, + linewidth=0.5, + zorder=3, + ) + # Add markers for Top 1% and Top 10% - ax.axvline(x=1, color='red', linestyle='--', linewidth=1, zorder=4) - ax.axvline(x=10, color='orange', linestyle='--', linewidth=1, zorder=4) - + ax.axvline(x=1, color="red", linestyle="--", linewidth=1, zorder=4) + ax.axvline(x=10, color="orange", linestyle="--", linewidth=1, zorder=4) + # Shade regions - ax.axvspan(0, 1, color='red', alpha=0.1, zorder=1) - ax.axvspan(1, 10, color='orange', alpha=0.1, zorder=1) - + ax.axvspan(0, 1, color="red", alpha=0.1, zorder=1) + ax.axvspan(1, 10, color="orange", alpha=0.1, zorder=1) + # Calculate stats - retained_1 = (target_ranks[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 - retained_10 = (target_ranks[subset_mask] <= (n_total * 0.1)).sum() / subset_mask.sum() * 100 - + retained_1 = ( + (target_ranks[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 + ) + retained_10 = ( + (target_ranks[subset_mask] <= (n_total * 0.1)).sum() / subset_mask.sum() * 100 + ) + # Add text box stats_text = ( f"\\textbf{{Retained}}:\n" f"Top 1\\%: {retained_1:.1f}\\%\n" f"Top 10\\%: {retained_10:.1f}\\%" ) - ax.text(0.95, 0.92, stats_text, transform=ax.transAxes, ha='right', va='top', - bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', boxstyle='round,pad=0.2'), zorder=5, fontsize=7) - - ax.set_title(rf"\textbf{{{source_name}}} $\rightarrow$ \textbf{{{target_name}}}", pad=5) - ax.grid(True, linestyle=':', alpha=0.4, zorder=0) + ax.text( + 0.95, + 0.92, + stats_text, + transform=ax.transAxes, + ha="right", + va="top", + bbox=dict( + facecolor="white", alpha=0.8, edgecolor="none", boxstyle="round,pad=0.2" + ), + zorder=5, + fontsize=7, + ) + + ax.set_title( + rf"\textbf{{{source_name}}} $\rightarrow$ \textbf{{{target_name}}}", pad=5 + ) + ax.grid(True, linestyle=":", alpha=0.4, zorder=0) + def main(argv: Sequence[str] | None = None) -> None: - parser = argparse.ArgumentParser(description="Generate multi-model/multi-type displacement analysis grid") + parser = argparse.ArgumentParser( + description="Generate multi-model/multi-type displacement analysis grid" + ) parser.add_argument("--aion-scores", required=True, help="AION scores CSV") parser.add_argument("--astropt-scores", required=True, help="AstroPT scores CSV") - parser.add_argument("--astroclip-scores", required=True, help="AstroCLIP scores CSV") - parser.add_argument("--save", default="paper_displacement_grid.png", help="Output filename") - + parser.add_argument( + "--astroclip-scores", required=True, help="AstroCLIP scores CSV" + ) + parser.add_argument( + "--save", default="paper_displacement_grid.png", help="Output filename" + ) + args = parser.parse_args(argv) - + # Load all datasets print("Loading scores...") dfs = { "AION": 
load_data(Path(args.aion_scores)), "AstroPT": load_data(Path(args.astropt_scores)), - "AstroCLIP": load_data(Path(args.astroclip_scores)) + "AstroCLIP": load_data(Path(args.astroclip_scores)), } - + embedding_types = ["Images", "Spectra", "Joint"] - comparisons = [ - ("AION", "AstroPT"), - ("AstroPT", "AstroCLIP"), - ("AstroCLIP", "AION") - ] - - fig, axes = plt.subplots(3, 3, figsize=(10, 9), sharex=True, sharey='row') - + comparisons = [("AION", "AstroPT"), ("AstroPT", "AstroCLIP"), ("AstroCLIP", "AION")] + + fig, axes = plt.subplots(3, 3, figsize=(10, 9), sharex=True, sharey="row") + for row_idx, emb_type in enumerate(embedding_types): print(f"Processing {emb_type}...") - + # Extract relevant subset for each model subsets = {} for model_name, df in dfs.items(): key = KEYS[model_name][emb_type] - subsets[model_name] = df[df['embedding_key'] == key].copy() + subsets[model_name] = df[df["embedding_key"] == key].copy() if len(subsets[model_name]) == 0: print(f"Warning: No data for {model_name} with key {key}") # Common merge to ensure we compare the same objects - merged = subsets["AION"][['object_id', 'rank']].rename(columns={'rank': 'rank_aion'}) - merged = merged.merge(subsets["AstroPT"][['object_id', 'rank']].rename(columns={'rank': 'rank_astropt'}), on='object_id') - merged = merged.merge(subsets["AstroCLIP"][['object_id', 'rank']].rename(columns={'rank': 'rank_astroclip'}), on='object_id') - + merged = subsets["AION"][["object_id", "rank"]].rename( + columns={"rank": "rank_aion"} + ) + merged = merged.merge( + subsets["AstroPT"][["object_id", "rank"]].rename( + columns={"rank": "rank_astropt"} + ), + on="object_id", + ) + merged = merged.merge( + subsets["AstroCLIP"][["object_id", "rank"]].rename( + columns={"rank": "rank_astroclip"} + ), + on="object_id", + ) + n_total = len(merged) print(f" Matched {n_total} objects for {emb_type}") @@ -168,15 +217,17 @@ def main(argv: Sequence[str] | None = None) -> None: ax = axes[row_idx, col_idx] src_col = f"rank_{src.lower()}" tgt_col = f"rank_{tgt.lower()}" - + plot_panel( - ax, src, tgt, - merged[src_col].values, - merged[tgt_col].values, + ax, + src, + tgt, + merged[src_col].values, + merged[tgt_col].values, n_total, - COLORS[src] + COLORS[src], ) - + # Labels if col_idx == 0: ax.set_ylabel(rf"\textbf{{{emb_type}}}" + "\nCount") @@ -184,12 +235,13 @@ def main(argv: Sequence[str] | None = None) -> None: ax.set_xlabel("Percentile Rank in Target") plt.tight_layout() - + save_path = Path(args.save) save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=300, bbox_inches="tight") print(f"Saved figure to {save_path}") plt.close(fig) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/displacement/plot_paper_displacement_cross.py b/src/fmb/viz/displacement/plot_paper_displacement_cross.py index 604f6e6..bdb96ad 100644 --- a/src/fmb/viz/displacement/plot_paper_displacement_cross.py +++ b/src/fmb/viz/displacement/plot_paper_displacement_cross.py @@ -7,185 +7,224 @@ import argparse from pathlib import Path -from typing import Sequence, Optional, Dict, List +from typing import Sequence import matplotlib.pyplot as plt -import numpy as np import pandas as pd -import matplotlib # --- Publication Style Settings --- try: - plt.rcParams.update({ - "text.usetex": True, - "font.family": "serif", - "font.serif": ["Computer Modern Roman"], - }) + plt.rcParams.update( + { + "text.usetex": True, + "font.family": "serif", + "font.serif": ["Computer Modern Roman"], + } + ) except Exception: print("Warning: LaTeX not available, falling 
back to STIX fonts.") - plt.rcParams.update({ - "text.usetex": False, - "mathtext.fontset": "stix", - "font.family": "STIXGeneral", - }) - -plt.rcParams.update({ - "font.size": 10, - "axes.labelsize": 10, - "axes.titlesize": 11, - "xtick.labelsize": 8, - "ytick.labelsize": 8, - "legend.fontsize": 7, - "figure.titlesize": 14, - "axes.linewidth": 1.0, - "xtick.major.width": 1.0, - "ytick.major.width": 1.0, - "xtick.minor.width": 0.8, - "ytick.minor.width": 0.8, - "xtick.direction": "in", - "ytick.direction": "in", - "lines.linewidth": 1.0, - "savefig.bbox": "tight", - "savefig.pad_inches": 0.05, -}) + plt.rcParams.update( + { + "text.usetex": False, + "mathtext.fontset": "stix", + "font.family": "STIXGeneral", + } + ) + +plt.rcParams.update( + { + "font.size": 10, + "axes.labelsize": 10, + "axes.titlesize": 11, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "legend.fontsize": 7, + "figure.titlesize": 14, + "axes.linewidth": 1.0, + "xtick.major.width": 1.0, + "ytick.major.width": 1.0, + "xtick.minor.width": 0.8, + "ytick.minor.width": 0.8, + "xtick.direction": "in", + "ytick.direction": "in", + "lines.linewidth": 1.0, + "savefig.bbox": "tight", + "savefig.pad_inches": 0.05, + } +) # Mapping of embedding types to keys for each model KEYS = { "AION": { "Images": "embedding_hsc", "Spectra": "embedding_spectrum", - "Joint": "embedding_hsc_desi" + "Joint": "embedding_hsc_desi", }, "AstroPT": { "Images": "embedding_images", "Spectra": "embedding_spectra", - "Joint": "embedding_joint" + "Joint": "embedding_joint", }, "AstroCLIP": { "Images": "embedding_images", "Spectra": "embedding_spectra", - "Joint": "embedding_joint" - } + "Joint": "embedding_joint", + }, } COLORS = { - "AION": "#1f77b4", # Blue + "AION": "#1f77b4", # Blue "AstroPT": "#ff7f0e", # Orange - "AstroCLIP": "#2ca02c" # Green + "AstroCLIP": "#2ca02c", # Green } + def load_data(path: Path) -> pd.DataFrame: """Load anomaly scores from CSV.""" df = pd.read_csv(path) - df['object_id'] = df['object_id'].astype(str) + df["object_id"] = df["object_id"].astype(str) return df + def plot_cross_panel(ax, src_mod, tgt_mod, model_data_list, n_total): """Plot overlaid displacement histograms for multiple models in one panel.""" - + top_1_threshold = n_total * 0.01 stats_lines = [] - + for model_name, ranks_src, ranks_tgt in model_data_list: color = COLORS[model_name] subset_mask = ranks_src <= top_1_threshold - + if subset_mask.sum() == 0: continue ranks_pct = (ranks_tgt[subset_mask] / n_total) * 100 - + # Plot Histogram (step for better overlay) - ax.hist(ranks_pct, bins=40, range=(0, 100), color=color, histtype='step', alpha=0.9, linewidth=1.2, zorder=3, label=model_name) + ax.hist( + ranks_pct, + bins=40, + range=(0, 100), + color=color, + histtype="step", + alpha=0.9, + linewidth=1.2, + zorder=3, + label=model_name, + ) # Optional: filled area with low alpha ax.hist(ranks_pct, bins=40, range=(0, 100), color=color, alpha=0.1, zorder=2) - + # Calculate stats - retained_1 = (ranks_tgt[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 + retained_1 = ( + (ranks_tgt[subset_mask] <= top_1_threshold).sum() / subset_mask.sum() * 100 + ) stats_lines.append(rf"{model_name}: {retained_1:.1f}\%") # Add markers for Top 1% and Top 10% - ax.axvline(x=1, color='red', linestyle='--', linewidth=0.8, alpha=0.5, zorder=4) - ax.axvspan(0, 1, color='red', alpha=0.05, zorder=1) - + ax.axvline(x=1, color="red", linestyle="--", linewidth=0.8, alpha=0.5, zorder=4) + ax.axvspan(0, 1, color="red", alpha=0.05, zorder=1) + # Add text box stats_text = 
"\\textbf{Retained Top 1\\%}:\n" + "\n".join(stats_lines) - ax.text(0.95, 0.92, stats_text, transform=ax.transAxes, ha='right', va='top', - bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', boxstyle='round,pad=0.2'), zorder=5, fontsize=6) - + ax.text( + 0.95, + 0.92, + stats_text, + transform=ax.transAxes, + ha="right", + va="top", + bbox=dict( + facecolor="white", alpha=0.8, edgecolor="none", boxstyle="round,pad=0.2" + ), + zorder=5, + fontsize=6, + ) + ax.set_title(rf"\textbf{{{src_mod}}} $\rightarrow$ \textbf{{{tgt_mod}}}", pad=5) - ax.grid(True, linestyle=':', alpha=0.4, zorder=0) + ax.grid(True, linestyle=":", alpha=0.4, zorder=0) + def main(argv: Sequence[str] | None = None) -> None: - parser = argparse.ArgumentParser(description="Generate cross-modality displacement analysis grid") + parser = argparse.ArgumentParser( + description="Generate cross-modality displacement analysis grid" + ) parser.add_argument("--aion-scores", required=True, help="AION scores CSV") parser.add_argument("--astropt-scores", required=True, help="AstroPT scores CSV") - parser.add_argument("--astroclip-scores", required=True, help="AstroCLIP scores CSV") - parser.add_argument("--save", default="paper_displacement_cross.png", help="Output filename") - + parser.add_argument( + "--astroclip-scores", required=True, help="AstroCLIP scores CSV" + ) + parser.add_argument( + "--save", default="paper_displacement_cross.png", help="Output filename" + ) + args = parser.parse_args(argv) - + # Load all datasets print("Loading scores...") dfs = { "AION": load_data(Path(args.aion_scores)), "AstroPT": load_data(Path(args.astropt_scores)), - "AstroCLIP": load_data(Path(args.astroclip_scores)) + "AstroCLIP": load_data(Path(args.astroclip_scores)), } - + modalities = ["Images", "Spectra", "Joint"] - - fig, axes = plt.subplots(3, 3, figsize=(11, 10), sharex=True, sharey='row') - + + fig, axes = plt.subplots(3, 3, figsize=(11, 10), sharex=True, sharey="row") + # We need to compute matched objects for *each* model across all *its* modalities # Actually, to keep it consistent across models, lets merge everything into one big DF print("Merging data for consistency...") big_merged = None - + for model_name, df in dfs.items(): model_subset = None for mod in modalities: key = KEYS[model_name][mod] - mod_data = df[df['embedding_key'] == key][['object_id', 'rank']].rename(columns={'rank': f'rank_{model_name}_{mod}'}) + mod_data = df[df["embedding_key"] == key][["object_id", "rank"]].rename( + columns={"rank": f"rank_{model_name}_{mod}"} + ) if model_subset is None: model_subset = mod_data else: - model_subset = model_subset.merge(mod_data, on='object_id') - + model_subset = model_subset.merge(mod_data, on="object_id") + if big_merged is None: big_merged = model_subset else: - big_merged = big_merged.merge(model_subset, on='object_id') - + big_merged = big_merged.merge(model_subset, on="object_id") + n_total = len(big_merged) print(f"Matched {n_total} objects across all models and modalities.") for row_idx, src_mod in enumerate(modalities): for col_idx, tgt_mod in enumerate(modalities): ax = axes[row_idx, col_idx] - + model_data_list = [] for model_name in ["AION", "AstroPT", "AstroCLIP"]: - ranks_src = big_merged[f'rank_{model_name}_{src_mod}'].values - ranks_tgt = big_merged[f'rank_{model_name}_{tgt_mod}'].values + ranks_src = big_merged[f"rank_{model_name}_{src_mod}"].values + ranks_tgt = big_merged[f"rank_{model_name}_{tgt_mod}"].values model_data_list.append((model_name, ranks_src, ranks_tgt)) - + plot_cross_panel(ax, 
src_mod, tgt_mod, model_data_list, n_total) - + if col_idx == 0: ax.set_ylabel(rf"\textbf{{{src_mod}}}" + "\nCount") if row_idx == 2: ax.set_xlabel(rf"Percentile Rank in \textbf{{{tgt_mod}}}") if row_idx == 0 and col_idx == 0: - ax.legend(loc='lower right', fontsize=6) + ax.legend(loc="lower right", fontsize=6) plt.tight_layout() - + save_path = Path(args.save) save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=300, bbox_inches="tight") print(f"Saved figure to {save_path}") plt.close(fig) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/displacement/plot_paper_displacement_extensive.py b/src/fmb/viz/displacement/plot_paper_displacement_extensive.py index d1ea95b..4ca6d13 100644 --- a/src/fmb/viz/displacement/plot_paper_displacement_extensive.py +++ b/src/fmb/viz/displacement/plot_paper_displacement_extensive.py @@ -7,7 +7,8 @@ import argparse from pathlib import Path -from typing import Sequence, Optional, Dict, List +from typing import Sequence + import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -15,51 +16,82 @@ # --- Publication Style Settings --- try: - plt.rcParams.update({ - "text.usetex": True, - "font.family": "serif", - "font.serif": ["Computer Modern Roman"], - }) + plt.rcParams.update( + { + "text.usetex": True, + "font.family": "serif", + "font.serif": ["Computer Modern Roman"], + } + ) except Exception: - plt.rcParams.update({ - "text.usetex": False, - "mathtext.fontset": "stix", - "font.family": "STIXGeneral", - }) - -plt.rcParams.update({ - "font.size": 9, - "axes.labelsize": 9, - "axes.titlesize": 10, - "xtick.labelsize": 7, - "ytick.labelsize": 7, - "savefig.bbox": "tight", -}) + plt.rcParams.update( + { + "text.usetex": False, + "mathtext.fontset": "stix", + "font.family": "STIXGeneral", + } + ) + +plt.rcParams.update( + { + "font.size": 9, + "axes.labelsize": 9, + "axes.titlesize": 10, + "xtick.labelsize": 7, + "ytick.labelsize": 7, + "savefig.bbox": "tight", + } +) KEYS = { - "AION": {"Images": "embedding_hsc", "Spectra": "embedding_spectrum", "Joint": "embedding_hsc_desi"}, - "AstroPT": {"Images": "embedding_images", "Spectra": "embedding_spectra", "Joint": "embedding_joint"}, - "AstroCLIP": {"Images": "embedding_images", "Spectra": "embedding_spectra", "Joint": "embedding_joint"} + "AION": { + "Images": "embedding_hsc", + "Spectra": "embedding_spectrum", + "Joint": "embedding_hsc_desi", + }, + "AstroPT": { + "Images": "embedding_images", + "Spectra": "embedding_spectra", + "Joint": "embedding_joint", + }, + "AstroCLIP": { + "Images": "embedding_images", + "Spectra": "embedding_spectra", + "Joint": "embedding_joint", + }, } COLORS = {"AION": "#1f77b4", "AstroPT": "#ff7f0e", "AstroCLIP": "#2ca02c"} + def load_data(path: Path) -> pd.DataFrame: df = pd.read_csv(path) - df['object_id'] = df['object_id'].astype(str) + df["object_id"] = df["object_id"].astype(str) return df + def main(argv: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser() parser.add_argument("--aion-scores", required=True) parser.add_argument("--astropt-scores", required=True) parser.add_argument("--astroclip-scores", required=True) - parser.add_argument("--save-prefix", default="paper/Final_results/paper_displacement_extensive") + parser.add_argument( + "--save-prefix", default="paper/Final_results/paper_displacement_extensive" + ) args = parser.parse_args(argv) print("Loading scores...") - dfs = {m: load_data(Path(args.aion_scores if m == "AION" else (args.astropt_scores if m == "AstroPT" else args.astroclip_scores))) 
for m in KEYS} - + dfs = { + m: load_data( + Path( + args.aion_scores + if m == "AION" + else (args.astropt_scores if m == "AstroPT" else args.astroclip_scores) + ) + ) + for m in KEYS + } + modalities = ["Images", "Spectra", "Joint"] systems = [] for model in ["AION", "AstroPT", "AstroCLIP"]: @@ -71,12 +103,14 @@ def main(argv: Sequence[str] | None = None) -> None: for model, mod in systems: key = KEYS[model][mod] df = dfs[model] - mod_data = df[df['embedding_key'] == key][['object_id', 'rank']].rename(columns={'rank': f'rank_{model}_{mod}'}) + mod_data = df[df["embedding_key"] == key][["object_id", "rank"]].rename( + columns={"rank": f"rank_{model}_{mod}"} + ) if big_merged is None: big_merged = mod_data else: - big_merged = big_merged.merge(mod_data, on='object_id') - + big_merged = big_merged.merge(mod_data, on="object_id") + n_total = len(big_merged) print(f"Matched {n_total} objects.") @@ -85,17 +119,27 @@ def main(argv: Sequence[str] | None = None) -> None: # --- 1. Generate Heatmap --- for i, (src_model, src_mod) in enumerate(systems): - src_col = f'rank_{src_model}_{src_mod}' + src_col = f"rank_{src_model}_{src_mod}" mask = big_merged[src_col] <= top_1_thr for j, (tgt_model, tgt_mod) in enumerate(systems): - tgt_col = f'rank_{tgt_model}_{tgt_mod}' - retained = (big_merged.loc[mask, tgt_col] <= top_1_thr).sum() / mask.sum() * 100 + tgt_col = f"rank_{tgt_model}_{tgt_mod}" + retained = ( + (big_merged.loc[mask, tgt_col] <= top_1_thr).sum() / mask.sum() * 100 + ) retention_matrix[i, j] = retained labels = [f"{m}-{mo}" for m, mo in systems] - + plt.figure(figsize=(10, 8)) - sns.heatmap(retention_matrix, annot=True, fmt=".1f", xticklabels=labels, yticklabels=labels, cmap="YlGnBu", cbar_kws={'label': 'Retention Top 1\% (\%)'}) + sns.heatmap( + retention_matrix, + annot=True, + fmt=".1f", + xticklabels=labels, + yticklabels=labels, + cmap="YlGnBu", + cbar_kws={"label": "Retention Top 1\% (\%)"}, + ) plt.title("Extensive Anomaly Retention Matrix (Top 1\% $\\rightarrow$ Top 1\%)") plt.xlabel("Target System") plt.ylabel("Source System") @@ -104,32 +148,52 @@ def main(argv: Sequence[str] | None = None) -> None: print(f"Saved heatmap to {args.save_prefix}_heatmap.png") # --- 2. 
Generate 9x9 Histogram Grid --- - fig, axes = plt.subplots(9, 9, figsize=(18, 18), sharex=True, sharey='row') + fig, axes = plt.subplots(9, 9, figsize=(18, 18), sharex=True, sharey="row") for i, (src_model, src_mod) in enumerate(systems): - src_col = f'rank_{src_model}_{src_mod}' + src_col = f"rank_{src_model}_{src_mod}" mask = big_merged[src_col] <= top_1_thr color = COLORS[src_model] - + for j, (tgt_model, tgt_mod) in enumerate(systems): ax = axes[i, j] - tgt_col = f'rank_{tgt_model}_{tgt_mod}' + tgt_col = f"rank_{tgt_model}_{tgt_mod}" ranks_pct = (big_merged.loc[mask, tgt_col] / n_total) * 100 - - ax.hist(ranks_pct, bins=30, range=(0, 100), color=color, alpha=0.7, edgecolor='none') - ax.axvline(x=1, color='red', linestyle='--', linewidth=0.5, alpha=0.5) - - if i == 0: ax.set_title(f"To: {tgt_model}\n{tgt_mod}", fontsize=8) - if j == 0: ax.set_ylabel(f"From: {src_model}\n{src_mod}\nCount", fontsize=8) - if i == 8: ax.set_xlabel("Rank \%", fontsize=8) - + + ax.hist( + ranks_pct, + bins=30, + range=(0, 100), + color=color, + alpha=0.7, + edgecolor="none", + ) + ax.axvline(x=1, color="red", linestyle="--", linewidth=0.5, alpha=0.5) + + if i == 0: + ax.set_title(f"To: {tgt_model}\n{tgt_mod}", fontsize=8) + if j == 0: + ax.set_ylabel(f"From: {src_model}\n{src_mod}\nCount", fontsize=8) + if i == 8: + ax.set_xlabel("Rank \%", fontsize=8) + # Retention text - ax.text(0.95, 0.95, f"{retention_matrix[i, j]:.1f}\%", transform=ax.transAxes, ha='right', va='top', fontsize=7, fontweight='bold') - ax.grid(True, linestyle=':', alpha=0.3) + ax.text( + 0.95, + 0.95, + f"{retention_matrix[i, j]:.1f}\%", + transform=ax.transAxes, + ha="right", + va="top", + fontsize=7, + fontweight="bold", + ) + ax.grid(True, linestyle=":", alpha=0.3) plt.tight_layout() plt.savefig(f"{args.save_prefix}_grid.png", dpi=200) print(f"Saved grid to {args.save_prefix}_grid.png") plt.close(fig) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/outliers/advanced_analysis.py b/src/fmb/viz/outliers/advanced_analysis.py index b01da56..c0bc01f 100644 --- a/src/fmb/viz/outliers/advanced_analysis.py +++ b/src/fmb/viz/outliers/advanced_analysis.py @@ -8,53 +8,66 @@ import argparse import sys from pathlib import Path -from typing import Sequence, Optional, Dict, List, Union +from typing import Sequence, Union import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns -from scipy.stats import spearmanr from fmb.paths import load_paths from fmb.viz.style import apply_style KEYS = { - "AION": {"Images": "embedding_hsc", "Spectra": "embedding_spectrum", "Joint": "embedding_hsc_desi"}, - "AstroPT": {"Images": "embedding_images", "Spectra": "embedding_spectra", "Joint": "embedding_joint"}, - "AstroCLIP": {"Images": "embedding_images", "Spectra": "embedding_spectra", "Joint": "embedding_joint"} + "AION": { + "Images": "embedding_hsc", + "Spectra": "embedding_spectrum", + "Joint": "embedding_hsc_desi", + }, + "AstroPT": { + "Images": "embedding_images", + "Spectra": "embedding_spectra", + "Joint": "embedding_joint", + }, + "AstroCLIP": { + "Images": "embedding_images", + "Spectra": "embedding_spectra", + "Joint": "embedding_joint", + }, } + def load_data(path: Path) -> pd.DataFrame: if not path.exists(): raise FileNotFoundError(f"Score file not found: {path}") df = pd.read_csv(path) - df['object_id'] = df['object_id'].astype(str) + df["object_id"] = df["object_id"].astype(str) return df + def run_analysis( aion_scores: Union[str, Path, None] = None, astropt_scores: Union[str, Path, None] = None, 
astroclip_scores: Union[str, Path, None] = None, - save_prefix: Union[str, Path, None] = None + save_prefix: Union[str, Path, None] = None, ): """ Main entry point for advanced analysis visualization. """ # 1. Apply Style apply_style() - + # 2. Resolve Paths paths = load_paths() out_dir = paths.analysis / "advanced" out_dir.mkdir(parents=True, exist_ok=True) - + if save_prefix is None: save_prefix = out_dir / "paper_advanced" else: save_prefix = Path(save_prefix) if not save_prefix.parent.exists(): - save_prefix.parent.mkdir(parents=True, exist_ok=True) + save_prefix.parent.mkdir(parents=True, exist_ok=True) # Defaults for input files if not provided # We look for standard names in paths.outliers @@ -67,7 +80,7 @@ def run_analysis( astropt_scores = paths.outliers / "anomaly_scores_astropt.csv" else: astropt_scores = Path(astropt_scores) - + if astroclip_scores is None: astroclip_scores = paths.outliers / "anomaly_scores_astroclip.csv" else: @@ -78,18 +91,18 @@ def run_analysis( print(f" AION: {aion_scores}") print(f" AstroPT: {astropt_scores}") print(f" AstroCLIP: {astroclip_scores}") - + try: dfs = { "AION": load_data(aion_scores), "AstroPT": load_data(astropt_scores), - "AstroCLIP": load_data(astroclip_scores) + "AstroCLIP": load_data(astroclip_scores), } except FileNotFoundError as e: print(f"Error: {e}") print("Please ensure all score files exist or provide paths via arguments.") sys.exit(1) - + systems = [] modalities = ["Images", "Spectra", "Joint"] for model in ["AION", "AstroPT", "AstroCLIP"]: @@ -102,39 +115,49 @@ def run_analysis( key = KEYS[model][mod] df = dfs[model] # Check if key exists (some runs might be partial) - if key not in df['embedding_key'].unique(): - print(f"Warning: Key '{key}' not found in {model} scores. Skipping.") - continue - - mod_data = df[df['embedding_key'] == key][['object_id', 'rank']].rename(columns={'rank': f'{model}-{mod}'}) + if key not in df["embedding_key"].unique(): + print(f"Warning: Key '{key}' not found in {model} scores. Skipping.") + continue + + mod_data = df[df["embedding_key"] == key][["object_id", "rank"]].rename( + columns={"rank": f"{model}-{mod}"} + ) if big_merged is None: big_merged = mod_data else: - big_merged = big_merged.merge(mod_data, on='object_id') - + big_merged = big_merged.merge(mod_data, on="object_id") + if big_merged is None or big_merged.empty: - print("Error: No common objects found after merge. Check object_ids consistency.") + print( + "Error: No common objects found after merge. Check object_ids consistency." + ) return n_total = len(big_merged) print(f"Matched {n_total} objects.") # Update systems list based on what was actually merged - valid_cols = [c for c in big_merged.columns if c != 'object_id'] + valid_cols = [c for c in big_merged.columns if c != "object_id"] # Re-order to keep generic order if possible final_cols = [] for model, mod in systems: - col = f'{model}-{mod}' + col = f"{model}-{mod}" if col in valid_cols: final_cols.append(col) # --- 1. 
Spearman Clustermap (Rank Correlation) --- print("Computing Spearman Correlations...") - corr_matrix = big_merged[final_cols].corr(method='spearman') - + corr_matrix = big_merged[final_cols].corr(method="spearman") + plt.figure() - g = sns.clustermap(corr_matrix, annot=True, fmt=".2f", cmap="vlag", center=0, - dendrogram_ratio=(.1, .1)) #, cbar_pos=(.02, .32, .03, .2)) + g = sns.clustermap( + corr_matrix, + annot=True, + fmt=".2f", + cmap="vlag", + center=0, + dendrogram_ratio=(0.1, 0.1), + ) # , cbar_pos=(.02, .32, .03, .2)) g.fig.suptitle("Hierarchical Clustering of Systems (Spearman Correlation)", y=1.02) s_path = f"{save_prefix}_spearman_clustermap.png" g.savefig(s_path, dpi=300) @@ -144,22 +167,33 @@ def run_analysis( print("Computing Jaccard Indices...") top_1_thr = n_total * 0.01 jaccard_matrix = pd.DataFrame(index=final_cols, columns=final_cols, dtype=float) - + top_1_sets = {} for col in final_cols: - top_1_sets[col] = set(big_merged[big_merged[col] <= top_1_thr]['object_id']) + top_1_sets[col] = set(big_merged[big_merged[col] <= top_1_thr]["object_id"]) for col1 in final_cols: for col2 in final_cols: s1 = top_1_sets[col1] s2 = top_1_sets[col2] - iou = len(s1.intersection(s2)) / len(s1.union(s2)) if len(s1.union(s2)) > 0 else 0 + iou = ( + len(s1.intersection(s2)) / len(s1.union(s2)) + if len(s1.union(s2)) > 0 + else 0 + ) jaccard_matrix.loc[col1, col2] = iou plt.figure() - g2 = sns.clustermap(jaccard_matrix, annot=True, fmt=".2f", cmap="YlGnBu", - dendrogram_ratio=(.1, .1)) #, cbar_pos=(.02, .32, .03, .2)) - g2.fig.suptitle(r"Hierarchical Clustering of Top 1\% Anomalies (Jaccard Index)", y=1.02) + g2 = sns.clustermap( + jaccard_matrix, + annot=True, + fmt=".2f", + cmap="YlGnBu", + dendrogram_ratio=(0.1, 0.1), + ) # , cbar_pos=(.02, .32, .03, .2)) + g2.fig.suptitle( + r"Hierarchical Clustering of Top 1\% Anomalies (Jaccard Index)", y=1.02 + ) j_path = f"{save_prefix}_jaccard_clustermap.png" g2.savefig(j_path, dpi=300) print(f"Saved {j_path}") @@ -167,55 +201,74 @@ def run_analysis( # --- 3. 
Disagreement Analysis (AION-Joint vs AstroPT-Joint) --- sys1 = "AION-Joint" sys2 = "AstroPT-Joint" - + if sys1 in final_cols and sys2 in final_cols: print("Analyzing Disagreements...") - + x = big_merged[sys1].values y = big_merged[sys2].values - + # Normalize ranks to percentile (0-100) x_pct = (x / n_total) * 100 y_pct = (y / n_total) * 100 - + # Identify Controversial Zones # Case A: Anomaly in Sys1 (Top 1%), Normal in Sys2 (>50%) mask_a = (x_pct <= 1) & (y_pct > 50) # Case B: Anomaly in Sys2 (Top 1%), Normal in Sys1 (>50%) mask_b = (y_pct <= 1) & (x_pct > 50) - + fig, ax = plt.subplots(figsize=(7, 7)) # Hexbin for density - hb = ax.hexbin(x_pct, y_pct, gridsize=50, cmap='Greys', mincnt=1, bins='log', alpha=0.6) - + ax.hexbin( + x_pct, y_pct, gridsize=50, cmap="Greys", mincnt=1, bins="log", alpha=0.6 + ) + # Highlight Controversial - ax.scatter(x_pct[mask_a], y_pct[mask_a], s=10, c='red', label=f'Anomaly in {sys1}\nNormal in {sys2}', alpha=0.8) - ax.scatter(x_pct[mask_b], y_pct[mask_b], s=10, c='blue', label=f'Anomaly in {sys2}\nNormal in {sys1}', alpha=0.8) - + ax.scatter( + x_pct[mask_a], + y_pct[mask_a], + s=10, + c="red", + label=f"Anomaly in {sys1}\nNormal in {sys2}", + alpha=0.8, + ) + ax.scatter( + x_pct[mask_b], + y_pct[mask_b], + s=10, + c="blue", + label=f"Anomaly in {sys2}\nNormal in {sys1}", + alpha=0.8, + ) + ax.set_xlabel(f"{sys1} Rank (\%)") ax.set_ylabel(f"{sys2} Rank (\%)") ax.set_title(f"Disagreement Analysis: {sys1} vs {sys2}") - ax.legend(loc='upper right', fontsize=8) - ax.grid(True, linestyle=':', alpha=0.5) - + ax.legend(loc="upper right", fontsize=8) + ax.grid(True, linestyle=":", alpha=0.5) + # Annotate regions - ax.axvline(x=1, color='red', linestyle=':', alpha=0.5) - ax.axhline(y=1, color='blue', linestyle=':', alpha=0.5) - + ax.axvline(x=1, color="red", linestyle=":", alpha=0.5) + ax.axhline(y=1, color="blue", linestyle=":", alpha=0.5) + plt.tight_layout() d_path = f"{save_prefix}_disagreement_scatter.png" fig.savefig(d_path, dpi=300) print(f"Saved {d_path}") # Export Disagreements - disagreements = big_merged[mask_a | mask_b][['object_id', sys1, sys2]].copy() - disagreements['type'] = np.where(mask_a[mask_a | mask_b], f'{sys1}_Anomaly', f'{sys2}_Anomaly') + disagreements = big_merged[mask_a | mask_b][["object_id", sys1, sys2]].copy() + disagreements["type"] = np.where( + mask_a[mask_a | mask_b], f"{sys1}_Anomaly", f"{sys2}_Anomaly" + ) csv_path = f"{save_prefix}_disagreement_objects.csv" disagreements.to_csv(csv_path, index=False) print(f"Exported {len(disagreements)} controversial objects to {csv_path}") else: print(f"Skipping disagreement analysis: {sys1} or {sys2} missing.") + def main(argv: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser() parser.add_argument("--aion-scores", default=None) @@ -228,8 +281,9 @@ def main(argv: Sequence[str] | None = None) -> None: aion_scores=args.aion_scores, astropt_scores=args.astropt_scores, astroclip_scores=args.astroclip_scores, - save_prefix=args.save_prefix + save_prefix=args.save_prefix, ) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/outliers/outlier_grid.py b/src/fmb/viz/outliers/outlier_grid.py index 1415193..0918e0b 100644 --- a/src/fmb/viz/outliers/outlier_grid.py +++ b/src/fmb/viz/outliers/outlier_grid.py @@ -7,25 +7,23 @@ import argparse from pathlib import Path -from typing import Sequence, Optional, List, Union +from typing import List, Optional, Sequence, Union import matplotlib.pyplot as plt import numpy as np import scipy.ndimage -import torch -from tqdm import tqdm 
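The `advanced_analysis.py` hunks above lean on two simple overlap statistics: the Jaccard index of the top-1% anomaly sets and the "anomalous in one system, ordinary in the other" masks behind the disagreement scatter. As a reading aid only, here is a minimal standalone sketch of both, using synthetic rank columns (the IDs, sizes, and column names are made up and independent of the real FMB score files):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 10_000  # hypothetical number of matched objects

# Two synthetic rank columns (1 = most anomalous), independent here purely for illustration
ranks = pd.DataFrame(
    {
        "object_id": np.arange(n).astype(str),
        "AION-Joint": rng.permutation(n) + 1,
        "AstroPT-Joint": rng.permutation(n) + 1,
    }
)

top = n * 0.01  # top-1% threshold, as in run_analysis

# Jaccard index of the two top-1% sets
s1 = set(ranks.loc[ranks["AION-Joint"] <= top, "object_id"])
s2 = set(ranks.loc[ranks["AstroPT-Joint"] <= top, "object_id"])
jaccard = len(s1 & s2) / len(s1 | s2) if (s1 | s2) else 0.0

# "Controversial" objects: top-1% in one system, below the median in the other
x_pct = ranks["AION-Joint"] / n * 100
y_pct = ranks["AstroPT-Joint"] / n * 100
mask_a = (x_pct <= 1) & (y_pct > 50)  # anomaly for AION-Joint only
mask_b = (y_pct <= 1) & (x_pct > 50)  # anomaly for AstroPT-Joint only

print(f"Jaccard(top 1%) = {jaccard:.3f}, "
      f"{mask_a.sum()} + {mask_b.sum()} controversial objects")
```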
-from fmb.paths import load_paths from fmb.data.load_display_data import EuclidDESIDataset +from fmb.data.utils import read_object_ids +from fmb.paths import load_paths +from fmb.viz.spectrum import REST_LINES, extract_spectrum +from fmb.viz.style import apply_style from fmb.viz.utils import ( - load_index, collect_samples, collect_samples_with_index, + load_index, prepare_rgb_image, ) -from fmb.viz.spectrum import extract_spectrum, REST_LINES -from fmb.data.utils import read_object_ids -from fmb.viz.style import apply_style def plot_publication_grid( @@ -38,95 +36,96 @@ def plot_publication_grid( if count == 0: print("No samples to display.") return - + rows = int(np.ceil(count / cols)) - + # Grid: 1 row per object (Image | Spectrum) # 2 cols per object total_grid_rows = rows total_grid_cols = cols * 2 - + # Figure size # Width: ~5 inches per object (1.5" img + 3.5" spec) -> ~10 inches for 2 cols # Height: ~1.5 inches per object row fig_width = 5.0 * cols - fig_height = 1.3 * rows - + fig_height = 1.3 * rows + fig = plt.figure(figsize=(fig_width, fig_height)) - + # GridSpec # Alternating widths: [1, 2.5] for [Image, Spectrum] gs = fig.add_gridspec( - total_grid_rows, - total_grid_cols, + total_grid_rows, + total_grid_cols, width_ratios=[1, 2.5] * cols, - hspace=0.03, + hspace=0.03, wspace=0.03, top=0.98, bottom=0.05, left=0.02, - right=0.98 + right=0.98, ) for idx, sample in enumerate(samples): row = idx // cols col = idx % cols - + # Grid indices c_img = 2 * col c_spec = 2 * col + 1 - + # 1. Image (Left) ax_img = fig.add_subplot(gs[row, c_img]) - + image = prepare_rgb_image(sample) if image.ndim == 3 and image.shape[2] == 1: ax_img.imshow(image[..., 0], cmap="gray", origin="lower") else: ax_img.imshow(image, origin="lower") - + ax_img.axis("off") - + # Overlay Text obj_id = str(sample.get("object_id") or sample.get("targetid", "N/A")) redshift = sample.get("redshift") - + text_content = f"{obj_id}" if redshift is not None and not np.isnan(redshift): try: text_content += f"\n$z={float(redshift):.3f}$" except (TypeError, ValueError): pass - + # Top-left of image ax_img.text( - 0.05, 0.95, - text_content, - transform=ax_img.transAxes, - fontsize=7, - va="top", + 0.05, + 0.95, + text_content, + transform=ax_img.transAxes, + fontsize=7, + va="top", ha="left", color="white", - fontweight='bold', - bbox=dict(facecolor='black', alpha=0.5, edgecolor='none', pad=1.0) + fontweight="bold", + bbox=dict(facecolor="black", alpha=0.5, edgecolor="none", pad=1.0), ) # 2. Spectrum (Right) ax_spec = fig.add_subplot(gs[row, c_spec]) wavelength, flux = extract_spectrum(sample) - + if flux is not None and wavelength is not None: sort_idx = np.argsort(wavelength) wave_sorted = wavelength[sort_idx] flux_sorted = flux[sort_idx] - + smoothed_flux = scipy.ndimage.gaussian_filter1d(flux_sorted, sigma=3) - + redshift = sample.get("redshift") rest_wave = wave_sorted z_val = None - + if redshift is not None and not np.isnan(redshift) and redshift > -1: try: z_val = float(redshift) @@ -138,16 +137,22 @@ def plot_publication_grid( x_label = r"Wavelength [\AA]" ax_spec.plot(rest_wave, smoothed_flux, linewidth=0.6, color="black") - + # Emission lines if z_val is not None: ymin, ymax = ax_spec.get_ylim() # Auto-scale Y to ignore extreme outliers? # For now keep as is. 
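The paired image-and-spectrum layout in `plot_publication_grid` above hinges on one GridSpec detail: two grid columns per object with alternating width ratios of `[1, 2.5]`. A stripped-down sketch of just that layout, with dummy data and an arbitrary figure size (not the function's real inputs):

```python
import matplotlib.pyplot as plt
import numpy as np

cols, rows = 2, 3  # objects per row / number of object rows
fig = plt.figure(figsize=(5.0 * cols, 1.3 * rows))

# Two grid columns per object: a narrow image slot and a wide spectrum slot
gs = fig.add_gridspec(rows, cols * 2, width_ratios=[1, 2.5] * cols,
                      hspace=0.03, wspace=0.03)

for idx in range(rows * cols):
    r, c = divmod(idx, cols)
    ax_img = fig.add_subplot(gs[r, 2 * c])       # image panel
    ax_spec = fig.add_subplot(gs[r, 2 * c + 1])  # spectrum panel
    ax_img.imshow(np.random.rand(32, 32), origin="lower")
    ax_img.axis("off")
    ax_spec.plot(np.linspace(3600, 9800, 200), np.random.rand(200),
                 color="black", linewidth=0.6)
    ax_spec.set_yticks([])

fig.savefig("grid_layout_demo.png", dpi=150, bbox_inches="tight")
```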
- + for name, line_rest in REST_LINES.items(): if rest_wave.min() <= line_rest <= rest_wave.max(): - ax_spec.axvline(line_rest, color="red", linestyle=":", alpha=0.6, linewidth=0.8) + ax_spec.axvline( + line_rest, + color="red", + linestyle=":", + alpha=0.6, + linewidth=0.8, + ) ax_spec.text( line_rest, ymax * 0.98, @@ -157,36 +162,45 @@ def plot_publication_grid( ha="center", fontsize=5, color="#cc0000", - bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", pad=0.5) + bbox=dict( + facecolor="white", alpha=0.7, edgecolor="none", pad=0.5 + ), ) ax_spec.set_xlim(rest_wave.min(), rest_wave.max()) - ax_spec.spines['top'].set_visible(False) - ax_spec.spines['right'].set_visible(False) - + ax_spec.spines["top"].set_visible(False) + ax_spec.spines["right"].set_visible(False) + # X-label only for bottom row of OBJECTS if row == rows - 1: ax_spec.set_xlabel(x_label, fontsize=7, labelpad=2) else: ax_spec.set_xticklabels([]) - - ax_spec.set_yticks([]) # Hide Y ticks + + ax_spec.set_yticks([]) # Hide Y ticks ax_spec.minorticks_on() - ax_spec.tick_params(axis='both', which='major', labelsize=6, length=3) - ax_spec.tick_params(axis='both', which='minor', length=1.5) - + ax_spec.tick_params(axis="both", which="major", labelsize=6, length=3) + ax_spec.tick_params(axis="both", which="minor", length=1.5) + else: - ax_spec.text(0.5, 0.5, "No spectrum", ha="center", va="center", transform=ax_spec.transAxes) + ax_spec.text( + 0.5, + 0.5, + "No spectrum", + ha="center", + va="center", + transform=ax_spec.transAxes, + ) ax_spec.axis("off") if save_path: save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=300, bbox_inches="tight") print(f"Saved publication grid to {save_path}") - + if show: plt.show() - + plt.close(fig) @@ -199,23 +213,23 @@ def run_grid_plot( save_path: Optional[Union[str, Path]] = "outliers_paper_grid.png", show: bool = True, index_path: Optional[Union[str, Path]] = None, - verbose: bool = False + verbose: bool = False, ): """ Load data and plot the outlier grid. """ # Apply style apply_style() - + # Resolve paths paths = load_paths() if not cache_dir: cache_dir = paths.dataset - + if index_path is None: index_path = paths.dataset_index if index_path and not index_path.exists(): - # Just warn or allow it to be ignored? + # Just warn or allow it to be ignored? # `collect_samples_with_index` needs it. pass @@ -223,13 +237,15 @@ def run_grid_plot( out_dir = paths.analysis / "outliers" out_dir.mkdir(parents=True, exist_ok=True) save_path = out_dir / "outliers_grid.png" - + csv_files = [Path(p) for p in csv_paths] - object_ids = read_object_ids(csv_files, verbose=verbose) # removed limit here to shuffle/sample later or use max - # Note: original read_object_ids took limit, but we pass full list then slice here? + object_ids = read_object_ids( + csv_files, verbose=verbose + ) # removed limit here to shuffle/sample later or use max + # Note: original read_object_ids took limit, but we pass full list then slice here? # Actually utils.read_object_ids doesn't have limit arg in all versions? - # Let's check signature... - # User's utils.py has `read_object_ids`? I saw it in `src/fmb/data/utils.py`? + # Let's check signature... + # User's utils.py has `read_object_ids`? I saw it in `src/fmb/data/utils.py`? # I didn't check that file content. Assuming it works or I should limit list manually. 
if len(object_ids) > max_count: object_ids = object_ids[:max_count] @@ -241,13 +257,15 @@ def run_grid_plot( if index_path: index_map = load_index(Path(index_path)) samples = collect_samples_with_index( - cache_dir=str(cache_dir), # func expects str + cache_dir=str(cache_dir), # func expects str object_ids=object_ids, index_map=index_map, verbose=verbose, ) else: - dataset = EuclidDESIDataset(split=split, cache_dir=str(cache_dir), verbose=verbose) + dataset = EuclidDESIDataset( + split=split, cache_dir=str(cache_dir), verbose=verbose + ) samples = collect_samples(dataset, object_ids, verbose=verbose) if not samples: @@ -266,15 +284,33 @@ def main(argv: Sequence[str] | None = None) -> None: parser = argparse.ArgumentParser( description="Display publication-ready grid of outliers with spectra", ) - parser.add_argument("--csv", nargs="+", required=True, help="CSV file(s) with object_id column") + parser.add_argument( + "--csv", nargs="+", required=True, help="CSV file(s) with object_id column" + ) parser.add_argument("--split", type=str, default="all", help="Dataset split(s)") parser.add_argument("--cache-dir", type=str, default=None) - parser.add_argument("--max", type=int, default=12, help="Maximum number of images to display") - parser.add_argument("--cols", type=int, default=3, help="Number of columns in the grid") - parser.add_argument("--save", type=str, default="outliers_paper_grid.png", help="Path to save the figure") - parser.add_argument("--no-show", action="store_true", help="Disable interactive display") + parser.add_argument( + "--max", type=int, default=12, help="Maximum number of images to display" + ) + parser.add_argument( + "--cols", type=int, default=3, help="Number of columns in the grid" + ) + parser.add_argument( + "--save", + type=str, + default="outliers_paper_grid.png", + help="Path to save the figure", + ) + parser.add_argument( + "--no-show", action="store_true", help="Disable interactive display" + ) parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") - parser.add_argument("--index", type=str, default=None, help="Optional CSV mapping object_id -> split/index") + parser.add_argument( + "--index", + type=str, + default=None, + help="Optional CSV mapping object_id -> split/index", + ) args = parser.parse_args(argv) run_grid_plot( @@ -286,8 +322,9 @@ def main(argv: Sequence[str] | None = None) -> None: save_path=args.save, show=not args.no_show, index_path=args.index, - verbose=args.verbose + verbose=args.verbose, ) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/outliers/single_object.py b/src/fmb/viz/outliers/single_object.py index 50f907d..cd0c34f 100644 --- a/src/fmb/viz/outliers/single_object.py +++ b/src/fmb/viz/outliers/single_object.py @@ -6,22 +6,24 @@ """ import argparse -import sys from pathlib import Path -from typing import Optional, Tuple, Union, Sequence +from typing import Optional, Sequence, Union import matplotlib.pyplot as plt import numpy as np import torch from scipy.ndimage import gaussian_filter1d -from fmb.paths import load_paths from fmb.data.load_display_data import EuclidDESIDataset -from fmb.viz.utils import load_index, prepare_rgb_image -from fmb.viz.spectrum import extract_spectrum, LATEX_REST_LINES +from fmb.paths import load_paths +from fmb.viz.spectrum import LATEX_REST_LINES, extract_spectrum from fmb.viz.style import apply_style +from fmb.viz.utils import load_index + -def get_sample_by_id(object_id: str, index_path: Optional[Path], cache_dir: str, verbose: bool = True) -> dict: +def 
get_sample_by_id( + object_id: str, index_path: Optional[Path], cache_dir: str, verbose: bool = True +) -> dict: if index_path and index_path.exists(): index_map = load_index(index_path) if object_id in index_map: @@ -29,13 +31,15 @@ def get_sample_by_id(object_id: str, index_path: Optional[Path], cache_dir: str, # Use offset logic if needed? load_display_data handles dataset loading. # We just load the specific sample. # Ideally we reuse valid offsets, but simply loading dataset and using generic access is safest. - dataset = EuclidDESIDataset(split=split, cache_dir=cache_dir, verbose=verbose) - - # Since index might be global/unsafe, checking ID is better, + dataset = EuclidDESIDataset( + split=split, cache_dir=cache_dir, verbose=verbose + ) + + # Since index might be global/unsafe, checking ID is better, # but EuclidDESIDataset doesn't support ID lookup. # We blindly trust index for now, but handle IndexError try: - # Correct index if strictly sequential? + # Correct index if strictly sequential? # See my previous logic in utils.py. # Here we will assume the index in index_map is usable or relative. # If it fails, we fail. @@ -43,14 +47,16 @@ def get_sample_by_id(object_id: str, index_path: Optional[Path], cache_dir: str, return sample except IndexError: # Fallback: scan whole split - if verbose: print(f"Index {idx} invalid, scanning split '{split}'...") + if verbose: + print(f"Index {idx} invalid, scanning split '{split}'...") for s in dataset: if str(s.get("object_id") or "") == object_id: return s raise ValueError(f"Object {object_id} not found in split {split}") - + # Fallback if no index: scan both splits - if verbose: print("No index provided or object not found. Scanning 'train' and 'test'...") + if verbose: + print("No index provided or object not found. 
Scanning 'train' and 'test'...") for split in ["train", "test"]: try: ds = EuclidDESIDataset(split=split, cache_dir=cache_dir, verbose=False) @@ -59,61 +65,77 @@ def get_sample_by_id(object_id: str, index_path: Optional[Path], cache_dir: str, return s except Exception: pass - + raise ValueError(f"Object {object_id} not found in dataset.") + def plot_spectrum(ax: plt.Axes, sample: dict, smooth_sigma: float = 2.0): wavelength, flux = extract_spectrum(sample) redshift = sample.get("redshift", 0.0) - - if hasattr(redshift, "item"): redshift = redshift.item() - if redshift is None or np.isnan(redshift): redshift = 0.0 - + + if hasattr(redshift, "item"): + redshift = redshift.item() + if redshift is None or np.isnan(redshift): + redshift = 0.0 + if flux is None or wavelength is None: - ax.text(0.5, 0.5, "No Spectrum", ha="center", va="center", transform=ax.transAxes) + ax.text( + 0.5, 0.5, "No Spectrum", ha="center", va="center", transform=ax.transAxes + ) return # Sort and Smooth order = np.argsort(wavelength) w = wavelength[order] f = flux[order] - + if smooth_sigma > 0: f = gaussian_filter1d(f, sigma=smooth_sigma) - + # Restframe conversion w_rest = w / (1.0 + redshift) - + ax.plot(w_rest, f, color="black", linewidth=0.8) - + # Overlay emission lines y_min, y_max = ax.get_ylim() for name, line_rest in LATEX_REST_LINES.items(): if w_rest.min() <= line_rest <= w_rest.max(): - ax.axvline(line_rest, color="gray", linestyle="--", linewidth=0.5, alpha=0.5) + ax.axvline( + line_rest, color="gray", linestyle="--", linewidth=0.5, alpha=0.5 + ) # Label at top - ax.text(line_rest, y_max * 0.95, name, rotation=90, fontsize=6, - ha="center", va="top", alpha=0.7) + ax.text( + line_rest, + y_max * 0.95, + name, + rotation=90, + fontsize=6, + ha="center", + va="top", + alpha=0.7, + ) ax.set_xlabel(r"Rest-frame Wavelength [\AA]") ax.set_ylabel("Flux [arb.]") ax.set_xlim(w_rest.min(), w_rest.max()) - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + def plot_band(ax: plt.Axes, image_data, title: str): if image_data is None: ax.text(0.5, 0.5, "N/A", ha="center", va="center", transform=ax.transAxes) ax.axis("off") return - + if isinstance(image_data, torch.Tensor): image_data = image_data.detach().cpu().numpy() - + im = np.nan_to_num(image_data) v_min = np.percentile(im, 1) v_max = np.percentile(im, 99.5) - + ax.imshow(im, cmap="magma", origin="lower", vmin=v_min, vmax=v_max) ax.set_title(title, fontsize=12) ax.set_xticks([]) @@ -124,76 +146,85 @@ def plot_band(ax: plt.Axes, image_data, title: str): spine.set_linewidth(1.0) spine.set_color("black") + def run_single_object_plot( object_id: str, index_path: Union[str, Path, None] = None, cache_dir: Union[str, Path, None] = None, save_path: Union[str, Path, None] = None, smooth: float = 2.0, - dpi: int = 300 + dpi: int = 300, ): apply_style() paths = load_paths() if not cache_dir: cache_dir = paths.dataset - + if index_path is None: index_path = paths.dataset_index - + if save_path is None: out_dir = paths.analysis / "objects" out_dir.mkdir(parents=True, exist_ok=True) save_path = out_dir / f"object_{object_id}.png" - + print(f"Loading object {object_id}...") # Ensure index path is resolved to Path if it exists idx_p = Path(index_path) if index_path else None sample = get_sample_by_id(object_id, idx_p, str(cache_dir)) - + # 5 panels: Spectrum, VIS, Y, J, H fig = plt.figure(figsize=(15, 3)) from matplotlib.gridspec import GridSpec + gs = GridSpec(1, 5, 
width_ratios=[2.5, 1, 1, 1, 1]) # 1. Spectrum ax_spec = fig.add_subplot(gs[0]) plot_spectrum(ax_spec, sample, smooth) - + # 2. VIS ax_vis = fig.add_subplot(gs[1]) plot_band(ax_vis, sample.get("vis_image"), "VIS") - + # 3. Y ax_y = fig.add_subplot(gs[2]) plot_band(ax_y, sample.get("nisp_y_image"), "Y") - + # 4. J ax_j = fig.add_subplot(gs[3]) plot_band(ax_j, sample.get("nisp_j_image"), "J") - + # 5. H ax_h = fig.add_subplot(gs[4]) plot_band(ax_h, sample.get("nisp_h_image"), "H") - + z = sample.get("redshift", 0.0) - if hasattr(z, "item"): z = z.item() - if z is None or np.isnan(z): z = 0.0 - + if hasattr(z, "item"): + z = z.item() + if z is None or np.isnan(z): + z = 0.0 + fig.suptitle(f"Object {object_id} ($z = {z:.3f}$)", fontsize=14, y=1.05) - + plt.tight_layout() sp = Path(save_path) sp.parent.mkdir(parents=True, exist_ok=True) fig.savefig(sp, dpi=dpi, bbox_inches="tight") print(f"Saved figure to {sp}") + def main(argv: Sequence[str] | None = None) -> None: - parser = argparse.ArgumentParser(description="Paper-ready single object visualization") + parser = argparse.ArgumentParser( + description="Paper-ready single object visualization" + ) parser.add_argument("--object-id", required=True, help="ID of the object to plot") parser.add_argument("--index", default=None, help="Path to index CSV") parser.add_argument("--cache-dir", default=None, help="Data cache directory") parser.add_argument("--save", default="object_viz.pdf", help="Output filename") - parser.add_argument("--smooth", type=float, default=2.0, help="Smoothing for spectrum") + parser.add_argument( + "--smooth", type=float, default=2.0, help="Smoothing for spectrum" + ) parser.add_argument("--dpi", type=int, default=300, help="DPI for saving") args = parser.parse_args(argv) @@ -203,8 +234,9 @@ def main(argv: Sequence[str] | None = None) -> None: cache_dir=args.cache_dir, save_path=args.save, smooth=args.smooth, - dpi=args.dpi + dpi=args.dpi, ) + if __name__ == "__main__": main() diff --git a/src/fmb/viz/plot_anomaly_scores.py b/src/fmb/viz/plot_anomaly_scores.py index 9ffd8c3..87f8b2a 100644 --- a/src/fmb/viz/plot_anomaly_scores.py +++ b/src/fmb/viz/plot_anomaly_scores.py @@ -12,15 +12,16 @@ import matplotlib.pyplot as plt import numpy as np -import torch try: import umap except ImportError as exc: # pragma: no cover - raise SystemExit("The 'umap-learn' package is required. Install it with 'pip install umap-learn'.") from exc - -from .detect_outliers import EMBEDDING_KEYS, load_records, stack_embeddings # type: ignore[import] + raise SystemExit( + "The 'umap-learn' package is required. Install it with 'pip install umap-learn'." 
+ ) from exc +from .detect_outliers import EMBEDDING_KEYS # type: ignore[import] +from .detect_outliers import load_records, stack_embeddings METRICS: list[tuple[str, str, str]] = [ ("log_prob", "Log Probability", "viridis"), @@ -33,10 +34,19 @@ def parse_scores_csv(path: Path) -> dict[str, dict[str, dict[str, float]]]: by_key: dict[str, dict[str, dict[str, float]]] = {} with path.open("r", newline="") as csvfile: reader = csv.DictReader(csvfile) - required = {"object_id", "embedding_key", "log_prob", "neg_log_prob", "anomaly_sigma", "rank"} + required = { + "object_id", + "embedding_key", + "log_prob", + "neg_log_prob", + "anomaly_sigma", + "rank", + } missing_cols = required - set(reader.fieldnames or []) if missing_cols: - raise SystemExit(f"Scores CSV is missing required columns: {sorted(missing_cols)}") + raise SystemExit( + f"Scores CSV is missing required columns: {sorted(missing_cols)}" + ) for row in reader: key = row["embedding_key"] object_id = row["object_id"] @@ -85,7 +95,9 @@ def plot_metric_umap( output_path: Path, ) -> None: fig, ax = plt.subplots(figsize=(7, 6)) - sc = ax.scatter(coords[:, 0], coords[:, 1], c=values, cmap=cmap, s=12, linewidths=0, alpha=0.9) + sc = ax.scatter( + coords[:, 0], coords[:, 1], c=values, cmap=cmap, s=12, linewidths=0, alpha=0.9 + ) ax.set_title(f"{embedding_key.replace('embedding_', '').upper()} – {metric_label}") ax.set_xlabel("UMAP-1") ax.set_ylabel("UMAP-2") @@ -102,18 +114,35 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser( description="Plot UMAP projections of embeddings coloured by flow-based anomaly scores.", ) - parser.add_argument("--embeddings", required=True, help="Path to embeddings .pt file") - parser.add_argument("--scores-csv", required=True, help="CSV produced by scratch.detect_outliers_NFs") - parser.add_argument("--output-dir", required=True, help="Directory to store generated figures") + parser.add_argument( + "--embeddings", required=True, help="Path to embeddings .pt file" + ) + parser.add_argument( + "--scores-csv", + required=True, + help="CSV produced by scratch.detect_outliers_NFs", + ) + parser.add_argument( + "--output-dir", required=True, help="Directory to store generated figures" + ) parser.add_argument( "--embedding-key", choices=EMBEDDING_KEYS, nargs="+", help="Subset of embedding keys to plot. 
Defaults to keys present in the CSV.", ) - parser.add_argument("--n-neighbors", type=int, default=30, help="UMAP number of neighbours") - parser.add_argument("--min-dist", type=float, default=0.05, help="UMAP minimum distance") - parser.add_argument("--random-state", type=int, default=42, help="Random state for UMAP reproducibility") + parser.add_argument( + "--n-neighbors", type=int, default=30, help="UMAP number of neighbours" + ) + parser.add_argument( + "--min-dist", type=float, default=0.05, help="UMAP minimum distance" + ) + parser.add_argument( + "--random-state", + type=int, + default=42, + help="Random state for UMAP reproducibility", + ) parser.add_argument( "--no-standardize", action="store_true", diff --git a/src/fmb/viz/similarity.py b/src/fmb/viz/similarity.py index 15da789..4a2ffd1 100644 --- a/src/fmb/viz/similarity.py +++ b/src/fmb/viz/similarity.py @@ -6,7 +6,7 @@ """ from pathlib import Path -from typing import Sequence, Optional, Dict +from typing import Dict, Optional, Sequence import matplotlib.pyplot as plt import numpy as np @@ -34,6 +34,7 @@ "[S II]": 6731.0, } + def extract_spectrum(sample: Dict) -> tuple[np.ndarray | None, np.ndarray | None]: """Helper to extract flux and wavelength from sample dict.""" spec = sample.get("spectrum") @@ -42,13 +43,13 @@ def extract_spectrum(sample: Dict) -> tuple[np.ndarray | None, np.ndarray | None flux = spec.get("flux") if flux is None: return None, None - + if isinstance(flux, torch.Tensor): flux_np = flux.detach().cpu().numpy() else: flux_np = np.asarray(flux) flux_np = np.squeeze(flux_np) - + wavelength = spec.get("wavelength") if wavelength is None: wavelength_np = np.arange(len(flux_np)) @@ -58,16 +59,17 @@ def extract_spectrum(sample: Dict) -> tuple[np.ndarray | None, np.ndarray | None else: wavelength_np = np.asarray(wavelength) wavelength_np = np.squeeze(wavelength_np) - + return wavelength_np, flux_np + def plot_vertical_panels( samples: Sequence[Dict], cols: int, save_path: Optional[Path], show: bool, row_labels: Optional[Sequence[str]] = None, - smooth_sigma: float = 3.0 + smooth_sigma: float = 3.0, ) -> None: """ Plots a grid where each cell contains a vertical stack of (Image, Spectrum). @@ -77,23 +79,23 @@ def plot_vertical_panels( if count == 0: print("No samples to display.") return - + rows = int(np.ceil(count / cols)) - + # Figure size estimation # Width: ~3 inches per col # Height: ~4 inches per row (2 inches img, 2 inches spec) fig_w = 3.5 * cols fig_h = 4.5 * rows - + fig, axes = plt.subplots(rows * 2, cols, figsize=(fig_w, fig_h)) - + # Ensure axes is 2D array [2*rows, cols] if rows * 2 == 1 and cols == 1: axes = np.array([[axes]]) elif (rows * 2 == 1) or (cols == 1): axes = axes.reshape(rows * 2, cols) - + # Add row labels if provided (e.g. Model Name) # We place them to the left of the First column of each Row-Block if row_labels: @@ -105,20 +107,21 @@ def plot_vertical_panels( # Coordinates in axes fraction? Or Figure? 
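For context on the `--n-neighbors`, `--min-dist`, and `--random-state` options exposed by `plot_anomaly_scores.py` above: they map directly onto the corresponding `umap-learn` arguments. A generic sketch with synthetic embeddings (the actual `compute_umap` helper is not shown in this diff, so this is only the usual call pattern, not the project's exact implementation):

```python
import numpy as np
import umap  # umap-learn

# Synthetic embeddings standing in for the .pt records (n_objects x dim)
embeddings = np.random.default_rng(42).normal(size=(1_000, 512)).astype(np.float32)

reducer = umap.UMAP(
    n_neighbors=30,   # --n-neighbors
    min_dist=0.05,    # --min-dist
    random_state=42,  # --random-state
)
coords = reducer.fit_transform(embeddings)  # (n_objects, 2) UMAP coordinates
print(coords.shape)
```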
# Using text with negative x ax_ref.text( - -0.2, 0.5, - lbl, - transform=ax_ref.transAxes, - rotation=90, - va='center', - ha='right', - fontsize=12, - fontweight='bold' + -0.2, + 0.5, + lbl, + transform=ax_ref.transAxes, + rotation=90, + va="center", + ha="right", + fontsize=12, + fontweight="bold", ) for idx, sample in enumerate(samples): row = idx // cols col = idx % cols - + img_ax = axes[2 * row, col] spec_ax = axes[2 * row + 1, col] @@ -128,69 +131,77 @@ def plot_vertical_panels( img_ax.imshow(image[..., 0], cmap="gray", origin="lower") else: img_ax.imshow(image, origin="lower") - + img_ax.axis("off") - + # Title/ID obj_id = str(sample.get("object_id", "N/A")) # Clean up ID if it has prefixes like [QUERY] # Maybe split lines if too long? - + redshift = sample.get("redshift") title_str = f"{obj_id}" if redshift is not None: - try: - title_str += f"\nz={float(redshift):.3f}" - except: - pass - + try: + title_str += f"\nz={float(redshift):.3f}" + except: + pass + img_ax.set_title(title_str, fontsize=8) # 2. Spectrum wavelength, flux = extract_spectrum(sample) spec_ax.clear() - + if flux is not None and wavelength is not None: sort_idx = np.argsort(wavelength) wave_sorted = wavelength[sort_idx] flux_sorted = flux[sort_idx] - + if smooth_sigma > 0: - smoothed_flux = scipy.ndimage.gaussian_filter1d(flux_sorted, sigma=smooth_sigma) + smoothed_flux = scipy.ndimage.gaussian_filter1d( + flux_sorted, sigma=smooth_sigma + ) else: smoothed_flux = flux_sorted - + redshift = sample.get("redshift") rest_wave = wave_sorted z_val = None - + if redshift is not None: try: z_val = float(redshift) rest_wave = wave_sorted / (1.0 + z_val) except (TypeError, ValueError): pass - + spec_ax.plot(rest_wave, smoothed_flux, linewidth=0.8, color="black") - + # Lines if z_val is not None: for name, line_rest in REST_LINES.items(): if rest_wave.min() <= line_rest <= rest_wave.max(): - spec_ax.axvline(line_rest, color="red", linestyle=":", alpha=0.5, linewidth=0.8) + spec_ax.axvline( + line_rest, + color="red", + linestyle=":", + alpha=0.5, + linewidth=0.8, + ) # Label? Maybe too crowded for small plots. - + spec_ax.set_xlim(rest_wave.min(), rest_wave.max()) - spec_ax.set_yticks([]) # Clean look - + spec_ax.set_yticks([]) # Clean look + # Only label X axis on bottom row? Or all? # All is safer for now. # spec_ax.set_xlabel("Wavelength [Å]", fontsize=7) - + else: spec_ax.text(0.5, 0.5, "No Spectrum", ha="center", va="center", fontsize=8) spec_ax.axis("off") - + # clear unused total_cells = rows * cols for idx in range(count, total_cells): @@ -200,13 +211,13 @@ def plot_vertical_panels( axes[2 * row + 1, col].axis("off") plt.tight_layout() - + if save_path: save_path.parent.mkdir(parents=True, exist_ok=True) - plt.savefig(save_path, dpi=200, bbox_inches='tight') + plt.savefig(save_path, dpi=200, bbox_inches="tight") print(f"Saved visualization to {save_path}") - + if show: plt.show() - + plt.close(fig) diff --git a/src/fmb/viz/spectrum.py b/src/fmb/viz/spectrum.py index 4543c3b..909cbb3 100644 --- a/src/fmb/viz/spectrum.py +++ b/src/fmb/viz/spectrum.py @@ -5,9 +5,10 @@ Description: Spectrum extraction and visualization utilities """ +from typing import Optional, Tuple + import numpy as np import torch -from typing import Optional, Tuple REST_LINES = { r"Ly$\alpha$": 1216.0, @@ -47,6 +48,7 @@ "[S II]": 6731.0, } + def extract_spectrum(sample: dict) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: """ Extract wavelength and flux from a sample dictionary. 
@@ -55,17 +57,17 @@ def extract_spectrum(sample: dict) -> Tuple[Optional[np.ndarray], Optional[np.nd spec = sample.get("spectrum") if spec is None: return None, None - + flux = spec.get("flux") if flux is None: return None, None - + if isinstance(flux, torch.Tensor): flux_np = flux.detach().cpu().numpy() else: flux_np = np.asarray(flux) flux_np = np.squeeze(flux_np) - + wavelength = spec.get("wavelength") if wavelength is None: wavelength_np = np.arange(len(flux_np)) @@ -75,5 +77,5 @@ def extract_spectrum(sample: dict) -> Tuple[Optional[np.ndarray], Optional[np.nd else: wavelength_np = np.asarray(wavelength) wavelength_np = np.squeeze(wavelength_np) - + return wavelength_np, flux_np diff --git a/src/fmb/viz/style.py b/src/fmb/viz/style.py index 945ec69..11dcf85 100644 --- a/src/fmb/viz/style.py +++ b/src/fmb/viz/style.py @@ -5,10 +5,12 @@ Description: Centralized visualization styling """ -from typing import Optional from pathlib import Path +from typing import Optional + from .utils import load_viz_style + def apply_style(config_path: Optional[Path] = None): """ Apply the FMB publication style. @@ -16,5 +18,6 @@ def apply_style(config_path: Optional[Path] = None): """ load_viz_style() + # Alias for backward compatibility or preference set_style = apply_style diff --git a/src/fmb/viz/utils.py b/src/fmb/viz/utils.py index 183db24..54b53a1 100644 --- a/src/fmb/viz/utils.py +++ b/src/fmb/viz/utils.py @@ -6,9 +6,8 @@ """ import csv -import math from pathlib import Path -from typing import Sequence, Dict, Tuple, Optional +from typing import Dict, Sequence, Tuple import numpy as np import torch @@ -17,18 +16,19 @@ from fmb.data.load_display_data import EuclidDESIDataset from fmb.paths import load_paths + def load_viz_style(): """Load matplotlib style from centralized YAML config.""" - import yaml import matplotlib.pyplot as plt - + import yaml + style_path = load_paths().repo_root / "src/fmb/configs/viz_style.yaml" if style_path.exists(): with open(style_path, "r") as f: config = yaml.safe_load(f) - + # Flatten and update rcParams - def flatten(d, parent_key='', sep='.'): + def flatten(d, parent_key="", sep="."): items = [] for k, v in d.items(): new_key = f"{parent_key}{sep}{k}" if parent_key else k @@ -37,7 +37,7 @@ def flatten(d, parent_key='', sep='.'): else: items.append((new_key, v)) return dict(items) - + plt.rcParams.update(flatten(config)) print(f"Loaded visualization style from {style_path}") else: @@ -58,7 +58,9 @@ def load_index(index_path: Path) -> Dict[str, Tuple[str, int]]: try: idx = int(row["index"]) except (ValueError, TypeError) as exc: - raise ValueError(f"Invalid index for object_id={oid}: {row['index']}") from exc + raise ValueError( + f"Invalid index for object_id={oid}: {row['index']}" + ) from exc if oid: mapping[oid] = (split, idx) return mapping @@ -85,7 +87,9 @@ def collect_samples( break missing = set(wanted.keys()) - {str(s.get("object_id")) for s in collected} if missing and verbose: - print(f"Warning: {len(missing)} object IDs not found: {sorted(list(missing))[:5]}...") + print( + f"Warning: {len(missing)} object IDs not found: {sorted(list(missing))[:5]}..." 
+ ) return collected @@ -109,45 +113,48 @@ def collect_samples_with_index( if verbose: print(f"Loading split '{split}' to fetch {len(entries)} samples") dataset = EuclidDESIDataset(split=split, cache_dir=cache_dir, verbose=False) - + # Indices are global 0..N across train+test min_idx = min(idx for _, idx in entries) - offset = 0 if min_idx >= len(dataset): - pass + pass - sorted_entries = sorted(entries, key=lambda x: x[1]) - base_offset = 0 - if split == 'test' and sorted_entries[0][1] >= len(dataset): - # For 'test' split, subtract the starting index to convert to local indices - base_offset = sorted_entries[0][1] + if split == "test" and sorted_entries[0][1] >= len(dataset): + # For 'test' split, subtract the starting index to convert to local indices + sorted_entries[0][1] - first_idx = sorted_entries[0][1] + sorted_entries[0][1] use_offset = 0 if sorted_entries[-1][1] >= len(dataset): - if split == 'test': - if split == 'test': - ds_train = EuclidDESIDataset(split='train', cache_dir=cache_dir, verbose=False) - use_offset = len(ds_train) - + if split == "test": + if split == "test": + ds_train = EuclidDESIDataset( + split="train", cache_dir=cache_dir, verbose=False + ) + use_offset = len(ds_train) + for oid, idx in entries: local_idx = idx - use_offset if local_idx < 0 or local_idx >= len(dataset): # Only check fallback if strictly needed - if use_offset == 0 and split == 'test': - pass - + if use_offset == 0 and split == "test": + pass + if verbose: - print(f"Warning: index {local_idx} (orig {idx}) out of range for split '{split}'") + print( + f"Warning: index {local_idx} (orig {idx}) out of range for split '{split}'" + ) continue - + sample = dataset[local_idx] # Verify ID match to be sure sample_id = str(sample.get("object_id") or sample.get("targetid")) if sample_id != str(oid): if verbose: - print(f"ID Mismatch at {local_idx}: expected {oid}, got {sample_id}. Fallback scanning...") + print( + f"ID Mismatch at {local_idx}: expected {oid}, got {sample_id}. Fallback scanning..." + ) found = False for i, s in enumerate(dataset): @@ -162,7 +169,9 @@ def collect_samples_with_index( missing = [oid for oid in object_ids if oid not in samples_by_id] if missing and verbose: - print(f"Warning: {len(missing)} object IDs missing in index/dataset: {missing[:5]}...") + print( + f"Warning: {len(missing)} object IDs missing in index/dataset: {missing[:5]}..." 
+ ) return collected @@ -109,45 +113,48 @@ if verbose: print(f"Loading split '{split}' to fetch {len(entries)} samples") dataset = EuclidDESIDataset(split=split, cache_dir=cache_dir, verbose=False) - + # Indices are global 0..N across train+test min_idx = min(idx for _, idx in entries) - offset = 0 if min_idx >= len(dataset): - pass + pass - sorted_entries = sorted(entries, key=lambda x: x[1]) - base_offset = 0 - if split == 'test' and sorted_entries[0][1] >= len(dataset): - # For 'test' split, subtract the starting index to convert to local indices - base_offset = sorted_entries[0][1] - first_idx = sorted_entries[0][1] + # Keep the sorted entries: they are used below to detect global (train+test) indices + sorted_entries = sorted(entries, key=lambda x: x[1]) use_offset = 0 if sorted_entries[-1][1] >= len(dataset): - if split == 'test': - if split == 'test': - ds_train = EuclidDESIDataset(split='train', cache_dir=cache_dir, verbose=False) - use_offset = len(ds_train) - + if split == "test": + ds_train = EuclidDESIDataset( + split="train", cache_dir=cache_dir, verbose=False + ) + use_offset = len(ds_train) + for oid, idx in entries: local_idx = idx - use_offset if local_idx < 0 or local_idx >= len(dataset): # Only check fallback if strictly needed - if use_offset == 0 and split == 'test': - pass - + if use_offset == 0 and split == "test": + pass + if verbose: - print(f"Warning: index {local_idx} (orig {idx}) out of range for split '{split}'") + print( + f"Warning: index {local_idx} (orig {idx}) out of range for split '{split}'" + ) continue - + sample = dataset[local_idx] # Verify ID match to be sure sample_id = str(sample.get("object_id") or sample.get("targetid")) if sample_id != str(oid): if verbose: - print(f"ID Mismatch at {local_idx}: expected {oid}, got {sample_id}. Fallback scanning...") + print( + f"ID Mismatch at {local_idx}: expected {oid}, got {sample_id}. Fallback scanning..." + ) found = False for i, s in enumerate(dataset): @@ -162,7 +169,9 @@ def collect_samples_with_index( missing = [oid for oid in object_ids if oid not in samples_by_id] if missing and verbose: - print(f"Warning: {len(missing)} object IDs missing in index/dataset: {missing[:5]}...")
+ ) return [samples_by_id[oid] for oid in object_ids if oid in samples_by_id] @@ -172,7 +181,7 @@ def prepare_rgb_image(sample: dict) -> np.ndarray: rgb = sample.get("rgb_image") if rgb is None: raise ValueError("Sample missing 'rgb_image'") - + if isinstance(rgb, torch.Tensor): tensor = rgb.detach().cpu() if tensor.dim() == 3: @@ -184,17 +193,19 @@ def prepare_rgb_image(sample: dict) -> np.ndarray: elif tensor.dim() == 2: array = tensor.numpy() else: - raise ValueError(f"Unexpected tensor shape for rgb_image: {tuple(tensor.shape)}") + raise ValueError( + f"Unexpected tensor shape for rgb_image: {tuple(tensor.shape)}" + ) else: # assume numpy array array = np.asarray(rgb) # Handle (C, H, W) -> (H, W, C) if needed if array.ndim == 3 and array.shape[0] in (1, 3): array = np.moveaxis(array, 0, -1) - + # Normalize if needed (assume float [0, 1] or int [0, 255]) - if array.dtype.kind in ('f', 'u'): + if array.dtype.kind in ("f", "u"): if array.max() > 1.01: array = array.astype(float) / 255.0 - + array = np.clip(array, 0.0, 1.0) return array diff --git a/src/fmb/viz/visualize_embedding_umap.py b/src/fmb/viz/visualize_embedding_umap.py index afefbaf..efb05f9 100644 --- a/src/fmb/viz/visualize_embedding_umap.py +++ b/src/fmb/viz/visualize_embedding_umap.py @@ -12,19 +12,15 @@ import matplotlib.pyplot as plt import numpy as np import torch -from tqdm import tqdm - from scratch.display_outlier_images import ( - load_index, collect_samples, collect_samples_with_index, + load_index, prepare_rgb_image, ) -from scratch.display_outlier_images_spectrum import ( - extract_spectrum, - REST_LINES, -) +from scratch.display_outlier_images_spectrum import REST_LINES, extract_spectrum from scratch.load_display_data import EuclidDESIDataset +from tqdm import tqdm try: import umap @@ -141,7 +137,9 @@ def add_thumbnails( zorder=3, ) except Exception as exc: # pragma: no cover - print(f"Failed to attach thumbnail for object {sample.get('object_id')}: {exc}") + print( + f"Failed to attach thumbnail for object {sample.get('object_id')}: {exc}" + ) def render_spectrum(ax: plt.Axes, sample: dict) -> None: @@ -160,7 +158,9 @@ def render_spectrum(ax: plt.Axes, sample: dict) -> None: if z is not None: for name, line_rest in REST_LINES.items(): if rest_wave.min() <= line_rest <= rest_wave.max(): - ax.axvline(line_rest, color="red", linestyle="--", alpha=0.6, linewidth=0.6) + ax.axvline( + line_rest, color="red", linestyle="--", alpha=0.6, linewidth=0.6 + ) ymax = ax.get_ylim()[1] ax.text( line_rest, @@ -171,11 +171,21 @@ def render_spectrum(ax: plt.Axes, sample: dict) -> None: ha="center", fontsize=6, color="black", - bbox=dict(facecolor="white", alpha=0.8, edgecolor="none", pad=1.0), + bbox=dict( + facecolor="white", alpha=0.8, edgecolor="none", pad=1.0 + ), ) ax.set_xlim(rest_wave.min(), rest_wave.max()) else: - ax.text(0.5, 0.5, "No spectrum", ha="center", va="center", fontsize=6, transform=ax.transAxes) + ax.text( + 0.5, + 0.5, + "No spectrum", + ha="center", + va="center", + fontsize=6, + transform=ax.transAxes, + ) ax.set_xticks([]) ax.set_yticks([]) @@ -191,8 +201,12 @@ def main(argv: Sequence[str] | None = None) -> None: choices=["embedding_hsc_desi", "embedding_hsc", "embedding_spectrum"], help="Embedding field to visualize", ) - parser.add_argument("--figure", required=True, help="Output image path for thumbnails") - parser.add_argument("--figure-spectrum", default=None, help="Optional path for spectrum-grid figure") + parser.add_argument( + "--figure", required=True, help="Output image path for thumbnails" + ) + 
parser.add_argument( + "--figure-spectrum", default=None, help="Optional path for spectrum-grid figure" + ) parser.add_argument( "--split", type=str, @@ -204,17 +218,38 @@ def main(argv: Sequence[str] | None = None) -> None: type=str, default="/n03data/ronceray/datasets", ) - parser.add_argument("--index", type=str, default=None, help="Optional CSV mapping object_id -> split/index") - parser.add_argument("--grid-rows", type=int, default=12, help="Number of rows in the thumbnail grid") - parser.add_argument("--grid-cols", type=int, default=12, help="Number of columns in the thumbnail grid") - parser.add_argument("--random-state", type=int, default=42, help="Random state for UMAP and sampling") + parser.add_argument( + "--index", + type=str, + default=None, + help="Optional CSV mapping object_id -> split/index", + ) + parser.add_argument( + "--grid-rows", type=int, default=12, help="Number of rows in the thumbnail grid" + ) + parser.add_argument( + "--grid-cols", + type=int, + default=12, + help="Number of columns in the thumbnail grid", + ) + parser.add_argument( + "--random-state", + type=int, + default=42, + help="Random state for UMAP and sampling", + ) parser.add_argument("--dpi", type=int, default=450, help="Output resolution in DPI") - parser.add_argument("--point-size", type=float, default=6.0, help="Scatter point size") + parser.add_argument( + "--point-size", type=float, default=6.0, help="Scatter point size" + ) parser.add_argument("--alpha", type=float, default=0.35, help="Scatter alpha") args = parser.parse_args(argv) records = load_records(Path(args.input)) - coords = compute_umap(stack_embeddings(records, args.embedding_key), args.random_state) + coords = compute_umap( + stack_embeddings(records, args.embedding_key), args.random_state + ) object_ids = [_to_str_id(rec.get("object_id", "")) for rec in records] @@ -246,7 +281,14 @@ def main(argv: Sequence[str] | None = None) -> None: fig, ax = plt.subplots(figsize=(args.grid_cols * 1.5, args.grid_rows * 1.5)) if np.isnan(redshifts).all(): - ax.scatter(coords_norm[:, 0] * args.grid_cols, coords_norm[:, 1] * args.grid_rows, s=args.point_size, alpha=args.alpha, color="slateblue", zorder=1) + ax.scatter( + coords_norm[:, 0] * args.grid_cols, + coords_norm[:, 1] * args.grid_rows, + s=args.point_size, + alpha=args.alpha, + color="slateblue", + zorder=1, + ) else: mask = ~np.isnan(redshifts) scatter = ax.scatter( @@ -294,12 +336,18 @@ def main(argv: Sequence[str] | None = None) -> None: retrieved_ids = {_to_str_id(sample.get("object_id")) for sample in samples} missing = [oid for oid in thumb_ids if oid not in retrieved_ids] if missing: - dataset = EuclidDESIDataset(split=args.split, cache_dir=args.cache_dir, verbose=True) + dataset = EuclidDESIDataset( + split=args.split, cache_dir=args.cache_dir, verbose=True + ) samples.extend(collect_samples(dataset, missing, verbose=True)) else: - dataset = EuclidDESIDataset(split=args.split, cache_dir=args.cache_dir, verbose=True) + dataset = EuclidDESIDataset( + split=args.split, cache_dir=args.cache_dir, verbose=True + ) samples = collect_samples(dataset, thumb_ids, verbose=True) - id_to_sample = {_to_str_id(sample.get("object_id")): sample for sample in samples} + id_to_sample = { + _to_str_id(sample.get("object_id")): sample for sample in samples + } for oid, cell in zip(thumb_ids, cell_positions): sample = id_to_sample.get(oid) if sample is not None: @@ -316,7 +364,10 @@ def main(argv: Sequence[str] | None = None) -> None: if args.figure_spectrum: if ordered_samples: - grid_pairs = [(gx, gy, 
sample) for (gx, gy), sample in zip(cell_positions, ordered_samples)] + grid_pairs = [ + (gx, gy, sample) + for (gx, gy), sample in zip(cell_positions, ordered_samples) + ] fig_spec, axes = plt.subplots( args.grid_rows, args.grid_cols, @@ -325,7 +376,9 @@ def main(argv: Sequence[str] | None = None) -> None: sharey=True, ) axes = np.atleast_2d(axes) - for gx, gy, sample in tqdm(grid_pairs, desc="Rendering spectra", unit="spec"): + for gx, gy, sample in tqdm( + grid_pairs, desc="Rendering spectra", unit="spec" + ): ax_spec = axes[gy, gx] render_spectrum(ax_spec, sample) for gy in range(args.grid_rows): diff --git a/src/fmb/viz/visualize_multimodel_embeddings.py b/src/fmb/viz/visualize_multimodel_embeddings.py index 6643039..638c4f7 100644 --- a/src/fmb/viz/visualize_multimodel_embeddings.py +++ b/src/fmb/viz/visualize_multimodel_embeddings.py @@ -28,13 +28,13 @@ --random-state 42 """ import argparse +import math from pathlib import Path from typing import Sequence import matplotlib.pyplot as plt import numpy as np import torch -import math try: from astropy.io import fits @@ -113,35 +113,41 @@ def load_fits_catalog(path: Path) -> tuple[dict, list[str], str]: columns = hdul[1].columns.names catalog_dict = {} id_column = None - for priority_col in ['TARGETID', 'targetid', 'TargetID']: + for priority_col in ["TARGETID", "targetid", "TargetID"]: if priority_col in columns: id_column = priority_col break if id_column is None: for col in columns: - if col.lower() in ['object_id', 'objid', 'id']: + if col.lower() in ["object_id", "objid", "id"]: id_column = col break if id_column is None: - raise ValueError(f"Could not find object ID column in FITS. Available columns: {columns}") - + raise ValueError( + f"Could not find object ID column in FITS. Available columns: {columns}" + ) + print(f"Using '{id_column}' as object ID column") - + for row in data: obj_id = str(row[id_column]) catalog_dict[obj_id] = {col: row[col] for col in columns} - + numeric_columns = [] for col in columns: - if col == id_column: continue + if col == id_column: + continue try: col_format = hdul[1].columns[col].format - if any(fmt in col_format.upper() for fmt in ['E', 'D', 'I', 'J', 'K', 'F']): + if any( + fmt in col_format.upper() for fmt in ["E", "D", "I", "J", "K", "F"] + ): numeric_columns.append(col) continue # sample check can be skipped for brevity or kept if robustness needed - except: pass - + except: + pass + print(f"Found {len(numeric_columns)} numeric physical parameters") return catalog_dict, numeric_columns, id_column @@ -152,39 +158,43 @@ def stack_embeddings_with_joint(records: Sequence[dict], key: str) -> np.ndarray # Key might be prefixed, e.g. "astropt_embedding_joint". # But records have keys "astropt_embedding_joint" directly merged? # Yes, we merge with prefixes. - + for rec in records: tensor = rec.get(key) # Fallback: if key is "astropt_embedding_joint" but it's missing, try to construct from existing fields - # BUT `merge_embedding_records` handles the merging and prefixing. + # BUT `merge_embedding_records` handles the merging and prefixing. # So we should rely on keys simply being present. - - # Exception: "joint" construction from components. + + # Exception: "joint" construction from components. # If the merged record has "astropt_embedding_images" and "astropt_embedding_spectra", we can built joint. 
if tensor is None and "embedding_joint" in key: - prefix = key.replace("embedding_joint", "") # "astropt_" + prefix = key.replace("embedding_joint", "") # "astropt_" img_key = f"{prefix}embedding_images" spec_key = f"{prefix}embedding_spectra" - + img = rec.get(img_key) spec = rec.get(spec_key) - + if img is not None and spec is not None: - if isinstance(img, torch.Tensor): img = img.detach().cpu().numpy() - else: img = np.asarray(img) - if isinstance(spec, torch.Tensor): spec = spec.detach().cpu().numpy() - else: spec = np.asarray(spec) + if isinstance(img, torch.Tensor): + img = img.detach().cpu().numpy() + else: + img = np.asarray(img) + if isinstance(spec, torch.Tensor): + spec = spec.detach().cpu().numpy() + else: + spec = np.asarray(spec) tensor = np.concatenate([img, spec]) - + if tensor is None: continue - + if isinstance(tensor, torch.Tensor): vectors.append(tensor.detach().cpu().numpy()) else: vectors.append(np.asarray(tensor)) - + if not vectors: raise ValueError(f"No embeddings found for key '{key}'") return np.stack(vectors, axis=0) @@ -205,7 +215,9 @@ def load_umap_coordinates(load_path: Path) -> dict[str, np.ndarray]: return coords_map -def compute_umap(embeddings: np.ndarray, random_state: int, preset: str = "balanced") -> np.ndarray: +def compute_umap( + embeddings: np.ndarray, random_state: int, preset: str = "balanced" +) -> np.ndarray: config = UMAP_PRESETS[preset] reducer = umap.UMAP( random_state=random_state, @@ -231,60 +243,66 @@ def merge_embedding_records( aion_dict = {str(r.get("object_id", "")): r for r in aion_records} astropt_dict = {str(r.get("object_id", "")): r for r in astropt_records} astroclip_dict = {str(r.get("object_id", "")): r for r in astroclip_records} - - all_ids = set(aion_dict.keys()) | set(astropt_dict.keys()) | set(astroclip_dict.keys()) + + all_ids = ( + set(aion_dict.keys()) | set(astropt_dict.keys()) | set(astroclip_dict.keys()) + ) all_ids.discard("") - + merged_records = [] physical_values = [] valid_ids = [] - + for obj_id in sorted(all_ids): # We start with empty dict, NOT reusing aion_dict content directly to strictly control keys. # But for AION we can keep original keys or prefix them? - # AION keys (hsc_desi) are generally unique. + # AION keys (hsc_desi) are generally unique. # But let's standardize: "aion_" vs "astropt_" vs "astroclip_" # To maintain compatibility with existing cache consumer `plot_paper_combined_umap.py` lines: # KEY_AION_CACHE = "aion_embedding_hsc_desi" # KEY_ASTROPT_CACHE = "astropt_embedding_joint" # KEY_ASTROCLIP_CACHE = "astroclip_embedding_joint" - + merged_rec = {"object_id": obj_id} - + # Merge AION if obj_id in aion_dict: for k in AION_EMBEDDING_KEYS: if k in aion_dict[obj_id]: # Standardize prefix "aion_" if strictly following updated plan - # OR check if keys already overlap. + # OR check if keys already overlap. 
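The comments above spell out the key convention used when merging per-model records: each model's embedding keys are prefixed with "aion_", "astropt_", or "astroclip_" so they can coexist in one record and match what `plot_paper_combined_umap.py` reads back from the cache. A toy illustration of that prefixing, with a fabricated object ID and 4-dimensional placeholder vectors (not real FMB data):

```python
import numpy as np

# Hypothetical per-model records keyed by object_id (values trimmed to 4 dims)
aion_rec = {"object_id": "39627745", "embedding_hsc_desi": np.zeros(4)}
astropt_rec = {"object_id": "39627745", "embedding_joint": np.ones(4)}
astroclip_rec = {"object_id": "39627745", "embedding_joint": np.full(4, 2.0)}

# Prefix each model's embedding keys so they can live side by side in one record
merged_rec = {"object_id": "39627745"}
for prefix, rec in [("aion", aion_rec), ("astropt", astropt_rec), ("astroclip", astroclip_rec)]:
    merged_rec.update({f"{prefix}_{k}": v for k, v in rec.items() if k.startswith("embedding")})

print(sorted(merged_rec))
# ['aion_embedding_hsc_desi', 'astroclip_embedding_joint', 'astropt_embedding_joint', 'object_id']
```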
merged_rec[f"aion_{k}"] = aion_dict[obj_id][k] - + # Merge AstroPT if obj_id in astropt_dict: for k in ASTROPT_EMBEDDING_KEYS: if k in astropt_dict[obj_id]: merged_rec[f"astropt_{k}"] = astropt_dict[obj_id][k] - # Merge AstroCLIP + # Merge AstroCLIP if obj_id in astroclip_dict: for k in ASTROCLIP_EMBEDDING_KEYS: if k in astroclip_dict[obj_id]: merged_rec[f"astroclip_{k}"] = astroclip_dict[obj_id][k] - + # Get physical parameter phys_val = np.nan if obj_id in catalog: try: raw_val = catalog[obj_id][physical_param] - if hasattr(raw_val, 'item'): phys_val = float(raw_val.item()) - else: phys_val = float(raw_val) - if np.isnan(phys_val) or np.isinf(phys_val): phys_val = np.nan - except: pass - + if hasattr(raw_val, "item"): + phys_val = float(raw_val.item()) + else: + phys_val = float(raw_val) + if np.isnan(phys_val) or np.isinf(phys_val): + phys_val = np.nan + except: + pass + merged_records.append(merged_rec) physical_values.append(phys_val) valid_ids.append(obj_id) - + return merged_records, np.array(physical_values), valid_ids @@ -297,30 +315,38 @@ def plot_umap_grid( # Sort keys for consistent display # Group by model # Prefer order: AION -> AstroPT -> AstroCLIP and within that: Image -> Spec -> Joint - + # helper for sorting def sort_key(k): score = 0 - if "aion" in k: score += 100 - elif "astropt" in k: score += 200 - elif "astroclip" in k: score += 300 - - if "image" in k or "hsc" in k: score += 1 - elif "spectr" in k: score += 2 # spectrum or spectra - elif "joint" in k or "hsc_desi" in k: score += 3 + if "aion" in k: + score += 100 + elif "astropt" in k: + score += 200 + elif "astroclip" in k: + score += 300 + + if "image" in k or "hsc" in k: + score += 1 + elif "spectr" in k: + score += 2 # spectrum or spectra + elif "joint" in k or "hsc_desi" in k: + score += 3 return score names = sorted(list(coords_map.keys()), key=sort_key) n_plots = len(names) - + # Flexible Grid - cols = 3 + cols = 3 rows = math.ceil(n_plots / cols) - + fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) - if n_plots == 1: axes = [axes] - else: axes = axes.flatten() - + if n_plots == 1: + axes = [axes] + else: + axes = axes.flatten() + # Colors valid_mask = ~np.isnan(colors) vmin, vmax = 0, 1 @@ -328,13 +354,14 @@ def sort_key(k): valid_colors = colors[valid_mask] vmin = np.percentile(valid_colors, 2) vmax = np.percentile(valid_colors, 98) - if vmax - vmin < 1e-6: vmin, vmax = valid_colors.min(), valid_colors.max() + if vmax - vmin < 1e-6: + vmin, vmax = valid_colors.min(), valid_colors.max() print(f" 🎨 Color scale: [{vmin:.3f}, {vmax:.3f}]") for i, name in enumerate(names): ax = axes[i] coords = coords_map[name] - + if valid_mask.any(): scatter = ax.scatter( coords[valid_mask, 0], @@ -344,27 +371,36 @@ def sort_key(k): s=10, alpha=0.7, edgecolors="none", - vmin=vmin, vmax=vmax, + vmin=vmin, + vmax=vmax, ) cbar = plt.colorbar(scatter, ax=ax) cbar.set_label(param_name, fontsize=8) - + if (~valid_mask).any(): - ax.scatter(coords[~valid_mask, 0], coords[~valid_mask, 1], s=10, - color="lightgray", alpha=0.3, label="N/A") - + ax.scatter( + coords[~valid_mask, 0], + coords[~valid_mask, 1], + s=10, + color="lightgray", + alpha=0.3, + label="N/A", + ) + pretty = name.replace("embedding_", "").replace("_", " ").title() ax.set_title(pretty, fontsize=11, fontweight="bold") ax.set_xticks([]) ax.set_yticks([]) ax.grid(True, alpha=0.2) - + # Hide unused for j in range(i + 1, len(axes)): axes[j].axis("off") - - fig.suptitle(f"Multi-Model UMAP Projections: {param_name}", fontsize=16, fontweight="bold") - 
fig.tight_layout(rect=[0, 0, 1, 0.97]) # space for tile + + fig.suptitle( + f"Multi-Model UMAP Projections: {param_name}", fontsize=16, fontweight="bold" + ) + fig.tight_layout(rect=[0, 0, 1, 0.97]) # space for tile save_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(save_path, dpi=200, bbox_inches="tight") plt.close(fig) @@ -372,60 +408,87 @@ def sort_key(k): def main(argv: Sequence[str] | None = None) -> None: - parser = argparse.ArgumentParser(description="Multi-Model UMAP Visualization (AION/AstroPT/AstroCLIP)") - parser.add_argument("--aion-embeddings", required=True, help="Path to AION embeddings .pt") - parser.add_argument("--astropt-embeddings", required=True, help="Path to AstroPT embeddings .pt") - parser.add_argument("--astroclip-embeddings", required=True, help="Path to AstroCLIP embeddings .pt") + parser = argparse.ArgumentParser( + description="Multi-Model UMAP Visualization (AION/AstroPT/AstroCLIP)" + ) + parser.add_argument( + "--aion-embeddings", required=True, help="Path to AION embeddings .pt" + ) + parser.add_argument( + "--astropt-embeddings", required=True, help="Path to AstroPT embeddings .pt" + ) + parser.add_argument( + "--astroclip-embeddings", required=True, help="Path to AstroCLIP embeddings .pt" + ) parser.add_argument("--catalog", required=True, help="Path to FITS catalog") parser.add_argument("--output-dir", required=True, help="Output directory") parser.add_argument("--physical-param", help="Physical parameter to color by") - parser.add_argument("--all-params", action="store_true", help="Visualize all numeric params") + parser.add_argument( + "--all-params", action="store_true", help="Visualize all numeric params" + ) parser.add_argument("--random-state", type=int, default=42) parser.add_argument("--umap-cache", help="Path to cache file") - parser.add_argument("--umap-preset", default="balanced", choices=list(UMAP_PRESETS.keys())) - + parser.add_argument( + "--umap-preset", default="balanced", choices=list(UMAP_PRESETS.keys()) + ) + args = parser.parse_args(argv) - + print("Loading data...") aion_recs = load_embeddings(Path(args.aion_embeddings)) astropt_recs = load_embeddings(Path(args.astropt_embeddings)) astroclip_recs = load_embeddings(Path(args.astroclip_embeddings)) catalog, numeric_cols, _ = load_fits_catalog(Path(args.catalog)) - - params_to_viz = numeric_cols if args.all_params else ([args.physical_param] if args.physical_param else []) + + params_to_viz = ( + numeric_cols + if args.all_params + else ([args.physical_param] if args.physical_param else []) + ) if not params_to_viz: - print("No parameters selected to visualize (use --physical-param or --all-params).") + print( + "No parameters selected to visualize (use --physical-param or --all-params)." 
+ ) return # Compute UMAPs (Merged dummy first) print("Merging records and computing UMAPs...") - all_recs, _, _ = merge_embedding_records(aion_recs, astropt_recs, astroclip_recs, catalog, params_to_viz[0]) - + all_recs, _, _ = merge_embedding_records( + aion_recs, astropt_recs, astroclip_recs, catalog, params_to_viz[0] + ) + coords_map = {} if args.umap_cache and Path(args.umap_cache).exists(): - coords_map = load_umap_coordinates(Path(args.umap_cache)) + coords_map = load_umap_coordinates(Path(args.umap_cache)) else: - # Need to find all valid keys in merged records (that are embeddings) - sample_keys = [k for k in all_recs[0].keys() if k not in ["object_id", "redshift"]] - for key in sample_keys: - try: + # Need to find all valid keys in merged records (that are embeddings) + sample_keys = [ + k for k in all_recs[0].keys() if k not in ["object_id", "redshift"] + ] + for key in sample_keys: + try: embeddings = stack_embeddings_with_joint(all_recs, key) print(f" Computing UMAP for {key}...") - coords_map[key] = compute_umap(embeddings, args.random_state, args.umap_preset) - except ValueError: - print(f" Skipping {key} (no data)") - - if args.umap_cache: - save_umap_coordinates(coords_map, Path(args.umap_cache)) - + coords_map[key] = compute_umap( + embeddings, args.random_state, args.umap_preset + ) + except ValueError: + print(f" Skipping {key} (no data)") + + if args.umap_cache: + save_umap_coordinates(coords_map, Path(args.umap_cache)) + # Plotting for param in params_to_viz: print(f"Plotting {param}...") - _, values, _ = merge_embedding_records(aion_recs, astropt_recs, astroclip_recs, catalog, param) + _, values, _ = merge_embedding_records( + aion_recs, astropt_recs, astroclip_recs, catalog, param + ) save_path = Path(args.output_dir) / f"umap_grid_{param}.png" plot_umap_grid(coords_map, values, param, save_path) - + print("Done.") + if __name__ == "__main__": main()
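Example usage (illustrative sketch, not part of the diff): the command below shows how the reformatted multi-model UMAP script could be invoked, assuming the fmb package is installed in editable mode. Every path is a placeholder; only the flag names and defaults are taken from the argparse definitions above.

```bash
# Placeholder paths -- substitute real embedding dumps and the FITS catalog.
python -m fmb.viz.visualize_multimodel_embeddings \
    --aion-embeddings outputs/embeddings_aion.pt \
    --astropt-embeddings outputs/embeddings_astropt.pt \
    --astroclip-embeddings outputs/embeddings_astroclip.pt \
    --catalog data/catalog.fits \
    --output-dir figures/umap_grids \
    --all-params \
    --umap-preset balanced \
    --umap-cache outputs/umap_coords_cache.pt \
    --random-state 42
```

If --umap-cache points at an existing file, the script loads the cached coordinates instead of refitting UMAP, so repeated --all-params sweeps only redo the per-parameter plotting.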