From 976be8ac1fecb6900b270cf01bd23211f9c7253b Mon Sep 17 00:00:00 2001
From: David Stansby
Date: Fri, 23 Jan 2026 13:08:08 +0000
Subject: [PATCH] Allow passing a custom downsample function

---
 pyproject.toml                        |  3 ++-
 src/stack_to_chunk/__init__.py        |  2 ++
 src/stack_to_chunk/_array_helpers.py  | 15 ++++++++---
 src/stack_to_chunk/main.py            | 35 ++++++++++++++++++++++---
 src/stack_to_chunk/tests/test_main.py | 37 ++++++++++++++++++++++++++-
 5 files changed, 83 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 78a1f13..1b28e65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,8 @@ dependencies = [
     "scikit-image==0.26.0",
     "zarr==3.1.1",
     "numcodecs==0.15.1",
-    "tensorstore==0.1.80"
+    "tensorstore==0.1.80",
+    "scipy==1.16.2"
 ]
 
 optional-dependencies = { dev = [
diff --git a/src/stack_to_chunk/__init__.py b/src/stack_to_chunk/__init__.py
index 732bc8a..15ebe87 100644
--- a/src/stack_to_chunk/__init__.py
+++ b/src/stack_to_chunk/__init__.py
@@ -6,6 +6,7 @@
     "__version__",
     "memory_per_downsample_process",
     "memory_per_slab_process",
+    "mode",
     "open_multiscale_group",
 ]
 
@@ -16,6 +17,7 @@
     MultiScaleGroup,
     memory_per_downsample_process,
     memory_per_slab_process,
+    mode,
     open_multiscale_group,
 )
 from .ome_ngff import SPATIAL_UNIT
diff --git a/src/stack_to_chunk/_array_helpers.py b/src/stack_to_chunk/_array_helpers.py
index fa22f6a..57c3e21 100644
--- a/src/stack_to_chunk/_array_helpers.py
+++ b/src/stack_to_chunk/_array_helpers.py
@@ -1,7 +1,9 @@
+from collections.abc import Callable
 from pathlib import Path
 
 import dask.array as da
 import numpy as np
+import numpy.typing as npt
 import skimage.measure
 import tensorstore as ts
 from joblib import delayed
@@ -38,7 +40,10 @@ def _copy_slab(arr_path: Path, slab: da.Array, zstart: int, zend: int) -> None:
 
 @delayed  # type: ignore[misc]
 def _downsample_block(
-    arr_in_path: Path, arr_out_path: Path, block_idx: tuple[int, int, int]
+    arr_in_path: Path,
+    arr_out_path: Path,
+    block_idx: tuple[int, int, int],
+    downsample_func: Callable[[npt.ArrayLike], npt.NDArray] = np.mean,
 ) -> None:
     """
     Copy a single block from one array to the next, downsampling by a factor of two.
@@ -55,6 +60,8 @@ def _downsample_block(
         Path to output array. Must have the same chunk shape as `arr_in`.
     block_idx :
         Index of block to copy. Must be a multiple of the shard shape in `arr_out`.
+    downsample_func :
+        Function to use to downsample blocks of data.
""" arr_in = _open_with_tensorstore(arr_in_path) @@ -83,9 +90,9 @@ def _downsample_block( pads = np.array(data.shape) % 2 pad_width = [(0, p) for p in pads] data = np.pad(data, pad_width, mode="edge") - data = skimage.measure.block_reduce(data, block_size=2, func=np.mean).astype( - data.dtype - ) + data = skimage.measure.block_reduce( + data, block_size=2, func=downsample_func + ).astype(data.dtype) out_slice = ( slice(block_idx[0], min((block_idx[0] + shard_shape[0]), arr_out.shape[0])), diff --git a/src/stack_to_chunk/main.py b/src/stack_to_chunk/main.py index 2906243..741d9c2 100644 --- a/src/stack_to_chunk/main.py +++ b/src/stack_to_chunk/main.py @@ -3,10 +3,13 @@ """ import math +from collections.abc import Callable from os import PathLike from pathlib import Path import numpy as np +import numpy.typing as npt +import scipy.stats import zarr import zarr.storage from dask.array.core import Array @@ -323,7 +326,13 @@ def add_full_res_data( blosc.use_threads = blosc_use_threads logger.info("Finished full resolution copy to zarr.") - def add_downsample_level(self, level: int, *, n_processes: int = 1) -> None: + def add_downsample_level( + self, + level: int, + *, + n_processes: int = 1, + downsample_func: Callable[[npt.ArrayLike], npt.NDArray] = np.mean, + ) -> None: """ Add a level of downsampling. @@ -337,6 +346,10 @@ def add_downsample_level(self, level: int, *, n_processes: int = 1) -> None: joblib.Parallel documentation for more info of allowed values. Running with one process (the default) will use about 5/8 the amount of memory of a single slab/shard. + downsample_func : + Function used to downsample data. It can be helpful to set this + to `stack_to_chunk.mode` for label data to calculate the most common label + when downsampling. Notes ----- @@ -394,8 +407,17 @@ def add_downsample_level(self, level: int, *, n_processes: int = 1) -> None: for z in range(0, sink_arr.shape[2], sink_arr.shards[2]) ] - all_args: list[tuple[Path, Path, tuple[int, int, int]]] = [ - (self._path / str(level_minus_one), self._path / level_str, idxs) + all_args: list[ + tuple[ + Path, Path, tuple[int, int, int], Callable[[npt.ArrayLike], npt.NDArray] + ] + ] = [ + ( + self._path / str(level_minus_one), + self._path / level_str, + idxs, + downsample_func, + ) for idxs in block_indices ] @@ -472,3 +494,10 @@ def open_multiscale_group(path: Path) -> MultiScaleGroup: return MultiScaleGroup( path, name=name, voxel_size=voxel_size, spatial_unit=spatial_unit ) + + +def mode(arr: npt.ArrayLike, axis: int) -> npt.NDArray: + """ + Get the modal value of an array. 
+ """ + return scipy.stats.mode(arr, axis=axis)[0] diff --git a/src/stack_to_chunk/tests/test_main.py b/src/stack_to_chunk/tests/test_main.py index c3b5452..4bbcecd 100644 --- a/src/stack_to_chunk/tests/test_main.py +++ b/src/stack_to_chunk/tests/test_main.py @@ -2,11 +2,13 @@ import json import re +from collections.abc import Callable from pathlib import Path from typing import Any import dask.array as da import numpy as np +import numpy.typing as npt import ome_zarr_models.v05 import pytest import zarr @@ -15,10 +17,11 @@ from stack_to_chunk import ( MultiScaleGroup, + memory_per_downsample_process, memory_per_slab_process, + mode, open_multiscale_group, ) -from stack_to_chunk.main import memory_per_downsample_process def check_zattrs(zarr_path: Path, expected: dict[str, Any]) -> None: @@ -398,6 +401,38 @@ def test_known_data(tmp_path: Path) -> None: np.testing.assert_equal(arr_downsammpled[:], [[[3]]]) +@pytest.mark.parametrize( + ("downsample_func", "expected_value"), [(mode, 7), (np.mean, 6)] +) +def test_mode_downsample( + tmp_path: Path, + downsample_func: Callable[[npt.ArrayLike], npt.NDArray], + expected_value: float, +) -> None: + arr_npy = np.arange(8).reshape((2, 2, 2)).astype(np.uint8) + arr_npy[0] = 7 # Make sure there's two elements with 8 so the mode is well defined + arr = da.from_array(arr_npy) + arr = arr.rechunk(chunks=(2, 2, 1)) + + group = MultiScaleGroup( + tmp_path / "group.ome.zarr", + name="my_zarr_group", + spatial_unit="centimeter", + voxel_size=(3, 4, 5), + array_spec=ArraySpec.from_array( + arr, + chunk_grid=NamedConfig( + name="regular", + configuration={"chunk_shape": [1, 1, 1]}, + ), + ), + ) + group.add_full_res_data(arr, n_processes=1) + group.add_downsample_level(1, n_processes=1, downsample_func=downsample_func) + arr_downsammpled = group[1] + np.testing.assert_equal(arr_downsammpled[:], [[[expected_value]]]) + + def test_padding(tmp_path: Path) -> None: # Test data that doesn't fit exactly into (2, 2, 2) shaped chunks arr_npy = np.arange(8).reshape((2, 2, 2))