diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py
index a431c5bddfd..d530f9b679d 100644
--- a/distributed/shuffle/__init__.py
+++ b/distributed/shuffle/__init__.py
@@ -1,16 +1,7 @@
-try:
-    import pandas
-except ImportError:
-    SHUFFLE_AVAILABLE = False
-else:
-    del pandas
-    SHUFFLE_AVAILABLE = True
-
-    from .shuffle import rearrange_by_column_p2p
-    from .shuffle_extension import ShuffleId, ShuffleMetadata, ShuffleWorkerExtension
+from .shuffle import rearrange_by_column_p2p
+from .shuffle_extension import ShuffleId, ShuffleMetadata, ShuffleWorkerExtension
 
 __all__ = [
-    "SHUFFLE_AVAILABLE",
     "rearrange_by_column_p2p",
     "ShuffleId",
     "ShuffleMetadata",
diff --git a/distributed/shuffle/shuffle.py b/distributed/shuffle/shuffle.py
index 5811afd732f..33fe1189a09 100644
--- a/distributed/shuffle/shuffle.py
+++ b/distributed/shuffle/shuffle.py
@@ -3,7 +3,6 @@
 from typing import TYPE_CHECKING
 
 from dask.base import tokenize
-from dask.dataframe import DataFrame
 from dask.delayed import Delayed, delayed
 from dask.highlevelgraph import HighLevelGraph
 
@@ -12,6 +11,8 @@
 if TYPE_CHECKING:
     import pandas as pd
 
+    from dask.dataframe import DataFrame
+
 
 def get_ext() -> ShuffleWorkerExtension:
     from distributed import get_worker
@@ -53,6 +54,8 @@ def rearrange_by_column_p2p(
     column: str,
     npartitions: int | None = None,
 ):
+    from dask.dataframe import DataFrame
+
     npartitions = npartitions or df.npartitions
 
     token = tokenize(df, column, npartitions)
diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py
index e5e0baaf7bc..8f13480b91d 100644
--- a/distributed/shuffle/shuffle_extension.py
+++ b/distributed/shuffle/shuffle_extension.py
@@ -6,12 +6,12 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, NewType
 
-import pandas as pd
-
 from distributed.protocol import to_serialize
 from distributed.utils import sync
 
 if TYPE_CHECKING:
+    import pandas as pd
+
     from distributed.worker import Worker
 
 ShuffleId = NewType("ShuffleId", str)
@@ -103,6 +103,8 @@ async def add_partition(self, data: pd.DataFrame) -> None:
         await asyncio.gather(*tasks)
 
     def get_output_partition(self, i: int) -> pd.DataFrame:
+        import pandas as pd
+
         assert self.transferred, "`get_output_partition` called before barrier task"
 
         assert self.metadata.worker_for(i) == self.worker.address, (
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index dc3eba86e58..6dcf72c5214 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -4,6 +4,7 @@
 import multiprocessing as mp
 import os
 import random
+import sys
 from contextlib import suppress
 from time import sleep
 from unittest import mock
@@ -610,3 +611,31 @@ async def test_environ_plugin(c, s, a, b):
     assert results[a.worker_address] == "123"
     assert results[b.worker_address] == "123"
     assert results[n.worker_address] == "123"
+
+
+@pytest.mark.parametrize(
+    "modname",
+    [
+        pytest.param(
+            "numpy",
+            marks=pytest.mark.xfail(reason="distributed#5723, distributed#5729"),
+        ),
+        "scipy",
+        pytest.param("pandas", marks=pytest.mark.xfail(reason="distributed#5723")),
+    ],
+)
+@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)])
+async def test_no_unnecessary_imports_on_worker(c, s, a, modname):
+    """
+    Regression test against accidentally importing unnecessary modules at worker startup.
+
+    Importing modules like pandas slows down worker startup, especially if workers are
+    loading their software environment from NFS or other non-local filesystems.
+    It also slightly increases memory footprint.
+    """
+
+    def assert_no_import(dask_worker):
+        assert modname not in sys.modules
+
+    await c.wait_for_workers(1)
+    await c.run(assert_no_import)
diff --git a/distributed/worker.py b/distributed/worker.py
index 022b076054a..5b51aa8d2ef 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -51,7 +51,7 @@
     typename,
 )
 
-from . import comm, preloading, profile, shuffle, system, utils
+from . import comm, preloading, profile, system, utils
 from .batched import BatchedSend
 from .comm import Comm, connect, get_address_host
 from .comm.addressing import address_from_user_args, parse_address
@@ -74,6 +74,7 @@
 from .protocol import pickle, to_serialize
 from .pubsub import PubSubWorkerExtension
 from .security import Security
+from .shuffle import ShuffleWorkerExtension
 from .sizeof import safe_sizeof as sizeof
 from .threadpoolexecutor import ThreadPoolExecutor
 from .threadpoolexecutor import secede as tpe_secede
@@ -121,9 +122,7 @@
 # Worker.status subsets
 RUNNING = {Status.running, Status.paused, Status.closing_gracefully}
 
-DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension]
-if shuffle.SHUFFLE_AVAILABLE:
-    DEFAULT_EXTENSIONS.append(shuffle.ShuffleWorkerExtension)
+DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension, ShuffleWorkerExtension]
 
 DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}
 
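Note on the idiom: every change above is the same two-step deferral. Annotation-only uses of pandas and dask.dataframe move under typing.TYPE_CHECKING (evaluated by type checkers, never at runtime, since these modules use postponed annotations), and runtime uses move to a function-local import, so the cost is paid on first call rather than at worker startup. Below is a minimal standalone sketch of the pattern; `summarize` and its module are made up for illustration, not part of distributed:

    # Sketch of the deferred-import idiom applied in this diff.
    # `summarize` is a hypothetical example function.
    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers for the annotations below;
        # never executed at runtime, so import stays cheap.
        import pandas as pd


    def summarize(df: pd.DataFrame) -> pd.DataFrame:
        # Deferred to call time: the first caller pays the import cost,
        # and a process that never calls this never imports pandas.
        import pandas as pd

        return pd.concat([df.head(1), df.tail(1)])

Merely importing such a module leaves `"pandas" not in sys.modules`, which is what the new `test_no_unnecessary_imports_on_worker` asserts on a freshly started worker.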