Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion agentune/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from agentune.core.llmcache.sqlite_lru import ConnectionProviderFactory, SqliteLru
from agentune.core.progress.reporters.base import ProgressReporter, progress_setup
from agentune.core.progress.reporters.log import LogReporter
from agentune.core.progress.reporters.rich_console import RichConsoleReporter
from agentune.core.sercontext import SerializationContext
from agentune.core.util import asyncref
from agentune.core.util.lrucache import LRUCache
Expand All @@ -38,6 +39,8 @@
from .llm import BoundLlm
from .ops import BoundOps

_logger = logging.getLogger(__name__)

# These classes need to be in the base module because we use them to define the default values of parameters to create()
# and so we can't import them only when we need them

Expand Down Expand Up @@ -74,6 +77,13 @@ class WriteProgressToLog(ProgressReporterParams):
logger_name: str = 'agentune.progress'
log_level: int = logging.INFO

@frozen
class WriteProgressToConsole(ProgressReporterParams):
"""Writes progress updates in the console/terminal, shows interactive progress tree"""
poll_interval: timedelta = timedelta(milliseconds=500)
show_percentages: bool = True
show_colors: bool = True


@frozen
class RunContext:
Expand Down Expand Up @@ -128,7 +138,7 @@ async def create(duckdb: DuckdbDatabase | DuckdbManager = DuckdbInMemory(),
httpx_async_client: httpx.AsyncClient | None = None,
llm_providers: LLMProvider | Sequence[LLMProvider] | None = None,
llm_cache: LlmCacheInMemory | LlmCacheOnDisk | LLMCacheBackend | None = LlmCacheInMemory(1000),
progress_reporter: WriteProgressToLog | ProgressReporter | None = WriteProgressToLog()
progress_reporter: WriteProgressToLog | WriteProgressToConsole | ProgressReporter | None = WriteProgressToLog()
) -> RunContext:
"""Create a new context instance (see the class doc). Remember to close it when you are done, by using it as
a context manager or by calling the aclose() method explicitly.
Expand Down Expand Up @@ -196,6 +206,13 @@ async def create(duckdb: DuckdbDatabase | DuckdbManager = DuckdbInMemory(),
match progress_reporter:
case WriteProgressToLog(poll_interval, logger_name, log_level):
reporter_instance = LogReporter(poll_interval, logger_name, log_level)
case WriteProgressToConsole(poll_interval, show_percentages, show_colors):
from rich.console import Console
# Check whether the running environment supports Rich console
if Console().is_terminal:
reporter_instance = RichConsoleReporter(poll_interval, show_percentages, show_colors)
else:
_logger.warning('WriteProgressToConsole requested but it is not supported in this environment. Progress reporting will be disabled.')
case ProgressReporter() as reporter:
reporter_instance = reporter
owns_reporter = False
Expand Down
114 changes: 114 additions & 0 deletions agentune/core/progress/reporters/rich_console.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Rich console-based progress reporter for interactive display."""

from __future__ import annotations

from datetime import timedelta
from typing import override

from rich.console import Console
from rich.live import Live
from rich.tree import Tree

from agentune.core.progress.base import ProgressStage, root_stage
from agentune.core.progress.reporters.base import ProgressReporter


class RichConsoleReporter(ProgressReporter):
"""Progress reporter that displays interactive progress in the console using Rich.

Progress updates are displayed in a hierarchical tree visualization. Can be used in interactive terminals and Jupyter notebooks.

Args:
poll_interval: How often to poll for progress updates.
show_percentages: Whether to show percentage completion for counted stages.
show_colors: Whether to use colors in the display.
"""

def __init__(
self,
poll_interval: timedelta,
show_percentages: bool,
show_colors: bool,
) -> None:
self.poll_interval = poll_interval
self._show_percentages = show_percentages
self._show_colors = show_colors
self._live: Live | None = None
self._progress_tree: Tree | None = None

@override
async def start(self, root_stage: ProgressStage) -> None:
"""Start displaying progress for the given root stage."""
snapshot = root_stage.deepcopy()
self._progress_tree = self._build_tree(snapshot)
self._live = Live(
self._progress_tree,
console=Console(),
auto_refresh=False,
)
self._live.start(refresh=True)

@override
async def update(self, snapshot: ProgressStage) -> None:
"""Update the display with the latest progress snapshot."""
if self._live is None:
return
self._progress_tree = self._build_tree(snapshot)
self._live.update(self._progress_tree, refresh=True)

@override
async def stop(self) -> None:
"""Stop displaying and perform cleanup."""
if self._live is not None:
current_root = root_stage()
if current_root is not None:
await self.update(current_root.deepcopy())
self._live.stop()
self._live = None
self._progress_tree = None

def _build_tree(self, stage: ProgressStage) -> Tree:
"""Build a Rich Tree structure from a progress stage.

Args:
stage: The progress stage to convert.

Returns:
A Rich Tree object representing the progress hierarchy.
"""
label = self._format_stage_label(stage)
tree = Tree(label)

for child in stage.children:
tree.add(self._build_tree(child))

return tree

def _format_stage_label(self, stage: ProgressStage) -> str:
"""Format a stage label with optional Rich markup for colors."""
count = stage.count
total = stage.total

progress = ''
if count is not None and total is not None:
if self._show_percentages and total > 0:
percentage = (count / total) * 100
progress = f' [{count}/{total} ({percentage:.1f}%)]'
else:
progress = f' [{count}/{total}]'
elif count is not None:
progress = f' [{count}]'
elif total is not None:
progress = f' [0/{total}]'

if not self._show_colors:
name = stage.name + progress
return name + ' ✓' if stage.is_completed else name

if stage.is_completed:
return f'[green]{stage.name}[/green][dim]{progress}[/dim] [green]✓[/green]'
elif count is not None:
return f'[cyan]{stage.name}[/cyan][dim]{progress}[/dim]'
else:
return f'[white]{stage.name}[/white][dim]{progress}[/dim]'

59 changes: 57 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ more-itertools = "^10.7.0"
llama-index-core = "^0.14.7"
llama-index-llms-openai = "^0.6.7"
lightgbm = "^4.6.0"
rich = "^14.0.0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.18.2"
Expand Down
115 changes: 115 additions & 0 deletions tests/agentune/core/progress/reporters/test_rich_console_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for RichConsoleReporter formatting and tree building."""
from datetime import timedelta

import pytest

from agentune.core.progress.base import ProgressStage
from agentune.core.progress.reporters.rich_console import RichConsoleReporter


@pytest.fixture
def reporter_with_colors() -> RichConsoleReporter:
return RichConsoleReporter(timedelta(seconds=0.1), show_percentages=True, show_colors=True)


@pytest.fixture
def reporter_no_colors() -> RichConsoleReporter:
return RichConsoleReporter(timedelta(seconds=0.1), show_percentages=True, show_colors=False)


@pytest.fixture
def reporter_no_percentages() -> RichConsoleReporter:
return RichConsoleReporter(timedelta(seconds=0.1), show_percentages=False, show_colors=True)


def test_label_in_progress_with_colors(reporter_with_colors: RichConsoleReporter) -> None:
"""In-progress stage with count shows cyan name."""
stage = ProgressStage(name='task', count=5, total=10)
label = reporter_with_colors._format_stage_label(stage)

assert '[cyan]task[/cyan]' in label
assert '5/10' in label
assert '50.0%' in label


def test_label_completed_with_colors(reporter_with_colors: RichConsoleReporter) -> None:
"""Completed stage shows green name and checkmark."""
stage = ProgressStage(name='done', count=10, total=10)
stage.complete()
label = reporter_with_colors._format_stage_label(stage)

assert '[green]done[/green]' in label
assert '[green]✓[/green]' in label


def test_label_started_no_count_with_colors(reporter_with_colors: RichConsoleReporter) -> None:
"""Stage without count shows white name."""
stage = ProgressStage(name='waiting')
label = reporter_with_colors._format_stage_label(stage)

assert '[white]waiting[/white]' in label


def test_label_without_colors(reporter_no_colors: RichConsoleReporter) -> None:
"""Without colors, labels have no Rich markup."""
stage = ProgressStage(name='task', count=5, total=10)
label = reporter_no_colors._format_stage_label(stage)

assert '[' not in label or label.startswith('task [') # Only progress brackets, no color markup
assert 'task' in label
assert '5/10' in label


def test_label_completed_without_colors(reporter_no_colors: RichConsoleReporter) -> None:
"""Completed stage without colors shows plain checkmark."""
stage = ProgressStage(name='done')
stage.complete()
label = reporter_no_colors._format_stage_label(stage)

assert label == 'done ✓'


def test_percentage_display(reporter_with_colors: RichConsoleReporter) -> None:
"""Percentage shown when show_percentages=True."""
stage = ProgressStage(name='work', count=3, total=12)
label = reporter_with_colors._format_stage_label(stage)

assert '25.0%' in label


def test_no_percentage_display(reporter_no_percentages: RichConsoleReporter) -> None:
"""Percentage hidden when show_percentages=False."""
stage = ProgressStage(name='work', count=3, total=12)
label = reporter_no_percentages._format_stage_label(stage)

assert '%' not in label
assert '3/12' in label


def test_count_only_no_total(reporter_with_colors: RichConsoleReporter) -> None:
"""Count without total shows just the count."""
stage = ProgressStage(name='items', count=42)
label = reporter_with_colors._format_stage_label(stage)

assert '[42]' in label
assert '42/' not in label


def test_total_only_no_count(reporter_with_colors: RichConsoleReporter) -> None:
"""Total without count shows 0/total."""
stage = ProgressStage(name='pending', total=100)
label = reporter_with_colors._format_stage_label(stage)

assert '[0/100]' in label


def test_tree_builds_hierarchy(reporter_with_colors: RichConsoleReporter) -> None:
"""Tree structure reflects stage hierarchy."""
root = ProgressStage(name='root')
root.add_child('child1')
root.add_child('child2')

tree = reporter_with_colors._build_tree(root)

assert len(tree.children) == 2