diff --git a/tests/utils/swanlab_test.py b/tests/utils/swanlab_test.py new file mode 100644 index 0000000000..6b7f6a9c1e --- /dev/null +++ b/tests/utils/swanlab_test.py @@ -0,0 +1,37 @@ +import os +import unittest + + +class TestSwanlabMonitor(unittest.TestCase): + @classmethod + def setUpClass(cls): + os.environ["SWANLAB_API_KEY"] = "xxxxxxxxxxxxxxxxxxxxx" + + @classmethod + def tearDownClass(cls): + # Restore original environment variables + for k, v in cls._original_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + @unittest.skip("Requires swanlab package and network access") + def test_swanlab_monitor_smoke(self): + from trinity.utils.monitor import SwanlabMonitor + + # Try creating the monitor; if swanlab isn't installed, __init__ will assert + mon = SwanlabMonitor( + project="trinity-smoke", + group="cradle", + name="swanlab-env", + role="tester", + ) + + # Log a minimal metric to verify basic flow + mon.log({"smoke/metric": 1.0}, step=1) + mon.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 73b64229fb..21ef7726f1 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -16,6 +16,12 @@ import mlflow except ImportError: mlflow = None + +try: + import swanlab +except ImportError: + swanlab = None + from torch.utils.tensorboard import SummaryWriter from trinity.common.config import Config @@ -28,6 +34,7 @@ "tensorboard": "trinity.utils.monitor.TensorboardMonitor", "wandb": "trinity.utils.monitor.WandbMonitor", "mlflow": "trinity.utils.monitor.MlflowMonitor", + "swanlab": "trinity.utils.monitor.SwanlabMonitor", }, ) @@ -232,3 +239,120 @@ def default_args(cls) -> Dict: "username": None, "password": None, } + + +class SwanlabMonitor(Monitor): + """Monitor with SwanLab. + + This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments. + + Supported monitor_args in config.monitor.monitor_args: + - api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env + (SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login. + - workspace (Optional[str]): Organization/username workspace. + - mode (Optional[str]): "cloud" | "local" | "offline" | "disabled". + - logdir (Optional[str]): Local log directory when in local/offline modes. + - experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}". + - description (Optional[str]): Experiment description. + - tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically. + - id (Optional[str]): Resume target run id (21 chars) when using resume modes. + - resume (Optional[Literal['must','allow','never']|bool]): Resume policy. + - reinit (Optional[bool]): Whether to re-init on repeated init() calls. + """ + + def __init__( + self, project: str, group: str, name: str, role: str, config: Config = None + ) -> None: + assert ( + swanlab is not None + ), "swanlab is not installed. Please install it to use SwanlabMonitor." + + monitor_args = ( + (config.monitor.monitor_args or {}) + if config and getattr(config, "monitor", None) + else {} + ) + + # Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`. + api_key = os.environ.get("SWANLAB_API_KEY") + if api_key: + try: + swanlab.login(api_key=api_key, save=True) + except Exception: + # Best-effort login; continue to init which may still work if already logged in + pass + else: + raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.") + + # Compose tags (ensure list and include role/group markers) + tags = monitor_args.get("tags") or [] + if isinstance(tags, tuple): + tags = list(tags) + if role and role not in tags: + tags.append(role) + if group and group not in tags: + tags.append(group) + + # Determine experiment name + exp_name = monitor_args.get("experiment_name") or f"{name}_{role}" + self.exp_name = exp_name + + # Prepare init kwargs, passing only non-None values to respect library defaults + init_kwargs = { + "project": project, + "workspace": monitor_args.get("workspace"), + "experiment_name": exp_name, + "description": monitor_args.get("description"), + "tags": tags or None, + "logdir": monitor_args.get("logdir"), + "mode": monitor_args.get("mode") or "cloud", + "settings": monitor_args.get("settings"), + "id": monitor_args.get("id"), + "config": config.flatten() if config is not None else None, + "resume": monitor_args.get("resume"), + "reinit": monitor_args.get("reinit"), + } + # Strip None values to avoid overriding swanlab defaults + init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None} + + self.logger = swanlab.init(**init_kwargs) + self.console_logger = get_logger(__name__, in_ray_actor=True) + + def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int): + # Convert pandas DataFrame to SwanLab ECharts Table + headers: List[str] = list(experiences_table.columns) + # Ensure rows are native Python types + rows: List[List[object]] = experiences_table.astype(object).values.tolist() + try: + tbl = swanlab.echarts.Table() + tbl.add(headers, rows) + swanlab.log({table_name: tbl}, step=step) + except Exception as e: + self.console_logger.warning( + f"Failed to log table '{table_name}' as echarts, falling back to CSV. Error: {e}" + ) + # Fallback: log as CSV string if echarts table is unavailable + csv_str = experiences_table.to_csv(index=False) + swanlab.log({table_name: csv_str}, step=step) + + def log(self, data: dict, step: int, commit: bool = False) -> None: + """Log metrics.""" + # SwanLab doesn't use commit flag; keep signature for compatibility + swanlab.log(data, step=step) + self.console_logger.info(f"Step {step}: {data}") + + def close(self) -> None: + try: + # Prefer run.finish() if available + if hasattr(self, "logger") and hasattr(self.logger, "finish"): + self.logger.finish() + else: + # Fallback to global finish + swanlab.finish() + except Exception as e: + self.console_logger.warning(f"Failed to close SwanlabMonitor: {e}") + + @classmethod + def default_args(cls) -> Dict: + """Return default arguments for the monitor.""" + return {}