From 3d42fc83b6de732cefe84098d593221e27cd9810 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Fri, 9 Jan 2026 10:43:08 +0000 Subject: [PATCH] Use tensorestore for downsampling --- src/stack_to_chunk/_array_helpers.py | 47 ++++++++++++++++------------ src/stack_to_chunk/main.py | 7 +++-- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/stack_to_chunk/_array_helpers.py b/src/stack_to_chunk/_array_helpers.py index 07c7c0c..fa22f6a 100644 --- a/src/stack_to_chunk/_array_helpers.py +++ b/src/stack_to_chunk/_array_helpers.py @@ -4,7 +4,6 @@ import numpy as np import skimage.measure import tensorstore as ts -import zarr from joblib import delayed from loguru import logger @@ -32,23 +31,14 @@ def _copy_slab(arr_path: Path, slab: da.Array, zstart: int, zend: int) -> None: logger.info(f"Writing z={zstart} -> {zend - 1}") # Write out data - arr_zarr = ts.open( - { - "driver": "zarr3", - "kvstore": { - "driver": "file", - "path": str(arr_path), - }, - "open": True, - } - ).result() + arr_zarr = _open_with_tensorstore(arr_path) arr_zarr[:, :, zstart:zend].write(data).result() logger.info(f"Finished copying z={zstart} -> {zend - 1}") @delayed # type: ignore[misc] def _downsample_block( - arr_in: zarr.Array, arr_out: zarr.Array, block_idx: tuple[int, int, int] + arr_in_path: Path, arr_out_path: Path, block_idx: tuple[int, int, int] ) -> None: """ Copy a single block from one array to the next, downsampling by a factor of two. @@ -59,15 +49,17 @@ def _downsample_block( Parameters ---------- - arr_in : - Input array. - arr_out : - Output array. Must have the same chunk shape as `arr_in`. + arr_in_path : + Path to input array. + arr_out_path : + Path to output array. Must have the same chunk shape as `arr_in`. block_idx : Index of block to copy. Must be a multiple of the shard shape in `arr_out`. """ - shard_shape: tuple[int, int, int] = arr_out.shards + arr_in = _open_with_tensorstore(arr_in_path) + arr_out = _open_with_tensorstore(arr_out_path) + shard_shape: tuple[int, int, int] = arr_out.chunk_layout.write_chunk.shape np.testing.assert_equal( np.array(block_idx) % np.array(shard_shape), np.array([0, 0, 0]), @@ -85,17 +77,32 @@ def _downsample_block( block_idx[2] * 2, min((block_idx[2] + shard_shape[2]) * 2, arr_in.shape[2]) ), ) - data = arr_in[in_slice] + data = arr_in[in_slice].read().result() # Pad to an even number pads = np.array(data.shape) % 2 pad_width = [(0, p) for p in pads] data = np.pad(data, pad_width, mode="edge") - data = skimage.measure.block_reduce(data, block_size=2, func=np.mean) + data = skimage.measure.block_reduce(data, block_size=2, func=np.mean).astype( + data.dtype + ) out_slice = ( slice(block_idx[0], min((block_idx[0] + shard_shape[0]), arr_out.shape[0])), slice(block_idx[1], min((block_idx[1] + shard_shape[1]), arr_out.shape[1])), slice(block_idx[2], min((block_idx[2] + shard_shape[2]), arr_out.shape[2])), ) - arr_out[out_slice] = data + arr_out[out_slice].write(data).result() + + +def _open_with_tensorstore(arr_path: Path) -> ts.TensorStore: + return ts.open( + { + "driver": "zarr3", + "kvstore": { + "driver": "file", + "path": str(arr_path), + }, + "open": True, + } + ).result() diff --git a/src/stack_to_chunk/main.py b/src/stack_to_chunk/main.py index fb7d08e..2906243 100644 --- a/src/stack_to_chunk/main.py +++ b/src/stack_to_chunk/main.py @@ -387,14 +387,17 @@ def add_downsample_level(self, level: int, *, n_processes: int = 1) -> None: assert sink_arr.shards is not None # Get slice of every shard in the sink array - block_indices = [ + block_indices: list[tuple[int, int, int]] = [ (x, y, z) for x in range(0, sink_arr.shape[0], sink_arr.shards[0]) for y in range(0, sink_arr.shape[1], sink_arr.shards[1]) for z in range(0, sink_arr.shape[2], sink_arr.shards[2]) ] - all_args = [(source_arr, sink_arr, idxs) for idxs in block_indices] + all_args: list[tuple[Path, Path, tuple[int, int, int]]] = [ + (self._path / str(level_minus_one), self._path / level_str, idxs) + for idxs in block_indices + ] logger.info(f"Starting downsampling from level {level_minus_one} > {level}...") blosc_use_threads = blosc.use_threads