Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions tests/utils/swanlab_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import unittest


class TestSwanlabMonitor(unittest.TestCase):
    """Smoke test for ``SwanlabMonitor``.

    The class-level fixtures install a placeholder ``SWANLAB_API_KEY`` for the
    duration of the test case and put the previous value back afterwards, so
    the fake key never leaks into other tests running in the same process.
    """

    @classmethod
    def setUpClass(cls):
        """Snapshot the current API-key variable, then install a placeholder."""
        # Record the pre-existing value (None when unset) BEFORE overwriting it;
        # tearDownClass relies on this snapshot to restore the environment.
        cls._original_env = {"SWANLAB_API_KEY": os.environ.get("SWANLAB_API_KEY")}
        os.environ["SWANLAB_API_KEY"] = "xxxxxxxxxxxxxxxxxxxxx"

    @classmethod
    def tearDownClass(cls):
        """Restore every environment variable captured in ``setUpClass``."""
        # getattr guards against setUpClass having failed before the snapshot
        # was taken; in that case there is nothing to restore.
        for k, v in getattr(cls, "_original_env", {}).items():
            if v is None:
                # The variable did not exist originally: remove our placeholder.
                os.environ.pop(k, None)
            else:
                os.environ[k] = v

    @unittest.skip("Requires swanlab package and network access")
    def test_swanlab_monitor_smoke(self):
        """Create a monitor, log one metric and close it."""
        # Imported lazily so that merely collecting this module does not pull
        # in the monitor module and its dependencies.
        from trinity.utils.monitor import SwanlabMonitor

        # Try creating the monitor; if swanlab isn't installed, __init__ will assert
        mon = SwanlabMonitor(
            project="trinity-smoke",
            group="cradle",
            name="swanlab-env",
            role="tester",
        )

        # Log a minimal metric to verify basic flow
        mon.log({"smoke/metric": 1.0}, step=1)
        mon.close()


# Allow running this test module directly as a script (python swanlab_test.py).
if __name__ == "__main__":
    unittest.main()
124 changes: 124 additions & 0 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
import mlflow
except ImportError:
mlflow = None

try:
import swanlab
except ImportError:
swanlab = None

from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
Expand All @@ -28,6 +34,7 @@
"tensorboard": "trinity.utils.monitor.TensorboardMonitor",
"wandb": "trinity.utils.monitor.WandbMonitor",
"mlflow": "trinity.utils.monitor.MlflowMonitor",
"swanlab": "trinity.utils.monitor.SwanlabMonitor",
},
)

Expand Down Expand Up @@ -232,3 +239,120 @@ def default_args(cls) -> Dict:
"username": None,
"password": None,
}


class SwanlabMonitor(Monitor):
    """Monitor with SwanLab.

    This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments.

    Supported monitor_args in config.monitor.monitor_args:
        - api_key (Optional[str]): API key passed to swanlab.login(). If omitted, the
          SWANLAB_API_KEY environment variable is used instead. A key is required when
          the mode is "cloud"; other modes run without one.
        - workspace (Optional[str]): Organization/username workspace.
        - mode (Optional[str]): "cloud" | "local" | "offline" | "disabled".
          Defaults to "cloud".
        - logdir (Optional[str]): Local log directory when in local/offline modes.
        - experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}".
        - description (Optional[str]): Experiment description.
        - tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically.
        - settings: Forwarded unchanged to swanlab.init().
        - id (Optional[str]): Resume target run id (21 chars) when using resume modes.
        - resume (Optional[Literal['must','allow','never']|bool]): Resume policy.
        - reinit (Optional[bool]): Whether to re-init on repeated init() calls.
    """

    def __init__(
        self, project: str, group: str, name: str, role: str, config: Config = None
    ) -> None:
        """Log in to SwanLab (when a key is available) and start a run.

        Args:
            project: SwanLab project name.
            group: Group marker; appended to the run's tags.
            name: Base name used to build the default experiment name.
            role: Role marker; appended to the tags and the default experiment name.
            config: Optional project config. Its ``monitor.monitor_args`` customise
                the run and its flattened form is uploaded as the run config.

        Raises:
            AssertionError: If the swanlab package is not installed.
            RuntimeError: If the mode is "cloud" and no API key is available.
        """
        assert (
            swanlab is not None
        ), "swanlab is not installed. Please install it to use SwanlabMonitor."

        # Create the console logger first so that every later step (including a
        # failed login) can report through it and close() never finds it missing.
        self.console_logger = get_logger(__name__, in_ray_actor=True)

        monitor_args = (
            (config.monitor.monitor_args or {})
            if config and getattr(config, "monitor", None)
            else {}
        )
        mode = monitor_args.get("mode") or "cloud"

        # An explicit monitor arg wins over the environment variable.
        api_key = monitor_args.get("api_key") or os.environ.get("SWANLAB_API_KEY")
        if api_key:
            try:
                swanlab.login(api_key=api_key, save=True)
            except Exception as e:
                # Best-effort login: init may still work if already logged in,
                # but surface the failure instead of hiding it.
                self.console_logger.warning(
                    f"swanlab.login() failed, continuing with existing credentials: {e}"
                )
        elif mode == "cloud":
            # Only cloud mode uploads data, so only cloud mode needs a key.
            raise RuntimeError(
                "Swanlab API key not found: set monitor_args['api_key'] or the "
                "SWANLAB_API_KEY environment variable."
            )

        # Compose tags on a private copy: appending to the list stored in the
        # shared config would leak role/group markers into other monitors and
        # into the flattened config uploaded below.
        raw_tags = monitor_args.get("tags") or []
        tags = [raw_tags] if isinstance(raw_tags, str) else list(raw_tags)
        for marker in (role, group):
            if marker and marker not in tags:
                tags.append(marker)

        # Determine experiment name
        exp_name = monitor_args.get("experiment_name") or f"{name}_{role}"
        self.exp_name = exp_name

        # Prepare init kwargs, passing only non-None values to respect library defaults
        init_kwargs = {
            "project": project,
            "workspace": monitor_args.get("workspace"),
            "experiment_name": exp_name,
            "description": monitor_args.get("description"),
            "tags": tags or None,
            "logdir": monitor_args.get("logdir"),
            "mode": mode,
            "settings": monitor_args.get("settings"),
            "id": monitor_args.get("id"),
            "config": config.flatten() if config is not None else None,
            "resume": monitor_args.get("resume"),
            "reinit": monitor_args.get("reinit"),
        }
        # Strip None values to avoid overriding swanlab defaults
        init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}

        self.logger = swanlab.init(**init_kwargs)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        """Log a DataFrame as a SwanLab ECharts table, falling back to CSV text.

        Args:
            table_name: Key under which the table is logged.
            experiences_table: Table to log; column labels become the headers.
            step: Step index associated with the table.
        """
        headers: List[str] = list(experiences_table.columns)
        # Ensure rows are native Python types
        rows: List[List[object]] = experiences_table.astype(object).values.tolist()
        try:
            tbl = swanlab.echarts.Table()
            tbl.add(headers, rows)
            swanlab.log({table_name: tbl}, step=step)
        except Exception as e:
            self.console_logger.warning(
                f"Failed to log table '{table_name}' as echarts, falling back to CSV. Error: {e}"
            )
            # Fallback: log the CSV rendering wrapped in swanlab.Text, because a
            # bare str is not a value type swanlab.log() can record.
            csv_str = experiences_table.to_csv(index=False)
            swanlab.log({table_name: swanlab.Text(csv_str)}, step=step)

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics to SwanLab and echo them to the console logger.

        Args:
            data: Mapping of metric name to value.
            step: Step index associated with the metrics.
            commit: Unused; kept so the signature matches the other monitors.
        """
        # SwanLab doesn't use commit flag; keep signature for compatibility
        swanlab.log(data, step=step)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """Finish the SwanLab run; failures are logged as warnings, not raised."""
        try:
            # Prefer run.finish() if available
            if hasattr(self, "logger") and hasattr(self.logger, "finish"):
                self.logger.finish()
            else:
                # Fallback to global finish
                swanlab.finish()
        except Exception as e:
            self.console_logger.warning(f"Failed to close SwanlabMonitor: {e}")

    @classmethod
    def default_args(cls) -> Dict:
        """Return default arguments for the monitor."""
        return {}