diff --git a/.github/workflows/ci-tests.yaml b/.github/workflows/ci-tests.yaml index 41394c03c..7753a8592 100644 --- a/.github/workflows/ci-tests.yaml +++ b/.github/workflows/ci-tests.yaml @@ -159,7 +159,7 @@ jobs: run: uv sync --locked --no-dev --group docs - name: "Build documentation and check for consistency" env: - CHECKSUM: "be2933a30a986a448b8b7dfbab602e8301744520f96abde1f8ae35539061c411" + CHECKSUM: "e3715d10625d439a159b0a3bbf527403692cb93f0555b454eda7d5201c008d8a" run: | cd docs HASH="$(uv run --no-sync make checksum | tail -n1)" diff --git a/docs/source/ext/connector.rst b/docs/source/ext/connector.rst index 4868e609b..ecd444ecb 100644 --- a/docs/source/ext/connector.rst +++ b/docs/source/ext/connector.rst @@ -48,6 +48,10 @@ The ``streamflow.core.deployment`` module defines the ``Connector`` interface, w ) -> MutableMapping[str, AvailableLocation]: ... + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: ... + async def get_stream_reader( self, command: MutableSequence[str], @@ -86,6 +90,8 @@ The ``deploy`` method instantiates the remote execution environment, making it r The ``undeploy`` method destroys the remote execution environment, potentially cleaning up all the temporary resources instantiated during the workflow execution (e.g., intermediate results). If a ``deployment`` object is marked as ``external``, the ``undeploy`` method should not destroy it but just close all the connections opened by the ``deploy`` method. +The ``get_shell`` method returns a ``Shell`` object, an abstraction of a persistent remote shell that executes commands efficiently by reusing a single remote process instead of spawning a new one for each command. The ``command`` parameter specifies the command used to spawn the shell (e.g., ``["sh"]`` for a standard POSIX shell), and the ``location`` parameter identifies the remote ``ExecutionLocation`` where the shell should be instantiated. + The ``get_available_locations`` method is used in the scheduling phase to obtain the locations available for job execution, identified by their unique name (see :ref:`here `). The method receives an optional input parameter to filter valid locations. The ``service`` parameter specifies a specific set of locations in a deployment, and its precise meaning differs for each deployment type (see :ref:`here `). The ``get_stream_reader`` and ``get_stream_writer`` methods return an `Asynchronous Context Manager `_ wrapping a ``StreamWrapper`` instance, allowing it to be used inside ``async with`` statements. The ``StreamWrapper`` instance is obtained by executing the ``command`` on the ``location``, and can be used to read or write data using a stream (see :ref:`here `). The streams must be read and written respecting the size of the available buffer, which is defined by the ``transferBufferSize`` attribute of the ``Connector`` instance. These methods improve performance of data copies between pairs of remote locations.
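To make the documented interface concrete, here is a minimal usage sketch of the new ``get_shell`` API. It is illustrative and not part of this changeset: the helper name ``list_remote_dir`` and the ``/tmp`` working directory are assumptions, and ``connector``/``location`` are presumed to come from an already deployed environment.

```python
from streamflow.core.deployment import Connector, ExecutionLocation


async def list_remote_dir(connector: Connector, location: ExecutionLocation) -> str:
    # Obtain (or transparently reuse) a persistent POSIX shell on the target
    # location, then run a command in it without spawning a new remote process.
    shell = await connector.get_shell(command=["sh"], location=location)
    stdout, returncode = await shell.execute(
        command=["ls", "-la"],
        workdir="/tmp",  # assumed working directory, for illustration only
        capture_output=True,
        timeout=30,
    )
    if returncode != 0:
        raise RuntimeError(f"ls failed with exit code {returncode}")
    return stdout
```

Note that the connector, not the caller, owns the shell's lifecycle: ``BaseConnector`` caches shells and closes them all on ``undeploy``, so callers should not close a shell obtained this way.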
diff --git a/streamflow/core/deployment.py b/streamflow/core/deployment.py index 092e643bb..445b49795 100644 --- a/streamflow/core/deployment.py +++ b/streamflow/core/deployment.py @@ -4,7 +4,7 @@ import os import posixpath import tempfile -from abc import abstractmethod +from abc import ABC, abstractmethod from collections.abc import MutableMapping, MutableSequence from typing import TYPE_CHECKING, AsyncContextManager, cast @@ -97,7 +97,6 @@ def __hash__(self) -> int: class BindingFilter(SchemaEntity): - __slots__ = "name" def __init__(self, name: str) -> None: @@ -170,6 +169,11 @@ async def run( @abstractmethod async def undeploy(self, external: bool) -> None: ... + @abstractmethod + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: ... + @abstractmethod async def get_stream_reader( self, command: MutableSequence[str], location: ExecutionLocation @@ -288,6 +292,41 @@ async def save(self, context: StreamFlowContext) -> None: ) +class FilterConfig(PersistableEntity): + __slots__ = ("name", "type", "config") + + def __init__(self, name: str, type: str, config: MutableMapping[str, Any]): + super().__init__() + self.name: str = name + self.type: str = type + self.config: MutableMapping[str, Any] = config or {} + + @classmethod + async def load( + cls, + context: StreamFlowContext, + persistent_id: int, + loading_context: DatabaseLoadingContext, + ) -> Self: + row = await context.database.get_filter(persistent_id) + obj = cls( + name=row["name"], + type=row["type"], + config=row["config"], + ) + loading_context.add_filter(persistent_id, obj) + return obj + + async def save(self, context: StreamFlowContext) -> None: + async with self.persistence_lock: + if not self.persistent_id: + self.persistent_id = await context.database.add_filter( + name=self.name, + type=self.type, + config=self.config, + ) + + class Target(PersistableEntity): def __init__( self, @@ -389,39 +428,32 @@ async def _load( return cls(workdir=row["workdir"]) -class FilterConfig(PersistableEntity): - __slots__ = ("name", "type", "config") +class Shell(ABC): + __slots__ = ("command", "buffer_size") - def __init__(self, name: str, type: str, config: MutableMapping[str, Any]): - super().__init__() - self.name: str = name - self.type: str = type - self.config: MutableMapping[str, Any] = config or {} + def __init__( + self, + command: MutableSequence[str], + buffer_size: int, + ) -> None: + self.command: MutableSequence[str] = command + self.buffer_size: int = buffer_size - @classmethod - async def load( - cls, - context: StreamFlowContext, - persistent_id: int, - loading_context: DatabaseLoadingContext, - ) -> Self: - row = await context.database.get_filter(persistent_id) - obj = cls( - name=row["name"], - type=row["type"], - config=row["config"], - ) - loading_context.add_filter(persistent_id, obj) - return obj + @abstractmethod + async def close(self) -> None: ... - async def save(self, context: StreamFlowContext) -> None: - async with self.persistence_lock: - if not self.persistent_id: - self.persistent_id = await context.database.add_filter( - name=self.name, - type=self.type, - config=self.config, - ) + @abstractmethod + async def closed(self) -> bool: ... + + @abstractmethod + async def execute( + self, + command: MutableSequence[str], + environment: MutableMapping[str, str] | None = ..., + workdir: str | None = ..., + capture_output: bool = ..., + timeout: int | None = ..., + ) -> tuple[str, int] | None: ... 
class WrapsConfig: diff --git a/streamflow/core/utils.py b/streamflow/core/utils.py index 080c4881d..54030a9f7 100644 --- a/streamflow/core/utils.py +++ b/streamflow/core/utils.py @@ -14,11 +14,12 @@ from streamflow.core.exception import ProcessorTypeError, WorkflowExecutionException from streamflow.core.persistence import PersistableEntity +from streamflow.log_handler import logger if TYPE_CHECKING: from typing import TypeVar - from streamflow.core.deployment import Connector, ExecutionLocation + from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.workflow import Token T = TypeVar("T") @@ -311,6 +312,32 @@ def random_name() -> str: return str(uuid.uuid4()) +async def run_in_shell( + shell: Shell, + location: ExecutionLocation, + command: MutableSequence[str], + environment: MutableMapping[str, str] | None = None, + workdir: str | None = None, + capture_output: bool = False, + timeout: int | None = None, +) -> tuple[str, int] | None: + try: + return await shell.execute( + command=command, + environment=environment, + workdir=workdir, + capture_output=capture_output, + timeout=timeout, + ) + except WorkflowExecutionException as e: + logger.warning( + f"Persistent shell failed for location {location.name} " + f"of deployment {location.deployment}: " + f"falling back to direct exec: {e}" + ) + raise e + + async def run_in_subprocess( location: ExecutionLocation, command: MutableSequence[str], diff --git a/streamflow/cwl/utils.py b/streamflow/cwl/utils.py index e26e18fff..929d58d08 100644 --- a/streamflow/cwl/utils.py +++ b/streamflow/cwl/utils.py @@ -1375,35 +1375,32 @@ async def update_file_token( load_contents: bool | None, load_listing: LoadListing | None = None, ) -> MutableMapping[str, Any]: - if path := get_path_from_token(token_value): + new_token_value = dict(token_value) + if path := get_path_from_token(new_token_value): filepath = StreamFlowPath(path, context=context, location=location) # Process contents - if get_token_class(token_value) == "File" and load_contents is not None: - if load_contents and "contents" not in token_value: - token_value |= { + if get_token_class(new_token_value) == "File" and load_contents is not None: + if load_contents and "contents" not in new_token_value: + new_token_value |= { "contents": await _get_contents( filepath, - token_value["size"], + new_token_value["size"], cwl_version, ) } - elif not load_contents and "contents" in token_value: - token_value = { - k: token_value[k] for k in token_value if k != "contents" - } + elif not load_contents and "contents" in new_token_value: + del new_token_value["contents"] # Process listings - if get_token_class(token_value) == "Directory" and load_listing is not None: + if get_token_class(new_token_value) == "Directory" and load_listing is not None: # If load listing is set to `no_listing`, remove the listing entries in present - if load_listing == LoadListing.no_listing: - if "listing" in token_value: - token_value = { - k: token_value[k] for k in token_value if k != "listing" - } + if load_listing == LoadListing.no_listing and "listing" in new_token_value: + del new_token_value["listing"] # If listing is not present or if the token needs a deep listing, process directory contents elif ( - "listing" not in token_value or load_listing == LoadListing.deep_listing + "listing" not in new_token_value + or load_listing == LoadListing.deep_listing ): - token_value |= { + new_token_value |= { "listing": await _get_listing( context=context, connector=connector, @@ -1416,13 
+1413,13 @@ async def update_file_token( } # If load listing is set to `shallow_listing`, remove the deep listing entries if present elif load_listing == LoadListing.shallow_listing: - token_value |= { + new_token_value |= { "listing": [ {k: v[k] for k in v if k != "listing"} - for v in token_value["listing"] + for v in new_token_value["listing"] ] } - return token_value + return new_token_value async def write_remote_file( diff --git a/streamflow/data/remotepath.py b/streamflow/data/remotepath.py index 755e407b8..01cd5f653 100644 --- a/streamflow/data/remotepath.py +++ b/streamflow/data/remotepath.py @@ -625,18 +625,19 @@ async def glob( if not pattern: raise ValueError(f"Unacceptable pattern: {pattern!r}") command = [ + "set", + "--", + f"{shlex.quote(str(self))}/{pattern}", + ";", + "test", + "-e", + '"$1"', + "&&", "printf", - '"%s\\0"', - str(self / pattern), - "|", - "xargs", - "-0", - "-I{}", - "sh", - "-c", - '"if [ -e \\"{}\\" ]; then echo \\"{}\\"; fi"', - "|", - "sort", + "'%s\\n'", + '"$@"', + ";", + ":", ] result, status = await self.connector.run( location=self.location, command=command, capture_output=True @@ -726,7 +727,7 @@ async def rmtree(self) -> None: if (inner_path := await self._get_inner_path()) != self: await inner_path.rmtree() else: - command = ["rm", "-rf ", self.__str__()] + command = ["rm", "-rf", self.__str__()] result, status = await self.connector.run( location=self.location, command=command, capture_output=True ) diff --git a/streamflow/deployment/aiotarstream.py b/streamflow/deployment/aiotarstream.py index a4277e555..9a916f598 100644 --- a/streamflow/deployment/aiotarstream.py +++ b/streamflow/deployment/aiotarstream.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import copy import grp import os @@ -76,10 +77,7 @@ def __init__( self.cmp: Compressor | Decompressor | None = cmp self.exception: type[Exception] | None = exception - async def close(self) -> None: - if self.closed: - return - self.closed = True + async def _close(self) -> None: try: if self.mode == "w": await self.stream.write(self.cmp.flush()) @@ -179,10 +177,7 @@ async def _init_write_gz(self) -> None: b"\037\213\010\010" + timestamp + b"\002\377" + tarfile.NUL ) - async def close(self) -> None: - if self.closed: - return - self.closed = True + async def _close(self) -> None: try: await self.stream.write(self.cmp.flush()) await self.stream.write(struct.pack(" uname self._gnames: dict[int, str] = {} # Cached mappings of gid -> gname + self._closing: asyncio.Event | None = None async def __aenter__(self) -> Self: self._check() @@ -665,6 +661,19 @@ def _check(self, mode: Literal["r", "a", "w", "x"] | None = None) -> None: if mode is not None and self.mode not in mode: raise OSError("bad operation for mode %r" % self.mode) + async def _close(self): + try: + if self.mode in ("a", "w", "x"): + await self.stream.write(tarfile.NUL * (tarfile.BLOCKSIZE * 2)) + self.offset += tarfile.BLOCKSIZE * 2 + blocks, remainder = divmod(self.offset, tarfile.RECORDSIZE) + if remainder > 0: + await self.stream.write( + tarfile.NUL * (tarfile.RECORDSIZE - remainder) + ) + finally: + await self.stream.close() + def _dbg(self, level: int, msg: str) -> None: if level <= self.debug: print(msg, file=sys.stderr) @@ -831,18 +840,13 @@ def chmod(self, tarinfo: AioTarInfo, targetpath: StrOrBytesPath) -> None: async def close(self) -> None: if self.closed: return - self.closed = True - try: - if self.mode in ("a", "w", "x"): - await self.stream.write(tarfile.NUL * (tarfile.BLOCKSIZE * 2)) - self.offset 
+= tarfile.BLOCKSIZE * 2 - blocks, remainder = divmod(self.offset, tarfile.RECORDSIZE) - if remainder > 0: - await self.stream.write( - tarfile.NUL * (tarfile.RECORDSIZE - remainder) - ) - finally: - await self.stream.close() + if self._closing is not None: + await self._closing.wait() + else: + self._closing = asyncio.Event() + await self._close() + self.closed = True + self._closing.set() async def extract( self, diff --git a/streamflow/deployment/connector/base.py b/streamflow/deployment/connector/base.py index 23e16bb60..58134cf3f 100644 --- a/streamflow/deployment/connector/base.py +++ b/streamflow/deployment/connector/base.py @@ -9,19 +9,18 @@ import tarfile from abc import ABC from collections.abc import MutableMapping, MutableSequence -from typing import AsyncContextManager +from types import TracebackType +from typing import Any, AsyncContextManager from streamflow.core import utils from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.exception import WorkflowExecutionException from streamflow.core.utils import get_local_to_remote_destination from streamflow.deployment import aiotarstream from streamflow.deployment.future import FutureAware -from streamflow.deployment.stream import ( - SubprocessStreamReaderWrapperContextManager, - SubprocessStreamWriterWrapperContextManager, -) +from streamflow.deployment.shell import BaseShell +from streamflow.deployment.stream import BaseStreamWrapper from streamflow.log_handler import logger FS_TYPES_TO_SKIP = { @@ -262,6 +261,105 @@ async def copy_same_connector( return locations +class SubprocessStreamWrapperContextManager(AsyncContextManager[StreamWrapper], ABC): + def __init__(self, coro) -> None: + self.coro = coro + self.proc: asyncio.subprocess.Process | None = None + self.stream: StreamWrapper | None = None + + +class SubprocessStreamReaderWrapper(BaseStreamWrapper): + async def close(self) -> None: + self.closed = True + + async def write(self, data: bytes) -> int: + raise NotImplementedError + + +class SubprocessStreamReaderWrapperContextManager( + SubprocessStreamWrapperContextManager +): + async def __aenter__(self) -> StreamWrapper: + self.proc = await self.coro + self.stream = SubprocessStreamReaderWrapper(self.proc.stdout) + return self.stream + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.proc.wait() + if self.stream: + await self.stream.close() + + +class SubprocessStreamWriterWrapper(BaseStreamWrapper): + async def _close(self) -> None: + self.stream.close() + await self.stream.wait_closed() + + async def read(self, n: int = -1) -> bytes: + raise NotImplementedError + + async def write(self, data: Any) -> None: + self.stream.write(data) + await self.stream.drain() + + +class SubprocessStreamWriterWrapperContextManager( + SubprocessStreamWrapperContextManager +): + async def __aenter__(self) -> StreamWrapper: + self.proc = await self.coro + self.stream = SubprocessStreamWriterWrapper(self.proc.stdin) + return self.stream + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.stream: + await self.stream.close() + await self.proc.wait() + + +class SubprocessShell(BaseShell): + __slots__ = "_proc" + + def __init__( + self, + command: 
MutableSequence[str], + buffer_size: int, + process: asyncio.subprocess.Process, + ) -> None: + super().__init__(command, buffer_size) + self._proc: asyncio.subprocess.Process = process + self._reader = SubprocessStreamReaderWrapper(self._proc.stdout) + self._writer = SubprocessStreamWriterWrapper(self._proc.stdin) + + async def _close(self) -> None: + try: + if self._proc.returncode is None: + await self._writer.write(b"exit\n") + try: + await asyncio.wait_for(self._proc.wait(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning( + f"Shell with command `{' '.join(self.command)}` " + "did not exit gracefully. Killing" + ) + self._proc.kill() + await self._proc.wait() + except Exception: + with contextlib.suppress(Exception): + self._proc.kill() + await self._proc.wait() + + class BaseConnector(Connector, FutureAware, ABC): def __init__(self, deployment_name: str, config_dir: str, transferBufferSize: int): super().__init__( @@ -269,6 +367,21 @@ def __init__(self, deployment_name: str, config_dir: str, transferBufferSize: in config_dir=config_dir, transferBufferSize=transferBufferSize, ) + self._shells: MutableMapping[str, Shell] = {} + self._shells_lock: asyncio.Lock = asyncio.Lock() + + async def _create_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ) + return SubprocessShell( + command=command, buffer_size=self.transferBufferSize, process=process + ) async def copy_local_to_remote( self, @@ -336,6 +449,24 @@ async def copy_remote_to_remote( source_location=source_location, ) + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + async with self._shells_lock: + if (key := ":".join((location.name, *command))) in self._shells: + shell = self._shells[key] + if not await shell.closed(): + return shell + else: + del self._shells[key] + try: + self._shells[key] = await self._create_shell(command, location) + return self._shells[key] + except Exception as e: + raise WorkflowExecutionException( + f"Failed to create shell with command `{' '.join(command)}`: {e}" + ) from e + async def get_stream_reader( self, command: MutableSequence[str], @@ -375,6 +506,19 @@ async def run( timeout: int | None = None, job_name: str | None = None, ) -> tuple[str, int] | None: + if job_name is None and stdin is None: + with contextlib.suppress(WorkflowExecutionException): + return await utils.run_in_shell( + shell=await self.get_shell( + command=["sh"], location=location + ), # nosec + location=location, + command=command, + environment=environment, + workdir=workdir, + capture_output=capture_output, + timeout=timeout, + ) command_str = utils.create_command( self.__class__.__name__, command, @@ -399,6 +543,13 @@ timeout=timeout, ) + async def undeploy(self, external: bool) -> None: + async with self._shells_lock: + await asyncio.gather( + *(asyncio.create_task(shell.close()) for shell in self._shells.values()) + ) + self._shells.clear() + class BatchConnector(Connector, ABC): pass diff --git a/streamflow/deployment/connector/container.py b/streamflow/deployment/connector/container.py index 751fb8efc..2a218e550 100644 --- a/streamflow/deployment/connector/container.py +++ b/streamflow/deployment/connector/container.py @@ -9,6 +9,7 @@ import shlex from abc import ABC, abstractmethod from collections.abc import MutableMapping, MutableSequence
+from contextlib import suppress from importlib.resources import files from shutil import which from typing import Any, AsyncContextManager, cast @@ -17,7 +18,7 @@ from streamflow.core import utils from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.exception import ( WorkflowDefinitionException, WorkflowExecutionException, @@ -613,6 +614,19 @@ async def run( timeout: int | None = None, job_name: str | None = None, ) -> tuple[str, int] | None: + if job_name is None and stdin is None: + with suppress(WorkflowExecutionException): + return await utils.run_in_shell( + shell=await self.get_shell( + command=["sh"], location=location + ), # nosec + location=location, + command=command, + environment=environment, + workdir=workdir, + capture_output=capture_output, + timeout=timeout, + ) command = utils.create_command( self.__class__.__name__, command, @@ -869,6 +883,21 @@ async def _populate_instance(self, name: str) -> None: volumes=volumes, ) + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + inner_location = get_inner_location(location) + return await self.connector.get_shell( + command=[ + "docker", + "exec", + "--interactive", + location.name, + *command, + ], + location=inner_location, + ) + class DockerConnector(DockerBaseConnector): def __init__( @@ -1972,6 +2001,19 @@ def get_schema(cls) -> str: .read_text("utf-8") ) + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + return await self.connector.get_shell( + command=[ + "singularity", + "exec", + f"instance://{location.name}", + *command, + ], + location=get_inner_location(location), + ) + async def undeploy(self, external: bool) -> None: if not external: stdout, returncode = await self.connector.run( diff --git a/streamflow/deployment/connector/kubernetes.py b/streamflow/deployment/connector/kubernetes.py index 96c0de3f9..6c8a4cb65 100644 --- a/streamflow/deployment/connector/kubernetes.py +++ b/streamflow/deployment/connector/kubernetes.py @@ -2,21 +2,25 @@ import asyncio import io +import json import logging import os import re import shlex import uuid from abc import ABC, abstractmethod -from collections.abc import Awaitable, Coroutine, MutableMapping, MutableSequence +from collections.abc import Awaitable, MutableMapping, MutableSequence +from contextlib import suppress from importlib.resources import files from math import ceil, floor from pathlib import Path from shutil import which from types import TracebackType -from typing import Any, AsyncContextManager, cast +from typing import Any, AsyncContextManager, Final, cast import yaml +from aiohttp import ClientWebSocketResponse, WSMsgType +from aiohttp.client import _BaseRequestContextManager from cachebox import BaseCacheImpl, TTLCache, cached from kubernetes_asyncio import client from kubernetes_asyncio.client import ApiClient, Configuration, V1Container, V1PodList @@ -30,7 +34,7 @@ from streamflow.core import utils from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.exception import ( WorkflowDefinitionException, WorkflowExecutionException, @@ -43,9 +47,15 @@ copy_remote_to_remote, copy_same_connector, ) +from streamflow.deployment.shell import BaseShell 
from streamflow.log_handler import logger -SERVICE_NAMESPACE_FILENAME = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" +HEADER_SEC_WEBSOCKET_PROTOCOL: Final[str] = "sec-websocket-protocol" +SERVICE_NAMESPACE_FILENAME: Final[str] = ( + "/var/run/secrets/kubernetes.io/serviceaccount/namespace" +) +STREAM_CLOSE: Final[bytes] = b"\xff" +STREAM_PROTOCOL_V5: Final[str] = "v5.channel.k8s.io" def _check_helm_installed() -> None: @@ -95,62 +105,121 @@ def _selector_from_set(selector: MutableMapping[str, Any]) -> str: return ",".join(requirements) -class KubernetesResponseWrapper(BaseStreamWrapper): - def __init__(self, stream: Any) -> None: +class KubernetesResponseWrapperContextManager(AsyncContextManager[StreamWrapper], ABC): + def __init__(self, request: _BaseRequestContextManager) -> None: + self.request: _BaseRequestContextManager = request + self.response: StreamWrapper | None = None + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.response: + await self.response.close() + + +class KubernetesResponseReaderWrapper(BaseStreamWrapper): + def __init__(self, stream: Any, flush: bool = False) -> None: super().__init__(stream) - self.msg: bytes = b"" + self.flush: bool = flush + self._buffer: bytearray = bytearray() async def read(self, size: int | None = None) -> bytes | None: - if len(self.msg) > 0: - if len(self.msg) > size: - data = self.msg[0:size] - self.msg = self.msg[size:] - return data + if self._buffer: + if size is None: + data = bytes(self._buffer) + self._buffer.clear() else: - data = self.msg - size -= len(self.msg) - self.msg = b"" + data = bytes(self._buffer[:size]) + del self._buffer[:size] + size -= len(data) else: data = b"" - while size > 0 and not self.stream.closed: - async for msg in self.stream: - channel = msg.data[0] - self.msg = msg.data[1:] - if self.msg and channel == ws_client.STDOUT_CHANNEL: - if len(self.msg) > size: - data += self.msg[0:size] - self.msg = self.msg[size:] - return data + while (size is None or size > 0) and not self.stream.closed: + msg = await self.stream.receive() + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + self.closed = True + break + if msg.data[0] == ws_client.STDOUT_CHANNEL and (payload := msg.data[1:]): + if size is None: + data += payload + else: + if (n := len(payload)) <= size: + data += payload + size -= n else: - data += self.msg - size -= len(self.msg) - self.msg = b"" + data += payload[:size] + self._buffer.extend(payload[size:]) + break + if self.flush: + break return data if len(data) > 0 else None + async def write(self, data: Any) -> None: + raise NotImplementedError + + +class KubernetesResponseReaderWrapperContextManager( + KubernetesResponseWrapperContextManager +): + async def __aenter__(self) -> StreamWrapper: + response = await self.request + self.response = KubernetesResponseReaderWrapper(await response.__aenter__()) + return self.response + + +class KubernetesResponseWriterWrapper(BaseStreamWrapper): + async def _close(self) -> None: + await self.stream.send_bytes( + STREAM_CLOSE + bytes(chr(ws_client.STDIN_CHANNEL), "ascii") + ) + while True: + msg = await self.stream.receive() + if payload := msg.data[1:]: + result = json.loads(payload.decode("utf-8")) + if result["status"] == "Success": + await self.stream.close() + break + else: + raise WorkflowExecutionException( + f"Kubernetes connection terminated with status {result['status']}." 
+ ) + + async def read(self, size: int | None = None) -> bytes | None: + raise NotImplementedError + async def write(self, data: Any) -> None: channel_prefix = bytes(chr(ws_client.STDIN_CHANNEL), "ascii") payload = channel_prefix + data await self.stream.send_bytes(payload) -class KubernetesResponseWrapperContextManager(AsyncContextManager[StreamWrapper]): - def __init__(self, coro) -> None: - self.coro = coro - self.response: KubernetesResponseWrapper | None = None - - async def __aenter__(self) -> KubernetesResponseWrapper: - response = await self.coro - self.response = KubernetesResponseWrapper(await response.__aenter__()) +class KubernetesResponseWriterWrapperContextManager( + KubernetesResponseWrapperContextManager +): + async def __aenter__(self) -> StreamWrapper: + response = await self.request + self.response = KubernetesResponseWriterWrapper(await response.__aenter__()) return self.response - async def __aexit__( + +class KubernetesShell(BaseShell): + def __init__( self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - if self.response: - await self.response.close() + command: MutableSequence[str], + buffer_size: int, + stream: ClientWebSocketResponse, + ): + super().__init__(command=command, buffer_size=buffer_size) + self._reader: StreamWrapper = KubernetesResponseReaderWrapper( + stream, flush=True + ) + self._writer: StreamWrapper = KubernetesResponseWriterWrapper(stream) + + async def _close(self) -> None: + await self._reader.close() class KubernetesBaseConnector(BaseConnector, ABC): @@ -222,60 +291,27 @@ def _configure_incluster_namespace(self) -> None: if not self.namespace: raise ConfigException("Namespace file exists but empty.") - async def copy_local_to_remote( - self, - src: str, - dst: str, - locations: MutableSequence[ExecutionLocation], - read_only: bool = False, - ) -> None: - await super().copy_local_to_remote( - src=src, - dst=dst, - locations=await self._get_effective_locations(locations, dst), - read_only=read_only, + async def _create_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + pod, container = location.name.split(":") + response = await self.client_ws.connect_get_namespaced_pod_exec( + name=pod, + namespace=self.namespace or "default", + container=container, + command=command, + stderr=True, + stdin=True, + stdout=True, + tty=False, + _headers={HEADER_SEC_WEBSOCKET_PROTOCOL: STREAM_PROTOCOL_V5}, + _preload_content=False, + ) + return KubernetesShell( + command=command, + buffer_size=self.transferBufferSize, + stream=await response.__aenter__(), ) - - async def copy_remote_to_remote( - self, - src: str, - dst: str, - locations: MutableSequence[ExecutionLocation], - source_location: ExecutionLocation, - source_connector: Connector | None = None, - read_only: bool = False, - ) -> None: - source_connector = source_connector or self - if locations := await copy_same_connector( - connector=self, - locations=await self._get_effective_locations(locations, dst), - source_location=source_location, - src=src, - dst=dst, - read_only=read_only, - ): - await copy_remote_to_remote( - connector=self, - locations=locations, - src=src, - dst=dst, - source_connector=source_connector, - source_location=source_location, - writer_command=[ - "sh", - "-c", - " ".join( - await utils.get_remote_to_remote_write_command( - src_connector=source_connector, - src_location=source_location, - src=src, - dst_connector=self, - dst_locations=locations, - dst=dst, - ) - ), - ], 
- ) async def _get_container( self, location: ExecutionLocation @@ -340,6 +376,61 @@ async def _get_effective_locations( @abstractmethod async def _get_running_pods(self) -> V1PodList: ... + async def copy_local_to_remote( + self, + src: str, + dst: str, + locations: MutableSequence[ExecutionLocation], + read_only: bool = False, + ) -> None: + await super().copy_local_to_remote( + src=src, + dst=dst, + locations=await self._get_effective_locations(locations, dst), + read_only=read_only, + ) + + async def copy_remote_to_remote( + self, + src: str, + dst: str, + locations: MutableSequence[ExecutionLocation], + source_location: ExecutionLocation, + source_connector: Connector | None = None, + read_only: bool = False, + ) -> None: + source_connector = source_connector or self + if locations := await copy_same_connector( + connector=self, + locations=await self._get_effective_locations(locations, dst), + source_location=source_location, + src=src, + dst=dst, + read_only=read_only, + ): + await copy_remote_to_remote( + connector=self, + locations=locations, + src=src, + dst=dst, + source_connector=source_connector, + source_location=source_location, + writer_command=[ + "sh", + "-c", + " ".join( + await utils.get_remote_to_remote_write_command( + src_connector=source_connector, + src_location=source_location, + src=src, + dst_connector=self, + dst_locations=locations, + dst=dst, + ) + ), + ], + ) + async def deploy(self, external: bool) -> None: # Init standard client configuration = await self._get_configuration() @@ -385,20 +476,18 @@ async def get_stream_reader( self, command: MutableSequence[int], location: ExecutionLocation ) -> AsyncContextManager[StreamWrapper]: pod, container = location.name.split(":") - return KubernetesResponseWrapperContextManager( - coro=cast( - Coroutine, - self.client_ws.connect_get_namespaced_pod_exec( - name=pod, - namespace=self.namespace or "default", - container=container, - command=command, - stderr=True, - stdin=False, - stdout=True, - tty=False, - _preload_content=False, - ), + return KubernetesResponseReaderWrapperContextManager( + request=self.client_ws.connect_get_namespaced_pod_exec( + name=pod, + namespace=self.namespace or "default", + container=container, + command=command, + stderr=True, + stdin=False, + stdout=True, + tty=False, + _headers={HEADER_SEC_WEBSOCKET_PROTOCOL: STREAM_PROTOCOL_V5}, + _preload_content=False, ) ) @@ -406,20 +495,18 @@ async def get_stream_writer( self, command: MutableSequence[str], location: ExecutionLocation ) -> AsyncContextManager[StreamWrapper]: pod, container = location.name.split(":") - return KubernetesResponseWrapperContextManager( - coro=cast( - Coroutine, - self.client_ws.connect_get_namespaced_pod_exec( - name=pod, - namespace=self.namespace or "default", - container=container, - command=command, - stderr=False, - stdin=True, - stdout=False, - tty=False, - _preload_content=False, - ), + return KubernetesResponseWriterWrapperContextManager( + request=self.client_ws.connect_get_namespaced_pod_exec( + name=pod, + namespace=self.namespace or "default", + container=container, + command=command, + stderr=False, + stdin=True, + stdout=False, + tty=False, + _headers={HEADER_SEC_WEBSOCKET_PROTOCOL: STREAM_PROTOCOL_V5}, + _preload_content=False, ) ) @@ -436,6 +523,19 @@ async def run( timeout: int | None = None, job_name: str | None = None, ) -> tuple[str, int] | None: + if job_name is None and stdin is None: + with suppress(WorkflowExecutionException): + return await utils.run_in_shell( + shell=await self.get_shell( + 
command=["sh"], location=location + ), # nosec + location=location, + command=command, + environment=environment, + workdir=workdir, + capture_output=capture_output, + timeout=timeout, + ) command = utils.create_command( self.__class__.__name__, command, @@ -506,6 +606,7 @@ async def run( return None async def undeploy(self, external: bool) -> None: + await super().undeploy(external) if self.client is not None: await self.client.api_client.close() self.client = None @@ -721,8 +822,8 @@ async def _undeploy(self, k8s_object: Any) -> Any: resp = await getattr(k8s_api, f"delete_{kind}")( name=k8s_object.metadata.name ) - if self.debug: - print(f"{kind} deleted. status='{str(resp.status)}'") + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"{kind} deleted. status='{str(resp.status)}'") return resp async def _wait(self, k8s_object: Any) -> None: diff --git a/streamflow/deployment/connector/ssh.py b/streamflow/deployment/connector/ssh.py index 6e3038b92..9f04e5af5 100644 --- a/streamflow/deployment/connector/ssh.py +++ b/streamflow/deployment/connector/ssh.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import logging import os from abc import ABC @@ -14,7 +15,7 @@ from streamflow.core import utils from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.exception import WorkflowExecutionException from streamflow.core.scheduling import AvailableLocation, Hardware, Storage from streamflow.deployment.connector.base import ( @@ -23,7 +24,8 @@ copy_remote_to_remote, copy_same_connector, ) -from streamflow.deployment.stream import StreamReaderWrapper, StreamWriterWrapper +from streamflow.deployment.shell import BaseShell +from streamflow.deployment.stream import BaseStreamWrapper from streamflow.deployment.template import CommandTemplateMap from streamflow.log_handler import logger @@ -300,6 +302,25 @@ def get( ) +class SSHShell(BaseShell): + __slots__ = "_context" + + def __init__( + self, + command: MutableSequence[str], + buffer_size: int, + process: asyncssh.SSHClientProcess, + context: SSHContextManager, + ): + super().__init__(command=command, buffer_size=buffer_size) + self._context: SSHContextManager = context + self._reader = SSHStreamReaderWrapper(process.stdout) + self._writer = SSHStreamWriterWrapper(process.stdin) + + async def _close(self) -> None: + await self._context.__aexit__(None, None, None) + + class SSHStreamWrapperContextManager(AsyncContextManager[StreamWrapper], ABC): def __init__( self, @@ -325,8 +346,16 @@ async def __aexit__( await self.ssh_context.__aexit__(exc_type, exc_val, exc_tb) +class SSHStreamReaderWrapper(BaseStreamWrapper): + async def close(self) -> None: + self.closed = True + + async def write(self, data: Any) -> None: + raise NotImplementedError + + class SSHStreamReaderWrapperContextManager(SSHStreamWrapperContextManager): - async def __aenter__(self) -> StreamReaderWrapper: + async def __aenter__(self) -> StreamWrapper: self.ssh_context = self.ssh_context_factory.get( command=" ".join(self.command), environment=self.environment, @@ -334,12 +363,25 @@ async def __aenter__(self) -> StreamReaderWrapper: encoding=None, ) proc = await self.ssh_context.__aenter__() - self.stream = StreamReaderWrapper(proc.stdout) + self.stream = SSHStreamReaderWrapper(proc.stdout) return self.stream +class SSHStreamWriterWrapper(BaseStreamWrapper): + async def _close(self) -> 
None: + self.stream.close() + await self.stream.wait_closed() + + async def read(self, n: int = -1) -> bytes: + raise NotImplementedError + + async def write(self, data: Any) -> None: + self.stream.write(data) + await self.stream.drain() + + class SSHStreamWriterWrapperContextManager(SSHStreamWrapperContextManager): - async def __aenter__(self) -> StreamWriterWrapper: + async def __aenter__(self) -> StreamWrapper: self.ssh_context = self.ssh_context_factory.get( command=" ".join(self.command), environment=self.environment, @@ -348,7 +390,7 @@ async def __aenter__(self) -> StreamWriterWrapper: encoding=None, ) proc = await self.ssh_context.__aenter__() - self.stream = StreamWriterWrapper(proc.stdin) + self.stream = SSHStreamWriterWrapper(proc.stdin) return self.stream @@ -447,41 +489,21 @@ def __init__( self.hardware: MutableMapping[str, Hardware] = {} self._cls_context: type[SSHContext] = SSHContext - async def copy_remote_to_remote( - self, - src: str, - dst: str, - locations: MutableSequence[ExecutionLocation], - source_location: ExecutionLocation, - source_connector: Connector | None = None, - read_only: bool = False, - ) -> None: - source_connector = source_connector or self - if locations := await copy_same_connector( - connector=self, - locations=locations, - source_location=source_location, - src=src, - dst=dst, - read_only=read_only, - ): - conn_per_round = min(len(locations), self.maxConcurrentSessions) - rounds = self.maxConcurrentSessions // conn_per_round - if len(locations) % conn_per_round != 0: - rounds += 1 - location_groups = [ - locations[i : i + rounds] for i in range(0, len(locations), rounds) - ] - for location_group in location_groups: - # Perform remote to remote copy - await copy_remote_to_remote( - connector=self, - locations=location_group, - src=src, - dst=dst, - source_connector=source_connector, - source_location=source_location, - ) + async def _create_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + context = self._get_ssh_context_factory(location).get( + command=" ".join(command), + environment=location.environment, + stderr=asyncio.subprocess.STDOUT, + encoding=None, + ) + return SSHShell( + command=command, + buffer_size=self.transferBufferSize, + process=await context.__aenter__(), + context=context, + ) async def _get_available_location(self, location: str) -> Hardware: if location not in self.hardware.keys(): @@ -634,6 +656,42 @@ def _get_ssh_context_factory( ) return self.ssh_context_factories[location.name] + async def copy_remote_to_remote( + self, + src: str, + dst: str, + locations: MutableSequence[ExecutionLocation], + source_location: ExecutionLocation, + source_connector: Connector | None = None, + read_only: bool = False, + ) -> None: + source_connector = source_connector or self + if locations := await copy_same_connector( + connector=self, + locations=locations, + source_location=source_location, + src=src, + dst=dst, + read_only=read_only, + ): + conn_per_round = min(len(locations), self.maxConcurrentSessions) + rounds = self.maxConcurrentSessions // conn_per_round + if len(locations) % conn_per_round != 0: + rounds += 1 + location_groups = [ + locations[i : i + rounds] for i in range(0, len(locations), rounds) + ] + for location_group in location_groups: + # Perform remote to remote copy + await copy_remote_to_remote( + connector=self, + locations=location_group, + src=src, + dst=dst, + source_connector=source_connector, + source_location=source_location, + ) + async def deploy(self, external: bool) 
-> None: pass @@ -701,6 +759,19 @@ async def run( timeout: int | None = None, job_name: str | None = None, ) -> tuple[str, int] | None: + if job_name is None and stdin is None: + with contextlib.suppress(WorkflowExecutionException): + return await utils.run_in_shell( + shell=await self.get_shell( + command=["sh"], location=location + ), # nosec + location=location, + command=command, + environment=environment, + workdir=workdir, + capture_output=capture_output, + timeout=timeout, + ) command = self._get_command( location=location, command=command, @@ -735,6 +806,7 @@ async def run( return (result.stdout.strip(), result.returncode) if capture_output else None async def undeploy(self, external: bool) -> None: + await super().undeploy(external) await asyncio.gather( *( asyncio.create_task(ssh_context.close()) diff --git a/streamflow/deployment/future.py b/streamflow/deployment/future.py index ba9e87d24..d897e411c 100644 --- a/streamflow/deployment/future.py +++ b/streamflow/deployment/future.py @@ -7,7 +7,7 @@ from typing import Any, AsyncContextManager from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.exception import WorkflowExecutionException from streamflow.core.scheduling import AvailableLocation from streamflow.log_handler import logger @@ -151,6 +151,17 @@ async def get_available_locations( await self._safe_deploy_event_wait() return await self._connector.get_available_locations(service=service) + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + if self._connector is None: + if not self.deploying: + self.deploying = True + await self.deploy(self.external) + else: + await self._safe_deploy_event_wait() + return await self._connector.get_shell(command=command, location=location) + async def get_stream_reader( self, command: MutableSequence[str], location: ExecutionLocation ) -> AsyncContextManager[StreamWrapper]: diff --git a/streamflow/deployment/shell.py b/streamflow/deployment/shell.py new file mode 100644 index 000000000..25ae56b8a --- /dev/null +++ b/streamflow/deployment/shell.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import asyncio +import codecs +import logging +import shlex +from abc import ABC, abstractmethod +from collections.abc import MutableMapping, MutableSequence + +from streamflow.core.data import StreamWrapper +from streamflow.core.deployment import Shell +from streamflow.core.exception import WorkflowExecutionException +from streamflow.core.utils import random_name +from streamflow.log_handler import logger + + +def _build_shell_command( + end_marker: str, + command: MutableSequence[str], + shell_class: str, + shell_cmd: MutableSequence[str], + environment: MutableMapping[str, str] | None = None, + workdir: str | None = None, +) -> str: + if environment or workdir: + subshell_parts = [] + if workdir: + subshell_parts.append(f"cd {shlex.quote(workdir)}") + if environment: + for key, value in environment.items(): + subshell_parts.append(f"export {key}={shlex.quote(value)}") + subshell_parts.append(" ".join(command)) + cmd = f"sh -c {shlex.quote('; '.join(subshell_parts))} 2>&1" + else: + cmd = f"{' '.join(command)} 2>&1" + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + f"EXECUTING command {cmd} on {shell_class} " + f"with command `{' '.join(shell_cmd)}`" + ) + return f'{cmd}\necho "{end_marker}:$?"\n' + + +class BaseShell(Shell, 
ABC): + __slots__ = ("_closed", "_lock", "_reader", "_writer", "_decoder") + + def __init__( + self, + command: MutableSequence[str], + buffer_size: int, + ): + super().__init__(command=command, buffer_size=buffer_size) + self._closed: bool = False + self._lock: asyncio.Lock = asyncio.Lock() + self._reader: StreamWrapper | None = None + self._writer: StreamWrapper | None = None + self._decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") + + @abstractmethod + async def _close(self) -> None: ... + + async def _read_with_output( + self, end_marker: str, timeout: int | None + ) -> tuple[str, int]: + output = "" + while True: + try: + chunk = await asyncio.wait_for( + self._reader.read(self.buffer_size), timeout=timeout + ) + except asyncio.TimeoutError: + raise WorkflowExecutionException("Timeout waiting for command output") + if not chunk: + raise WorkflowExecutionException( + "Shell process terminated unexpectedly" + ) + decoded_chunk = self._decoder.decode(chunk, final=False) + output += decoded_chunk + if (marker_pos := output.find(f"{end_marker}:")) != -1 and ( + newline_pos := output.find("\n", marker_pos) + ) != -1: + returncode_str = output[marker_pos + len(end_marker) + 1 : newline_pos] + try: + returncode = int(returncode_str) + except ValueError: + raise WorkflowExecutionException( + f"Invalid return code: {returncode_str}" + ) + final_output = output[:marker_pos].strip() + self._decoder.reset() + return final_output, returncode + + async def _read_without_output(self, end_marker: str, timeout: int | None) -> None: + output = "" + while True: + try: + chunk = await asyncio.wait_for( + self._reader.read(self.buffer_size), timeout=timeout + ) + except asyncio.TimeoutError: + raise WorkflowExecutionException("Timeout discarding output") + if not chunk: + raise WorkflowExecutionException( + "Shell process terminated unexpectedly" + ) + output += self._decoder.decode(chunk, final=False) + if (marker_pos := output.find(f"{end_marker}:")) != -1 and output.find( + "\n", marker_pos + ) != -1: + self._decoder.reset() + return None + + async def close(self) -> None: + async with self._lock: + if self._closed: + return + await self._close() + self._closed = True + + async def closed(self) -> bool: + async with self._lock: + return self._closed + + async def execute( + self, + command: MutableSequence[str], + environment: MutableMapping[str, str] | None = None, + workdir: str | None = None, + capture_output: bool = False, + timeout: int | None = None, + ) -> tuple[str, int] | None: + async with self._lock: + if self._closed: + raise WorkflowExecutionException("Shell process is terminated") + end_marker = f"SF_CMD_END_{random_name()}" + shell_command = _build_shell_command( + end_marker=end_marker, + command=command, + shell_class=self.__class__.__name__, + shell_cmd=self.command, + environment=environment, + workdir=workdir, + ) + try: + await self._writer.write(shell_command.encode()) + if capture_output: + return await self._read_with_output(end_marker, timeout) + else: + return await self._read_without_output(end_marker, timeout) + except (BrokenPipeError, ConnectionResetError) as e: + raise WorkflowExecutionException(f"Shell pipe broken: {e}") from e + except asyncio.TimeoutError as e: + raise WorkflowExecutionException( + f"Command timeout after {timeout}s" + ) from e diff --git a/streamflow/deployment/stream.py b/streamflow/deployment/stream.py index 9b65b3d85..07a2363ac 100644 --- a/streamflow/deployment/stream.py +++ b/streamflow/deployment/stream.py @@ -1,9 +1,7 @@ from 
__future__ import annotations -import asyncio.subprocess -from abc import ABC -from types import TracebackType -from typing import Any, AsyncContextManager +import asyncio +from typing import Any from streamflow.core.data import StreamWrapper @@ -12,84 +10,24 @@ class BaseStreamWrapper(StreamWrapper): def __init__(self, stream) -> None: super().__init__(stream) self.closed = False + self._closing: asyncio.Event | None = None - async def close(self) -> None: - if self.closed: - return - self.closed = True + async def _close(self) -> None: await self.stream.close() - async def read(self, size: int | None = None) -> bytes: - return await self.stream.read(size) - - async def write(self, data: Any) -> None: - return await self.stream.write(data) - - -class StreamReaderWrapper(StreamWrapper): async def close(self) -> None: - pass + if self.closed: + return + if self._closing is not None: + await self._closing.wait() + else: + self._closing = asyncio.Event() + await self._close() + self.closed = True + self._closing.set() async def read(self, size: int | None = None) -> bytes: return await self.stream.read(size) async def write(self, data: Any) -> None: - raise NotImplementedError - - -class StreamWriterWrapper(StreamWrapper): - async def close(self) -> None: - self.stream.close() - await self.stream.wait_closed() - - async def read(self, size: int | None = None) -> bytes: - raise NotImplementedError - - async def write(self, data: Any) -> None: - self.stream.write(data) - await self.stream.drain() - - -class SubprocessStreamWrapperContextManager(AsyncContextManager[StreamWrapper], ABC): - def __init__(self, coro) -> None: - self.coro = coro - self.proc: asyncio.subprocess.Process | None = None - self.stream: StreamWrapper | None = None - - -class SubprocessStreamReaderWrapperContextManager( - SubprocessStreamWrapperContextManager -): - async def __aenter__(self) -> StreamReaderWrapper: - self.proc = await self.coro - self.stream = StreamReaderWrapper(self.proc.stdout) - return self.stream - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - await self.proc.wait() - if self.stream: - await self.stream.close() - - -class SubprocessStreamWriterWrapperContextManager( - SubprocessStreamWrapperContextManager -): - async def __aenter__(self) -> StreamWriterWrapper: - self.proc = await self.coro - self.stream = StreamWriterWrapper(self.proc.stdin) - return self.stream - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - if self.stream: - await self.stream.close() - await self.proc.wait() + await self.stream.write(data) diff --git a/streamflow/deployment/wrapper.py b/streamflow/deployment/wrapper.py index d9eeb6c4b..7141158af 100644 --- a/streamflow/deployment/wrapper.py +++ b/streamflow/deployment/wrapper.py @@ -6,7 +6,7 @@ from typing import AsyncContextManager from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.exception import WorkflowExecutionException from streamflow.core.scheduling import AvailableLocation from streamflow.deployment.future import FutureAware @@ -97,6 +97,11 @@ async def get_available_locations( ) -> MutableMapping[str, AvailableLocation]: return await self.connector.get_available_locations(service=service) + async def 
get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + return await self.connector.get_shell(command, location) + async def get_stream_reader( self, command: MutableSequence[str], location: ExecutionLocation ) -> AsyncContextManager[StreamWrapper]: diff --git a/tests/test_connector.py b/tests/test_connector.py index 31bd37618..6ab86aff2 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -47,7 +47,7 @@ def _get_connector_method_params(method_name: str) -> MutableSequence[Any]: return ["test_src", "test_dst", [loc], loc] case "get_available_locations": return [] - case "get_stream_reader" | "get_stream_writer": + case "get_shell" | "get_stream_reader" | "get_stream_writer": return [["test_command"], loc] case "run": return [loc, ["ls"]] diff --git a/tests/utils/connector.py b/tests/utils/connector.py index 438ac3b38..d0aa6cb0f 100644 --- a/tests/utils/connector.py +++ b/tests/utils/connector.py @@ -12,7 +12,7 @@ from asyncssh import SSHClient, SSHClientConnection from streamflow.core.data import StreamWrapper -from streamflow.core.deployment import Connector, ExecutionLocation +from streamflow.core.deployment import Connector, ExecutionLocation, Shell from streamflow.core.scheduling import AvailableLocation, Hardware from streamflow.deployment.connector import LocalConnector, SSHConnector from streamflow.deployment.connector.base import BaseConnector @@ -22,7 +22,7 @@ get_param_from_file, parse_hostname, ) -from streamflow.deployment.stream import StreamReaderWrapper, StreamWriterWrapper +from streamflow.deployment.stream import BaseStreamWrapper from streamflow.log_handler import logger @@ -39,7 +39,7 @@ def _get_path_from_cmd(command: MutableSequence[str]) -> str: raise NotImplementedError(command) -class AioTarStreamReaderWrapper(StreamReaderWrapper): +class AioTarStreamReaderWrapper(BaseStreamWrapper): async def close(self) -> None: if self.stream: os.close(self.stream) @@ -48,11 +48,17 @@ async def close(self) -> None: async def read(self, size: int | None = None) -> bytes: return os.read(self.stream, size) + async def write(self, data: Any) -> None: + raise NotImplementedError -class AioTarStreamWriterWrapper(StreamWriterWrapper): - async def close(self) -> None: + +class AioTarStreamWriterWrapper(BaseStreamWrapper): + async def _close(self) -> None: self.stream.close() + async def read(self, size: int | None = None) -> bytes: + raise NotImplementedError + async def write(self, data: Any) -> None: self.stream.write(data) @@ -138,6 +144,11 @@ async def get_available_locations( ) } + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ) -> Shell: + raise NotImplementedError("AioTarConnector get_shell") + @classmethod def get_schema(cls) -> str: return json.dumps( @@ -267,6 +278,11 @@ async def run( async def undeploy(self, external: bool) -> None: raise FailureConnectorException("FailureConnector undeploy") + async def get_shell( + self, command: MutableSequence[str], location: ExecutionLocation + ): + raise FailureConnectorException("FailureConnector get_shell") + async def get_stream_reader( self, command: MutableSequence[str], location: ExecutionLocation ) -> AsyncContextManager[StreamWrapper]:
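Two implementation details of the new ``streamflow/deployment/shell.py`` module are worth illustrating. First, the framing protocol: ``_build_shell_command`` appends ``echo "<marker>:$?"`` after each command and merges stderr into stdout with ``2>&1``, and ``_read_with_output`` scans the merged stream for the unique end marker to split the command's output from its exit status. The self-contained sketch below mirrors that logic against a plain local ``sh`` using only the standard library; it assumes a POSIX ``sh`` on ``PATH``, and the helper name ``run_in_persistent_sh`` is mine, not from the changeset.

```python
import asyncio
import uuid


async def run_in_persistent_sh(
    proc: asyncio.subprocess.Process, command: str
) -> tuple[str, int]:
    # Frame the command with a unique end marker that carries the exit status,
    # mirroring _build_shell_command() / _read_with_output() in this changeset.
    marker = f"SF_CMD_END_{uuid.uuid4()}"
    proc.stdin.write(f'{command} 2>&1\necho "{marker}:$?"\n'.encode())
    await proc.stdin.drain()
    output = ""
    while True:
        chunk = await proc.stdout.read(1024)
        if not chunk:
            raise RuntimeError("shell terminated unexpectedly")
        output += chunk.decode("utf-8", errors="replace")
        # The marker line has the form "<marker>:<returncode>\n".
        if (pos := output.find(f"{marker}:")) != -1 and (
            nl := output.find("\n", pos)
        ) != -1:
            return output[:pos].strip(), int(output[pos + len(marker) + 1 : nl])


async def main() -> None:
    proc = await asyncio.create_subprocess_exec(
        "sh",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.DEVNULL,
    )
    print(await run_in_persistent_sh(proc, "echo hello"))  # ('hello', 0)
    print(await run_in_persistent_sh(proc, "false"))       # ('', 1)
    proc.stdin.write(b"exit\n")
    await proc.wait()


asyncio.run(main())
```

The random marker makes collisions with legitimate command output vanishingly unlikely, which is why the same pattern works unchanged over subprocess pipes, SSH channels, and Kubernetes websockets.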
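Second, the idempotent-close pattern: ``BaseStreamWrapper`` and the tar stream class in ``streamflow/deployment/aiotarstream.py`` now guard ``close()`` with an ``asyncio.Event``, so concurrent callers wait for the first close to finish instead of racing it or returning before resources are released. Reduced to its essentials (the class name ``CloseOnce`` and the sleep stand-in are illustrative, not from the changeset), the pattern is:

```python
import asyncio


class CloseOnce:
    # Illustrative reduction of the close() guard added in this changeset:
    # the first caller performs _close(), while any concurrent caller waits
    # on the Event until closing has completed.
    def __init__(self) -> None:
        self.closed = False
        self._closing: asyncio.Event | None = None

    async def _close(self) -> None:
        await asyncio.sleep(0.1)  # stand-in for flushing and closing a stream

    async def close(self) -> None:
        if self.closed:
            return
        if self._closing is not None:
            await self._closing.wait()  # another task is already closing
        else:
            self._closing = asyncio.Event()
            await self._close()
            self.closed = True
            self._closing.set()


async def main() -> None:
    obj = CloseOnce()
    # Both calls return only after the single _close() has finished.
    await asyncio.gather(obj.close(), obj.close())
    assert obj.closed


asyncio.run(main())
```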