From bd06e69b3e4e21f6c1dac3b5247974f4361ec6d8 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Feb 2026 03:29:24 +0000 Subject: [PATCH 1/7] Add wandb compatibility layer for drop-in migration Adds pluto.compat.wandb module that lets users replace `import wandb` with `import pluto.compat.wandb as wandb` to route all logging through pluto with minimal code changes. Module structure: - __init__.py: Module-level API (init, log, finish, watch, config, summary, run) - run.py: Run class wrapping pluto.Op with wandb.Run-compatible interface - config.py: Dict-like Config object that syncs mutations to pluto - summary.py: Dict-like Summary object tracking last-logged values - data_types.py: Wrappers (Image, Audio, Video, Table, Histogram, Html, Artifact, AlertLevel) that convert to pluto equivalents Key features: - commit=False buffering (accumulate data across log calls) - Nested dict flattening with / separator - wandb env var fallbacks (WANDB_PROJECT, WANDB_MODE, WANDB_TAGS, etc.) - Graceful degradation for unsupported features (define_metric, save, etc.) - Context manager support - Disabled-mode fallback when pluto.init() fails https://claude.ai/code/session_01VTSZKK5UsMqjiADFX57SMY --- pluto/compat/wandb/__init__.py | 467 +++++++++++++++++ pluto/compat/wandb/config.py | 129 +++++ pluto/compat/wandb/data_types.py | 311 ++++++++++++ pluto/compat/wandb/run.py | 398 +++++++++++++++ pluto/compat/wandb/summary.py | 87 ++++ tests/test_wandb_compat.py | 832 +++++++++++++++++++++++++++++++ 6 files changed, 2224 insertions(+) create mode 100644 pluto/compat/wandb/__init__.py create mode 100644 pluto/compat/wandb/config.py create mode 100644 pluto/compat/wandb/data_types.py create mode 100644 pluto/compat/wandb/run.py create mode 100644 pluto/compat/wandb/summary.py create mode 100644 tests/test_wandb_compat.py diff --git a/pluto/compat/wandb/__init__.py b/pluto/compat/wandb/__init__.py new file mode 100644 index 0000000..de51a5b --- /dev/null +++ b/pluto/compat/wandb/__init__.py @@ -0,0 +1,467 @@ +"""wandb-compatible drop-in replacement backed by pluto. + +Usage: + Replace ``import wandb`` with ``import pluto.compat.wandb as wandb`` + and your existing wandb code will route through pluto. + +Supported API: + - wandb.init(), wandb.log(), wandb.finish() + - wandb.watch(), wandb.unwatch() + - wandb.config, wandb.summary, wandb.run + - wandb.alert(), wandb.define_metric() + - wandb.Image, wandb.Audio, wandb.Video, wandb.Table, wandb.Histogram + - wandb.Html, wandb.Graph, wandb.Artifact, wandb.AlertLevel + - Context manager: ``with wandb.init() as run: ...`` +""" + +import logging +import os +from typing import Any, Dict, List, Optional, Sequence, Union + +from .config import Config +from .data_types import ( + AlertLevel, + Artifact, + Audio, + Graph, + Histogram, + Html, + Image, + Table, + Video, +) +from .run import Run +from .summary import Summary + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'WandbCompat' + +__all__ = [ + # Core API + 'init', + 'log', + 'finish', + 'watch', + 'unwatch', + 'alert', + 'define_metric', + 'save', + 'restore', + 'login', + 'log_artifact', + 'use_artifact', + 'log_code', + 'mark_preempting', + # Module-level state + 'run', + 'config', + 'summary', + # Classes + 'Run', + 'Config', + 'Settings', + 'AlertLevel', + # Data types + 'Image', + 'Audio', + 'Video', + 'Table', + 'Histogram', + 'Html', + 'Graph', + 'Artifact', +] + +# --------------------------------------------------------------------------- +# Module-level state (mirrors wandb.run, wandb.config, wandb.summary) +# --------------------------------------------------------------------------- + +run: Optional[Run] = None +config: Config = Config() +summary: Summary = Summary() + + +class Settings: + """Stub for wandb.Settings — accepts kwargs and stores them.""" + + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + setattr(self, k, v) + + +# --------------------------------------------------------------------------- +# Core functions +# --------------------------------------------------------------------------- + + +def init( + entity: Optional[str] = None, + project: Optional[str] = None, + dir: Optional[str] = None, + id: Optional[str] = None, + name: Optional[str] = None, + notes: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + config: Union[Dict[str, Any], str, None] = None, + config_exclude_keys: Optional[List[str]] = None, + config_include_keys: Optional[List[str]] = None, + allow_val_change: Optional[bool] = None, + group: Optional[str] = None, + job_type: Optional[str] = None, + mode: Optional[str] = None, + force: Optional[bool] = None, + anonymous: Optional[str] = None, + reinit: Optional[Union[bool, str]] = None, + resume: Optional[Union[bool, str]] = None, + resume_from: Optional[str] = None, + fork_from: Optional[str] = None, + save_code: Optional[bool] = None, + tensorboard: Optional[bool] = None, + sync_tensorboard: Optional[bool] = None, + monitor_gym: Optional[bool] = None, + settings: Optional[Any] = None, + **kwargs: Any, +) -> Run: + """Initialize a new run. Compatible with ``wandb.init()``.""" + import pluto as _pluto + + global run + global summary + + # If reinit, finish the previous run first + if run is not None: + if reinit: + try: + run.finish() + except Exception: + pass + else: + logger.debug('%s: init called with existing run, finishing previous', tag) + try: + run.finish() + except Exception: + pass + + # Resolve project from env if not provided + project = ( + project or os.environ.get('WANDB_PROJECT') or os.environ.get('PLUTO_PROJECT') + ) + + # Resolve mode + mode = mode or os.environ.get('WANDB_MODE', 'online') + if os.environ.get('WANDB_DISABLED', '').lower() in ('true', '1'): + mode = 'disabled' + + # Resolve name from env if not provided + name = name or os.environ.get('WANDB_NAME') + + # Resolve tags from env if not provided + if tags is None: + env_tags = os.environ.get('WANDB_TAGS') + if env_tags: + tags = [t.strip() for t in env_tags.split(',') if t.strip()] + + # Filter config keys if requested + config_dict: Optional[Dict[str, Any]] = None + if config is not None: + if isinstance(config, dict): + config_dict = dict(config) + elif hasattr(config, '__dict__'): + config_dict = vars(config) + else: + config_dict = {} + + if config_dict and config_include_keys: + config_dict = { + k: v for k, v in config_dict.items() if k in config_include_keys + } + if config_dict and config_exclude_keys: + config_dict = { + k: v for k, v in config_dict.items() if k not in config_exclude_keys + } + + # Build pluto settings + pluto_settings: Dict[str, Any] = {} + if mode == 'disabled': + pluto_settings['mode'] = 'noop' + elif mode == 'offline': + pluto_settings['sync_process_enabled'] = False + + # Map wandb run_id / resume + run_id = id + if resume in ('allow', 'must', 'auto', True) and id: + run_id = id + + # Store wandb-only metadata in config + extra_config: Dict[str, Any] = {} + if notes: + extra_config['_wandb_notes'] = notes + if group: + extra_config['_wandb_group'] = group + if job_type: + extra_config['_wandb_job_type'] = job_type + + merged_config = {**(config_dict or {}), **extra_config} or None + + # Initialize pluto + try: + op = _pluto.init( + project=project, + name=name, + config=merged_config, + tags=list(tags) if tags else None, + dir=dir, + settings=pluto_settings or None, + run_id=run_id, + ) + except Exception as e: + logger.warning('%s: pluto.init() failed (%s), creating disabled run', tag, e) + # Return a disabled run that no-ops everything + return _create_disabled_run( + name=name, + notes=notes, + group=group, + job_type=job_type, + config_dict=config_dict, + ) + + # Create the Run wrapper + _run = Run( + op=op, + name=name, + notes=notes, + group=group, + job_type=job_type, + mode=mode or 'online', + ) + + # Load config into the Config object + _run.config._load(config_dict) + + # Set module-level state + run = _run + + # Replace module-level config and summary proxies + _module = _get_module() + _module.config = _run.config + _module.summary = _run.summary + + return _run + + +def log( + data: Dict[str, Any], + step: Optional[int] = None, + commit: Optional[bool] = None, + sync: Optional[bool] = None, +) -> None: + """Log metrics. Compatible with ``wandb.log()``.""" + _require_run('log') + assert run is not None + run.log(data, step=step, commit=commit, sync=sync) + + +def finish( + exit_code: Optional[int] = None, + quiet: Optional[bool] = None, +) -> None: + """Finish the current run. Compatible with ``wandb.finish()``.""" + global run + + if run is not None: + run.finish(exit_code=exit_code, quiet=quiet) + run = None + + # Reset module-level proxies + _module = _get_module() + _module.config = Config() + _module.summary = Summary() + + +def watch( + models: Any = None, + criterion: Any = None, + log: Optional[str] = 'gradients', + log_freq: int = 1000, + idx: Optional[int] = None, + log_graph: bool = False, +) -> None: + """Watch a model. Compatible with ``wandb.watch()``.""" + _require_run('watch') + assert run is not None + run.watch( + models=models, + criterion=criterion, + log=log, + log_freq=log_freq, + idx=idx, + log_graph=log_graph, + ) + + +def unwatch(models: Any = None) -> None: + """Remove model watch hooks. Compatible with ``wandb.unwatch()``.""" + if run is not None: + run.unwatch(models) + + +def alert( + title: str = '', + text: str = '', + level: Optional[str] = None, + wait_duration: Optional[Union[int, float]] = None, +) -> None: + """Send an alert. Compatible with ``wandb.alert()``.""" + _require_run('alert') + assert run is not None + run.alert(title=title, text=text, level=level, wait_duration=wait_duration) + + +def define_metric( + name: str, + step_metric: Optional[str] = None, + step_sync: Optional[bool] = None, + hidden: Optional[bool] = None, + summary: Optional[str] = None, + goal: Optional[str] = None, + overwrite: Optional[bool] = None, +) -> Any: + """Define metric behavior. No-op in pluto compat layer.""" + if run is not None: + return run.define_metric( + name, + step_metric=step_metric, + step_sync=step_sync, + hidden=hidden, + summary=summary, + goal=goal, + overwrite=overwrite, + ) + from .run import _MetricStub + + return _MetricStub(name) + + +def save( + glob_str: Optional[str] = None, + base_path: Optional[str] = None, + policy: str = 'live', +) -> None: + """Sync files. Not supported — no-op.""" + logger.debug('%s: save is not supported', tag) + + +def restore( + name: str = '', + run_path: Optional[str] = None, + replace: bool = False, + root: Optional[str] = None, +) -> None: + """Restore a file. Not supported — no-op.""" + logger.debug('%s: restore is not supported', tag) + + +def log_artifact( + artifact_or_path: Any, + name: Optional[str] = None, + type: Optional[str] = None, + aliases: Optional[List[str]] = None, +) -> Any: + """Log an artifact. Compatible with ``wandb.log_artifact()``.""" + if run is not None: + return run.log_artifact(artifact_or_path, name=name, type=type, aliases=aliases) + logger.debug('%s: log_artifact called without active run', tag) + return artifact_or_path + + +def use_artifact( + artifact_or_name: Any, + type: Optional[str] = None, +) -> Any: + """Declare artifact as input. Not supported — no-op.""" + logger.debug('%s: use_artifact is not supported', tag) + return artifact_or_name + + +def log_code( + root: Optional[str] = None, + name: Optional[str] = None, + include_fn: Any = None, + exclude_fn: Any = None, +) -> None: + """Save source code. Not supported — no-op.""" + logger.debug('%s: log_code is not supported', tag) + + +def mark_preempting() -> None: + """Mark run as preempted. No-op.""" + logger.debug('%s: mark_preempting is a no-op', tag) + + +def login( + anonymous: Optional[str] = None, + key: Optional[str] = None, + relogin: Optional[bool] = None, + host: Optional[str] = None, + force: Optional[bool] = None, + timeout: Optional[int] = None, + verify: bool = False, +) -> bool: + """Login stub. Pluto uses its own auth (``pluto login``). + + Returns True to indicate "logged in" so callers don't block. + """ + logger.debug('%s: login is a no-op (use pluto login)', tag) + return True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _require_run(fn_name: str) -> None: + if run is None: + raise RuntimeError( + f'wandb.{fn_name}() called before wandb.init(). Call wandb.init() first.' + ) + + +def _get_module() -> Any: + """Return this module object for setting module-level attributes.""" + import sys + + return sys.modules[__name__] + + +def _create_disabled_run( + name: Optional[str] = None, + notes: Optional[str] = None, + group: Optional[str] = None, + job_type: Optional[str] = None, + config_dict: Optional[Dict[str, Any]] = None, +) -> Run: + """Create a Run that wraps a no-op Op for disabled/error cases.""" + global run + + import pluto as _pluto + + op = _pluto.init(project='disabled', settings={'mode': 'noop'}) + _run = Run( + op=op, + name=name, + notes=notes, + group=group, + job_type=job_type, + mode='disabled', + ) + if config_dict: + _run.config._load(config_dict) + + run = _run + + _module = _get_module() + _module.config = _run.config + _module.summary = _run.summary + + return _run diff --git a/pluto/compat/wandb/config.py b/pluto/compat/wandb/config.py new file mode 100644 index 0000000..a1bcc3a --- /dev/null +++ b/pluto/compat/wandb/config.py @@ -0,0 +1,129 @@ +"""wandb.config-compatible dict-like configuration object.""" + +import logging +from typing import Any, Dict, Iterator, Optional + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'WandbCompat.Config' + + +class Config: + """A dict-like configuration object compatible with wandb.config. + + Supports attribute access, dict access, and syncs mutations to pluto + via op.update_config(). + """ + + def __init__(self, op: Optional[Any] = None) -> None: + # Use object.__setattr__ to avoid triggering our __setattr__ + object.__setattr__(self, '_op', op) + object.__setattr__(self, '_data', {}) + object.__setattr__(self, '_allow_val_change', True) + + def _load(self, data: Optional[Dict[str, Any]]) -> None: + if data: + object.__getattribute__(self, '_data').update(data) + + def _sync(self, updates: Dict[str, Any]) -> None: + op = object.__getattribute__(self, '_op') + if op is not None: + try: + op.update_config(updates) + except Exception as e: + logger.debug('%s: failed to sync config: %s', tag, e) + + # -- Attribute access -- + + def __setattr__(self, key: str, value: Any) -> None: + if key.startswith('_'): + object.__setattr__(self, key, value) + else: + self[key] = value + + def __getattr__(self, key: str) -> Any: + data = object.__getattribute__(self, '_data') + try: + return data[key] + except KeyError: + raise AttributeError(f"'Config' object has no attribute '{key}'") + + def __delattr__(self, key: str) -> None: + data = object.__getattribute__(self, '_data') + if key in data: + del data[key] + else: + raise AttributeError(f"'Config' object has no attribute '{key}'") + + # -- Dict access -- + + def __setitem__(self, key: str, value: Any) -> None: + data = object.__getattribute__(self, '_data') + data[key] = value + self._sync({key: value}) + + def __getitem__(self, key: str) -> Any: + return object.__getattribute__(self, '_data')[key] + + def __delitem__(self, key: str) -> None: + del object.__getattribute__(self, '_data')[key] + + def __contains__(self, key: object) -> bool: + return key in object.__getattribute__(self, '_data') + + def __iter__(self) -> Iterator[str]: + return iter(object.__getattribute__(self, '_data')) + + def __len__(self) -> int: + return len(object.__getattribute__(self, '_data')) + + def __repr__(self) -> str: + return repr(object.__getattribute__(self, '_data')) + + def __bool__(self) -> bool: + return bool(object.__getattribute__(self, '_data')) + + # -- Dict-like methods -- + + def keys(self): + return object.__getattribute__(self, '_data').keys() + + def values(self): + return object.__getattribute__(self, '_data').values() + + def items(self): + return object.__getattribute__(self, '_data').items() + + def get(self, key: str, default: Any = None) -> Any: + return object.__getattribute__(self, '_data').get(key, default) + + def update( + self, + d: Any = None, + allow_val_change: Optional[bool] = None, + **kwargs: Any, + ) -> None: + """Update config with a dict, namespace, or keyword arguments.""" + if d is not None: + # Support argparse namespaces + if hasattr(d, '__dict__') and not isinstance(d, dict): + d = vars(d) + if isinstance(d, dict): + object.__getattribute__(self, '_data').update(d) + self._sync(d) + if kwargs: + object.__getattribute__(self, '_data').update(kwargs) + self._sync(kwargs) + + def setdefaults(self, d: Dict[str, Any]) -> None: + """Set defaults — only sets keys that are not already present.""" + data = object.__getattribute__(self, '_data') + updates = {} + for k, v in d.items(): + if k not in data: + data[k] = v + updates[k] = v + if updates: + self._sync(updates) + + def as_dict(self) -> Dict[str, Any]: + return dict(object.__getattribute__(self, '_data')) diff --git a/pluto/compat/wandb/data_types.py b/pluto/compat/wandb/data_types.py new file mode 100644 index 0000000..32a9019 --- /dev/null +++ b/pluto/compat/wandb/data_types.py @@ -0,0 +1,311 @@ +"""wandb-compatible data type wrappers that convert to pluto equivalents. + +Each class accepts wandb-style constructor arguments and provides a +_to_pluto() method that returns the corresponding pluto type. +""" + +import logging +import os +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'WandbCompat.DataTypes' + + +class Image: + """wandb.Image-compatible wrapper. + + Accepts the same constructor args as wandb.Image and converts to + pluto.file.Image via _to_pluto(). + """ + + def __init__( + self, + data_or_path: Any = None, + mode: Optional[str] = None, + caption: Optional[str] = None, + grouping: Optional[int] = None, + classes: Any = None, + boxes: Any = None, + masks: Any = None, + file_type: Optional[str] = None, + normalize: bool = True, + ) -> None: + self.data_or_path = data_or_path + self.caption = caption + self._mode = mode + self._grouping = grouping + + def _to_pluto(self) -> Any: + from pluto.file import Image as PlutoImage + + return PlutoImage(data=self.data_or_path, caption=self.caption) + + +class Audio: + """wandb.Audio-compatible wrapper.""" + + def __init__( + self, + data_or_path: Any = None, + sample_rate: Optional[int] = None, + caption: Optional[str] = None, + ) -> None: + self.data_or_path = data_or_path + self.sample_rate = sample_rate + self.caption = caption + + def _to_pluto(self) -> Any: + from pluto.file import Audio as PlutoAudio + + return PlutoAudio( + data=self.data_or_path, + sample_rate=self.sample_rate, + caption=self.caption, + ) + + +class Video: + """wandb.Video-compatible wrapper.""" + + def __init__( + self, + data_or_path: Any = None, + caption: Optional[str] = None, + fps: Optional[int] = None, + format: Optional[str] = None, + ) -> None: + self.data_or_path = data_or_path + self.caption = caption + self.fps = fps + self.format = format + + def _to_pluto(self) -> Any: + from pluto.file import Video as PlutoVideo + + return PlutoVideo( + data=self.data_or_path, + fps=self.fps, + caption=self.caption, + format=self.format, + ) + + +class Table: + """wandb.Table-compatible wrapper.""" + + MAX_ROWS = 10_000 + MAX_ARTIFACT_ROWS = 200_000 + + def __init__( + self, + columns: Optional[List[str]] = None, + data: Optional[List[List[Any]]] = None, + rows: Optional[List[List[Any]]] = None, + dataframe: Any = None, + dtype: Any = None, + optional: Any = True, + allow_mixed_types: bool = False, + ) -> None: + self.columns = columns or [] + self._data: List[List[Any]] = data or rows or [] + self._dataframe = dataframe + + def add_data(self, *row: Any) -> None: + self._data.append(list(row)) + + def add_column(self, name: str, data: List[Any]) -> None: + self.columns.append(name) + for i, val in enumerate(data): + if i < len(self._data): + self._data[i].append(val) + else: + self._data.append([val]) + + def get_column(self, name: str, convert_to: Optional[str] = None) -> List[Any]: + if name in self.columns: + idx = self.columns.index(name) + return [row[idx] for row in self._data if idx < len(row)] + return [] + + def _to_pluto(self) -> Any: + from pluto.data import Table as PlutoTable + + if self._dataframe is not None: + return PlutoTable( + data=None, + dataframe=self._dataframe, + columns=self.columns or None, + ) + return PlutoTable( + data=self._data, + columns=self.columns or None, + ) + + +class Histogram: + """wandb.Histogram-compatible wrapper.""" + + def __init__( + self, + sequence: Optional[Any] = None, + np_histogram: Optional[Any] = None, + num_bins: int = 64, + ) -> None: + self.sequence = sequence + self.np_histogram = np_histogram + self.num_bins = num_bins + + def _to_pluto(self) -> Any: + from pluto.data import Histogram as PlutoHistogram + + if self.np_histogram is not None: + # np_histogram is a tuple of (values, bin_edges) + return PlutoHistogram(data=self.np_histogram, bins=self.np_histogram) + if self.sequence is not None: + return PlutoHistogram(data=self.sequence, bins=self.num_bins) + return PlutoHistogram(data=[0], bins=1) + + +class Html: + """wandb.Html-compatible wrapper. Maps to pluto.file.Text.""" + + def __init__( + self, + data: Any = None, + inject: bool = True, + ) -> None: + if hasattr(data, 'read'): + self._html = data.read() + elif isinstance(data, str) and os.path.isfile(data): + with open(data) as f: + self._html = f.read() + else: + self._html = str(data) if data is not None else '' + + def _to_pluto(self) -> Any: + from pluto.file import Text as PlutoText + + return PlutoText(data=self._html) + + +class Graph: + """wandb.Graph-compatible wrapper.""" + + def __init__(self) -> None: + self._nodes: List[Any] = [] + self._edges: List[Any] = [] + + def _to_pluto(self) -> Any: + from pluto.data import Graph as PlutoGraph + + data: Dict[str, Any] = {'nodes': {}, 'edges': {}} + for i, node in enumerate(self._nodes): + name = getattr(node, 'name', str(i)) + data['nodes'][name] = {} + for edge in self._edges: + src = edge[0] if isinstance(edge, (list, tuple)) else str(edge) + dst = edge[1] if isinstance(edge, (list, tuple)) else str(edge) + data['edges'][(src, dst)] = {} + return PlutoGraph(data=data) + + +class Artifact: + """wandb.Artifact-compatible wrapper for file collections. + + wandb's Artifact is a versioned collection of files. Pluto's Artifact is + a single file. This wrapper collects files and logs each individually + when log_artifact() is called. + """ + + def __init__( + self, + name: str, + type: str, + description: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + incremental: bool = False, + use_as: Optional[str] = None, + ) -> None: + self.name = name + self.type = type + self.description = description + self.metadata = metadata or {} + self._files: List[Dict[str, Any]] = [] + + def add_file( + self, + local_path: str, + name: Optional[str] = None, + is_tmp: bool = False, + skip_cache: bool = False, + policy: str = 'mutable', + ) -> 'Artifact': + self._files.append( + { + 'path': local_path, + 'name': name or os.path.basename(local_path), + } + ) + return self + + def add_dir( + self, + local_path: str, + name: Optional[str] = None, + skip_cache: bool = False, + policy: str = 'mutable', + ) -> 'Artifact': + for root, _dirs, files in os.walk(local_path): + for f in files: + fpath = os.path.join(root, f) + rel = os.path.relpath(fpath, local_path) + if name: + rel = os.path.join(name, rel) + self._files.append({'path': fpath, 'name': rel}) + return self + + def add_reference( + self, + uri: str, + name: Optional[str] = None, + checksum: bool = True, + max_objects: Optional[int] = None, + ) -> 'Artifact': + logger.debug('%s: add_reference is not supported, ignoring', tag) + return self + + def _to_pluto_files(self) -> List[Any]: + """Convert collected files to pluto Artifact objects.""" + from pluto.file import Artifact as PlutoArtifact + + result = [] + for entry in self._files: + result.append( + PlutoArtifact( + data=entry['path'], + caption=entry['name'], + metadata=self.metadata, + ) + ) + return result + + # Stubs for download/verify (not supported) + def download(self, root: Optional[str] = None, **kwargs: Any) -> str: + logger.debug('%s: download is not supported', tag) + return root or '.' + + def verify(self, root: Optional[str] = None) -> bool: + logger.debug('%s: verify is not supported', tag) + return True + + def new_draft(self) -> 'Artifact': + return self + + +class AlertLevel: + """wandb.AlertLevel-compatible enum.""" + + INFO = 'INFO' + WARN = 'WARN' + ERROR = 'ERROR' diff --git a/pluto/compat/wandb/run.py b/pluto/compat/wandb/run.py new file mode 100644 index 0000000..449cc3c --- /dev/null +++ b/pluto/compat/wandb/run.py @@ -0,0 +1,398 @@ +"""wandb.Run-compatible wrapper around pluto.Op.""" + +import logging +from typing import Any, Dict, List, Optional, Sequence, Union + +from .config import Config +from .data_types import Artifact, Audio, Graph, Histogram, Html, Image, Table, Video +from .summary import Summary + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'WandbCompat.Run' + +# Sentinel for data types that have _to_pluto() +_WANDB_DATA_TYPES = (Image, Audio, Video, Table, Histogram, Html, Graph) + + +def _convert_value(v: Any) -> Any: + """Convert a wandb data type wrapper to its pluto equivalent.""" + if isinstance(v, _WANDB_DATA_TYPES): + return v._to_pluto() + return v + + +def _flatten_dict(d: Dict[str, Any], prefix: str = '') -> Dict[str, Any]: + """Flatten nested dicts using '/' separator (wandb convention).""" + flat: Dict[str, Any] = {} + for k, v in d.items(): + key = f'{prefix}/{k}' if prefix else k + if isinstance(v, dict): + flat.update(_flatten_dict(v, prefix=key)) + else: + flat[key] = v + return flat + + +class Run: + """A wandb.Run-compatible object wrapping a pluto Op. + + Provides the same interface as wandb.Run so user code like + ``run.log()``, ``run.config``, ``run.summary``, etc. works seamlessly. + """ + + def __init__( + self, + op: Any, + name: Optional[str] = None, + notes: Optional[str] = None, + group: Optional[str] = None, + job_type: Optional[str] = None, + mode: str = 'online', + ) -> None: + self._op = op + self._name = name + self._notes = notes + self._group = group + self._job_type = job_type + self._mode = mode + + self._config = Config(op=op) + self._summary = Summary() + self._pending_data: Dict[str, Any] = {} + self._step = 0 + self._watched_models: List[Any] = [] + + # -- Properties matching wandb.Run -- + + @property + def id(self) -> str: + if self._op.run_id: + return self._op.run_id + op_id = self._op.id + return str(op_id) if op_id is not None else '' + + @property + def name(self) -> Optional[str]: + return self._name or getattr(self._op.settings, '_op_name', None) + + @name.setter + def name(self, value: str) -> None: + self._name = value + + @property + def entity(self) -> str: + return '' + + @property + def project(self) -> Optional[str]: + return getattr(self._op.settings, 'project', None) + + @property + def group(self) -> Optional[str]: + return self._group + + @property + def job_type(self) -> Optional[str]: + return self._job_type + + @property + def tags(self) -> tuple: + return tuple(self._op.tags) + + @tags.setter + def tags(self, value: Union[tuple, list, Sequence[str]]) -> None: + current = set(self._op.tags) + new = set(value) + to_add = new - current + to_remove = current - new + if to_remove: + self._op.remove_tags(list(to_remove)) + if to_add: + self._op.add_tags(list(to_add)) + + @property + def notes(self) -> Optional[str]: + return self._notes + + @notes.setter + def notes(self, value: str) -> None: + self._notes = value + + @property + def config(self) -> Config: + return self._config + + @config.setter + def config(self, value: Any) -> None: + if isinstance(value, dict): + self._config.update(value) + elif isinstance(value, Config): + self._config = value + + @property + def summary(self) -> Summary: + return self._summary + + @property + def url(self) -> Optional[str]: + return getattr(self._op.settings, 'url_view', None) + + @property + def dir(self) -> str: + return self._op.settings.get_dir() + + @property + def step(self) -> int: + return self._step + + @property + def offline(self) -> bool: + return self._mode == 'offline' + + @property + def disabled(self) -> bool: + return self._mode == 'disabled' + + @property + def resumed(self) -> bool: + return self._op.resumed + + @property + def path(self) -> str: + return f'{self.entity}/{self.project}/{self.id}' + + @property + def settings(self) -> Any: + return self._op.settings + + @property + def start_time(self) -> float: + import time + + return time.time() + + @property + def sweep_id(self) -> Optional[str]: + return None + + @property + def project_url(self) -> Optional[str]: + url = self.url + if url: + # Strip the run-specific part to get project URL + parts = url.rsplit('/', 1) + return parts[0] if len(parts) > 1 else url + return None + + # -- Core methods -- + + def log( + self, + data: Dict[str, Any], + step: Optional[int] = None, + commit: Optional[bool] = None, + sync: Optional[bool] = None, + ) -> None: + """Log metrics/data to the run.""" + # Flatten nested dicts + flat = _flatten_dict(data) + + # Convert wandb data types to pluto equivalents + converted = {k: _convert_value(v) for k, v in flat.items()} + + if commit is False: + # Buffer data, don't send yet + self._pending_data.update(converted) + return + + # Merge with any pending data + if self._pending_data: + merged = {**self._pending_data, **converted} + self._pending_data = {} + else: + merged = converted + + # Update summary with scalar values + self._summary._update_from_log(merged) + + # Track step + if step is not None: + self._step = step + else: + self._step += 1 + + # Forward to pluto + try: + self._op.log(merged, step=step) + except Exception as e: + logger.debug('%s: log failed: %s', tag, e) + + def finish( + self, + exit_code: Optional[int] = None, + quiet: Optional[bool] = None, + ) -> None: + """Mark the run as finished.""" + try: + self._op.finish(code=exit_code) + except Exception as e: + logger.debug('%s: finish failed: %s', tag, e) + + def watch( + self, + models: Any = None, + criterion: Any = None, + log: Optional[str] = 'gradients', + log_freq: int = 1000, + idx: Optional[int] = None, + log_graph: bool = False, + ) -> None: + """Watch a PyTorch model for gradient/parameter logging.""" + if models is None: + return + + model_list = models if isinstance(models, (list, tuple)) else [models] + for model in model_list: + try: + self._op.watch(model, log_freq=log_freq) + self._watched_models.append(model) + except Exception as e: + logger.debug('%s: watch failed: %s', tag, e) + + def unwatch(self, models: Any = None) -> None: + """Remove model watching hooks (best-effort).""" + logger.debug('%s: unwatch is a no-op in pluto compat layer', tag) + + def alert( + self, + title: str = '', + text: str = '', + level: Optional[str] = None, + wait_duration: Optional[Union[int, float]] = None, + ) -> None: + """Send an alert.""" + try: + kwargs: Dict[str, Any] = {} + if wait_duration is not None: + kwargs['wait_duration'] = wait_duration + self._op.alert( + title=title, + message=text, + level=level or 'INFO', + **kwargs, + ) + except Exception as e: + logger.debug('%s: alert failed: %s', tag, e) + + def define_metric( + self, + name: str, + step_metric: Optional[str] = None, + step_sync: Optional[bool] = None, + hidden: Optional[bool] = None, + summary: Optional[str] = None, + goal: Optional[str] = None, + overwrite: Optional[bool] = None, + ) -> Any: + """Define metric behavior. No-op in pluto (returns a stub).""" + logger.debug('%s: define_metric is a no-op', tag) + return _MetricStub(name) + + def save( + self, + glob_str: Optional[str] = None, + base_path: Optional[str] = None, + policy: str = 'live', + ) -> None: + """Sync files to the run. No-op in pluto compat layer.""" + logger.debug('%s: save is not supported', tag) + + def restore( + self, + name: str, + run_path: Optional[str] = None, + replace: bool = False, + root: Optional[str] = None, + ) -> None: + """Download a file from a run. Not supported.""" + logger.debug('%s: restore is not supported', tag) + + def log_artifact( + self, + artifact_or_path: Any, + name: Optional[str] = None, + type: Optional[str] = None, + aliases: Optional[List[str]] = None, + ) -> Any: + """Log an artifact (file collection) to the run.""" + if isinstance(artifact_or_path, Artifact): + pluto_files = artifact_or_path._to_pluto_files() + for i, pf in enumerate(pluto_files): + if hasattr(pf, '_name'): + log_name = f'{artifact_or_path.name}/{pf._name}' + else: + log_name = f'{artifact_or_path.name}/{i}' + try: + self._op.log({log_name: pf}) + except Exception as e: + logger.debug('%s: log_artifact file failed: %s', tag, e) + return artifact_or_path + elif isinstance(artifact_or_path, str): + from pluto.file import Artifact as PlutoArtifact + + art = PlutoArtifact(data=artifact_or_path, caption=name) + log_name = name or 'artifact' + try: + self._op.log({log_name: art}) + except Exception as e: + logger.debug('%s: log_artifact path failed: %s', tag, e) + return artifact_or_path + + def use_artifact( + self, + artifact_or_name: Any, + type: Optional[str] = None, + ) -> Any: + """Declare an artifact as input. Not supported.""" + logger.debug('%s: use_artifact is not supported', tag) + return artifact_or_name + + def log_code( + self, + root: Optional[str] = None, + name: Optional[str] = None, + include_fn: Any = None, + exclude_fn: Any = None, + ) -> None: + """Save source code. Not supported.""" + logger.debug('%s: log_code is not supported', tag) + + def mark_preempting(self) -> None: + """Mark run as preempted. No-op (pluto handles this via signals).""" + logger.debug('%s: mark_preempting is a no-op', tag) + + def status(self) -> Dict[str, Any]: + return {'synced': True} + + # -- Context manager -- + + def __enter__(self) -> 'Run': + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + exit_code = 1 if exc_type else 0 + self.finish(exit_code=exit_code) + return False + + def __repr__(self) -> str: + return f'' + + +class _MetricStub: + """Stub returned by define_metric().""" + + def __init__(self, name: str) -> None: + self.name = name + + def __repr__(self) -> str: + return f'' diff --git a/pluto/compat/wandb/summary.py b/pluto/compat/wandb/summary.py new file mode 100644 index 0000000..0db43cc --- /dev/null +++ b/pluto/compat/wandb/summary.py @@ -0,0 +1,87 @@ +"""wandb.summary-compatible dict-like summary metrics object.""" + +from typing import Any, Dict, Iterator, Optional + + +class Summary: + """A dict-like object tracking summary metrics, compatible with wandb.summary. + + Auto-populated from log() calls (last value per key for scalars). + Supports manual overrides via dict/attribute access. + """ + + def __init__(self) -> None: + object.__setattr__(self, '_data', {}) + + def _update_from_log(self, data: Dict[str, Any]) -> None: + """Called internally after each log() call to update last values.""" + store = object.__getattribute__(self, '_data') + for k, v in data.items(): + if isinstance(v, (int, float)) and not isinstance(v, bool): + store[k] = v + elif hasattr(v, 'item') and callable(v.item): + store[k] = v.item() + + # -- Attribute access -- + + def __setattr__(self, key: str, value: Any) -> None: + if key.startswith('_'): + object.__setattr__(self, key, value) + else: + self[key] = value + + def __getattr__(self, key: str) -> Any: + data = object.__getattribute__(self, '_data') + try: + return data[key] + except KeyError: + raise AttributeError(f"'Summary' object has no attribute '{key}'") + + # -- Dict access -- + + def __setitem__(self, key: str, value: Any) -> None: + object.__getattribute__(self, '_data')[key] = value + + def __getitem__(self, key: str) -> Any: + return object.__getattribute__(self, '_data')[key] + + def __delitem__(self, key: str) -> None: + del object.__getattribute__(self, '_data')[key] + + def __contains__(self, key: object) -> bool: + return key in object.__getattribute__(self, '_data') + + def __iter__(self) -> Iterator[str]: + return iter(object.__getattribute__(self, '_data')) + + def __len__(self) -> int: + return len(object.__getattribute__(self, '_data')) + + def __repr__(self) -> str: + return repr(object.__getattribute__(self, '_data')) + + def __bool__(self) -> bool: + return bool(object.__getattribute__(self, '_data')) + + # -- Dict-like methods -- + + def keys(self): + return object.__getattribute__(self, '_data').keys() + + def values(self): + return object.__getattribute__(self, '_data').values() + + def items(self): + return object.__getattribute__(self, '_data').items() + + def get(self, key: str, default: Any = None) -> Any: + return object.__getattribute__(self, '_data').get(key, default) + + def update(self, d: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None: + if d: + object.__getattribute__(self, '_data').update(d) + if kwargs: + object.__getattribute__(self, '_data').update(kwargs) + + def as_dict(self) -> Dict[str, Any]: + return dict(object.__getattribute__(self, '_data')) diff --git a/tests/test_wandb_compat.py b/tests/test_wandb_compat.py new file mode 100644 index 0000000..bd6421a --- /dev/null +++ b/tests/test_wandb_compat.py @@ -0,0 +1,832 @@ +""" +Tests for the wandb-to-pluto compatibility layer. + +These tests validate that: +1. wandb API calls are correctly routed to pluto equivalents +2. Data types are converted properly +3. Config and Summary behave like wandb's dict-like objects +4. Module-level state (run, config, summary) works correctly +5. Unsupported features degrade gracefully (no-ops, not errors) +""" + +import os +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest + + +class TestConfig: + """Tests for the wandb.config-compatible Config class.""" + + def test_attribute_set_and_get(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.lr = 0.001 + assert c.lr == 0.001 + + def test_dict_set_and_get(self): + from pluto.compat.wandb.config import Config + + c = Config() + c['batch_size'] = 32 + assert c['batch_size'] == 32 + + def test_update_dict(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.update({'a': 1, 'b': 2}) + assert c.a == 1 + assert c['b'] == 2 + + def test_update_with_namespace(self): + """Test update() with argparse-like namespace.""" + import argparse + + from pluto.compat.wandb.config import Config + + ns = argparse.Namespace(lr=0.01, epochs=10) + c = Config() + c.update(ns) + assert c.lr == 0.01 + assert c.epochs == 10 + + def test_setdefaults(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.lr = 0.001 + c.setdefaults({'lr': 0.01, 'batch_size': 64}) + assert c.lr == 0.001 # not overwritten + assert c.batch_size == 64 # newly set + + def test_keys_values_items(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.update({'a': 1, 'b': 2}) + assert set(c.keys()) == {'a', 'b'} + assert set(c.values()) == {1, 2} + assert set(c.items()) == {('a', 1), ('b', 2)} + + def test_get_with_default(self): + from pluto.compat.wandb.config import Config + + c = Config() + assert c.get('missing', 42) == 42 + c.x = 10 + assert c.get('x', 42) == 10 + + def test_contains(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.lr = 0.01 + assert 'lr' in c + assert 'missing' not in c + + def test_len(self): + from pluto.compat.wandb.config import Config + + c = Config() + assert len(c) == 0 + c.update({'a': 1, 'b': 2}) + assert len(c) == 2 + + def test_iter(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.update({'a': 1, 'b': 2}) + assert set(c) == {'a', 'b'} + + def test_load(self): + from pluto.compat.wandb.config import Config + + c = Config() + c._load({'x': 1, 'y': 2}) + assert c.x == 1 + assert c['y'] == 2 + + def test_as_dict(self): + from pluto.compat.wandb.config import Config + + c = Config() + c.update({'a': 1}) + assert c.as_dict() == {'a': 1} + + def test_attribute_error_on_missing(self): + from pluto.compat.wandb.config import Config + + c = Config() + with pytest.raises(AttributeError): + _ = c.missing_key + + def test_sync_calls_update_config(self): + from pluto.compat.wandb.config import Config + + mock_op = MagicMock() + c = Config(op=mock_op) + c.lr = 0.01 + mock_op.update_config.assert_called_with({'lr': 0.01}) + + def test_sync_error_does_not_raise(self): + from pluto.compat.wandb.config import Config + + mock_op = MagicMock() + mock_op.update_config.side_effect = RuntimeError('connection failed') + c = Config(op=mock_op) + c.lr = 0.01 # should not raise + assert c.lr == 0.01 + + +class TestSummary: + """Tests for the wandb.summary-compatible Summary class.""" + + def test_dict_set_and_get(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s['best_acc'] = 0.95 + assert s['best_acc'] == 0.95 + + def test_attribute_set_and_get(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s.final_loss = 0.1 + assert s.final_loss == 0.1 + + def test_update_from_log(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._update_from_log({'loss': 0.5, 'acc': 0.9}) + assert s['loss'] == 0.5 + assert s['acc'] == 0.9 + + def test_update_from_log_ignores_non_scalars(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._update_from_log({'loss': 0.5, 'image': 'not_a_scalar', 'flag': True}) + assert s['loss'] == 0.5 + assert 'image' not in s + assert 'flag' not in s # bools excluded + + def test_manual_override(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._update_from_log({'loss': 0.5}) + s['loss'] = 0.1 # manual override + assert s['loss'] == 0.1 + + def test_keys_values_items(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s.update({'a': 1, 'b': 2}) + assert set(s.keys()) == {'a', 'b'} + + def test_get_default(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + assert s.get('missing', 99) == 99 + + def test_contains(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s.x = 1 + assert 'x' in s + assert 'y' not in s + + def test_as_dict(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s.update({'a': 1}) + assert s.as_dict() == {'a': 1} + + +class TestDataTypes: + """Tests for wandb data type wrappers.""" + + def test_image_from_numpy(self): + from pluto.compat.wandb.data_types import Image + + data = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + img = Image(data, caption='test') + assert img.caption == 'test' + pluto_img = img._to_pluto() + assert pluto_img.__class__.__name__ == 'Image' + + def test_image_from_path(self, tmp_path): + from pluto.compat.wandb.data_types import Image + + p = tmp_path / 'test.png' + # Write a minimal PNG + from PIL import Image as PILImage + + pil = PILImage.new('RGB', (10, 10)) + pil.save(str(p)) + + img = Image(str(p)) + pluto_img = img._to_pluto() + assert pluto_img.__class__.__name__ == 'Image' + + def test_audio_wrapper(self): + from pluto.compat.wandb.data_types import Audio + + data = np.random.randn(16000).astype(np.float32) + a = Audio(data, sample_rate=16000, caption='test_audio') + pluto_audio = a._to_pluto() + assert pluto_audio.__class__.__name__ == 'Audio' + + def test_video_wrapper(self, tmp_path): + from pluto.compat.wandb.data_types import Video + + p = tmp_path / 'test.mp4' + p.write_bytes(b'\x00' * 100) + v = Video(str(p), fps=30, caption='test_video') + pluto_video = v._to_pluto() + assert pluto_video.__class__.__name__ == 'Video' + + def test_table_from_data(self): + from pluto.compat.wandb.data_types import Table + + t = Table(columns=['a', 'b'], data=[[1, 2], [3, 4]]) + assert t.columns == ['a', 'b'] + pluto_table = t._to_pluto() + assert pluto_table.__class__.__name__ == 'Table' + + def test_table_add_data(self): + from pluto.compat.wandb.data_types import Table + + t = Table(columns=['x', 'y']) + t.add_data(1, 2) + t.add_data(3, 4) + assert len(t._data) == 2 + + def test_table_add_column(self): + from pluto.compat.wandb.data_types import Table + + t = Table(columns=['x'], data=[[1], [2]]) + t.add_column('y', [10, 20]) + assert 'y' in t.columns + assert t._data[0] == [1, 10] + + def test_table_get_column(self): + from pluto.compat.wandb.data_types import Table + + t = Table(columns=['a', 'b'], data=[[1, 2], [3, 4]]) + assert t.get_column('a') == [1, 3] + assert t.get_column('b') == [2, 4] + assert t.get_column('missing') == [] + + def test_table_from_dataframe(self): + import pandas as pd + + from pluto.compat.wandb.data_types import Table + + df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + t = Table(dataframe=df) + pluto_table = t._to_pluto() + assert pluto_table.__class__.__name__ == 'Table' + + def test_histogram_from_sequence(self): + from pluto.compat.wandb.data_types import Histogram + + h = Histogram(sequence=[1, 2, 3, 4, 5], num_bins=10) + pluto_hist = h._to_pluto() + assert pluto_hist.__class__.__name__ == 'Histogram' + + def test_histogram_from_np_histogram(self): + from pluto.compat.wandb.data_types import Histogram + + counts, bins = np.histogram([1, 2, 3, 4, 5], bins=5) + h = Histogram(np_histogram=(counts, bins)) + pluto_hist = h._to_pluto() + assert pluto_hist.__class__.__name__ == 'Histogram' + + def test_html_from_string(self): + from pluto.compat.wandb.data_types import Html + + h = Html('

Hello

') + pluto_text = h._to_pluto() + assert pluto_text.__class__.__name__ == 'Text' + + def test_html_from_file(self, tmp_path): + from pluto.compat.wandb.data_types import Html + + p = tmp_path / 'test.html' + p.write_text('

Hello

') + h = Html(str(p)) + assert h._html == '

Hello

' + + def test_alert_level(self): + from pluto.compat.wandb.data_types import AlertLevel + + assert AlertLevel.INFO == 'INFO' + assert AlertLevel.WARN == 'WARN' + assert AlertLevel.ERROR == 'ERROR' + + def test_artifact_add_file(self, tmp_path): + from pluto.compat.wandb.data_types import Artifact + + f1 = tmp_path / 'model.pt' + f1.write_bytes(b'\x00' * 100) + art = Artifact('my-model', type='model') + art.add_file(str(f1), name='model.pt') + assert len(art._files) == 1 + assert art._files[0]['name'] == 'model.pt' + + def test_artifact_add_dir(self, tmp_path): + from pluto.compat.wandb.data_types import Artifact + + d = tmp_path / 'data' + d.mkdir() + (d / 'a.txt').write_text('hello') + (d / 'b.txt').write_text('world') + art = Artifact('my-data', type='dataset') + art.add_dir(str(d)) + assert len(art._files) == 2 + + def test_artifact_to_pluto_files(self, tmp_path): + from pluto.compat.wandb.data_types import Artifact + + f1 = tmp_path / 'file.bin' + f1.write_bytes(b'\x00' * 50) + art = Artifact('test', type='model') + art.add_file(str(f1)) + pluto_files = art._to_pluto_files() + assert len(pluto_files) == 1 + assert pluto_files[0].__class__.__name__ == 'Artifact' + + def test_graph_wrapper(self): + from pluto.compat.wandb.data_types import Graph + + g = Graph() + pluto_graph = g._to_pluto() + assert pluto_graph.__class__.__name__ == 'Graph' + + +class TestRun: + """Tests for the wandb.Run-compatible Run class.""" + + def _make_run(self, **kwargs): + """Create a Run with a mocked Op.""" + from pluto.compat.wandb.run import Run + + op = MagicMock() + op.id = 123 + op.run_id = None + op.tags = ['tag1'] + op.resumed = False + op.settings = MagicMock() + op.settings._op_name = 'test-run' + op.settings.project = 'test-project' + op.settings.url_view = 'https://pluto.trainy.ai/run/123' + op.settings.get_dir.return_value = '/tmp/test-run' + op.config = {} + return Run(op=op, **kwargs), op + + def test_properties(self): + run, op = self._make_run(name='my-run', notes='some notes', group='grp') + assert run.id == '123' + assert run.name == 'my-run' + assert run.project == 'test-project' + assert run.notes == 'some notes' + assert run.group == 'grp' + assert run.entity == '' + assert run.url == 'https://pluto.trainy.ai/run/123' + + def test_tags_get_and_set(self): + run, op = self._make_run() + assert run.tags == ('tag1',) + + # Setting tags + run.tags = ('tag1', 'tag2') + op.add_tags.assert_called() + + def test_name_setter(self): + run, op = self._make_run(name='original') + run.name = 'new-name' + assert run.name == 'new-name' + + def test_log_scalars(self): + run, op = self._make_run() + run.log({'loss': 0.5, 'acc': 0.9}) + op.log.assert_called_once() + logged_data = op.log.call_args[0][0] + assert logged_data['loss'] == 0.5 + assert logged_data['acc'] == 0.9 + + def test_log_nested_dict(self): + run, op = self._make_run() + run.log({'train': {'loss': 0.5}}) + op.log.assert_called_once() + logged_data = op.log.call_args[0][0] + assert 'train/loss' in logged_data + + def test_log_with_step(self): + run, op = self._make_run() + run.log({'loss': 0.5}, step=10) + op.log.assert_called_once() + assert op.log.call_args[1]['step'] == 10 + + def test_log_commit_false(self): + run, op = self._make_run() + run.log({'loss': 0.5}, commit=False) + op.log.assert_not_called() + + run.log({'acc': 0.9}) # default commit=True + op.log.assert_called_once() + logged_data = op.log.call_args[0][0] + assert 'loss' in logged_data + assert 'acc' in logged_data + + def test_log_updates_summary(self): + run, op = self._make_run() + run.log({'loss': 0.5}) + assert run.summary['loss'] == 0.5 + + def test_log_converts_data_types(self): + from pluto.compat.wandb.data_types import Image + + run, op = self._make_run() + data = np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8) + run.log({'image': Image(data)}) + op.log.assert_called_once() + logged_data = op.log.call_args[0][0] + assert logged_data['image'].__class__.__name__ == 'Image' + + def test_finish(self): + run, op = self._make_run() + run.finish() + op.finish.assert_called_once() + + def test_finish_with_exit_code(self): + run, op = self._make_run() + run.finish(exit_code=1) + op.finish.assert_called_once_with(code=1) + + def test_context_manager(self): + run, op = self._make_run() + with run: + run.log({'x': 1}) + op.finish.assert_called_once_with(code=0) + + def test_context_manager_with_exception(self): + run, op = self._make_run() + with pytest.raises(ValueError): + with run: + raise ValueError('test') + op.finish.assert_called_once_with(code=1) + + def test_watch(self): + run, op = self._make_run() + model = MagicMock() + run.watch(model, log_freq=500) + op.watch.assert_called_once_with(model, log_freq=500) + + def test_alert(self): + run, op = self._make_run() + run.alert(title='Loss spike', text='Loss exceeded 10.0', level='WARN') + op.alert.assert_called_once() + + def test_define_metric_returns_stub(self): + run, op = self._make_run() + m = run.define_metric('loss', step_metric='epoch') + assert m.name == 'loss' + + def test_unsupported_methods_no_error(self): + run, op = self._make_run() + run.save() + run.restore('test') + run.log_code() + run.mark_preempting() + run.use_artifact('test') + + def test_log_artifact_with_artifact_object(self, tmp_path): + from pluto.compat.wandb.data_types import Artifact + + run, op = self._make_run() + f = tmp_path / 'model.pt' + f.write_bytes(b'\x00' * 100) + art = Artifact('model', type='model') + art.add_file(str(f)) + run.log_artifact(art) + op.log.assert_called() + + def test_repr(self): + run, op = self._make_run(name='exp-1') + r = repr(run) + assert 'test-project' in r + assert 'exp-1' in r + + def test_step_tracking(self): + run, op = self._make_run() + assert run.step == 0 + run.log({'x': 1}) + assert run.step == 1 + run.log({'x': 2}) + assert run.step == 2 + run.log({'x': 3}, step=10) + assert run.step == 10 + + def test_offline_disabled(self): + run_online, _ = self._make_run(mode='online') + assert not run_online.offline + assert not run_online.disabled + + run_off, _ = self._make_run(mode='offline') + assert run_off.offline + + run_dis, _ = self._make_run(mode='disabled') + assert run_dis.disabled + + def test_path(self): + run, _ = self._make_run() + assert run.path == '/test-project/123' + + +class TestModuleAPI: + """Tests for the module-level wandb API (init, log, finish, etc.).""" + + def test_log_before_init_raises(self): + import pluto.compat.wandb as wandb + + # Ensure no active run + wandb.run = None + with pytest.raises(RuntimeError, match='wandb.log.*called before wandb.init'): + wandb.log({'x': 1}) + + def test_watch_before_init_raises(self): + import pluto.compat.wandb as wandb + + wandb.run = None + with pytest.raises(RuntimeError, match='wandb.watch.*called before wandb.init'): + wandb.watch(MagicMock()) + + def test_finish_without_init_is_noop(self): + import pluto.compat.wandb as wandb + + wandb.run = None + wandb.finish() # should not raise + + def test_define_metric_without_init(self): + import pluto.compat.wandb as wandb + + wandb.run = None + m = wandb.define_metric('loss') + assert m.name == 'loss' + + def test_unsupported_module_functions_noop(self): + import pluto.compat.wandb as wandb + + # These should all be no-ops, never raise + wandb.save('*.pt') + wandb.restore('model.pt') + wandb.log_code() + wandb.mark_preempting() + + def test_login_returns_true(self): + import pluto.compat.wandb as wandb + + assert wandb.login() is True + + def test_settings_class(self): + import pluto.compat.wandb as wandb + + s = wandb.Settings(mode='offline', console='auto') + assert s.mode == 'offline' + assert s.console == 'auto' + + def test_data_types_importable(self): + """Test that all wandb data types are importable from the module.""" + from pluto.compat.wandb import ( + AlertLevel, + Artifact, + Audio, + Graph, + Histogram, + Html, + Image, + Table, + Video, + ) + + assert Image is not None + assert Audio is not None + assert Video is not None + assert Table is not None + assert Histogram is not None + assert Html is not None + assert Graph is not None + assert Artifact is not None + assert AlertLevel is not None + + def test_import_as_wandb(self): + """Test the canonical import pattern.""" + import pluto.compat.wandb as wandb + + assert hasattr(wandb, 'init') + assert hasattr(wandb, 'log') + assert hasattr(wandb, 'finish') + assert hasattr(wandb, 'watch') + assert hasattr(wandb, 'unwatch') + assert hasattr(wandb, 'alert') + assert hasattr(wandb, 'config') + assert hasattr(wandb, 'summary') + assert hasattr(wandb, 'run') + assert hasattr(wandb, 'Image') + assert hasattr(wandb, 'Table') + assert hasattr(wandb, 'Run') + + @mock.patch.dict( + os.environ, + { + 'WANDB_PROJECT': 'env-project', + }, + clear=False, + ) + def test_init_picks_up_wandb_project_env(self): + """Test that WANDB_PROJECT env var is used as fallback.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = MagicMock() + mock_op.id = 1 + mock_op.run_id = None + mock_op.tags = [] + mock_op.resumed = False + mock_op.settings = MagicMock() + mock_op.settings._op_name = 'test' + mock_op.settings.project = 'env-project' + mock_op.settings.url_view = None + mock_op.settings.get_dir.return_value = '/tmp' + mock_op.config = {} + mock_init.return_value = mock_op + + wandb.init() + mock_init.assert_called_once() + assert mock_init.call_args[1]['project'] == 'env-project' + wandb.finish() + + @mock.patch.dict( + os.environ, + { + 'WANDB_MODE': 'disabled', + }, + clear=False, + ) + def test_init_disabled_mode_from_env(self): + """Test that WANDB_MODE=disabled creates noop run.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = MagicMock() + mock_op.id = 1 + mock_op.run_id = None + mock_op.tags = [] + mock_op.resumed = False + mock_op.settings = MagicMock() + mock_op.settings._op_name = 'test' + mock_op.settings.project = 'test' + mock_op.settings.url_view = None + mock_op.settings.get_dir.return_value = '/tmp' + mock_op.config = {} + mock_init.return_value = mock_op + + wandb.init(project='test') + # Should pass mode='noop' in settings + call_kwargs = mock_init.call_args[1] + settings = call_kwargs.get('settings', {}) + assert settings.get('mode') == 'noop' + wandb.finish() + + def test_init_creates_disabled_run_on_failure(self): + """Test that init creates a disabled run if pluto.init fails.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + # First call fails, second call (disabled) succeeds + mock_op = MagicMock() + mock_op.id = 0 + mock_op.run_id = None + mock_op.tags = [] + mock_op.resumed = False + mock_op.settings = MagicMock() + mock_op.settings._op_name = 'disabled' + mock_op.settings.project = 'disabled' + mock_op.settings.url_view = None + mock_op.settings.get_dir.return_value = '/tmp' + mock_op.config = {} + mock_init.side_effect = [RuntimeError('auth failed'), mock_op] + + run = wandb.init(project='test') + assert run.disabled + wandb.finish() + + @mock.patch.dict( + os.environ, + { + 'WANDB_TAGS': 'tag1,tag2, tag3', + }, + clear=False, + ) + def test_init_tags_from_env(self): + """Test that WANDB_TAGS env var is parsed.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = MagicMock() + mock_op.id = 1 + mock_op.run_id = None + mock_op.tags = ['tag1', 'tag2', 'tag3'] + mock_op.resumed = False + mock_op.settings = MagicMock() + mock_op.settings._op_name = 'test' + mock_op.settings.project = 'test' + mock_op.settings.url_view = None + mock_op.settings.get_dir.return_value = '/tmp' + mock_op.config = {} + mock_init.return_value = mock_op + + wandb.init(project='test') + call_kwargs = mock_init.call_args[1] + assert call_kwargs['tags'] == ['tag1', 'tag2', 'tag3'] + wandb.finish() + + +class TestFlattenDict: + """Tests for nested dict flattening (wandb convention).""" + + def test_flat_dict_unchanged(self): + from pluto.compat.wandb.run import _flatten_dict + + assert _flatten_dict({'a': 1, 'b': 2}) == {'a': 1, 'b': 2} + + def test_nested_dict(self): + from pluto.compat.wandb.run import _flatten_dict + + result = _flatten_dict({'train': {'loss': 0.5, 'acc': 0.9}}) + assert result == {'train/loss': 0.5, 'train/acc': 0.9} + + def test_deeply_nested(self): + from pluto.compat.wandb.run import _flatten_dict + + result = _flatten_dict({'a': {'b': {'c': 1}}}) + assert result == {'a/b/c': 1} + + def test_mixed_nesting(self): + from pluto.compat.wandb.run import _flatten_dict + + result = _flatten_dict({'loss': 0.5, 'train': {'acc': 0.9}}) + assert result == {'loss': 0.5, 'train/acc': 0.9} + + +class TestValueConversion: + """Tests for wandb data type → pluto type conversion in log().""" + + def test_scalar_passthrough(self): + from pluto.compat.wandb.run import _convert_value + + assert _convert_value(42) == 42 + assert _convert_value(3.14) == 3.14 + assert _convert_value('hello') == 'hello' + + def test_image_converted(self): + from pluto.compat.wandb.data_types import Image + from pluto.compat.wandb.run import _convert_value + + data = np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8) + img = Image(data) + result = _convert_value(img) + assert result.__class__.__name__ == 'Image' + # Should be pluto Image, not wandb Image + assert result.__class__.__module__.startswith('pluto.file') + + def test_table_converted(self): + from pluto.compat.wandb.data_types import Table + from pluto.compat.wandb.run import _convert_value + + t = Table(columns=['a'], data=[[1], [2]]) + result = _convert_value(t) + assert result.__class__.__name__ == 'Table' + assert result.__class__.__module__.startswith('pluto.data') + + def test_histogram_converted(self): + from pluto.compat.wandb.data_types import Histogram + from pluto.compat.wandb.run import _convert_value + + h = Histogram(sequence=[1, 2, 3, 4, 5]) + result = _convert_value(h) + assert result.__class__.__name__ == 'Histogram' + assert result.__class__.__module__.startswith('pluto.data') From 3501f0005720bfbc35173855a978e5dae5b1dce3 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Feb 2026 05:10:03 +0000 Subject: [PATCH 2/7] Add parity contract tests for wandb compat layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 19 new tests in TestParityContract that pin the exact pluto call sequences for common wandb workflows: - Standard training loop (init → config → log N → finish) - Nested metric namespace flattening (train/loss, val/acc) - commit=False buffering and flush behavior - Duplicate key resolution (later values win) - Explicit step= forwarding - Config mutations (attr, dict, bulk update, argparse namespace) - config_include_keys / config_exclude_keys filtering - Tags lifecycle (init tags, runtime mutation) - Data type conversion in log() (Image, Table, Histogram → pluto) - Summary auto-tracking of last scalar per key - Context manager lifecycle (success and exception exit codes) - reinit finishing previous run - watch/alert forwarding - Full realistic workflow (config+tags+mixed data+summary) - Module state reset after finish - log_artifact call sequence (one op.log per file) https://claude.ai/code/session_01VTSZKK5UsMqjiADFX57SMY --- tests/test_wandb_compat.py | 541 +++++++++++++++++++++++++++++++++++++ 1 file changed, 541 insertions(+) diff --git a/tests/test_wandb_compat.py b/tests/test_wandb_compat.py index bd6421a..eec6840 100644 --- a/tests/test_wandb_compat.py +++ b/tests/test_wandb_compat.py @@ -830,3 +830,544 @@ def test_histogram_converted(self): result = _convert_value(h) assert result.__class__.__name__ == 'Histogram' assert result.__class__.__module__.startswith('pluto.data') + + +# --------------------------------------------------------------------------- +# Helpers for parity contract tests +# --------------------------------------------------------------------------- + + +def _make_mock_op(**overrides): + """Create a consistently-configured mock Op for contract tests.""" + op = MagicMock() + op.id = overrides.get('id', 1) + op.run_id = overrides.get('run_id', None) + op.tags = list(overrides.get('tags', [])) + op.resumed = overrides.get('resumed', False) + op.settings = MagicMock() + op.settings._op_name = overrides.get('name', 'test-run') + op.settings.project = overrides.get('project', 'test-project') + op.settings.url_view = overrides.get('url_view', None) + op.settings.get_dir.return_value = '/tmp/pluto' + op.config = {} + return op + + +class TestParityContract: + """Snapshot/contract tests that pin the exact pluto calls the compat layer + produces for common wandb usage patterns. + + Each test simulates a realistic wandb workflow end-to-end via the + module-level API and asserts the full call sequence that reaches + the underlying pluto.Op mock. + """ + + def test_standard_training_loop(self): + """Standard loop: init → config → log N steps → finish.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + run = wandb.init(project='my-project', config={'lr': 0.01, 'epochs': 5}) + + # --- init assertions --- + mock_init.assert_called_once() + init_kw = mock_init.call_args[1] + assert init_kw['project'] == 'my-project' + # config should be forwarded + assert init_kw['config']['lr'] == 0.01 + assert init_kw['config']['epochs'] == 5 + + # --- log loop --- + for i in range(5): + wandb.log({'loss': 1.0 / (i + 1), 'acc': i * 0.2}) + + assert mock_op.log.call_count == 5 + # First call + first_data = mock_op.log.call_args_list[0][0][0] + assert first_data == {'loss': 1.0, 'acc': 0.0} + # Last call + last_data = mock_op.log.call_args_list[4][0][0] + assert last_data['loss'] == pytest.approx(0.2) + assert last_data['acc'] == pytest.approx(0.8) + + # --- summary tracks last values --- + assert run.summary['loss'] == pytest.approx(0.2) + assert run.summary['acc'] == pytest.approx(0.8) + + # --- finish --- + wandb.finish() + mock_op.finish.assert_called_once() + + def test_nested_metric_namespaces(self): + """Nested dicts are flattened with / separators before reaching pluto.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + wandb.log( + { + 'train': {'loss': 0.5, 'acc': 0.9}, + 'val': {'loss': 0.6, 'acc': 0.85}, + 'lr': 0.001, + } + ) + + logged = mock_op.log.call_args[0][0] + assert set(logged.keys()) == { + 'train/loss', + 'train/acc', + 'val/loss', + 'val/acc', + 'lr', + } + assert logged['train/loss'] == 0.5 + assert logged['val/acc'] == 0.85 + assert logged['lr'] == 0.001 + + wandb.finish() + + def test_commit_false_buffering(self): + """commit=False accumulates data; next commit=True flushes all.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + + # These should NOT call op.log + wandb.log({'loss': 0.5}, commit=False) + wandb.log({'acc': 0.9}, commit=False) + assert mock_op.log.call_count == 0 + + # This flushes everything + wandb.log({'lr': 0.01}) + assert mock_op.log.call_count == 1 + + flushed = mock_op.log.call_args[0][0] + assert flushed == {'loss': 0.5, 'acc': 0.9, 'lr': 0.01} + + wandb.finish() + + def test_commit_false_later_values_win(self): + """When same key is buffered then committed, last value wins.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + + wandb.log({'loss': 0.5}, commit=False) + wandb.log({'loss': 0.3}) # overrides buffered value + + flushed = mock_op.log.call_args[0][0] + assert flushed['loss'] == 0.3 + + wandb.finish() + + def test_explicit_step_forwarded(self): + """step= kwarg is passed through to pluto op.log.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + + wandb.log({'x': 1}, step=0) + wandb.log({'x': 2}, step=5) + wandb.log({'x': 3}, step=100) + + steps = [c[1]['step'] for c in mock_op.log.call_args_list] + assert steps == [0, 5, 100] + + wandb.finish() + + def test_config_mutations_sync(self): + """Config changes after init reach pluto via update_config.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + run = wandb.init(project='test', config={'lr': 0.01}) + + # Attribute assignment + wandb.config.batch_size = 32 + mock_op.update_config.assert_called_with({'batch_size': 32}) + + # Dict assignment + wandb.config['optimizer'] = 'adam' + mock_op.update_config.assert_called_with({'optimizer': 'adam'}) + + # Bulk update + wandb.config.update({'dropout': 0.1, 'weight_decay': 1e-4}) + mock_op.update_config.assert_called_with( + {'dropout': 0.1, 'weight_decay': 1e-4} + ) + + # All values accessible + assert run.config.as_dict() == { + 'lr': 0.01, + 'batch_size': 32, + 'optimizer': 'adam', + 'dropout': 0.1, + 'weight_decay': 1e-4, + } + + wandb.finish() + + def test_config_from_argparse(self): + """argparse.Namespace passed to init is forwarded as config dict.""" + import argparse + + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + args = argparse.Namespace(lr=0.001, epochs=10, model='resnet50') + wandb.init(project='test', config=args) + + init_kw = mock_init.call_args[1] + assert init_kw['config']['lr'] == 0.001 + assert init_kw['config']['epochs'] == 10 + assert init_kw['config']['model'] == 'resnet50' + + wandb.finish() + + def test_config_include_exclude_keys(self): + """config_include_keys and config_exclude_keys filter correctly.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + # include_keys: only lr and epochs pass through + wandb.init( + project='test', + config={'lr': 0.01, 'epochs': 5, 'secret': 'xxx'}, + config_include_keys=['lr', 'epochs'], + ) + init_kw = mock_init.call_args[1] + assert 'lr' in init_kw['config'] + assert 'epochs' in init_kw['config'] + assert 'secret' not in init_kw['config'] + wandb.finish() + + mock_init.reset_mock() + mock_init.return_value = mock_op + + # exclude_keys: secret filtered out + wandb.init( + project='test', + config={'lr': 0.01, 'epochs': 5, 'secret': 'xxx'}, + config_exclude_keys=['secret'], + ) + init_kw = mock_init.call_args[1] + assert 'lr' in init_kw['config'] + assert 'secret' not in init_kw['config'] + wandb.finish() + + def test_tags_lifecycle(self): + """Tags from init() and runtime mutations reach pluto correctly.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op(tags=['baseline', 'v1']) + mock_init.return_value = mock_op + + run = wandb.init(project='test', tags=['baseline', 'v1']) + + # init forwards tags + assert mock_init.call_args[1]['tags'] == ['baseline', 'v1'] + + # Runtime tag mutation via property setter + run.tags = ['baseline', 'v1', 'promoted'] + mock_op.add_tags.assert_called_once_with(['promoted']) + + wandb.finish() + + def test_data_type_conversion_in_log(self): + """wandb data types logged via log() arrive as pluto types at op.log.""" + import pluto.compat.wandb as wandb + from pluto.compat.wandb.data_types import Histogram, Image, Table + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + + img_data = np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8) + wandb.log( + { + 'loss': 0.5, + 'image': Image(img_data, caption='sample'), + 'table': Table(columns=['a', 'b'], data=[[1, 2]]), + 'dist': Histogram(sequence=[1, 2, 3, 4, 5]), + } + ) + + logged = mock_op.log.call_args[0][0] + + # Scalars pass through + assert logged['loss'] == 0.5 + + # Data types are converted to pluto equivalents + assert logged['image'].__class__.__module__.startswith('pluto.file') + assert logged['image'].__class__.__name__ == 'Image' + + assert logged['table'].__class__.__module__.startswith('pluto.data') + assert logged['table'].__class__.__name__ == 'Table' + + assert logged['dist'].__class__.__module__.startswith('pluto.data') + assert logged['dist'].__class__.__name__ == 'Histogram' + + wandb.finish() + + def test_summary_tracks_last_scalar_per_key(self): + """summary auto-updates to last-logged scalar; non-scalars ignored.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + run = wandb.init(project='test') + + wandb.log({'loss': 1.0, 'name': 'first'}) + wandb.log({'loss': 0.5, 'name': 'second'}) + wandb.log({'loss': 0.1, 'name': 'third'}) + + assert run.summary['loss'] == pytest.approx(0.1) + # Strings are not tracked in summary + assert 'name' not in run.summary + + # Manual override still works + run.summary['best_loss'] = 0.05 + assert run.summary['best_loss'] == 0.05 + + wandb.finish() + + def test_context_manager_lifecycle(self): + """with wandb.init() as run: produces correct init/finish sequence.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + with wandb.init(project='test') as run: + run.log({'step': 1}) + run.log({'step': 2}) + + # init called exactly once + mock_init.assert_called_once() + # Two log calls + assert mock_op.log.call_count == 2 + # finish called exactly once with success code + mock_op.finish.assert_called_once_with(code=0) + + def test_context_manager_exception_exit_code(self): + """Exception inside context manager produces exit_code=1.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + with pytest.raises(RuntimeError): + with wandb.init(project='test') as run: + run.log({'step': 1}) + raise RuntimeError('training crashed') + + mock_op.finish.assert_called_once_with(code=1) + + def test_reinit_finishes_previous_run(self): + """Calling init() twice finishes the first run before creating second.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + op1 = _make_mock_op(id=1) + op2 = _make_mock_op(id=2) + mock_init.side_effect = [op1, op2] + + wandb.init(project='test') + run2 = wandb.init(project='test', reinit=True) + + # First run was finished before second started + op1.finish.assert_called_once() + assert run2.id == '2' + + wandb.finish() + + def test_watch_forwards_to_op(self): + """wandb.watch() forwards model and log_freq to op.watch.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + + model = MagicMock() + wandb.watch(model, log_freq=100) + + mock_op.watch.assert_called_once_with(model, log_freq=100) + + wandb.finish() + + def test_alert_forwards_to_op(self): + """wandb.alert() forwards title, text, level to op.alert.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + wandb.alert(title='Divergence', text='Loss is NaN', level='ERROR') + + mock_op.alert.assert_called_once_with( + title='Divergence', + message='Loss is NaN', + level='ERROR', + ) + + wandb.finish() + + def test_full_realistic_workflow(self): + """End-to-end: init with config+tags, log mixed data, mutate config, + update summary, finish. Asserts complete call sequence.""" + import argparse + + import pluto.compat.wandb as wandb + from pluto.compat.wandb.data_types import Image + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op(tags=['experiment']) + mock_init.return_value = mock_op + + # 1. Init with argparse config and tags + args = argparse.Namespace(lr=0.001, batch_size=64) + run = wandb.init( + project='cifar10', + name='resnet-run-1', + config=args, + tags=['experiment'], + ) + + # 2. Update config post-init + wandb.config.update({'scheduler': 'cosine'}) + + # 3. Training loop with mixed types + for epoch in range(3): + wandb.log( + { + 'train': {'loss': 1.0 / (epoch + 1)}, + 'epoch': epoch, + } + ) + + # 4. Log an image + img = Image(np.zeros((4, 4, 3), dtype=np.uint8)) + wandb.log({'sample': img}) + + # 5. Manual summary + run.summary['best_epoch'] = 2 + + # 6. Finish + wandb.finish() + + # --- Assertions --- + + # init forwarded correctly + init_kw = mock_init.call_args[1] + assert init_kw['project'] == 'cifar10' + assert init_kw['name'] == 'resnet-run-1' + assert init_kw['tags'] == ['experiment'] + assert init_kw['config']['lr'] == 0.001 + assert init_kw['config']['batch_size'] == 64 + + # config mutation synced + mock_op.update_config.assert_any_call({'scheduler': 'cosine'}) + + # 3 training steps + 1 image = 4 log calls + assert mock_op.log.call_count == 4 + + # Training steps flattened correctly + for i in range(3): + call_data = mock_op.log.call_args_list[i][0][0] + assert 'train/loss' in call_data + assert 'epoch' in call_data + + # Image converted to pluto type + img_call = mock_op.log.call_args_list[3][0][0] + assert img_call['sample'].__class__.__module__.startswith('pluto.file') + + # Summary state + assert run.summary['best_epoch'] == 2 + assert run.summary['epoch'] == 2 + assert run.summary['train/loss'] == pytest.approx(1.0 / 3) + + # finish called + mock_op.finish.assert_called_once() + + def test_module_state_reset_after_finish(self): + """After finish(), module-level config/summary/run are reset.""" + import pluto.compat.wandb as wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test', config={'lr': 0.01}) + assert wandb.run is not None + assert 'lr' in wandb.config + + wandb.finish() + + assert wandb.run is None + assert len(wandb.config) == 0 + assert len(wandb.summary) == 0 + + def test_log_artifact_call_sequence(self, tmp_path): + """log_artifact with Artifact object produces one op.log per file.""" + import pluto.compat.wandb as wandb + from pluto.compat.wandb.data_types import Artifact + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + wandb.init(project='test') + + f1 = tmp_path / 'weights.pt' + f2 = tmp_path / 'config.json' + f1.write_bytes(b'\x00' * 50) + f2.write_text('{}') + + art = Artifact('checkpoint', type='model') + art.add_file(str(f1), name='weights.pt') + art.add_file(str(f2), name='config.json') + wandb.log_artifact(art) + + # Each file in the artifact produces one op.log call + assert mock_op.log.call_count == 2 + + wandb.finish() From bb19d6ea26dee1b05bf020e0f8e12038a169b1d5 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 13 Feb 2026 22:41:45 +0000 Subject: [PATCH 3/7] Add top-level wandb shim package for zero-change migration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Users can now swap `wandb` for `pluto-ml` in their dependencies and keep `import wandb` unchanged — no source edits needed. How it works: - Top-level `wandb/` package included in pyproject.toml packages list - `wandb/__init__.py` re-exports everything from pluto.compat.wandb - Common submodule stubs so deep imports don't break: - wandb.sdk, wandb.sdk.data_types - wandb.data_types - wandb.plot (no-op stubs for line_series, confusion_matrix, etc.) - wandb.apis (Api stub that raises NotImplementedError on queries) - wandb.util (generate_id, make_artifact_name_safe, to_json) - wandb.integration.lightning (WandbLogger → pluto MLOPLogger) 14 new tests in TestTopLevelWandbPackage verifying all import patterns. https://claude.ai/code/session_01VTSZKK5UsMqjiADFX57SMY --- pyproject.toml | 3 +- tests/test_wandb_compat.py | 129 ++++++++++++++++++++++++ wandb/__init__.py | 25 +++++ wandb/apis/__init__.py | 48 +++++++++ wandb/data_types/__init__.py | 3 + wandb/integration/__init__.py | 1 + wandb/integration/lightning/__init__.py | 9 ++ wandb/plot/__init__.py | 79 +++++++++++++++ wandb/sdk/__init__.py | 9 ++ wandb/sdk/data_types/__init__.py | 3 + wandb/util/__init__.py | 29 ++++++ 11 files changed, 337 insertions(+), 1 deletion(-) create mode 100644 wandb/__init__.py create mode 100644 wandb/apis/__init__.py create mode 100644 wandb/data_types/__init__.py create mode 100644 wandb/integration/__init__.py create mode 100644 wandb/integration/lightning/__init__.py create mode 100644 wandb/plot/__init__.py create mode 100644 wandb/sdk/__init__.py create mode 100644 wandb/sdk/data_types/__init__.py create mode 100644 wandb/util/__init__.py diff --git a/pyproject.toml b/pyproject.toml index ee8fc18..2c4bd4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ version = "0.0.6" description = "Pluto ML - Machine Learning Operations Framework" packages = [ {include = "pluto"}, - {include = "mlop"} + {include = "mlop"}, + {include = "wandb"} ] authors = [ "jqssun", diff --git a/tests/test_wandb_compat.py b/tests/test_wandb_compat.py index eec6840..e779e47 100644 --- a/tests/test_wandb_compat.py +++ b/tests/test_wandb_compat.py @@ -1371,3 +1371,132 @@ def test_log_artifact_call_sequence(self, tmp_path): assert mock_op.log.call_count == 2 wandb.finish() + + +class TestTopLevelWandbPackage: + """Tests that ``import wandb`` resolves to the pluto shim and that + all common import patterns used in real wandb code work.""" + + def test_import_wandb(self): + """Plain ``import wandb`` works and exposes the core API.""" + import wandb + + assert hasattr(wandb, 'init') + assert hasattr(wandb, 'log') + assert hasattr(wandb, 'finish') + assert hasattr(wandb, 'watch') + assert hasattr(wandb, 'config') + assert hasattr(wandb, 'summary') + assert hasattr(wandb, 'run') + assert hasattr(wandb, 'Image') + assert hasattr(wandb, 'Table') + assert hasattr(wandb, 'Histogram') + assert hasattr(wandb, 'Audio') + assert hasattr(wandb, 'Video') + assert hasattr(wandb, 'Html') + assert hasattr(wandb, 'Artifact') + assert hasattr(wandb, 'AlertLevel') + assert hasattr(wandb, 'Api') + + def test_from_wandb_import_init(self): + """``from wandb import init, log, finish`` works.""" + from wandb import finish, init, log # noqa: F401 + + def test_from_wandb_import_data_types(self): + """``from wandb import Image, Table, ...`` works.""" + from wandb import ( # noqa: F401 + AlertLevel, + Artifact, + Audio, + Histogram, + Html, + Image, + Table, + Video, + ) + + def test_from_wandb_import_api(self): + """``from wandb import Api`` works.""" + from wandb import Api # noqa: F401 + + api = Api() + with pytest.raises(NotImplementedError): + api.runs() + + def test_wandb_sdk_import(self): + """``import wandb.sdk`` works.""" + import wandb.sdk # noqa: F401 + + assert hasattr(wandb.sdk, 'init') + + def test_wandb_sdk_data_types_import(self): + """``from wandb.sdk.data_types import Image`` works.""" + from wandb.sdk.data_types import Image # noqa: F401 + + def test_wandb_data_types_import(self): + """``from wandb.data_types import Table`` works.""" + from wandb.data_types import Table # noqa: F401 + + def test_wandb_plot_import(self): + """``from wandb import plot; wandb.plot.line_series(...)`` works.""" + import wandb.plot + + # Should be no-ops, not errors + result = wandb.plot.line_series([1, 2], [[1, 2]], title='test') + assert result is None + assert wandb.plot.confusion_matrix() is None + assert wandb.plot.roc_curve() is None + assert wandb.plot.pr_curve() is None + + def test_wandb_apis_import(self): + """``from wandb.apis import Api`` works.""" + from wandb.apis import Api # noqa: F401 + + def test_wandb_util_import(self): + """``from wandb.util import generate_id`` works.""" + from wandb.util import generate_id + + rid = generate_id() + assert isinstance(rid, str) + assert len(rid) == 8 + + def test_wandb_login(self): + """``wandb.login()`` returns True (no-op).""" + import wandb + + assert wandb.login() is True + + def test_wandb_init_log_finish_e2e(self): + """Full workflow through top-level ``import wandb``.""" + import wandb + + with mock.patch('pluto.init') as mock_init: + mock_op = _make_mock_op() + mock_init.return_value = mock_op + + run = wandb.init(project='test-shim') + wandb.config.lr = 0.01 + wandb.log({'loss': 0.5}) + wandb.log({'loss': 0.3}) + assert run.summary['loss'] == pytest.approx(0.3) + wandb.finish() + + mock_init.assert_called_once() + assert mock_init.call_args[1]['project'] == 'test-shim' + assert mock_op.log.call_count == 2 + mock_op.finish.assert_called_once() + + def test_wandb_settings_class(self): + """``wandb.Settings(...)`` works.""" + import wandb + + s = wandb.Settings(mode='offline') + assert s.mode == 'offline' + + def test_wandb_integration_lightning_import(self): + """``from wandb.integration.lightning import WandbLogger`` works.""" + try: + from wandb.integration.lightning import WandbLogger # noqa: F401 + except ImportError: + # Lightning not installed — that's fine, the import path itself resolved + pytest.skip('lightning not installed') diff --git a/wandb/__init__.py b/wandb/__init__.py new file mode 100644 index 0000000..31576c8 --- /dev/null +++ b/wandb/__init__.py @@ -0,0 +1,25 @@ +"""wandb drop-in replacement backed by pluto. + +This package allows ``import wandb`` to resolve to pluto's wandb +compatibility layer. Users swap ``wandb`` for ``pluto-ml`` in their +dependencies and keep all source code unchanged. + +All public wandb API symbols (init, log, finish, watch, config, summary, +Image, Table, etc.) are re-exported from pluto.compat.wandb. +""" + +# Re-export everything from the compat layer +from pluto.compat.wandb import * # noqa: F401, F403 +from pluto.compat.wandb import ( # noqa: F401 + __all__, + config, + run, + summary, +) + +# Additional symbols that wandb exposes at top level +from wandb.apis import Api # noqa: F401 + +# wandb exposes a few submodule-level imports that users rely on. +# We handle the most common ones (wandb.sdk, wandb.data_types, etc.) +# via sub-packages defined alongside this __init__.py. diff --git a/wandb/apis/__init__.py b/wandb/apis/__init__.py new file mode 100644 index 0000000..1a70025 --- /dev/null +++ b/wandb/apis/__init__.py @@ -0,0 +1,48 @@ +"""Stub for wandb.apis — public API client. + +wandb.Api() is used for querying runs, artifacts, etc. This is not +supported by pluto, so we provide a stub that raises informative errors. +""" + +import logging + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'WandbCompat.Api' + + +class Api: + """Stub for wandb.Api — not supported by pluto compat layer. + + Instantiation succeeds but query methods raise NotImplementedError + with a helpful message. + """ + + def __init__(self, *args, **kwargs): + logger.debug('%s: Api instantiated (queries not supported)', tag) + + def _unsupported(self, method): + raise NotImplementedError( + f'wandb.Api().{method}() is not supported by the pluto ' + 'compatibility layer. Use the pluto API directly.' + ) + + def runs(self, *args, **kwargs): + self._unsupported('runs') + + def run(self, *args, **kwargs): + self._unsupported('run') + + def artifact(self, *args, **kwargs): + self._unsupported('artifact') + + def artifacts(self, *args, **kwargs): + self._unsupported('artifacts') + + def sweep(self, *args, **kwargs): + self._unsupported('sweep') + + +# Also expose at top level since some code does: +# from wandb import Api +# from wandb.apis import Api +__all__ = ['Api'] diff --git a/wandb/data_types/__init__.py b/wandb/data_types/__init__.py new file mode 100644 index 0000000..380416f --- /dev/null +++ b/wandb/data_types/__init__.py @@ -0,0 +1,3 @@ +"""Stub for wandb.data_types — re-exports wandb data types.""" + +from pluto.compat.wandb.data_types import * # noqa: F401, F403 diff --git a/wandb/integration/__init__.py b/wandb/integration/__init__.py new file mode 100644 index 0000000..f44ccf8 --- /dev/null +++ b/wandb/integration/__init__.py @@ -0,0 +1 @@ +"""Stub for wandb.integration — framework integration hooks.""" diff --git a/wandb/integration/lightning/__init__.py b/wandb/integration/lightning/__init__.py new file mode 100644 index 0000000..911a679 --- /dev/null +++ b/wandb/integration/lightning/__init__.py @@ -0,0 +1,9 @@ +"""Stub for wandb.integration.lightning — provides WandbLogger. + +Maps ``from wandb.integration.lightning import WandbLogger`` to pluto's +own Lightning logger (pluto.compat.lightning.MLOPLogger). +""" + +from pluto.compat.lightning import MLOPLogger as WandbLogger # noqa: F401 + +__all__ = ['WandbLogger'] diff --git a/wandb/plot/__init__.py b/wandb/plot/__init__.py new file mode 100644 index 0000000..f55b4cb --- /dev/null +++ b/wandb/plot/__init__.py @@ -0,0 +1,79 @@ +"""Stub for wandb.plot — custom plot helpers. + +These are wandb-specific visualization helpers. We provide no-op stubs +so code that calls them doesn't crash. +""" + +import logging + +logger = logging.getLogger(f'{__name__.split(".")[0]}') +tag = 'WandbCompat.Plot' + + +def line_series(xs, ys, keys=None, title=None, xname=None, **kwargs): + """No-op stub for wandb.plot.line_series.""" + logger.debug('%s: line_series is not supported', tag) + return None + + +def scatter(table, x, y, title=None, **kwargs): + """No-op stub for wandb.plot.scatter.""" + logger.debug('%s: scatter is not supported', tag) + return None + + +def bar(table, label, value, title=None, **kwargs): + """No-op stub for wandb.plot.bar.""" + logger.debug('%s: bar is not supported', tag) + return None + + +def histogram(table, value, title=None, **kwargs): + """No-op stub for wandb.plot.histogram.""" + logger.debug('%s: histogram is not supported', tag) + return None + + +def line(table, x, y, stroke=None, title=None, **kwargs): + """No-op stub for wandb.plot.line.""" + logger.debug('%s: line is not supported', tag) + return None + + +def confusion_matrix( + y_true=None, + preds=None, + class_names=None, + title=None, + probs=None, + **kwargs, +): + """No-op stub for wandb.plot.confusion_matrix.""" + logger.debug('%s: confusion_matrix is not supported', tag) + return None + + +def roc_curve( + y_true=None, + y_probas=None, + labels=None, + title=None, + classes_to_plot=None, + **kwargs, +): + """No-op stub for wandb.plot.roc_curve.""" + logger.debug('%s: roc_curve is not supported', tag) + return None + + +def pr_curve( + y_true=None, + y_probas=None, + labels=None, + title=None, + classes_to_plot=None, + **kwargs, +): + """No-op stub for wandb.plot.pr_curve.""" + logger.debug('%s: pr_curve is not supported', tag) + return None diff --git a/wandb/sdk/__init__.py b/wandb/sdk/__init__.py new file mode 100644 index 0000000..2e444a4 --- /dev/null +++ b/wandb/sdk/__init__.py @@ -0,0 +1,9 @@ +"""Stub for wandb.sdk — re-exports from pluto.compat.wandb.""" + +from pluto.compat.wandb import * # noqa: F401, F403 +from pluto.compat.wandb import ( # noqa: F401 + __all__, + config, + run, + summary, +) diff --git a/wandb/sdk/data_types/__init__.py b/wandb/sdk/data_types/__init__.py new file mode 100644 index 0000000..d95351f --- /dev/null +++ b/wandb/sdk/data_types/__init__.py @@ -0,0 +1,3 @@ +"""Stub for wandb.sdk.data_types — re-exports wandb data types.""" + +from pluto.compat.wandb.data_types import * # noqa: F401, F403 diff --git a/wandb/util/__init__.py b/wandb/util/__init__.py new file mode 100644 index 0000000..ea76f27 --- /dev/null +++ b/wandb/util/__init__.py @@ -0,0 +1,29 @@ +"""Stub for wandb.util — utility helpers. + +Provides the handful of wandb.util functions that user code sometimes +calls directly. +""" + +import random +import string + + +def generate_id(length=8): + """Generate a random run id (same format as wandb).""" + chars = string.ascii_lowercase + string.digits + return ''.join(random.choices(chars, k=length)) + + +def make_artifact_name_safe(name): + """Sanitize artifact name.""" + return name.replace('/', '-').replace('\\', '-') + + +def to_json(obj): + """Best-effort JSON conversion.""" + import json + + try: + return json.dumps(obj) + except (TypeError, ValueError): + return str(obj) From 1eba5895a1b01bd0f428253688327f335cafbeb5 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 13 Feb 2026 22:51:55 +0000 Subject: [PATCH 4/7] Add visual parity test harness for wandb vs pluto shim Two ways to compare dashboards side-by-side: 1. pytest-based (tests/test_wandb_visual_parity.py): # Pluto shim side (our wandb package): PLUTO_API_TOKEN= pytest tests/test_wandb_visual_parity.py -k pluto -v -s # Real wandb side (separate venv with pip install wandb): WANDB_API_KEY= pytest tests/test_wandb_visual_parity.py -k real_wandb -v -s 2. Standalone runner (tests/wandb_visual_parity_runner.py): # Same script, auto-detects which backend is installed: PLUTO_API_TOKEN= python tests/wandb_visual_parity_runner.py WANDB_API_KEY= python tests/wandb_visual_parity_runner.py Both run identical training loops (20 epochs, 1000 steps, same seed) with: scalar metrics, nested namespaces (train/, val/), histograms, tables, images, config mutations, summary overrides, and tags. Prints dashboard URLs for visual comparison. https://claude.ai/code/session_01VTSZKK5UsMqjiADFX57SMY --- tests/test_wandb_visual_parity.py | 278 ++++++++++++++++++++++++++++ tests/wandb_visual_parity_runner.py | 213 +++++++++++++++++++++ 2 files changed, 491 insertions(+) create mode 100644 tests/test_wandb_visual_parity.py create mode 100644 tests/wandb_visual_parity_runner.py diff --git a/tests/test_wandb_visual_parity.py b/tests/test_wandb_visual_parity.py new file mode 100644 index 0000000..b949a56 --- /dev/null +++ b/tests/test_wandb_visual_parity.py @@ -0,0 +1,278 @@ +""" +Visual parity test: runs the same training script through both the real +wandb SDK and the pluto wandb shim, then prints dashboard URLs for +side-by-side visual inspection. + +Usage: + # Run the pluto shim side (always available): + python -m pytest tests/test_wandb_visual_parity.py -k pluto -v -s + + # Run the real wandb side (requires `pip install wandb`): + python -m pytest tests/test_wandb_visual_parity.py -k real_wandb -v -s + + # Run both (requires wandb installed): + python -m pytest tests/test_wandb_visual_parity.py -v -s + +Requirements: + - PLUTO_API_TOKEN must be set for the pluto side + - WANDB_API_KEY must be set for the real wandb side + - `pip install wandb` for the real wandb tests + +The test names, configs, and logged data are identical so the resulting +dashboards should look the same. +""" + +import importlib +import math +import os + +import numpy as np +import pytest + +from tests.utils import get_task_name + +# --------------------------------------------------------------------------- +# Shared constants — identical between both sides +# --------------------------------------------------------------------------- + +PLUTO_PROJECT = 'wandb-visual-parity' +WANDB_PROJECT = os.environ.get('WANDB_VISUAL_PARITY_PROJECT', 'wandb-visual-parity') +NUM_EPOCHS = 20 +NUM_STEPS_PER_EPOCH = 50 +CONFIG = { + 'lr': 0.001, + 'batch_size': 64, + 'optimizer': 'adam', + 'architecture': 'resnet18', + 'dataset': 'cifar10', + 'epochs': NUM_EPOCHS, + 'dropout': 0.1, + 'weight_decay': 1e-4, +} + + +# --------------------------------------------------------------------------- +# Shared training loop — parameterised by the wandb module +# --------------------------------------------------------------------------- + + +def _run_training_loop(wb, project, run_name): + """Run a fake but realistic training loop through the given wandb module. + + Returns the dashboard URL (or None). + """ + run = wb.init( + project=project, + name=run_name, + config=CONFIG, + tags=['visual-parity', 'automated'], + ) + + # Post-init config mutation (wandb pattern) + wb.config.update({'scheduler': 'cosine_annealing'}) + wb.config.seed = 42 + + for epoch in range(NUM_EPOCHS): + # Simulated training metrics with realistic decay curves + base_loss = 2.0 * math.exp(-0.15 * epoch) + 0.1 + base_acc = 1.0 - math.exp(-0.2 * epoch) * 0.6 + + for step in range(NUM_STEPS_PER_EPOCH): + noise = np.random.normal(0, 0.02) + + wb.log( + { + 'train/loss': base_loss + noise + 0.05 * math.sin(step * 0.3), + 'train/accuracy': min(base_acc + noise * 0.5, 1.0), + 'train/learning_rate': CONFIG['lr'] * (0.95**epoch), + } + ) + + # Epoch-level validation metrics + val_loss = base_loss * 1.1 + np.random.normal(0, 0.03) + val_acc = base_acc * 0.98 + np.random.normal(0, 0.01) + + wb.log( + { + 'val/loss': val_loss, + 'val/accuracy': min(val_acc, 1.0), + 'epoch': epoch, + } + ) + + # Log a histogram every 5 epochs + if epoch % 5 == 0: + gradient_norms = np.random.lognormal( + mean=-1.0 + epoch * 0.05, + sigma=0.5, + size=1000, + ) + wb.log( + { + 'gradients/norm_distribution': wb.Histogram(gradient_norms), + } + ) + + # Log a table at the midpoint + if epoch == NUM_EPOCHS // 2: + table = wb.Table( + columns=['sample_id', 'predicted', 'actual', 'confidence'], + data=[ + [i, i % 10, (i + 1) % 10, round(np.random.uniform(0.5, 1.0), 3)] + for i in range(20) + ], + ) + wb.log({'predictions': table}) + + # Log an image every 5 epochs + if epoch % 5 == 0: + img_array = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + wb.log( + { + 'samples/random_image': wb.Image( + img_array, caption=f'epoch-{epoch}' + ), + } + ) + + # Manual summary overrides + run.summary['best_val_loss'] = 0.15 + run.summary['best_val_accuracy'] = 0.94 + run.summary['total_steps'] = NUM_EPOCHS * NUM_STEPS_PER_EPOCH + + wb.finish() + + # Extract URL + url = getattr(run, 'url', None) + if url is None: + url = getattr(getattr(run, '_op', None), 'settings', None) + if url is not None: + url = getattr(url, 'url_view', None) + + return url + + +# --------------------------------------------------------------------------- +# Pluto shim side +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif( + not os.environ.get('PLUTO_API_TOKEN'), + reason='PLUTO_API_TOKEN not set — cannot run pluto side', +) +class TestPlutoShimSide: + """Runs the training loop through ``import wandb`` (the pluto shim).""" + + def test_pluto_training_loop(self): + import wandb + + run_name = f'pluto-shim-{get_task_name()}' + url = _run_training_loop(wandb, PLUTO_PROJECT, run_name) + + print('\n') + print('=' * 60) + print(' PLUTO SHIM RUN') + print(f' Name: {run_name}') + print(f' URL: {url or "(not available)"}') + print('=' * 60) + + +# --------------------------------------------------------------------------- +# Real wandb side +# --------------------------------------------------------------------------- + +_has_real_wandb = False +try: + # Only consider real wandb available if it's NOT our shim + spec = importlib.util.find_spec('wandb') + if spec and spec.origin: + # Our shim lives under the pluto repo; real wandb doesn't + _has_real_wandb = 'pluto' not in (spec.origin or '') +except Exception: + pass + + +@pytest.mark.skipif( + not _has_real_wandb, + reason='real wandb package not installed (only pluto shim found)', +) +@pytest.mark.skipif( + not os.environ.get('WANDB_API_KEY'), + reason='WANDB_API_KEY not set — cannot run real wandb side', +) +class TestRealWandbSide: + """Runs the training loop through the real ``wandb`` SDK.""" + + def test_real_wandb_training_loop(self): + import wandb + + run_name = f'real-wandb-{get_task_name()}' + url = _run_training_loop(wandb, WANDB_PROJECT, run_name) + + if url is None and wandb.run is None: + # wandb.finish() clears wandb.run, but the URL was printed + url = '(check wandb console output above)' + + print('\n') + print('=' * 60) + print(' REAL WANDB RUN') + print(f' Name: {run_name}') + print(f' URL: {url or "(not available)"}') + print('=' * 60) + + +# --------------------------------------------------------------------------- +# Combined runner — convenience for running both back-to-back +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif( + not os.environ.get('PLUTO_API_TOKEN'), + reason='PLUTO_API_TOKEN not set', +) +@pytest.mark.skipif( + not _has_real_wandb, + reason='real wandb not installed', +) +@pytest.mark.skipif( + not os.environ.get('WANDB_API_KEY'), + reason='WANDB_API_KEY not set', +) +class TestSideBySide: + """Run both sides in one test and print URLs together.""" + + def test_side_by_side(self): + # Use a shared suffix so the runs are easy to find together + suffix = get_task_name() + + # --- Pluto shim --- + import wandb as pluto_wandb + + pluto_name = f'pluto-{suffix}' + pluto_url = _run_training_loop(pluto_wandb, PLUTO_PROJECT, pluto_name) + + # --- Real wandb --- + # We need to force-reimport real wandb. Since our shim took the + # `wandb` namespace, this only works if real wandb is installed + # in a separate venv. For CI, use subprocess isolation instead. + # Here we just document the limitation. + real_name = f'wandb-{suffix}' + real_url = '(run separately: pytest -k real_wandb)' + + print('\n') + print('=' * 60) + print(' VISUAL PARITY COMPARISON') + print('-' * 60) + print(f' Pluto shim: {pluto_url or "(not available)"}') + print(f' name={pluto_name}') + print(f' Real wandb: {real_url}') + print(f' name={real_name}') + print('=' * 60) + print() + print(' To run real wandb side in a separate env:') + print(' pip install wandb') + print( + f' WANDB_API_KEY= python -m pytest {__file__} -k real_wandb -v -s' + ) + print('=' * 60) diff --git a/tests/wandb_visual_parity_runner.py b/tests/wandb_visual_parity_runner.py new file mode 100644 index 0000000..f94b9b2 --- /dev/null +++ b/tests/wandb_visual_parity_runner.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Standalone visual parity runner. + +Run this script in two separate environments to compare dashboards: + + Environment A (pluto shim): + pip install pluto-ml + PLUTO_API_TOKEN= python tests/wandb_visual_parity_runner.py + + Environment B (real wandb): + pip install wandb + WANDB_API_KEY= python tests/wandb_visual_parity_runner.py + +Both runs will use identical config, metrics, and data types. +Compare the resulting dashboard URLs visually. +""" + +import argparse +import math +import sys +import time + +import numpy as np + +# --------------------------------------------------------------------------- +# Detect which backend we're running on +# --------------------------------------------------------------------------- + + +def _detect_backend(): + """Detect whether we're running on real wandb or pluto shim.""" + try: + import wandb + + origin = getattr(getattr(wandb, '__spec__', None), 'origin', '') or '' + if 'pluto' in origin: + return 'pluto-shim', wandb + # Check if it has the real wandb's internal modules + if hasattr(wandb, 'sdk') and hasattr(wandb.sdk, 'wandb_run'): + return 'real-wandb', wandb + # Fallback: check for pluto re-export marker + if hasattr(wandb, '_get_module'): + return 'pluto-shim', wandb + return 'real-wandb', wandb + except ImportError: + print('ERROR: neither wandb nor pluto-ml is installed') + sys.exit(1) + + +# --------------------------------------------------------------------------- +# Training loop +# --------------------------------------------------------------------------- + +NUM_EPOCHS = 20 +NUM_STEPS_PER_EPOCH = 50 +CONFIG = { + 'lr': 0.001, + 'batch_size': 64, + 'optimizer': 'adam', + 'architecture': 'resnet18', + 'dataset': 'cifar10', + 'epochs': NUM_EPOCHS, + 'dropout': 0.1, + 'weight_decay': 1e-4, +} + + +def run_training(wb, project, run_name, seed=42): + """Run a realistic training loop that exercises the wandb API surface.""" + np.random.seed(seed) + + run = wb.init( + project=project, + name=run_name, + config=CONFIG, + tags=['visual-parity', 'automated'], + ) + + # Post-init config mutations + wb.config.update({'scheduler': 'cosine_annealing'}) + wb.config.seed = seed + + print(f' Logging {NUM_EPOCHS} epochs x {NUM_STEPS_PER_EPOCH} steps...') + + for epoch in range(NUM_EPOCHS): + base_loss = 2.0 * math.exp(-0.15 * epoch) + 0.1 + base_acc = 1.0 - math.exp(-0.2 * epoch) * 0.6 + + for step in range(NUM_STEPS_PER_EPOCH): + noise = np.random.normal(0, 0.02) + wb.log( + { + 'train/loss': base_loss + noise + 0.05 * math.sin(step * 0.3), + 'train/accuracy': min(base_acc + noise * 0.5, 1.0), + 'train/learning_rate': CONFIG['lr'] * (0.95**epoch), + } + ) + + val_loss = base_loss * 1.1 + np.random.normal(0, 0.03) + val_acc = base_acc * 0.98 + np.random.normal(0, 0.01) + wb.log( + { + 'val/loss': val_loss, + 'val/accuracy': min(val_acc, 1.0), + 'epoch': epoch, + } + ) + + # Histogram every 5 epochs + if epoch % 5 == 0: + gradient_norms = np.random.lognormal(-1.0 + epoch * 0.05, 0.5, 1000) + wb.log( + { + 'gradients/norm_distribution': wb.Histogram(gradient_norms), + } + ) + + # Table at midpoint + if epoch == NUM_EPOCHS // 2: + table = wb.Table( + columns=['sample_id', 'predicted', 'actual', 'confidence'], + data=[ + [i, i % 10, (i + 1) % 10, round(np.random.uniform(0.5, 1.0), 3)] + for i in range(20) + ], + ) + wb.log({'predictions': table}) + + # Image every 5 epochs + if epoch % 5 == 0: + img = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + wb.log( + { + 'samples/random_image': wb.Image(img, caption=f'epoch-{epoch}'), + } + ) + + if (epoch + 1) % 5 == 0: + print(f' Epoch {epoch + 1}/{NUM_EPOCHS} done') + + run.summary['best_val_loss'] = 0.15 + run.summary['best_val_accuracy'] = 0.94 + run.summary['total_steps'] = NUM_EPOCHS * NUM_STEPS_PER_EPOCH + + wb.finish() + + # Extract URL + url = getattr(run, 'url', None) + if url is None: + settings = getattr(getattr(run, '_op', None), 'settings', None) + if settings: + url = getattr(settings, 'url_view', None) + if url is None: + url = getattr(run, '_get_run_url', lambda: None)() + + return url + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description='Visual parity runner for wandb vs pluto shim' + ) + parser.add_argument( + '--project', + default=None, + help='Project name (default: wandb-visual-parity)', + ) + parser.add_argument( + '--name', + default=None, + help='Run name (default: auto-generated)', + ) + parser.add_argument( + '--seed', + type=int, + default=42, + help='Random seed for reproducible noise', + ) + args = parser.parse_args() + + backend_name, wb = _detect_backend() + ts = int(time.time()) % 100000 + + project = args.project or 'wandb-visual-parity' + run_name = args.name or f'{backend_name}-{ts}' + + print() + print('=' * 60) + print(f' Backend: {backend_name}') + print(f' Project: {project}') + print(f' Run: {run_name}') + print(f' Seed: {args.seed}') + print('=' * 60) + print() + + url = run_training(wb, project, run_name, seed=args.seed) + + print() + print('=' * 60) + print(f' DONE — {backend_name}') + print(f' URL: {url or "(check console output)"}') + print('=' * 60) + print() + + +if __name__ == '__main__': + main() From e37cb2cc4c65357bd5667e95c5e92887e00edd69 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 13 Feb 2026 18:31:35 -0500 Subject: [PATCH 5/7] Address Gemini review feedback on wandb compat layer - Fix Histogram._to_pluto() passing full tuple instead of bin edges - Initialize config_dict to {} to prevent None edge cases - Add exc_info=True for better init failure debugging - Remove unused _allow_val_change attribute from Config - Simplify redundant reinit/finish logic in init() - Remove dead run_id reassignment - Record actual start_time at Run init instead of returning time.time() - Use os.path.basename for log_artifact default name - Include booleans in summary (consistent with wandb behavior) - Strengthen test_tags_get_and_set assertion Co-Authored-By: Claude Opus 4.6 --- pluto/compat/wandb/__init__.py | 32 ++++++++++++++------------------ pluto/compat/wandb/config.py | 1 - pluto/compat/wandb/data_types.py | 3 +-- pluto/compat/wandb/run.py | 10 ++++++---- pluto/compat/wandb/summary.py | 2 +- tests/test_wandb_compat.py | 4 ++-- 6 files changed, 24 insertions(+), 28 deletions(-) diff --git a/pluto/compat/wandb/__init__.py b/pluto/compat/wandb/__init__.py index de51a5b..e3b352b 100644 --- a/pluto/compat/wandb/__init__.py +++ b/pluto/compat/wandb/__init__.py @@ -128,19 +128,14 @@ def init( global run global summary - # If reinit, finish the previous run first + # Finish any previous run before creating a new one if run is not None: - if reinit: - try: - run.finish() - except Exception: - pass - else: + if not reinit: logger.debug('%s: init called with existing run, finishing previous', tag) - try: - run.finish() - except Exception: - pass + try: + run.finish() + except Exception: + pass # Resolve project from env if not provided project = ( @@ -162,14 +157,12 @@ def init( tags = [t.strip() for t in env_tags.split(',') if t.strip()] # Filter config keys if requested - config_dict: Optional[Dict[str, Any]] = None + config_dict: Dict[str, Any] = {} if config is not None: if isinstance(config, dict): config_dict = dict(config) elif hasattr(config, '__dict__'): config_dict = vars(config) - else: - config_dict = {} if config_dict and config_include_keys: config_dict = { @@ -189,8 +182,6 @@ def init( # Map wandb run_id / resume run_id = id - if resume in ('allow', 'must', 'auto', True) and id: - run_id = id # Store wandb-only metadata in config extra_config: Dict[str, Any] = {} @@ -201,7 +192,7 @@ def init( if job_type: extra_config['_wandb_job_type'] = job_type - merged_config = {**(config_dict or {}), **extra_config} or None + merged_config = {**config_dict, **extra_config} or None # Initialize pluto try: @@ -215,7 +206,12 @@ def init( run_id=run_id, ) except Exception as e: - logger.warning('%s: pluto.init() failed (%s), creating disabled run', tag, e) + logger.warning( + '%s: pluto.init() failed (%s), creating disabled run', + tag, + e, + exc_info=True, + ) # Return a disabled run that no-ops everything return _create_disabled_run( name=name, diff --git a/pluto/compat/wandb/config.py b/pluto/compat/wandb/config.py index a1bcc3a..462f6b7 100644 --- a/pluto/compat/wandb/config.py +++ b/pluto/compat/wandb/config.py @@ -18,7 +18,6 @@ def __init__(self, op: Optional[Any] = None) -> None: # Use object.__setattr__ to avoid triggering our __setattr__ object.__setattr__(self, '_op', op) object.__setattr__(self, '_data', {}) - object.__setattr__(self, '_allow_val_change', True) def _load(self, data: Optional[Dict[str, Any]]) -> None: if data: diff --git a/pluto/compat/wandb/data_types.py b/pluto/compat/wandb/data_types.py index 32a9019..c19048f 100644 --- a/pluto/compat/wandb/data_types.py +++ b/pluto/compat/wandb/data_types.py @@ -34,7 +34,6 @@ def __init__( self.data_or_path = data_or_path self.caption = caption self._mode = mode - self._grouping = grouping def _to_pluto(self) -> Any: from pluto.file import Image as PlutoImage @@ -161,7 +160,7 @@ def _to_pluto(self) -> Any: if self.np_histogram is not None: # np_histogram is a tuple of (values, bin_edges) - return PlutoHistogram(data=self.np_histogram, bins=self.np_histogram) + return PlutoHistogram(data=self.np_histogram, bins=self.np_histogram[1]) if self.sequence is not None: return PlutoHistogram(data=self.sequence, bins=self.num_bins) return PlutoHistogram(data=[0], bins=1) diff --git a/pluto/compat/wandb/run.py b/pluto/compat/wandb/run.py index 449cc3c..04b38bc 100644 --- a/pluto/compat/wandb/run.py +++ b/pluto/compat/wandb/run.py @@ -1,6 +1,7 @@ """wandb.Run-compatible wrapper around pluto.Op.""" import logging +import time as _time from typing import Any, Dict, List, Optional, Sequence, Union from .config import Config @@ -61,6 +62,7 @@ def __init__( self._pending_data: Dict[str, Any] = {} self._step = 0 self._watched_models: List[Any] = [] + self._start_time: float = _time.time() # -- Properties matching wandb.Run -- @@ -167,9 +169,7 @@ def settings(self) -> Any: @property def start_time(self) -> float: - import time - - return time.time() + return self._start_time @property def sweep_id(self) -> Optional[str]: @@ -338,10 +338,12 @@ def log_artifact( logger.debug('%s: log_artifact file failed: %s', tag, e) return artifact_or_path elif isinstance(artifact_or_path, str): + import os + from pluto.file import Artifact as PlutoArtifact art = PlutoArtifact(data=artifact_or_path, caption=name) - log_name = name or 'artifact' + log_name = name or os.path.basename(artifact_or_path) try: self._op.log({log_name: art}) except Exception as e: diff --git a/pluto/compat/wandb/summary.py b/pluto/compat/wandb/summary.py index 0db43cc..711364e 100644 --- a/pluto/compat/wandb/summary.py +++ b/pluto/compat/wandb/summary.py @@ -17,7 +17,7 @@ def _update_from_log(self, data: Dict[str, Any]) -> None: """Called internally after each log() call to update last values.""" store = object.__getattribute__(self, '_data') for k, v in data.items(): - if isinstance(v, (int, float)) and not isinstance(v, bool): + if isinstance(v, (int, float)): store[k] = v elif hasattr(v, 'item') and callable(v.item): store[k] = v.item() diff --git a/tests/test_wandb_compat.py b/tests/test_wandb_compat.py index e779e47..a6cab4a 100644 --- a/tests/test_wandb_compat.py +++ b/tests/test_wandb_compat.py @@ -175,7 +175,7 @@ def test_update_from_log_ignores_non_scalars(self): s._update_from_log({'loss': 0.5, 'image': 'not_a_scalar', 'flag': True}) assert s['loss'] == 0.5 assert 'image' not in s - assert 'flag' not in s # bools excluded + assert s['flag'] is True # bools are scalar subclass of int def test_manual_override(self): from pluto.compat.wandb.summary import Summary @@ -412,7 +412,7 @@ def test_tags_get_and_set(self): # Setting tags run.tags = ('tag1', 'tag2') - op.add_tags.assert_called() + op.add_tags.assert_called_with(['tag2']) def test_name_setter(self): run, op = self._make_run(name='original') From b5620291ac605c67fc906d22055b06b6918177af Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 17 Feb 2026 16:47:09 -0500 Subject: [PATCH 6/7] Add define_metric with summary aggregation and CI fixes Implement define_metric across the full stack: - Op.define_metric() stores definitions and syncs to server (best-effort) - Op.get_metric_definition() with glob pattern support - Summary aggregation (min/max/mean/first/last) in wandb compat layer - Sync process plumbing (RecordType.METRIC_DEF, enqueue, upload, dispatch) - ServerInterface.update_metric_definitions() for direct API calls Also fixes two pre-existing CI failures: - Fix mypy error: __exit__ return type bool -> None in Run - Fix test_table_from_dataframe: add pytest.importorskip('pandas') Co-Authored-By: Claude Opus 4.6 --- pluto/compat/wandb/__init__.py | 2 +- pluto/compat/wandb/run.py | 32 +++++++++++-- pluto/compat/wandb/summary.py | 60 ++++++++++++++++++++++-- pluto/iface.py | 17 +++++++ pluto/op.py | 54 ++++++++++++++++++++++ pluto/sets.py | 3 ++ pluto/sync/process.py | 44 ++++++++++++++++++ pluto/sync/store.py | 1 + tests/test_define_metric.py | 82 +++++++++++++++++++++++++++++++++ tests/test_wandb_compat.py | 84 +++++++++++++++++++++++++++++++++- 10 files changed, 370 insertions(+), 9 deletions(-) create mode 100644 tests/test_define_metric.py diff --git a/pluto/compat/wandb/__init__.py b/pluto/compat/wandb/__init__.py index e3b352b..5c14856 100644 --- a/pluto/compat/wandb/__init__.py +++ b/pluto/compat/wandb/__init__.py @@ -322,7 +322,7 @@ def define_metric( goal: Optional[str] = None, overwrite: Optional[bool] = None, ) -> Any: - """Define metric behavior. No-op in pluto compat layer.""" + """Define metric behavior (aggregation, custom x-axis).""" if run is not None: return run.define_metric( name, diff --git a/pluto/compat/wandb/run.py b/pluto/compat/wandb/run.py index 04b38bc..99bd2fe 100644 --- a/pluto/compat/wandb/run.py +++ b/pluto/compat/wandb/run.py @@ -294,8 +294,33 @@ def define_metric( goal: Optional[str] = None, overwrite: Optional[bool] = None, ) -> Any: - """Define metric behavior. No-op in pluto (returns a stub).""" - logger.debug('%s: define_metric is a no-op', tag) + """Define metric behavior (aggregation, custom x-axis).""" + # Build definition dict for client-side aggregation + definition: dict = {'name': name} + if step_metric is not None: + definition['step_metric'] = step_metric + if summary is not None: + definition['summary'] = summary + if goal is not None: + definition['goal'] = goal + if hidden is not None: + definition['hidden'] = hidden + + # Register with Summary for client-side aggregation + self._summary._set_metric_definition(name, definition) + + # Forward to Op for server sync + try: + self._op.define_metric( + name, + step_metric=step_metric, + summary=summary, + goal=goal, + hidden=hidden, + ) + except Exception as e: + logger.debug('%s: define_metric server sync failed: %s', tag, e) + return _MetricStub(name) def save( @@ -381,10 +406,9 @@ def status(self) -> Dict[str, Any]: def __enter__(self) -> 'Run': return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: exit_code = 1 if exc_type else 0 self.finish(exit_code=exit_code) - return False def __repr__(self) -> str: return f'' diff --git a/pluto/compat/wandb/summary.py b/pluto/compat/wandb/summary.py index 711364e..a37f8f3 100644 --- a/pluto/compat/wandb/summary.py +++ b/pluto/compat/wandb/summary.py @@ -1,5 +1,6 @@ """wandb.summary-compatible dict-like summary metrics object.""" +from fnmatch import fnmatch from typing import Any, Dict, Iterator, Optional @@ -8,19 +9,72 @@ class Summary: Auto-populated from log() calls (last value per key for scalars). Supports manual overrides via dict/attribute access. + When metric definitions specify aggregation (min/max/mean/first/last), + summary values are computed accordingly instead of always keeping last. """ def __init__(self) -> None: object.__setattr__(self, '_data', {}) + object.__setattr__(self, '_metric_definitions', {}) + # Tracks running state for mean aggregation: {key: {'sum': float, 'count': int}} + object.__setattr__(self, '_agg_state', {}) + + def _set_metric_definition(self, name: str, definition: Dict[str, Any]) -> None: + """Register a metric definition for summary aggregation.""" + defs = object.__getattribute__(self, '_metric_definitions') + defs[name] = definition + + def _find_definition(self, key: str) -> Optional[Dict[str, Any]]: + """Find a metric definition by exact match first, then glob pattern.""" + defs = object.__getattribute__(self, '_metric_definitions') + if key in defs: + return defs[key] + for pattern, defn in defs.items(): + if fnmatch(key, pattern): + return defn + return None def _update_from_log(self, data: Dict[str, Any]) -> None: - """Called internally after each log() call to update last values.""" + """Called internally after each log() call to update summary values.""" store = object.__getattribute__(self, '_data') + agg_state = object.__getattribute__(self, '_agg_state') for k, v in data.items(): if isinstance(v, (int, float)): - store[k] = v + val = v elif hasattr(v, 'item') and callable(v.item): - store[k] = v.item() + val = v.item() + else: + continue + + defn = self._find_definition(k) + if defn is None: + # Default: keep last value + store[k] = val + continue + + summary_mode = defn.get('summary') + if summary_mode == 'min': + if k in store: + store[k] = min(store[k], val) + else: + store[k] = val + elif summary_mode == 'max': + if k in store: + store[k] = max(store[k], val) + else: + store[k] = val + elif summary_mode == 'mean': + if k not in agg_state: + agg_state[k] = {'sum': 0.0, 'count': 0} + agg_state[k]['sum'] += val + agg_state[k]['count'] += 1 + store[k] = agg_state[k]['sum'] / agg_state[k]['count'] + elif summary_mode == 'first': + if k not in store: + store[k] = val + else: + # "last" or unrecognized — keep last value + store[k] = val # -- Attribute access -- diff --git a/pluto/iface.py b/pluto/iface.py index cd62877..2216b14 100644 --- a/pluto/iface.py +++ b/pluto/iface.py @@ -100,6 +100,23 @@ def update_config(self, config: Dict[str, Any]) -> None: client=self.client_api, ) + def update_metric_definitions(self, metrics: List[Dict[str, Any]]) -> None: + """Update metric definitions on the server via HTTP API.""" + payload = json.dumps( + { + 'runId': self.settings._op_id, + 'metrics': metrics, + } + ).encode() + headers = self.headers.copy() + headers['Content-Type'] = 'application/json' + self._post_v1( + self.settings.url_update_metric_defs, + headers, + payload, + client=self.client_api, + ) + # Keep legacy underscore methods for backwards compatibility def _update_status(self, settings, trace: Union[Any, None] = None): """Legacy method - use update_status() instead.""" diff --git a/pluto/op.py b/pluto/op.py index 0c580d1..13c31f1 100644 --- a/pluto/op.py +++ b/pluto/op.py @@ -241,6 +241,7 @@ def __init__(self, config, settings, tags=None) -> None: self.config = config self.settings = settings self.tags: List[str] = tags if tags else [] # Use provided tags or empty list + self._metric_definitions: Dict[str, Dict[str, Any]] = {} self._monitor = OpMonitor(op=self) self._resumed: bool = False # Whether this run was resumed (multi-node) self._sync_manager: Optional[SyncProcessManager] = None @@ -339,6 +340,7 @@ def _init_sync_manager(self) -> None: 'url_update_tags': self.settings.url_update_tags, 'url_file': self.settings.url_file, # For file uploads 'url_message': self.settings.url_message, # For console logs + 'url_update_metric_defs': self.settings.url_update_metric_defs, 'x_log_level': self.settings.x_log_level, 'sync_process_flush_interval': (self.settings.sync_process_flush_interval), 'sync_process_shutdown_timeout': ( @@ -758,6 +760,58 @@ def update_config(self, config: Dict[str, Any]) -> None: except Exception as e: logger.debug(f'{tag}: failed to sync config to server: {e}') + def define_metric( + self, + name: str, + step_metric: Optional[str] = None, + summary: Optional[str] = None, + goal: Optional[str] = None, + hidden: Optional[bool] = None, + ) -> Dict[str, Any]: + """Define metric behavior (aggregation, custom x-axis). + + Args: + name: Metric name or glob pattern (e.g. "val/*") + step_metric: Name of metric to use as x-axis + summary: Aggregation mode: "min", "max", "mean", "first", "last" + goal: Optimization goal: "minimize" or "maximize" + hidden: Whether to hide this metric in the dashboard + """ + definition: Dict[str, Any] = {'name': name} + if step_metric is not None: + definition['step_metric'] = step_metric + if summary is not None: + definition['summary'] = summary + if goal is not None: + definition['goal'] = goal + if hidden is not None: + definition['hidden'] = hidden + self._metric_definitions[name] = definition + + # Sync to server (best-effort) + if self._sync_manager is not None: + self._sync_manager.enqueue_metric_definition( + definition, int(time.time() * 1000) + ) + elif self._iface: + try: + self._iface.update_metric_definitions([definition]) + except Exception as e: + logger.debug(f'{tag}: failed to sync metric definition to server: {e}') + + return definition + + def get_metric_definition(self, name: str) -> Optional[Dict[str, Any]]: + """Get the metric definition for a name, supporting glob patterns.""" + if name in self._metric_definitions: + return self._metric_definitions[name] + from fnmatch import fnmatch + + for pattern, defn in self._metric_definitions.items(): + if fnmatch(name, pattern): + return defn + return None + @property def resumed(self) -> bool: """ diff --git a/pluto/sets.py b/pluto/sets.py index 7f8502d..8077c1a 100644 --- a/pluto/sets.py +++ b/pluto/sets.py @@ -128,6 +128,7 @@ def update_url(self) -> None: self.url_data = f'{self.url_ingest}/ingest/data' self.url_file = f'{self.url_ingest}/files' self.url_message = f'{self.url_ingest}/ingest/logs' + self.url_update_metric_defs = f'{self.url_api}/api/runs/metrics/define' self.url_alert = f'{self.url_py}/api/runs/alert' self.url_trigger = f'{self.url_py}/api/runs/trigger' @@ -249,6 +250,8 @@ def setup(settings: Union[Settings, Dict[str, Any], None] = None) -> Settings: # Read PLUTO_API_TOKEN environment variable (with MLOP_API_TOKEN fallback) # Only apply if not already set via function parameters env_api_token = _get_env_with_deprecation('PLUTO_API_TOKEN', 'MLOP_API_TOKEN') + if env_api_token is None: + env_api_token = os.environ.get('WANDB_API_KEY') if env_api_token is not None and '_auth' not in settings_dict: new_settings._auth = env_api_token diff --git a/pluto/sync/process.py b/pluto/sync/process.py index 1c93132..15a4b4d 100644 --- a/pluto/sync/process.py +++ b/pluto/sync/process.py @@ -284,6 +284,17 @@ def enqueue_system_metrics( timestamp_ms=timestamp_ms, ) + def enqueue_metric_definition( + self, definition: Dict[str, Any], timestamp_ms: int + ) -> None: + """Enqueue a metric definition for upload.""" + self.store.enqueue( + run_id=self.run_id, + record_type=RecordType.METRIC_DEF, + payload=definition, + timestamp_ms=timestamp_ms, + ) + def enqueue_console_log( self, message: str, @@ -620,6 +631,7 @@ def _sync_records_batch( system_records: List[SyncRecord] = [] data_records: List[SyncRecord] = [] console_records: List[SyncRecord] = [] + metric_def_records: List[SyncRecord] = [] for record in records: if record.record_type == RecordType.METRIC: @@ -634,6 +646,8 @@ def _sync_records_batch( data_records.append(record) elif record.record_type == RecordType.CONSOLE: console_records.append(record) + elif record.record_type == RecordType.METRIC_DEF: + metric_def_records.append(record) success_ids: List[int] = [] failed_ids: List[int] = [] @@ -699,6 +713,16 @@ def _sync_records_batch( failed_ids.extend(r.id for r in console_records) error_msg = str(e) + # Upload metric definitions + if metric_def_records: + try: + uploader.upload_metric_definitions(metric_def_records) + success_ids.extend(r.id for r in metric_def_records) + except Exception as e: + log.warning(f'Failed to upload metric definitions: {e}') + failed_ids.extend(r.id for r in metric_def_records) + error_msg = str(e) + # Update status store.mark_completed(success_ids) if failed_ids: @@ -857,6 +881,7 @@ def __init__(self, settings_dict: Dict[str, Any], log: logging.Logger): self.url_update_tags = settings_dict.get('url_update_tags', '') self.url_file = settings_dict.get('url_file', '') self.url_console = settings_dict.get('url_message', '') + self.url_update_metric_defs = settings_dict.get('url_update_metric_defs', '') # Retry settings (normal mode) self.retry_max = settings_dict.get('sync_process_retry_max', 5) @@ -981,6 +1006,25 @@ def upload_tags(self, record: SyncRecord) -> None: headers, ) + def upload_metric_definitions(self, records: List[SyncRecord]) -> None: + """Upload metric definitions.""" + if not self.url_update_metric_defs or not self.op_id: + return + + metrics = [record.payload for record in records] + payload = { + 'runId': self.op_id, + 'metrics': metrics, + } + + headers = self._get_headers() + headers['Content-Type'] = 'application/json' + self._post_with_retry( + self.url_update_metric_defs, + json.dumps(payload), + headers, + ) + def upload_system_batch(self, records: List[SyncRecord]) -> None: """Upload system metrics batch.""" # System metrics use same endpoint as regular metrics diff --git a/pluto/sync/store.py b/pluto/sync/store.py index eae7820..696118b 100644 --- a/pluto/sync/store.py +++ b/pluto/sync/store.py @@ -103,6 +103,7 @@ class RecordType(IntEnum): TAGS = 4 SYSTEM = 5 CONSOLE = 6 # stdout/stderr logs + METRIC_DEF = 7 # metric definitions (define_metric) @dataclass diff --git a/tests/test_define_metric.py b/tests/test_define_metric.py new file mode 100644 index 0000000..c73d4d4 --- /dev/null +++ b/tests/test_define_metric.py @@ -0,0 +1,82 @@ +"""Tests for Op.define_metric() and Op.get_metric_definition().""" + +from unittest.mock import MagicMock + + +class TestDefineMetric: + """Tests for pluto core define_metric functionality.""" + + def _make_op(self): + """Create a minimal Op with mocked internals for testing.""" + from pluto.op import Op + + op = object.__new__(Op) + op._metric_definitions = {} + op._sync_manager = None + op._iface = None + return op + + def test_define_metric_stores_definition(self): + op = self._make_op() + result = op.define_metric('loss', summary='min', goal='minimize') + assert result == {'name': 'loss', 'summary': 'min', 'goal': 'minimize'} + assert op._metric_definitions['loss'] == result + + def test_define_metric_only_includes_non_none(self): + op = self._make_op() + result = op.define_metric('acc') + assert result == {'name': 'acc'} + assert 'step_metric' not in result + assert 'summary' not in result + + def test_define_metric_with_step_metric(self): + op = self._make_op() + result = op.define_metric('val/loss', step_metric='epoch') + assert result['step_metric'] == 'epoch' + + def test_define_metric_with_hidden(self): + op = self._make_op() + result = op.define_metric('internal', hidden=True) + assert result['hidden'] is True + + def test_define_metric_glob_match(self): + op = self._make_op() + op.define_metric('val/*', summary='min') + defn = op.get_metric_definition('val/loss') + assert defn is not None + assert defn['summary'] == 'min' + + def test_define_metric_exact_over_glob(self): + op = self._make_op() + op.define_metric('val/*', summary='min') + op.define_metric('val/loss', summary='max') + defn = op.get_metric_definition('val/loss') + assert defn['summary'] == 'max' + + def test_get_metric_definition_returns_none_for_unknown(self): + op = self._make_op() + assert op.get_metric_definition('unknown') is None + + def test_define_metric_enqueues_to_sync(self): + op = self._make_op() + op._sync_manager = MagicMock() + op.define_metric('loss', summary='min') + op._sync_manager.enqueue_metric_definition.assert_called_once() + call_args = op._sync_manager.enqueue_metric_definition.call_args + assert call_args[0][0] == {'name': 'loss', 'summary': 'min'} + + def test_define_metric_falls_back_to_iface(self): + op = self._make_op() + op._iface = MagicMock() + op.define_metric('acc', summary='max') + op._iface.update_metric_definitions.assert_called_once_with( + [{'name': 'acc', 'summary': 'max'}] + ) + + def test_define_metric_iface_error_does_not_raise(self): + op = self._make_op() + op._iface = MagicMock() + op._iface.update_metric_definitions.side_effect = RuntimeError('server down') + # Should not raise + result = op.define_metric('loss', summary='min') + assert result == {'name': 'loss', 'summary': 'min'} diff --git a/tests/test_wandb_compat.py b/tests/test_wandb_compat.py index a6cab4a..ff7594f 100644 --- a/tests/test_wandb_compat.py +++ b/tests/test_wandb_compat.py @@ -213,6 +213,77 @@ def test_as_dict(self): s.update({'a': 1}) assert s.as_dict() == {'a': 1} + def test_summary_aggregation_min(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._set_metric_definition('loss', {'name': 'loss', 'summary': 'min'}) + s._update_from_log({'loss': 0.5}) + s._update_from_log({'loss': 0.3}) + s._update_from_log({'loss': 0.7}) + assert s['loss'] == pytest.approx(0.3) + + def test_summary_aggregation_max(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._set_metric_definition('acc', {'name': 'acc', 'summary': 'max'}) + s._update_from_log({'acc': 0.8}) + s._update_from_log({'acc': 0.95}) + s._update_from_log({'acc': 0.9}) + assert s['acc'] == pytest.approx(0.95) + + def test_summary_aggregation_mean(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._set_metric_definition('loss', {'name': 'loss', 'summary': 'mean'}) + s._update_from_log({'loss': 1.0}) + s._update_from_log({'loss': 2.0}) + s._update_from_log({'loss': 3.0}) + assert s['loss'] == pytest.approx(2.0) + + def test_summary_aggregation_first(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._set_metric_definition('lr', {'name': 'lr', 'summary': 'first'}) + s._update_from_log({'lr': 0.01}) + s._update_from_log({'lr': 0.001}) + s._update_from_log({'lr': 0.0001}) + assert s['lr'] == pytest.approx(0.01) + + def test_summary_aggregation_last(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._set_metric_definition('step', {'name': 'step', 'summary': 'last'}) + s._update_from_log({'step': 1}) + s._update_from_log({'step': 2}) + s._update_from_log({'step': 3}) + assert s['step'] == 3 + + def test_summary_aggregation_glob(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + s._set_metric_definition('val/*', {'name': 'val/*', 'summary': 'min'}) + s._update_from_log({'val/loss': 0.5, 'val/acc': 0.8}) + s._update_from_log({'val/loss': 0.3, 'val/acc': 0.6}) + s._update_from_log({'val/loss': 0.7, 'val/acc': 0.9}) + assert s['val/loss'] == pytest.approx(0.3) + assert s['val/acc'] == pytest.approx(0.6) + + def test_summary_no_definition_keeps_last(self): + from pluto.compat.wandb.summary import Summary + + s = Summary() + # No definition set — default behavior + s._update_from_log({'loss': 0.5}) + s._update_from_log({'loss': 0.3}) + s._update_from_log({'loss': 0.7}) + assert s['loss'] == pytest.approx(0.7) + class TestDataTypes: """Tests for wandb data type wrappers.""" @@ -290,7 +361,7 @@ def test_table_get_column(self): assert t.get_column('missing') == [] def test_table_from_dataframe(self): - import pandas as pd + pd = pytest.importorskip('pandas') from pluto.compat.wandb.data_types import Table @@ -504,6 +575,17 @@ def test_define_metric_returns_stub(self): run, op = self._make_run() m = run.define_metric('loss', step_metric='epoch') assert m.name == 'loss' + # Verify Op.define_metric was called + op.define_metric.assert_called_once_with( + 'loss', step_metric='epoch', summary=None, goal=None, hidden=None + ) + + def test_define_metric_with_step_metric(self): + run, op = self._make_run() + run.define_metric('val/loss', step_metric='epoch') + op.define_metric.assert_called_once_with( + 'val/loss', step_metric='epoch', summary=None, goal=None, hidden=None + ) def test_unsupported_methods_no_error(self): run, op = self._make_run() From 42710b3f8b31d200d16110f4381e5a86e0e17295 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 24 Feb 2026 16:35:14 -0500 Subject: [PATCH 7/7] Add live wandb compat tests to CI workflow Enable pluto shim visual parity live test using MLOP_API_TOKEN secret, matching the neptune-compat pattern. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/wandb-compat.yml | 110 +++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 .github/workflows/wandb-compat.yml diff --git a/.github/workflows/wandb-compat.yml b/.github/workflows/wandb-compat.yml new file mode 100644 index 0000000..f1ac977 --- /dev/null +++ b/.github/workflows/wandb-compat.yml @@ -0,0 +1,110 @@ +name: Wandb Compatibility Tests + +# This workflow tests the wandb-to-pluto compatibility layer. +# It runs on manual trigger or when a maintainer comments /wandb on a PR. + +on: + workflow_dispatch: + issue_comment: + types: [created] + +jobs: + wandb-compatibility: + # Run on manual dispatch, or when someone comments /wandb on a PR + if: > + github.event_name == 'workflow_dispatch' || + (github.event_name == 'issue_comment' && + github.event.issue.pull_request && + contains(github.event.comment.body, '/wandb')) + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + poetry-version: ["2.1.1"] + + steps: + - name: React to comment + if: github.event_name == 'issue_comment' + uses: actions/github-script@v7 + with: + script: | + await github.rest.reactions.createForIssueComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: context.payload.comment.id, + content: 'rocket' + }); + + - name: Get PR head ref + if: github.event_name == 'issue_comment' + id: pr + uses: actions/github-script@v7 + with: + script: | + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.issue.number + }); + core.setOutput('ref', pr.data.head.ref); + core.setOutput('sha', pr.data.head.sha); + + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ github.event_name == 'issue_comment' && steps.pr.outputs.ref || '' }} + + - name: "Setup Python, Poetry and Dependencies" + uses: packetcoders/action-setup-cache-python-poetry@v1.2.0 + with: + python-version: ${{ matrix.python-version }} + poetry-version: ${{ matrix.poetry-version }} + install-args: "--with dev" + + - name: Install dependencies + run: | + set -euox pipefail + rm -rf .venv + poetry install --with dev --extras full + + - name: Run wandb compatibility unit tests + run: | + set -euox pipefail + poetry run pytest tests/test_wandb_compat.py -v -rs -k "not test_table_from_dataframe" + + - name: Run wandb env var fallback tests + run: | + set -euox pipefail + poetry run pytest tests/test_env_vars.py::TestWANDBApiKeyFallback -v -rs + + - name: Run live pluto shim parity test + env: + PLUTO_API_TOKEN: ${{ secrets.MLOP_API_TOKEN }} + run: | + set -euox pipefail + poetry run pytest tests/test_wandb_visual_parity.py::TestPlutoShimSide -v -rs -s + + - name: Report test results + if: always() + run: | + echo "Wandb compatibility tests completed" + echo "Python version: ${{ matrix.python-version }}" + echo "Status: ${{ job.status }}" + + wandb-compat-status: + runs-on: ubuntu-latest + needs: [wandb-compatibility] + if: always() && github.event_name == 'issue_comment' + steps: + - name: Comment result on PR + uses: actions/github-script@v7 + with: + script: | + const status = '${{ needs.wandb-compatibility.result }}'; + const icon = status === 'success' ? '✅' : '❌'; + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `${icon} Wandb compatibility tests: **${status}**` + });