diff --git a/agentune/api/base.py b/agentune/api/base.py index b1236c3f..9cf9bc6e 100644 --- a/agentune/api/base.py +++ b/agentune/api/base.py @@ -26,6 +26,7 @@ from agentune.core.llmcache.sqlite_lru import ConnectionProviderFactory, SqliteLru from agentune.core.progress.reporters.base import ProgressReporter, progress_setup from agentune.core.progress.reporters.log import LogReporter +from agentune.core.progress.reporters.rich_console import RichConsoleReporter from agentune.core.sercontext import SerializationContext from agentune.core.util import asyncref from agentune.core.util.lrucache import LRUCache @@ -38,6 +39,8 @@ from .llm import BoundLlm from .ops import BoundOps +_logger = logging.getLogger(__name__) + # These classes need to be in the base module because we use them to define the default values of parameters to create() # and so we can't import them only when we need them @@ -74,6 +77,13 @@ class WriteProgressToLog(ProgressReporterParams): logger_name: str = 'agentune.progress' log_level: int = logging.INFO +@frozen +class WriteProgressToConsole(ProgressReporterParams): + """Writes progress updates in the console/terminal, shows interactive progress tree""" + poll_interval: timedelta = timedelta(milliseconds=500) + show_percentages: bool = True + show_colors: bool = True + @frozen class RunContext: @@ -128,7 +138,7 @@ async def create(duckdb: DuckdbDatabase | DuckdbManager = DuckdbInMemory(), httpx_async_client: httpx.AsyncClient | None = None, llm_providers: LLMProvider | Sequence[LLMProvider] | None = None, llm_cache: LlmCacheInMemory | LlmCacheOnDisk | LLMCacheBackend | None = LlmCacheInMemory(1000), - progress_reporter: WriteProgressToLog | ProgressReporter | None = WriteProgressToLog() + progress_reporter: WriteProgressToLog | WriteProgressToConsole | ProgressReporter | None = WriteProgressToLog() ) -> RunContext: """Create a new context instance (see the class doc). Remember to close it when you are done, by using it as a context manager or by calling the aclose() method explicitly. @@ -196,6 +206,13 @@ async def create(duckdb: DuckdbDatabase | DuckdbManager = DuckdbInMemory(), match progress_reporter: case WriteProgressToLog(poll_interval, logger_name, log_level): reporter_instance = LogReporter(poll_interval, logger_name, log_level) + case WriteProgressToConsole(poll_interval, show_percentages, show_colors): + from rich.console import Console + # Check whether the running environment supports Rich console + if Console().is_terminal: + reporter_instance = RichConsoleReporter(poll_interval, show_percentages, show_colors) + else: + _logger.warning('WriteProgressToConsole requested but it is not supported in this environment. Progress reporting will be disabled.') case ProgressReporter() as reporter: reporter_instance = reporter owns_reporter = False diff --git a/agentune/core/progress/reporters/rich_console.py b/agentune/core/progress/reporters/rich_console.py new file mode 100644 index 00000000..8e2d806b --- /dev/null +++ b/agentune/core/progress/reporters/rich_console.py @@ -0,0 +1,114 @@ +"""Rich console-based progress reporter for interactive display.""" + +from __future__ import annotations + +from datetime import timedelta +from typing import override + +from rich.console import Console +from rich.live import Live +from rich.tree import Tree + +from agentune.core.progress.base import ProgressStage, root_stage +from agentune.core.progress.reporters.base import ProgressReporter + + +class RichConsoleReporter(ProgressReporter): + """Progress reporter that displays interactive progress in the console using Rich. + + Progress updates are displayed in a hierarchical tree visualization. Can be used in interactive terminals and Jupyter notebooks. + + Args: + poll_interval: How often to poll for progress updates. + show_percentages: Whether to show percentage completion for counted stages. + show_colors: Whether to use colors in the display. + """ + + def __init__( + self, + poll_interval: timedelta, + show_percentages: bool, + show_colors: bool, + ) -> None: + self.poll_interval = poll_interval + self._show_percentages = show_percentages + self._show_colors = show_colors + self._live: Live | None = None + self._progress_tree: Tree | None = None + + @override + async def start(self, root_stage: ProgressStage) -> None: + """Start displaying progress for the given root stage.""" + snapshot = root_stage.deepcopy() + self._progress_tree = self._build_tree(snapshot) + self._live = Live( + self._progress_tree, + console=Console(), + auto_refresh=False, + ) + self._live.start(refresh=True) + + @override + async def update(self, snapshot: ProgressStage) -> None: + """Update the display with the latest progress snapshot.""" + if self._live is None: + return + self._progress_tree = self._build_tree(snapshot) + self._live.update(self._progress_tree, refresh=True) + + @override + async def stop(self) -> None: + """Stop displaying and perform cleanup.""" + if self._live is not None: + current_root = root_stage() + if current_root is not None: + await self.update(current_root.deepcopy()) + self._live.stop() + self._live = None + self._progress_tree = None + + def _build_tree(self, stage: ProgressStage) -> Tree: + """Build a Rich Tree structure from a progress stage. + + Args: + stage: The progress stage to convert. + + Returns: + A Rich Tree object representing the progress hierarchy. + """ + label = self._format_stage_label(stage) + tree = Tree(label) + + for child in stage.children: + tree.add(self._build_tree(child)) + + return tree + + def _format_stage_label(self, stage: ProgressStage) -> str: + """Format a stage label with optional Rich markup for colors.""" + count = stage.count + total = stage.total + + progress = '' + if count is not None and total is not None: + if self._show_percentages and total > 0: + percentage = (count / total) * 100 + progress = f' [{count}/{total} ({percentage:.1f}%)]' + else: + progress = f' [{count}/{total}]' + elif count is not None: + progress = f' [{count}]' + elif total is not None: + progress = f' [0/{total}]' + + if not self._show_colors: + name = stage.name + progress + return name + ' ✓' if stage.is_completed else name + + if stage.is_completed: + return f'[green]{stage.name}[/green][dim]{progress}[/dim] [green]✓[/green]' + elif count is not None: + return f'[cyan]{stage.name}[/cyan][dim]{progress}[/dim]' + else: + return f'[white]{stage.name}[/white][dim]{progress}[/dim]' + diff --git a/poetry.lock b/poetry.lock index 890c23f8..40d90cb1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2476,6 +2476,30 @@ typing-extensions = ">=4.6.0" client = ["httpx (>=0.28.1,<1)"] server = ["starlette (>=0.39.0)", "uvicorn (>=0.32.0)"] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147"}, + {file = "markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "markdown-it-pyrs", "mistletoe (>=1.0,<2.0)", "mistune (>=3.0,<4.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins (>=0.5.0)"] +profiling = ["gprof2dot"] +rtd = ["ipykernel", "jupyter_sphinx", "mdit-py-plugins (>=0.5.0)", "myst-parser", "pyyaml", "sphinx", "sphinx-book-theme (>=1.0,<2.0)", "sphinx-copybutton", "sphinx-design"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions", "requests"] + [[package]] name = "markupsafe" version = "3.0.3" @@ -2692,6 +2716,18 @@ traitlets = "*" [package.extras] test = ["flake8", "nbdime", "nbval", "notebook", "pytest"] +[[package]] +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, +] + [[package]] name = "mistune" version = "3.1.4" @@ -4014,7 +4050,7 @@ version = "2.19.2" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" -groups = ["dev", "examples"] +groups = ["main", "dev", "examples"] files = [ {file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"}, {file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"}, @@ -4550,6 +4586,25 @@ lark = ">=1.2.2" [package.extras] testing = ["pytest (>=8.3.5)"] +[[package]] +name = "rich" +version = "14.2.0" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.8.0" +groups = ["main"] +files = [ + {file = "rich-14.2.0-py3-none-any.whl", hash = "sha256:76bc51fe2e57d2b1be1f96c524b890b816e334ab4c1e45888799bfaab0021edd"}, + {file = "rich-14.2.0.tar.gz", hash = "sha256:73ff50c7c0c1c77c8243079283f4edb376f0f6442433aecb8ce7e6d0b92d1fe4"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + [[package]] name = "rpds-py" version = "0.28.0" @@ -5683,4 +5738,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = "^3.12.6" -content-hash = "9d5906ba31f5c8d2ba58b96a8772c3d1a1deffa20344fe73c757c8163f7f806e" +content-hash = "a3b1781e697f9e164c18c57aec3538a17d6ecc6a66c367181e73addf6ba1515e" diff --git a/pyproject.toml b/pyproject.toml index 535c2e4c..a0685219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ more-itertools = "^10.7.0" llama-index-core = "^0.14.7" llama-index-llms-openai = "^0.6.7" lightgbm = "^4.6.0" +rich = "^14.0.0" [tool.poetry.group.dev.dependencies] mypy = "^1.18.2" diff --git a/tests/agentune/core/progress/reporters/test_rich_console_reporter.py b/tests/agentune/core/progress/reporters/test_rich_console_reporter.py new file mode 100644 index 00000000..79e831e2 --- /dev/null +++ b/tests/agentune/core/progress/reporters/test_rich_console_reporter.py @@ -0,0 +1,115 @@ +"""Tests for RichConsoleReporter formatting and tree building.""" +from datetime import timedelta + +import pytest + +from agentune.core.progress.base import ProgressStage +from agentune.core.progress.reporters.rich_console import RichConsoleReporter + + +@pytest.fixture +def reporter_with_colors() -> RichConsoleReporter: + return RichConsoleReporter(timedelta(seconds=0.1), show_percentages=True, show_colors=True) + + +@pytest.fixture +def reporter_no_colors() -> RichConsoleReporter: + return RichConsoleReporter(timedelta(seconds=0.1), show_percentages=True, show_colors=False) + + +@pytest.fixture +def reporter_no_percentages() -> RichConsoleReporter: + return RichConsoleReporter(timedelta(seconds=0.1), show_percentages=False, show_colors=True) + + +def test_label_in_progress_with_colors(reporter_with_colors: RichConsoleReporter) -> None: + """In-progress stage with count shows cyan name.""" + stage = ProgressStage(name='task', count=5, total=10) + label = reporter_with_colors._format_stage_label(stage) + + assert '[cyan]task[/cyan]' in label + assert '5/10' in label + assert '50.0%' in label + + +def test_label_completed_with_colors(reporter_with_colors: RichConsoleReporter) -> None: + """Completed stage shows green name and checkmark.""" + stage = ProgressStage(name='done', count=10, total=10) + stage.complete() + label = reporter_with_colors._format_stage_label(stage) + + assert '[green]done[/green]' in label + assert '[green]✓[/green]' in label + + +def test_label_started_no_count_with_colors(reporter_with_colors: RichConsoleReporter) -> None: + """Stage without count shows white name.""" + stage = ProgressStage(name='waiting') + label = reporter_with_colors._format_stage_label(stage) + + assert '[white]waiting[/white]' in label + + +def test_label_without_colors(reporter_no_colors: RichConsoleReporter) -> None: + """Without colors, labels have no Rich markup.""" + stage = ProgressStage(name='task', count=5, total=10) + label = reporter_no_colors._format_stage_label(stage) + + assert '[' not in label or label.startswith('task [') # Only progress brackets, no color markup + assert 'task' in label + assert '5/10' in label + + +def test_label_completed_without_colors(reporter_no_colors: RichConsoleReporter) -> None: + """Completed stage without colors shows plain checkmark.""" + stage = ProgressStage(name='done') + stage.complete() + label = reporter_no_colors._format_stage_label(stage) + + assert label == 'done ✓' + + +def test_percentage_display(reporter_with_colors: RichConsoleReporter) -> None: + """Percentage shown when show_percentages=True.""" + stage = ProgressStage(name='work', count=3, total=12) + label = reporter_with_colors._format_stage_label(stage) + + assert '25.0%' in label + + +def test_no_percentage_display(reporter_no_percentages: RichConsoleReporter) -> None: + """Percentage hidden when show_percentages=False.""" + stage = ProgressStage(name='work', count=3, total=12) + label = reporter_no_percentages._format_stage_label(stage) + + assert '%' not in label + assert '3/12' in label + + +def test_count_only_no_total(reporter_with_colors: RichConsoleReporter) -> None: + """Count without total shows just the count.""" + stage = ProgressStage(name='items', count=42) + label = reporter_with_colors._format_stage_label(stage) + + assert '[42]' in label + assert '42/' not in label + + +def test_total_only_no_count(reporter_with_colors: RichConsoleReporter) -> None: + """Total without count shows 0/total.""" + stage = ProgressStage(name='pending', total=100) + label = reporter_with_colors._format_stage_label(stage) + + assert '[0/100]' in label + + +def test_tree_builds_hierarchy(reporter_with_colors: RichConsoleReporter) -> None: + """Tree structure reflects stage hierarchy.""" + root = ProgressStage(name='root') + root.add_child('child1') + root.add_child('child2') + + tree = reporter_with_colors._build_tree(root) + + assert len(tree.children) == 2 +