From 41ab1599286521d9562e9077683efbeff1a2902d Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sat, 21 Jun 2025 19:57:16 +0200 Subject: [PATCH 1/4] add predicate selectors --- ablate/queries/selectors.py | 41 ++++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/ablate/queries/selectors.py b/ablate/queries/selectors.py index a00a604..89f7bb5 100644 --- a/ablate/queries/selectors.py +++ b/ablate/queries/selectors.py @@ -1,8 +1,29 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from operator import eq, ge, gt, le, lt, ne -from typing import Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal + + +if TYPE_CHECKING: # pragma: no cover + from ablate.core.types import Run + + +class Predicate: + def __init__(self, fn: Callable[[Run], bool]) -> None: + self._fn = fn + + def __call__(self, run: Run) -> bool: + return self._fn(run) + + def __and__(self, other: Predicate) -> Predicate: + return Predicate(lambda run: self(run) and other(run)) + + def __or__(self, other: Predicate) -> Predicate: + return Predicate(lambda run: self(run) or other(run)) -from ablate.core.types import Run + def __invert__(self) -> Predicate: + return Predicate(lambda run: not self(run)) class AbstractSelector(ABC): @@ -20,25 +41,25 @@ def __init__(self, name: str, label: str | None = None) -> None: @abstractmethod def __call__(self, run: Run) -> Any: ... - def _cmp(self, op: Callable[[Any, Any], bool], other: Any) -> Callable[[Run], bool]: - return lambda run: op(self.__call__(run), other) + def _cmp(self, op: Callable[[Any, Any], bool], other: Any) -> Predicate: + return Predicate(lambda run: op(self(run), other)) - def __eq__(self, other: object) -> Callable[[Run], bool]: # type: ignore[override] + def __eq__(self, other: object) -> Predicate: # type: ignore[override] return self._cmp(eq, other) - def __ne__(self, other: object) -> Callable[[Run], bool]: # type: ignore[override] + def __ne__(self, other: object) -> Predicate: # type: ignore[override] return self._cmp(ne, other) - def __lt__(self, other: Any) -> Callable[[Run], bool]: + def __lt__(self, other: Any) -> Predicate: return self._cmp(lt, other) - def __le__(self, other: Any) -> Callable[[Run], bool]: + def __le__(self, other: Any) -> Predicate: return self._cmp(le, other) - def __gt__(self, other: Any) -> Callable[[Run], bool]: + def __gt__(self, other: Any) -> Predicate: return self._cmp(gt, other) - def __ge__(self, other: Any) -> Callable[[Run], bool]: + def __ge__(self, other: Any) -> Predicate: return self._cmp(ge, other) From 4ed25ddd6c8924c2ac610901a9ca6258f1209d15 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sat, 21 Jun 2025 19:57:50 +0200 Subject: [PATCH 2/4] fix mock source typing --- ablate/sources/mock_source.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ablate/sources/mock_source.py b/ablate/sources/mock_source.py index 736a677..c6c182f 100644 --- a/ablate/sources/mock_source.py +++ b/ablate/sources/mock_source.py @@ -1,5 +1,5 @@ import itertools -from typing import Any, Dict, List +from typing import Dict, List import numpy as np @@ -11,7 +11,7 @@ class Mock(AbstractSource): def __init__( self, - grid: Dict[str, List[Any]], + grid: Dict[str, List[str | int | float | bool]], num_seeds: int = 1, steps: int = 25, ) -> None: @@ -30,7 +30,11 @@ def __init__( self.num_seeds = num_seeds self.steps = steps - def _generate_runs(self, param_dict: Dict[str, Any], idx: int) -> List[Run]: + def _generate_runs( + self, 
+        param_dict: Dict[str, str | int | float | bool],
+        idx: int,
+    ) -> List[Run]:
         runs: List[Run] = []
         for local_seed in range(self.num_seeds):
             global_seed = idx * self.num_seeds + local_seed

From 69ccbfcf8a2be634d9033432e24a2f3bbb1bb309 Mon Sep 17 00:00:00 2001
From: Simon Rampp
Date: Sat, 21 Jun 2025 19:58:13 +0200
Subject: [PATCH 3/4] add tests

---
 tests/queries/test_selectors.py | 44 +++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/queries/test_selectors.py b/tests/queries/test_selectors.py
index 60db694..08afda0 100644
--- a/tests/queries/test_selectors.py
+++ b/tests/queries/test_selectors.py
@@ -83,3 +83,47 @@ def test_temporal_metric_missing_returns_nan(example_run: Run) -> None:
 def test_temporal_metric_invalid_reduction() -> None:
     with pytest.raises(ValueError, match="Invalid reduction method"):
         TemporalMetric("accuracy", direction="max", reduction="median")  # type: ignore[arg-type]
+
+
+def test_predicate_and(example_run: Run) -> None:
+    acc = Metric("accuracy", direction="max")
+    loss = Metric("loss", direction="min")
+
+    pred = (acc > 0.8) & (loss < 0.2)
+    assert pred(example_run) is True
+
+    pred = (acc > 0.95) & (loss < 0.2)
+    assert pred(example_run) is False
+
+
+def test_predicate_or(example_run: Run) -> None:
+    acc = Metric("accuracy", direction="max")
+    loss = Metric("loss", direction="min")
+
+    pred = (acc > 0.95) | (loss < 0.2)
+    assert pred(example_run) is True
+
+    pred = (acc > 0.95) | (loss > 0.2)
+    assert pred(example_run) is False
+
+
+def test_predicate_not(example_run: Run) -> None:
+    acc = Metric("accuracy", direction="max")
+
+    pred = ~(acc > 0.95)
+    assert pred(example_run) is True
+
+    pred = ~(acc < 0.95)
+    assert pred(example_run) is False
+
+
+def test_chained_predicates(example_run: Run) -> None:
+    acc = Metric("accuracy", direction="max")
+    loss = Metric("loss", direction="min")
+    lr = Param("lr")
+
+    pred = ((acc > 0.8) & (loss < 0.2)) | (lr == 0.01)
+    assert pred(example_run) is True
+
+    pred = ((acc > 0.95) & (loss < 0.05)) | (lr == 0.02)
+    assert pred(example_run) is False

From 3eb5d4c05162bbbd298c4291713acfcc6b5d6467 Mon Sep 17 00:00:00 2001
From: Simon Rampp
Date: Sat, 21 Jun 2025 19:59:43 +0200
Subject: [PATCH 4/4] update docs

---
 README.md  | 100 ++++++++++++++++++++++++++++++++++----------------
 README.rst | 105 +++++++++++++++++++++++++++++++++++++---------------
 2 files changed, 141 insertions(+), 64 deletions(-)

diff --git a/README.md b/README.md
index 0439c57..c680e28 100644
--- a/README.md
+++ b/README.md
@@ -59,9 +59,9 @@ To create your first [Report](https://ramppdev.github.io/ablate/modules/report.h
 For example, the built in [Mock](https://ramppdev.github.io/ablate/modules/sources.html#mock-source) can be used to simulate runs:
 
 ```python
-import ablate
+from ablate.sources import Mock
 
-source = ablate.sources.Mock(
+source = Mock(
     grid={"model": ["vgg", "resnet"], "lr": [0.01, 0.001]},
     num_seeds=2,
 )
 ```
 
Next, the runs can be loaded and processed using functional-style queries to e.g
 group by seed, aggregate the results by mean, and finally collect all results into a single list:
 
 ```python
+from ablate.queries import Query, Metric, Param
+
 runs = (
-    ablate.queries.Query(source.load())
-    .sort(ablate.queries.Metric("accuracy", direction="max"))
-    .groupdiff(ablate.queries.Param("seed"))
+    Query(source.load())
+    .sort(Metric("accuracy", direction="max"))
+    .groupdiff(Param("seed"))
     .aggregate("mean")
     .all()
 )
 ```
 
Now that the runs are loaded and processed, a 
[Report](https://ramppdev.github.i comprising multiple blocks can be created to structure the content: ```python -report = ablate.Report(runs) -report.add(ablate.blocks.H1("Model Performance")) +from ablate import Report +from ablate.blocks import H1, Table + +report = Report(runs) +report.add(H1("Model Performance")) report.add( - ablate.blocks.Table( + Table( columns=[ - ablate.queries.Param("model", label="Model"), - ablate.queries.Param("lr", label="Learning Rate"), - ablate.queries.Metric("accuracy", direction="max", label="Accuracy"), - ablate.queries.Metric("f1", direction="max", label="F1 Score"), - ablate.queries.Metric("loss", direction="min", label="Loss"), + Param("model", label="Model"), + Param("lr", label="Learning Rate"), + Metric("accuracy", direction="max", label="Accuracy"), + Metric("f1", direction="max", label="F1 Score"), + Metric("loss", direction="min", label="Loss"), ] ) ) @@ -105,7 +110,9 @@ report.add( Finally, the report can be exported to a desired format such as [Markdown](https://ramppdev.github.io/ablate/modules/exporters.html#ablate.exporters.Markdown): ```python -ablate.exporters.Markdown().export(report) +from ablate.exporters import Markdown + +Markdown().export(report) ``` This will produce a `report.md` file with the following content: @@ -127,24 +134,53 @@ To compose multiple sources, they can be added together using the `+` operator as they represent lists of [Run](https://ramppdev.github.io/ablate/modules/core.html#ablate.core.types.Run) objects: ```python -runs1 = ablate.sources.Mock(...).load() -runs2 = ablate.sources.Mock(...).load() +runs1 = Mock(...).load() +runs2 = Mock(...).load() all_runs = runs1 + runs2 # combines both sources into a single list of runs ``` +### Selector Expressions + +_ablate_ selectors are lightweight expressions that access attributes of experiment runs, such as parameters, metrics, or IDs. +They support standard Python comparison operators and can be composed using logical operators to define complex query logic: + +```python +accuracy = Metric("accuracy", direction="max") +loss = Metric("loss", direction="min") + +runs = ( + Query(source.load()) + .filter((accuracy > 0.9) & (loss < 0.1)) + .all() +) +``` + +Selectors return callable predicates, so they can be used in any query operation that requires a condition. +All standard comparisons are supported: `==`, `!=`, `<`, `<=`, `>`, `>=`. 
+Logical operators `&` (and), `|` (or), and `~` (not) can be used to combine expressions: + +```python +from ablate.queries import Id + +select = (Param("model") == "resnet") | (Param("lr") < 0.001) # select resnet or LR below 0.001 + +exclude = ~(Id() == "run-42") # exclude a specific run by ID + +runs = Query(source.load()).filter(select & exclude).all() + +``` + ### Functional Queries _ablate_ queries are functionally pure such that intermediate results are not modified and can be reused: ```python -runs = ablate.sources.Mock(...).load() +runs = Mock(...).load() -sorted_runs = Query(runs).sort(ablate.queries.Metric("accuracy", direction="max")) +sorted_runs = Query(runs).sort(Metric("accuracy", direction="max")) -filtered_runs = sorted_runs.filter( - ablate.queries.Metric("accuracy", direction="max") > 0.9 -) +filtered_runs = sorted_runs.filter(Metric("accuracy", direction="max") > 0.9) sorted_runs.all() # still contains all runs sorted by accuracy filtered_runs.all() # only contains runs with accuracy > 0.9 @@ -157,25 +193,25 @@ To create more complex reports, blocks can be populated with a custom list of ru ```python report = ablate.Report(sorted_runs.all()) -report.add(ablate.blocks.H1("Report with Sorted Runs and Filtered Runs")) -report.add(ablate.blocks.H2("Sorted Runs")) +report.add(H1("Report with Sorted Runs and Filtered Runs")) +report.add(H2("Sorted Runs")) report.add( - ablate.blocks.Table( + Table( columns=[ - ablate.queries.Param("model", label="Model"), - ablate.queries.Param("lr", label="Learning Rate"), - ablate.queries.Metric("accuracy", direction="max", label="Accuracy"), + Param("model", label="Model"), + Param("lr", label="Learning Rate"), + Metric("accuracy", direction="max", label="Accuracy"), ] ) ) -report.add(ablate.blocks.H2("Filtered Runs")) +report.add(H2("Filtered Runs")) report.add( - ablate.blocks.Table( + Table( runs = filtered_runs.all(), # use filtered runs only for this block columns=[ - ablate.queries.Param("model", label="Model"), - ablate.queries.Param("lr", label="Learning Rate"), - ablate.queries.Metric("accuracy", direction="max", label="Accuracy"), + Param("model", label="Model"), + Param("lr", label="Learning Rate"), + Metric("accuracy", direction="max", label="Accuracy"), ] ) ) diff --git a/README.rst b/README.rst index 824494d..f6a7010 100644 --- a/README.rst +++ b/README.rst @@ -80,9 +80,9 @@ For example, the built in :class:`~ablate.sources.Mock` can be used to simulate .. code-block:: python :linenos: - import ablate + from ablate.sources import Mock - source = ablate.sources.Mock( + source = Mock( grid={"model": ["vgg", "resnet"], "lr": [0.01, 0.001]}, num_seeds=2, ) @@ -95,10 +95,12 @@ group by seed, aggregate the results by mean, and finally collect all results in .. code-block:: python :linenos: + from ablate.queries import Metric, Param, Query + runs = ( - ablate.queries.Query(source.load()) - .sort(ablate.queries.Metric("accuracy", direction="max")) - .groupdiff(ablate.queries.Param("seed")) + Query(source.load()) + .sort(Metric("accuracy", direction="max")) + .groupdiff(Param("seed")) .aggregate("mean") .all() ) @@ -109,16 +111,19 @@ can be created to structure the content: .. 
code-block:: python
     :linenos:
 
-    report = ablate.Report(runs)
-    report.add(ablate.blocks.H1("Model Performance"))
+    from ablate import Report
+    from ablate.blocks import H1, Table
+
+    report = Report(runs)
+    report.add(H1("Model Performance"))
     report.add(
-        ablate.blocks.Table(
+        Table(
             columns=[
-                ablate.queries.Param("model", label="Model"),
-                ablate.queries.Param("lr", label="Learning Rate"),
-                ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
-                ablate.queries.Metric("f1", direction="max", label="F1 Score"),
-                ablate.queries.Metric("loss", direction="min", label="Loss"),
+                Param("model", label="Model"),
+                Param("lr", label="Learning Rate"),
+                Metric("accuracy", direction="max", label="Accuracy"),
+                Metric("f1", direction="max", label="F1 Score"),
+                Metric("loss", direction="min", label="Loss"),
             ]
         )
     )
@@ -128,7 +133,9 @@ Finally, the report can be exported to a desired format such as :class:`~ablate.
 
 .. code-block:: python
     :linenos:
 
-    ablate.exporters.Markdown().export(report)
+    from ablate.exporters import Markdown
+
+    Markdown().export(report)
 
 This will produce a :file:`report.md` file with the following content:
 
@@ -153,12 +160,47 @@ as they represent lists of :class:`~ablate.core.types.Run` objects:
 
 .. code-block:: python
     :linenos:
 
-    runs1 = ablate.sources.Mock(...).load()
-    runs2 = ablate.sources.Mock(...).load()
+    runs1 = Mock(...).load()
+    runs2 = Mock(...).load()
 
     all_runs = runs1 + runs2  # combines both sources into a single list of runs
 
 
+Selector Expressions
+~~~~~~~~~~~~~~~~~~~~
+
+`ablate` selectors are lightweight expressions that access attributes of experiment runs, such as parameters, metrics, or IDs.
+They support standard Python comparison operators and can be composed using logical operators to define complex query logic:
+
+.. code-block:: python
+    :linenos:
+
+    accuracy = Metric("accuracy", direction="max")
+    loss = Metric("loss", direction="min")
+
+    runs = (
+        Query(source.load())
+        .filter((accuracy > 0.9) & (loss < 0.1))
+        .all()
+    )
+
+
+Selectors return callable predicates, so they can be used in any query operation that requires a condition.
+All standard comparisons are supported: ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``.
+Logical operators ``&`` (and), ``|`` (or), and ``~`` (not) can be used to combine expressions:
+
+.. code-block:: python
+    :linenos:
+
+    from ablate.queries import Id
+
+    select = (Param("model") == "resnet") | (Param("lr") < 0.001)  # select resnet or LR below 0.001
+
+    exclude = ~(Id() == "run-42")  # exclude a specific run by ID
+
+    runs = Query(source.load()).filter(select & exclude).all()
+
+
 Functional Queries
 ~~~~~~~~~~~~~~~~~~
 
 .. 
code-block:: python :linenos: - runs = ablate.sources.Mock(...).load() + runs = Mock(...).load() - sorted_runs = Query(runs).sort(ablate.queries.Metric("accuracy", direction="max")) + sorted_runs = Query(runs).sort(Metric("accuracy", direction="max")) - filtered_runs = sorted_runs.filter( - ablate.queries.Metric("accuracy", direction="max") > 0.9 - ) + filtered_runs = sorted_runs.filter(Metric("accuracy", direction="max") > 0.9) sorted_runs.all() # still contains all runs sorted by accuracy filtered_runs.all() # only contains runs with accuracy > 0.9 @@ -189,29 +229,30 @@ To create more complex reports, blocks can be populated with a custom list of ru :linenos: report = ablate.Report(sorted_runs.all()) - report.add(ablate.blocks.H1("Report with Sorted Runs and Filtered Runs")) - report.add(ablate.blocks.H2("Sorted Runs")) + report.add(H1("Report with Sorted Runs and Filtered Runs")) + report.add(H2("Sorted Runs")) report.add( - ablate.blocks.Table( + Table( columns=[ - ablate.queries.Param("model", label="Model"), - ablate.queries.Param("lr", label="Learning Rate"), - ablate.queries.Metric("accuracy", direction="max", label="Accuracy"), + Param("model", label="Model"), + Param("lr", label="Learning Rate"), + Metric("accuracy", direction="max", label="Accuracy"), ] ) ) - report.add(ablate.blocks.H2("Filtered Runs")) + report.add(H2("Filtered Runs")) report.add( - ablate.blocks.Table( + Table( runs = filtered_runs.all(), # use filtered runs only for this block columns=[ - ablate.queries.Param("model", label="Model"), - ablate.queries.Param("lr", label="Learning Rate"), - ablate.queries.Metric("accuracy", direction="max", label="Accuracy"), + Param("model", label="Model"), + Param("lr", label="Learning Rate"), + Metric("accuracy", direction="max", label="Accuracy"), ] ) ) + Extending `ablate` ------------------
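
A minimal end-to-end sketch of the predicate selectors introduced in [PATCH 1/4], assuming all four patches are applied. It uses only the `ablate` API shown in the diffs above (`Mock`, `Query`, `Metric`, `Param`); the grid values and thresholds are illustrative, not part of the patches:

```python
from ablate.queries import Metric, Param, Query
from ablate.sources import Mock

# Simulate a small grid of runs with the Mock source (see PATCH 2/4).
source = Mock(
    grid={"model": ["vgg", "resnet"], "lr": [0.01, 0.001]},
    num_seeds=2,
)

acc = Metric("accuracy", direction="max")
model = Param("model")

# Comparisons on selectors now return Predicate objects (PATCH 1/4),
# which compose lazily with & (and), | (or), and ~ (not).
keep = (acc > 0.8) & ~(model == "vgg")

# The combined predicate is evaluated per run by Query.filter.
runs = Query(source.load()).filter(keep).all()  # resnet runs with accuracy > 0.8
```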