diff --git a/chimerapy/engine/exceptions.py b/chimerapy/engine/exceptions.py index 2c68ac2..aa3956d 100644 --- a/chimerapy/engine/exceptions.py +++ b/chimerapy/engine/exceptions.py @@ -1,2 +1,6 @@ class CommitGraphError(Exception): ... + + +class TagError(Exception): + ... diff --git a/chimerapy/engine/manager/events.py b/chimerapy/engine/manager/events.py index e85bf8b..89d93e4 100644 --- a/chimerapy/engine/manager/events.py +++ b/chimerapy/engine/manager/events.py @@ -1,3 +1,4 @@ +from typing import Optional from dataclasses import dataclass from ..states import WorkerState @@ -38,3 +39,10 @@ class DeregisterEntityEvent: # entity_deregister @dataclass class MoveTransferredFilesEvent: # move_transferred_files worker_state: WorkerState + + +@dataclass +class TagEvent: + uuid: str + name: str + description: Optional[str] = None diff --git a/chimerapy/engine/manager/manager.py b/chimerapy/engine/manager/manager.py index 517412e..27fdc40 100644 --- a/chimerapy/engine/manager/manager.py +++ b/chimerapy/engine/manager/manager.py @@ -12,6 +12,7 @@ from ..networking.async_loop_thread import AsyncLoopThread from chimerapy.engine.states import ManagerState, WorkerState from chimerapy.engine.graph import Graph +from chimerapy.engine.exceptions import TagError # Eventbus from ..eventbus import EventBus, Event, make_evented @@ -24,6 +25,7 @@ from .zeroconf_service import ZeroconfService from .session_record_service import SessionRecordService from .distributed_logging_service import DistributedLoggingService +from .session_tag_service import SessionTagService logger = _logger.getLogger("chimerapy-engine") @@ -107,6 +109,11 @@ def __init__( eventbus=self.eventbus, state=self.state, ) + self.session_tags = SessionTagService( + name="session_tag", + eventbus=self.eventbus, + state=self.state, + ) self.distributed_logging = DistributedLoggingService( name="distributed_logging", publish_logs_via_zmq=publish_logs_via_zmq, @@ -338,6 +345,27 @@ async def async_shutdown(self) -> bool: return True + async def async_create_tag( + self, name: str, description: Optional[str] = None + ) -> str: + can, reason = self.session_tags.can_create_tag() + if can: + return await self.worker_handler.create_tag( + name=name, description=description + ) + else: + raise TagError(reason) + + async def async_update_tag_descr(self, tag_id, description) -> bool: + tag_name = self.session_tags.get_tag_name(tag_id) + + if tag_name is None: + raise TagError(f"Tag with id {tag_id} not found") + + return await self.worker_handler.update_tag_description( + tag_id, tag_name, description + ) + #################################################################### ## Front-facing Sync API #################################################################### @@ -476,6 +504,12 @@ def reset( return future + def create_tag(self, name: str, description: Optional[str] = None) -> Future[str]: + return self._exec_coro(self.async_create_tag(name, description)) + + def update_tag_descr(self, tag_id, description) -> Future[bool]: + return self._exec_coro(self.async_update_tag_descr(tag_id, description)) + def shutdown(self, blocking: bool = True) -> Union[bool, Future[bool]]: """Proper shutting down ChimeraPy-Engine cluster. diff --git a/chimerapy/engine/manager/session_tag_service.py b/chimerapy/engine/manager/session_tag_service.py new file mode 100644 index 0000000..1a3fc0c --- /dev/null +++ b/chimerapy/engine/manager/session_tag_service.py @@ -0,0 +1,104 @@ +import datetime +import json +from typing import Optional, Dict, Tuple + +from ..states import ManagerState +from ..eventbus import EventBus, TypedObserver +from ..service import Service +from .events import TagEvent + + +class SessionTagService(Service): + def __init__( + self, + name: str, + eventbus: EventBus, + state: ManagerState, + ): + super().__init__(name=name) + + # Input parameters + self.eventbus = eventbus + self.state = state + + # State information + self.start_time: Optional[datetime.datetime] = None + self.tags: Dict[str, Dict] = {} + + # Specify observers + self.observers: Dict[str, TypedObserver] = { + "start_recording": TypedObserver( + "start_recording", on_asend=self._record_start_time, handle_event="drop" + ), + "create_tag": TypedObserver( + "create_tag", + event_data_cls=TagEvent, + on_asend=self._create_tag, + handle_event="unpack", + ), + "update_tag": TypedObserver( + "update_tag", + event_data_cls=TagEvent, + on_asend=self._update_tag, + handle_event="unpack", + ), + "stop_recording": TypedObserver( + "stop_recording", on_asend=self._consolidate_tags, handle_event="drop" + ), + } + for ob in self.observers.values(): + self.eventbus.subscribe(ob).result(timeout=1) + + def can_create_tag(self) -> Tuple[bool, Optional[str]]: + """Check if a tag can be created""" + all_node_states = list( + node_state.fsm + for worker in self.state.workers.values() + for node_state in worker.nodes.values() + ) + + if len(all_node_states) == 0: + return False, "No nodes to tag" + + can_create = all(node_state == "RECORDING" for node_state in all_node_states) + reason = ( + "All nodes must be in RECORDING state to add a tag" + if not can_create + else None + ) + return can_create, reason + + def _record_start_time(self) -> None: + self.start_time = datetime.datetime.now() + + def _get_elapsed_time(self) -> float: + assert self.start_time is not None + delta = datetime.datetime.now() - self.start_time + return delta.total_seconds() + + def _create_tag(self, uuid: str, name: str, description: Optional[str] = None): + e_time = self._get_elapsed_time() + self.tags[uuid] = { + "uuid": uuid, + "name": name, + "description": description, + "timestamp": e_time, + "timestamp_str": f"{e_time // 3600} Hours, " + f"{(e_time // 60) % 60} Minutes, " + f"{e_time % 60} Seconds", + } + + def _update_tag(self, uuid, name, description): + if uuid in self.tags: + self.tags[uuid]["name"] = name + self.tags[uuid]["description"] = description + + def get_tag_name(self, uuid: str) -> Optional[str]: + if uuid in self.tags: + return self.tags[uuid]["name"] + else: + return None + + def _consolidate_tags(self) -> None: + with (self.state.logdir / "session_tags.json").open("w") as f: + json.dump(self.tags, f, indent=2) diff --git a/chimerapy/engine/manager/worker_handler_service.py b/chimerapy/engine/manager/worker_handler_service.py index 1b8f33f..8f55c46 100644 --- a/chimerapy/engine/manager/worker_handler_service.py +++ b/chimerapy/engine/manager/worker_handler_service.py @@ -11,6 +11,7 @@ import aiohttp import networkx as nx +from uuid import uuid4 from chimerapy.engine import config from chimerapy.engine import _logger from ..utils import async_waiting_for @@ -29,6 +30,7 @@ DeregisterEntityEvent, MoveTransferredFilesEvent, UpdateSendArchiveEvent, + TagEvent, ) logger = _logger.getLogger("chimerapy-engine") @@ -850,3 +852,15 @@ async def reset(self, keep_workers: bool = True): self._deregister_graph() return all(results) + + async def create_tag(self, name, description=None) -> str: + # TODO: Ping Nodes + tag_event = TagEvent(uuid=str(uuid4()), name=name, description=description) + await self.eventbus.asend(Event("create_tag", tag_event)) + return tag_event.uuid + + async def update_tag_description(self, uuid, name, description) -> bool: + tag_event = TagEvent(uuid=uuid, name=name, description=description) + + await self.eventbus.asend(Event("update_tag", tag_event)) + return True diff --git a/test/manager/test_manager.py b/test/manager/test_manager.py index 0b6318e..911caf3 100644 --- a/test/manager/test_manager.py +++ b/test/manager/test_manager.py @@ -1,6 +1,7 @@ import time import os import pathlib +import json import pytest @@ -156,3 +157,38 @@ def test_manager_recommit_graph(self, manager_with_worker): assert ((delta2 - delta) / (delta)) < 1 manager.reset() + + def test_manager_session_tags(self, manager_with_worker): + manager, worker = manager_with_worker + + # Define graph + gen_node = GenNode(name="Gen1") + con_node = ConsumeNode(name="Con1") + simple_graph = cpe.Graph() + simple_graph.add_nodes_from([gen_node, con_node]) + simple_graph.add_edge(src=gen_node, dst=con_node) + + mapping = {worker.id: [gen_node.id, con_node.id]} + + graph_info = {"graph": simple_graph, "mapping": mapping} + + assert manager.commit_graph(**graph_info).result(timeout=30) + + assert manager.start().result(timeout=30) + assert manager.record().result(timeout=30) + time.sleep(2) # Wait for the FSMs to change + tag_uuid = manager.create_tag("test_tag").result(timeout=30) + + manager.update_tag_descr(tag_uuid, "test_tag_descr").result(timeout=30) + + assert manager.stop().result(timeout=30) + assert manager.collect().result(timeout=30) + + assert (manager.logdir / "session_tags.json").exists() + + with open(manager.logdir / "session_tags.json", "r") as f: + session_tags = json.load(f) + assert len(session_tags) == 1 + tag_details = session_tags[tag_uuid] + assert tag_details["name"] == "test_tag" + assert tag_details["description"] == "test_tag_descr" diff --git a/test/manager/test_session_tag_service.py b/test/manager/test_session_tag_service.py new file mode 100644 index 0000000..bef472a --- /dev/null +++ b/test/manager/test_session_tag_service.py @@ -0,0 +1,69 @@ +import json +import tempfile +import time +from pathlib import Path + +import pytest + +from chimerapy.engine.eventbus import Event, EventBus, make_evented +from chimerapy.engine.manager.events import TagEvent +from chimerapy.engine.manager.session_tag_service import SessionTagService +from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread +from chimerapy.engine.states import ManagerState, NodeState, WorkerState + + +@pytest.fixture(scope="module") +def testbed_setup(): + thread = AsyncLoopThread() + thread.start() + eventbus = EventBus(thread=thread) + + state = ManagerState(logdir=Path(tempfile.mkdtemp())) + + make_evented(state, event_bus=eventbus) + + tag_service = SessionTagService(name="tag_service", eventbus=eventbus, state=state) + + return tag_service, state, eventbus + + +def test_tag_service_tag_error(testbed_setup): + tag_service, state, bus = testbed_setup + can_create_tag, reason = tag_service.can_create_tag() + + assert not can_create_tag + assert reason == "No nodes to tag" + + +def test_tag_service_tag(testbed_setup): + # Add workers and nodes to the state + tag_service, state, bus = testbed_setup + + state.workers["worker1"] = WorkerState( + name="worker1", + id="worker1", + ip="0.0.0.0", + port=0, + nodes={"node1": NodeState(name="node1", id="node1", fsm="NULL")}, + ) + + can, reason = tag_service.can_create_tag() + assert not can + assert reason == "All nodes must be in RECORDING state to add a tag" + + state.workers["worker1"].nodes["node1"].fsm = "RECORDING" + can, reason = tag_service.can_create_tag() + assert can + assert reason is None + + bus.send(Event("start_recording")).result() + time.sleep(2) + bus.send(Event("create_tag", TagEvent("A", "tag1"))).result() + bus.send(Event("update_tag", TagEvent("A", "tag2", "tag2_descr"))).result() + bus.send(Event("stop_recording")).result() + + with (state.logdir / "session_tags.json").open("r") as f: + tags = json.load(f) + assert tags["A"]["name"] == "tag2" + assert tags["A"]["description"] == "tag2_descr" + assert tags["A"]["timestamp"] > 2