diff --git a/distributed/node.py b/distributed/node.py index 9917c434f7c..477311eafce 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -27,7 +27,7 @@ class ServerNode(Server): # XXX avoid inheriting from Server? there is some large potential for confusion # between base and derived attribute namespaces... - def versions(self, packages=None): + def versions(self, packages=()): return get_versions(packages=packages) def start_services(self, default_listen_ip): diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index b908236c4ab..dd4ebc0fb50 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -516,10 +516,9 @@ async def test_environ_plugin(c, s, a, b): @pytest.mark.parametrize( "modname", [ - # numpy is always imported, and for a good reason: - # https://github.com/dask/distributed/issues/5729 + pytest.param("numpy", marks=pytest.mark.xfail(reason="distributed#5729")), "scipy", - pytest.param("pandas", marks=pytest.mark.xfail(reason="distributed#5723")), + "pandas", ], ) @gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)]) diff --git a/distributed/tests/test_versions.py b/distributed/tests/test_versions.py index a24151dcf6e..15ec95dfcf7 100644 --- a/distributed/tests/test_versions.py +++ b/distributed/tests/test_versions.py @@ -3,6 +3,7 @@ import re import sys +import msgpack import pytest import tornado @@ -147,17 +148,15 @@ def test_python_version(): def test_version_custom_pkgs(): out = get_versions( [ - # Use custom function - ("distributed", lambda mod: "123"), - # Use version_of_package "notexist", - ("pytest", None), # has __version__ - "tornado", # has version - "math", # has nothing + "pytest", + "tornado", + "msgpack", + "math", ] )["packages"] - assert out["distributed"] == "123" assert out["notexist"] is None assert out["pytest"] == pytest.__version__ assert out["tornado"] == tornado.version + assert out["msgpack"] == ".".join(str(v) for v in msgpack.version) assert out["math"] is None diff --git a/distributed/versions.py b/distributed/versions.py index 1994cd2e7f5..98b3019946f 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -2,34 +2,33 @@ from __future__ import annotations -import importlib +import importlib.metadata import os import platform import struct import sys -from collections.abc import Callable, Iterable +from collections.abc import Iterable from itertools import chain -from types import ModuleType from typing import Any required_packages = [ - ("dask", lambda p: p.__version__), - ("distributed", lambda p: p.__version__), - ("msgpack", lambda p: ".".join([str(v) for v in p.version])), - ("cloudpickle", lambda p: p.__version__), - ("tornado", lambda p: p.version), - ("toolz", lambda p: p.__version__), + "dask", + "distributed", + "msgpack", + "cloudpickle", + "tornado", + "toolz", ] optional_packages = [ - ("numpy", lambda p: p.__version__), - ("pandas", lambda p: p.__version__), - ("lz4", lambda p: p.__version__), + "numpy", + "pandas", + "lz4", ] # only these scheduler packages will be checked for version mismatch -scheduler_relevant_packages = {pkg for pkg, _ in required_packages} | {"lz4"} +scheduler_relevant_packages = set(required_packages) | {"lz4"} # notes to be displayed for mismatch packages @@ -38,18 +37,16 @@ } -def get_versions( - packages: Iterable[str | tuple[str, Callable[[ModuleType], str | None]]] - | None = None -) -> dict[str, dict[str, Any]]: +def get_versions(packages: Iterable[str] = ()) -> dict[str, dict[str, Any]]: """Return basic information on our software installation, and our installed versions of packages """ return { "host": get_system_info(), - "packages": get_package_info( - chain(required_packages, optional_packages, packages or []) - ), + "packages": { + "python": ".".join(map(str, sys.version_info)), + **get_package_info(chain(required_packages, optional_packages, packages)), + }, } @@ -68,37 +65,13 @@ def get_system_info() -> dict[str, Any]: } -def version_of_package(pkg: ModuleType) -> str | None: - """Try a variety of common ways to get the version of a package""" - from contextlib import suppress - - with suppress(AttributeError): - return pkg.__version__ - with suppress(AttributeError): - return str(pkg.version) - with suppress(AttributeError): - return ".".join(map(str, pkg.version_info)) - return None - - -def get_package_info( - pkgs: Iterable[str | tuple[str, Callable[[ModuleType], str | None] | None]] -) -> dict[str, str | None]: +def get_package_info(pkgs: Iterable[str]) -> dict[str, str | None]: """get package versions for the passed required & optional packages""" - pversions: dict[str, str | None] = {"python": ".".join(map(str, sys.version_info))} - for pkg in pkgs: - if isinstance(pkg, (tuple, list)): - modname, ver_f = pkg - if ver_f is None: - ver_f = version_of_package - else: - modname = pkg - ver_f = version_of_package - + pversions: dict[str, str | None] = {} + for modname in pkgs: try: - mod = importlib.import_module(modname) - pversions[modname] = ver_f(mod) + pversions[modname] = importlib.metadata.version(modname) except Exception: pversions[modname] = None