From 92c8f60e0c5716ff121038918a766196796e55f2 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Thu, 2 Jul 2026 12:14:39 +0000 Subject: [PATCH 01/11] Support a block size smaller than the output along the reprojected dimensions for non_reprojected_dims reprojection, tiling each plane so the coordinate transform memory stays bounded, and route dask array inputs through map_blocks so each block is reprojected from the input directly --- reproject/_common.py | 465 +++++++++++-------- reproject/adaptive/_high_level.py | 8 +- reproject/interpolation/_high_level.py | 8 +- reproject/tests/test_non_reprojected_dims.py | 79 +++- 4 files changed, 363 insertions(+), 197 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index f8a52b54c..9633ba2f5 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -11,7 +11,6 @@ from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper from dask import delayed -from ._array_utils import ArrayWrapper from .utils import _dask_to_numpy_memmap __all__ = ["_reproject_dispatcher"] @@ -53,6 +52,171 @@ def as_delayed_memmap_path(array, tmp_dir): return array_path +def _reproject_region( + array_region, + *, + wcs_in, + wcs_out, + slices_in_wcs, + slices_out_wcs, + shape_out_region, + reproject_func, + reproject_func_kwargs, +): + """ + Reproject a single region of the input into a single output block. + + This is the core used by the map_blocks block function to reproject one block. + + Parameters + ---------- + array_region : `numpy.ndarray` or `dask.array.Array` + The region of the input array to reproject. + wcs_in : `~astropy.wcs.wcsapi.BaseHighLevelWCS` + The full input WCS, sliced down to ``slices_in_wcs`` before use. + wcs_out : `~astropy.wcs.wcsapi.BaseHighLevelWCS` + The full output WCS, sliced down to ``slices_out_wcs`` before use. + slices_in_wcs : tuple or None + Slices used to reduce ``wcs_in`` to the region being reprojected. If + `None`, the input WCS is used unchanged, for example when the + reprojection function broadcasts the extra dimensions itself. + slices_out_wcs : tuple + Slices used to reduce ``wcs_out`` to the output block. + shape_out_region : tuple + The shape of the output block. + reproject_func : callable + The low-level reprojection function to call. + reproject_func_kwargs : dict + Extra keyword arguments passed through to ``reproject_func``. + + Returns + ------- + `numpy.ndarray` + A stacked array containing the reprojected data and its footprint. + """ + + # The WCS class from astropy is not thread-safe, see e.g. + # https://github.com/astropy/astropy/issues/16244 + # https://github.com/astropy/astropy/issues/16245 + # To work around these issues, we make sure we do a deep copy of the WCS object + # in here when using FITS WCS. This is a very fast operation (<0.1ms) so should + # not be a concern in terms of performance. We only need to do this for FITS WCS. + wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in + wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out + + if slices_in_wcs is None: + wcs_in_sub = wcs_in_cp + else: + if isinstance(wcs_in_cp, BaseHighLevelWCS): + low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in_wcs) + else: + low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in_wcs) + wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in) + + if isinstance(wcs_out_cp, BaseHighLevelWCS): + low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out_wcs) + else: + low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out_wcs) + wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) + + array, footprint = reproject_func( + array_region, + wcs_in_sub, + wcs_out_sub, + shape_out=shape_out_region, + array_out=np.zeros(shape_out_region), + **reproject_func_kwargs, + ) + + return np.array([array, footprint]) + + +def _reproject_single_block( + a, + array_or_path, + block_info=None, + *, + wcs_in, + wcs_out, + shape_in, + broadcasted_parallelization, + n_dim_reproject, + reproject_func, + reproject_func_kwargs, +): + # Reproject a single output block for the map_blocks path. The input is passed + # as an opaque object (a memmap, a memmap path, or an ``_ArrayContainer`` wrapping + # a dask array) and the output block location comes from ``block_info``. + + if ( + a.ndim == 0 + or block_info is None + or block_info == [] + or (isinstance(block_info, np.ndarray) and block_info.tolist() == []) + ): + return np.array([a, a]) + + if isinstance(array_or_path, _ArrayContainer): + array_or_path = array_or_path._array + + shape_out = block_info[None]["chunk-shape"][1:] + + # Three sets of slices are derived from this output block: which region of the + # output WCS it covers, which broadcasted slice of the input WCS it corresponds + # to, and which broadcasted slice of the input data to read. Along the + # reprojected dimensions the input is always kept whole (any output pixel can map + # anywhere within it), while dask may tile the output; along the broadcasted + # dimensions each block is a single slice. + slices_out_wcs = [] + slices_in_wcs = [] + slices_in_data = [] + for idx in range(len(shape_out)): + interval = block_info[None]["array-location"][idx + 1] + if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: + if interval[1] - interval[0] != 1: + raise RuntimeError( + f"Expected a chunk of width 1 along dimension {idx} " + f"(got {interval[1] - interval[0]})" + ) + slices_out_wcs.append(interval[0]) + slices_in_wcs.append(interval[0]) + slices_in_data.append(slice(*interval)) + else: + slices_out_wcs.append(slice(*interval)) + slices_in_wcs.append(slice(None)) + slices_in_data.append(slice(None)) + + slices_out_wcs = slices_out_wcs[-wcs_out.low_level_wcs.pixel_n_dim :] + slices_in_wcs = slices_in_wcs[-wcs_in.low_level_wcs.pixel_n_dim :] + + if array_or_path is None: + raise RuntimeError("array_or_path is not set") + + if isinstance(array_or_path, tuple): + array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r") + elif isinstance(array_or_path, str): + array_in = np.memmap(array_or_path, dtype=float, shape=shape_in, mode="r") + else: + array_in = array_or_path + + if broadcasted_parallelization: + # Read just this broadcasted slice out of the whole input; the reprojected + # dimensions are kept whole (see above). For a memmap this stays a lazy view, + # so only the touched pages are loaded. + array_in = array_in[tuple(slices_in_data)] + + return _reproject_region( + array_in, + wcs_in=wcs_in, + wcs_out=wcs_out, + slices_in_wcs=slices_in_wcs if broadcasted_parallelization else None, + slices_out_wcs=slices_out_wcs, + shape_out_region=shape_out, + reproject_func=reproject_func, + reproject_func_kwargs=reproject_func_kwargs, + ) + + def _reproject_dispatcher( reproject_func, *, @@ -101,9 +265,10 @@ def _reproject_dispatcher( given as a tuple of sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). If `None` (the default), any leading dimensions for which the WCS has fewer dimensions than the data are treated this - way. Reprojecting fewer dimensions than the WCS currently requires a - ``block_size`` that matches the output shape along the reprojected - dimensions. + way. Reprojecting fewer dimensions than the WCS currently requires an + explicit ``block_size``; its entries along the reprojected dimensions + may either match the output shape or be smaller, in which case each + plane is reprojected in sub-tiles of that size. array_out : `~numpy.ndarray`, optional An array in which to store the reprojected data. This can be any numpy array including a memory map, which may be helpful when dealing with @@ -184,15 +349,15 @@ def _reproject_dispatcher( "non_reprojected_dims should leave at least one dimension to be " "reprojected" ) - # If we are reprojecting fewer dimensions than the input or output WCS has, - # the WCS needs to be sliced down to the reprojected dimensions for each - # non-reprojected slice. This is currently only done when parallelizing over - # the non-reprojected (broadcasted) dimensions, so any other code path would - # silently reproject the dimensions that should have been left untouched. - # This is gated on non_reprojected_dims being set since that is the only way - # to opt into reprojecting fewer dimensions than the WCS; a plain mismatch - # between the input and output WCS dimensionality is instead a validation - # error raised by the underlying reprojection function. + # ``wcs_slicing_required`` flags that we are reprojecting fewer dimensions than + # the input or output WCS describes, so the WCS must be sliced down to the + # reprojected dimensions for each non-reprojected slice. That slicing is only + # implemented on the path that parallelizes over the non-reprojected + # (broadcasted) dimensions; the other code paths raise NotImplementedError below + # rather than attempting it. It is gated on non_reprojected_dims being set, the + # only way to opt into reprojecting fewer dimensions than the WCS; a plain + # mismatch between the input and output WCS dimensionality is instead a + # validation error raised by the underlying reprojection function. wcs_slicing_required = non_reprojected_dims is not None and ( n_dim_reproject < wcs_in.low_level_wcs.pixel_n_dim or n_dim_reproject < wcs_out.low_level_wcs.pixel_n_dim @@ -305,28 +470,47 @@ def _reproject_dispatcher( for i in range(len(block_size)) ) - # Check block size and determine whether block size indicates we should - # parallelize over broadcasted dimension. The logic is as follows: if - # the block size and output shape are the same size, then either the - # block size should match the output shape along the broadcasted - # dimensions or along the non-broadcasted dimensions. If it matches the - # non-broadcasted dimensions we can parallelize over the broadcasted - # dimensions. If the block size does not match the output shape, we - # don't make any assumptions for now and assume a single chunk in the - # missing dimensions. + # Decide whether the requested block size means we should parallelize over + # the broadcasted (non-reprojected, leading) dimensions. block_size has + # already been padded above to one entry per output dimension, so this is not + # about the number of entries but about which dimensions the block spans the + # full output extent along: + # - if the block spans the full extent along the reprojected (trailing) + # dimensions, each block is one whole reprojected plane, so we parallelize + # over the broadcasted dimensions (one broadcasted slice per block); + # - if instead it spans the full extent along the broadcasted (leading) + # dimensions, the block tiles the reprojected plane and we do not + # parallelize over the broadcasted dimensions; + # - if it spans the full extent along neither, we raise, unless + # non_reprojected_dims requires slicing the WCS per plane, in which case a + # block smaller than the plane sub-tiles each plane. broadcasted_parallelization = False if broadcasting and block_size is not None and block_size != "auto": if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]: # TODO: maybe error if block_size was given in full and is wrong broadcasted_parallelization = True - block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[ - -n_dim_reproject: - ] + elif wcs_slicing_required: + # A block smaller than the output along the reprojected dimensions + # is only meaningful when the WCS has to be sliced per broadcasted + # slice (i.e. non_reprojected_dims). We parallelize one broadcasted + # slice per block and let dask additionally tile the reprojected + # dimensions according to the block size, which bounds the + # coordinate-transform memory (it would otherwise scale with the + # full plane size). Each output tile is still reprojected from the + # whole input slice, since any output pixel can map anywhere within + # it. + broadcasted_parallelization = True elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]: raise ValueError( "block shape should either match output data shape along " "reprojected dimensions or non-reprojected dimensions" ) + if broadcasted_parallelization: + # One broadcasted slice per block; dask tiles the reprojected + # dimensions using whatever block size was requested along them. + block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[ + -n_dim_reproject: + ] logger.info( f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along " @@ -341,182 +525,87 @@ def _reproject_dispatcher( raise NotImplementedError( "Reprojecting fewer dimensions than the input or output WCS " "(for example using non_reprojected_dims) currently requires " - "passing a block_size whose entries along the reprojected " - "dimensions match the output shape (optionally with parallel=True " - "to compute the blocks concurrently)" + "passing an explicit block_size whose entries along the reprojected " + "dimensions either match the output shape or are smaller (in which " + "case each plane is reprojected in sub-tiles of that size), " + "optionally with parallel=True to compute the blocks concurrently" ) if output_footprint is None and return_footprint and return_type != "dask": output_footprint = np.zeros(shape_out, dtype=float) - def reproject_single_block(a, array_or_path, block_info=None): - - if ( - a.ndim == 0 - or block_info is None - or block_info == [] - or (isinstance(block_info, np.ndarray) and block_info.tolist() == []) - ): - return np.array([a, a]) - - if isinstance(array_or_path, str) and array_or_path == "from-dict": - array_or_path = dask_arrays["array"] - - shape_out = block_info[None]["chunk-shape"][1:] - - # The WCS class from astropy is not thread-safe, see e.g. - # https://github.com/astropy/astropy/issues/16244 - # https://github.com/astropy/astropy/issues/16245 - # To work around these issues, we make sure we do a deep copy of - # the WCS object in here when using FITS WCS. This is a very fast - # operation (<0.1ms) so should not be a concern in terms of - # performance. We only need to do this for FITS WCS. - - wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in - wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out - - slices_in = [] - slices_out = [] - for idx in range(len(shape_out)): - interval = block_info[None]["array-location"][idx + 1] - if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: - if interval[1] - interval[0] != 1: - raise RuntimeError( - f"Expected a chunk of width 1 along dimension {idx} " - f"(got {interval[1] - interval[0]})" - ) - slices_in.append(interval[0]) - slices_out.append(interval[0]) - else: - slices_in.append(slice(None)) - slices_out.append(slice(*block_info[None]["array-location"][idx + 1])) - - slices_in = slices_in[-wcs_in.low_level_wcs.pixel_n_dim :] - slices_out = slices_out[-wcs_out.low_level_wcs.pixel_n_dim :] - - if broadcasted_parallelization: - if isinstance(wcs_in_cp, BaseHighLevelWCS): - low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in) + # The input is passed to map_blocks as an opaque (non-dask) argument + # rather than as a second dask array to align with the output, so that + # dask is free to tile the output however the block size dictates + # (including along the reprojected dimensions) while every task still sees + # the whole input; the block function then reads out the broadcasted slice + # it needs. As we use the synchronous or threads scheduler, we don't need + # to worry about the data getting copied, so if the data is already a Numpy + # array (including a memory-mapped array) then we don't need to do anything + # special. However, if the input array is a dask array, we should convert + # it to a Numpy memory-mapped array so that it can be used by the various + # reprojection functions (which don't internally work with dask arrays). + + if isinstance(array_in, np.memmap) and array_in.flags.c_contiguous: + array_in_or_path = array_in.filename, { + "dtype": array_in.dtype, + "shape": array_in.shape, + "offset": array_in.offset, + } + elif isinstance(array_in, da.core.Array) or return_type == "dask": + if dask_method == "memmap": + if return_type == "dask": + # We should use a temporary directory that will persist beyond + # the call to the reproject function. + tmp_dir = tempfile.mkdtemp() else: - low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in) - - wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in) - else: - wcs_in_sub = wcs_in_cp - - if isinstance(wcs_out_cp, BaseHighLevelWCS): - low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out) + tmp_dir = local_tmp_dir + array_in_or_path = as_delayed_memmap_path(_ArrayContainer(array_in), tmp_dir) else: - low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out) - - wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) - - if isinstance(array_or_path, tuple): - array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r") - elif isinstance(array_or_path, str): - array_in = np.memmap(array_or_path, dtype=float, shape=shape_in, mode="r") - else: - array_in = array_or_path - - if array_or_path is None: - raise RuntimeError("array_or_path is not set") - - array, footprint = reproject_func( - array_in, - wcs_in_sub, - wcs_out_sub, - shape_out=shape_out, - array_out=np.zeros(shape_out), - **reproject_func_kwargs, - ) - - return np.array([array, footprint]) - - if broadcasted_parallelization: + # Wrap the dask array in _ArrayContainer so dask treats it as an + # opaque constant (rather than a collection to compute/align) when + # it is passed through to the block function. + array_in_or_path = _ArrayContainer(array_in) + else: + # Here we could set array_in_or_path to array_in_path if it has + # been set previously, but in synchronous and threaded mode it is + # better to simply pass a reference to the memmap array itself to + # avoid having to load the memmap inside each + # _reproject_single_block call. + array_in_or_path = array_in + if block_size is not None and block_size != "auto": array_out_dask = da.empty(shape_out, chunks=block_size) - - # The input is reprojected in full for each output block, so it must - # not be chunked along the reprojected dimensions (which can have a - # different size from the output); only the broadcasted dimensions are - # chunked, matching array_out_dask block for block. - input_chunks = (1,) * (array_in.ndim - n_dim_reproject) + (-1,) * n_dim_reproject - if isinstance(array_in, da.core.Array): - array_in = array_in.rechunk(input_chunks) - else: - array_in = da.asarray( - ArrayWrapper(array_in), name=str(uuid.uuid4()), chunks=input_chunks - ) - - result = da.map_blocks( - reproject_single_block, - array_out_dask, - array_in, - dtype=" Date: Thu, 2 Jul 2026 12:38:09 +0000 Subject: [PATCH 02/11] style: pre-commit fixes --- reproject/_common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index 9633ba2f5..99d6992c7 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -584,9 +584,7 @@ def _reproject_dispatcher( else: rechunk_kwargs = {} array_out_dask = da.empty(shape_out) - array_out_dask = array_out_dask.rechunk( - block_size_limit=64 * 1024**2, **rechunk_kwargs - ) + array_out_dask = array_out_dask.rechunk(block_size_limit=64 * 1024**2, **rechunk_kwargs) logger.info("Setting up output dask array with map_blocks") From 0a648bc579ed45747fd791cc795f59b000136cdb Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Thu, 2 Jul 2026 22:47:10 +0000 Subject: [PATCH 03/11] Fold the per-block reprojection back into a single nested reproject_single_block function inside the dispatcher instead of two module-level helpers so the diff against main stays small and easy to review --- reproject/_common.py | 273 ++++++++++++++++--------------------------- 1 file changed, 99 insertions(+), 174 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index 99d6992c7..558b9101f 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -52,171 +52,6 @@ def as_delayed_memmap_path(array, tmp_dir): return array_path -def _reproject_region( - array_region, - *, - wcs_in, - wcs_out, - slices_in_wcs, - slices_out_wcs, - shape_out_region, - reproject_func, - reproject_func_kwargs, -): - """ - Reproject a single region of the input into a single output block. - - This is the core used by the map_blocks block function to reproject one block. - - Parameters - ---------- - array_region : `numpy.ndarray` or `dask.array.Array` - The region of the input array to reproject. - wcs_in : `~astropy.wcs.wcsapi.BaseHighLevelWCS` - The full input WCS, sliced down to ``slices_in_wcs`` before use. - wcs_out : `~astropy.wcs.wcsapi.BaseHighLevelWCS` - The full output WCS, sliced down to ``slices_out_wcs`` before use. - slices_in_wcs : tuple or None - Slices used to reduce ``wcs_in`` to the region being reprojected. If - `None`, the input WCS is used unchanged, for example when the - reprojection function broadcasts the extra dimensions itself. - slices_out_wcs : tuple - Slices used to reduce ``wcs_out`` to the output block. - shape_out_region : tuple - The shape of the output block. - reproject_func : callable - The low-level reprojection function to call. - reproject_func_kwargs : dict - Extra keyword arguments passed through to ``reproject_func``. - - Returns - ------- - `numpy.ndarray` - A stacked array containing the reprojected data and its footprint. - """ - - # The WCS class from astropy is not thread-safe, see e.g. - # https://github.com/astropy/astropy/issues/16244 - # https://github.com/astropy/astropy/issues/16245 - # To work around these issues, we make sure we do a deep copy of the WCS object - # in here when using FITS WCS. This is a very fast operation (<0.1ms) so should - # not be a concern in terms of performance. We only need to do this for FITS WCS. - wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in - wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out - - if slices_in_wcs is None: - wcs_in_sub = wcs_in_cp - else: - if isinstance(wcs_in_cp, BaseHighLevelWCS): - low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in_wcs) - else: - low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in_wcs) - wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in) - - if isinstance(wcs_out_cp, BaseHighLevelWCS): - low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out_wcs) - else: - low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out_wcs) - wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) - - array, footprint = reproject_func( - array_region, - wcs_in_sub, - wcs_out_sub, - shape_out=shape_out_region, - array_out=np.zeros(shape_out_region), - **reproject_func_kwargs, - ) - - return np.array([array, footprint]) - - -def _reproject_single_block( - a, - array_or_path, - block_info=None, - *, - wcs_in, - wcs_out, - shape_in, - broadcasted_parallelization, - n_dim_reproject, - reproject_func, - reproject_func_kwargs, -): - # Reproject a single output block for the map_blocks path. The input is passed - # as an opaque object (a memmap, a memmap path, or an ``_ArrayContainer`` wrapping - # a dask array) and the output block location comes from ``block_info``. - - if ( - a.ndim == 0 - or block_info is None - or block_info == [] - or (isinstance(block_info, np.ndarray) and block_info.tolist() == []) - ): - return np.array([a, a]) - - if isinstance(array_or_path, _ArrayContainer): - array_or_path = array_or_path._array - - shape_out = block_info[None]["chunk-shape"][1:] - - # Three sets of slices are derived from this output block: which region of the - # output WCS it covers, which broadcasted slice of the input WCS it corresponds - # to, and which broadcasted slice of the input data to read. Along the - # reprojected dimensions the input is always kept whole (any output pixel can map - # anywhere within it), while dask may tile the output; along the broadcasted - # dimensions each block is a single slice. - slices_out_wcs = [] - slices_in_wcs = [] - slices_in_data = [] - for idx in range(len(shape_out)): - interval = block_info[None]["array-location"][idx + 1] - if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: - if interval[1] - interval[0] != 1: - raise RuntimeError( - f"Expected a chunk of width 1 along dimension {idx} " - f"(got {interval[1] - interval[0]})" - ) - slices_out_wcs.append(interval[0]) - slices_in_wcs.append(interval[0]) - slices_in_data.append(slice(*interval)) - else: - slices_out_wcs.append(slice(*interval)) - slices_in_wcs.append(slice(None)) - slices_in_data.append(slice(None)) - - slices_out_wcs = slices_out_wcs[-wcs_out.low_level_wcs.pixel_n_dim :] - slices_in_wcs = slices_in_wcs[-wcs_in.low_level_wcs.pixel_n_dim :] - - if array_or_path is None: - raise RuntimeError("array_or_path is not set") - - if isinstance(array_or_path, tuple): - array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r") - elif isinstance(array_or_path, str): - array_in = np.memmap(array_or_path, dtype=float, shape=shape_in, mode="r") - else: - array_in = array_or_path - - if broadcasted_parallelization: - # Read just this broadcasted slice out of the whole input; the reprojected - # dimensions are kept whole (see above). For a memmap this stays a lazy view, - # so only the touched pages are loaded. - array_in = array_in[tuple(slices_in_data)] - - return _reproject_region( - array_in, - wcs_in=wcs_in, - wcs_out=wcs_out, - slices_in_wcs=slices_in_wcs if broadcasted_parallelization else None, - slices_out_wcs=slices_out_wcs, - shape_out_region=shape_out, - reproject_func=reproject_func, - reproject_func_kwargs=reproject_func_kwargs, - ) - - def _reproject_dispatcher( reproject_func, *, @@ -534,6 +369,103 @@ def _reproject_dispatcher( if output_footprint is None and return_footprint and return_type != "dask": output_footprint = np.zeros(shape_out, dtype=float) + def reproject_single_block(a, array_or_path, block_info=None): + + if ( + a.ndim == 0 + or block_info is None + or block_info == [] + or (isinstance(block_info, np.ndarray) and block_info.tolist() == []) + ): + return np.array([a, a]) + + if isinstance(array_or_path, _ArrayContainer): + array_or_path = array_or_path._array + + shape_out = block_info[None]["chunk-shape"][1:] + + # The WCS class from astropy is not thread-safe, see e.g. + # https://github.com/astropy/astropy/issues/16244 + # https://github.com/astropy/astropy/issues/16245 + # To work around these issues, we make sure we do a deep copy of + # the WCS object in here when using FITS WCS. This is a very fast + # operation (<0.1ms) so should not be a concern in terms of + # performance. We only need to do this for FITS WCS. + + wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in + wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out + + # Along the reprojected dimensions the input is always kept whole (any + # output pixel can map anywhere within it) while dask may tile the + # output; along the broadcasted dimensions each block is a single slice. + # slices_in/slices_out reduce the input/output WCS to this block, and + # slices_in_data selects the matching broadcasted slice of the input. + slices_in = [] + slices_out = [] + slices_in_data = [] + for idx in range(len(shape_out)): + interval = block_info[None]["array-location"][idx + 1] + if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: + if interval[1] - interval[0] != 1: + raise RuntimeError( + f"Expected a chunk of width 1 along dimension {idx} " + f"(got {interval[1] - interval[0]})" + ) + slices_in.append(interval[0]) + slices_out.append(interval[0]) + slices_in_data.append(slice(*interval)) + else: + slices_in.append(slice(None)) + slices_out.append(slice(*block_info[None]["array-location"][idx + 1])) + slices_in_data.append(slice(None)) + + slices_in = slices_in[-wcs_in.low_level_wcs.pixel_n_dim :] + slices_out = slices_out[-wcs_out.low_level_wcs.pixel_n_dim :] + + if broadcasted_parallelization: + if isinstance(wcs_in_cp, BaseHighLevelWCS): + low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp.low_level_wcs, slices=slices_in) + else: + low_level_wcs_in = SlicedLowLevelWCS(wcs_in_cp, slices=slices_in) + + wcs_in_sub = HighLevelWCSWrapper(low_level_wcs_in) + else: + wcs_in_sub = wcs_in_cp + + if isinstance(wcs_out_cp, BaseHighLevelWCS): + low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp.low_level_wcs, slices=slices_out) + else: + low_level_wcs_out = SlicedLowLevelWCS(wcs_out_cp, slices=slices_out) + + wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) + + if isinstance(array_or_path, tuple): + array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r") + elif isinstance(array_or_path, str): + array_in = np.memmap(array_or_path, dtype=float, shape=shape_in, mode="r") + else: + array_in = array_or_path + + if array_or_path is None: + raise RuntimeError("array_or_path is not set") + + if broadcasted_parallelization: + # Read just this broadcasted slice out of the whole input; the + # reprojected dimensions are kept whole (see above). For a memmap + # this stays a lazy view, so only the touched pages are loaded. + array_in = array_in[tuple(slices_in_data)] + + array, footprint = reproject_func( + array_in, + wcs_in_sub, + wcs_out_sub, + shape_out=shape_out, + array_out=np.zeros(shape_out), + **reproject_func_kwargs, + ) + + return np.array([array, footprint]) + # The input is passed to map_blocks as an opaque (non-dask) argument # rather than as a second dask array to align with the output, so that # dask is free to tile the output however the block size dictates @@ -571,7 +503,7 @@ def _reproject_dispatcher( # been set previously, but in synchronous and threaded mode it is # better to simply pass a reference to the memmap array itself to # avoid having to load the memmap inside each - # _reproject_single_block call. + # reproject_single_block call. array_in_or_path = array_in if block_size is not None and block_size != "auto": @@ -589,19 +521,12 @@ def _reproject_dispatcher( logger.info("Setting up output dask array with map_blocks") result = da.map_blocks( - _reproject_single_block, + reproject_single_block, array_out_dask, array_in_or_path, dtype=" Date: Fri, 3 Jul 2026 11:21:26 +0000 Subject: [PATCH 04/11] Only reconstruct memmap inputs from filename and offset for base memmaps, since views keep the parent's unadjusted offset and were silently reprojected from the wrong file region, and pass views by reference instead --- reproject/_common.py | 11 +++++- reproject/tests/test_non_reprojected_dims.py | 37 ++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/reproject/_common.py b/reproject/_common.py index 558b9101f..f1faf41b9 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -1,4 +1,5 @@ import logging +import mmap import os import tempfile import uuid @@ -478,7 +479,15 @@ def reproject_single_block(a, array_or_path, block_info=None): # it to a Numpy memory-mapped array so that it can be used by the various # reprojection functions (which don't internally work with dask arrays). - if isinstance(array_in, np.memmap) and array_in.flags.c_contiguous: + # Only base memmaps can be reconstructed from filename and offset: views + # (e.g. a slice of a memmap) keep the parent's unadjusted .offset, so + # reconstructing them would silently read the wrong file region. Views + # fall through and are passed by reference like plain arrays. + if ( + isinstance(array_in, np.memmap) + and array_in.flags.c_contiguous + and isinstance(array_in.base, mmap.mmap) + ): array_in_or_path = array_in.filename, { "dtype": array_in.dtype, "shape": array_in.shape, diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py index 30a55be5b..7c8097a5a 100644 --- a/reproject/tests/test_non_reprojected_dims.py +++ b/reproject/tests/test_non_reprojected_dims.py @@ -129,6 +129,43 @@ def test_non_reprojected_dims_dask_input(reproject_function, block_size): assert_allclose(array_out, reference, equal_nan=True) +def test_non_reprojected_dims_sliced_memmap(tmp_path, reproject_function): + # A sliced memmap view keeps the parent's unadjusted .offset, so it must not + # be reconstructed from filename and offset inside the block tasks (which + # would silently reproject the wrong planes); views are passed by reference + # instead. Slicing off the leading plane keeps the view c-contiguous, which + # is the case that used to take the reconstruction path. + + data = np.arange(5 * 20 * 20, dtype=float).reshape((5, 20, 20)) + mm = np.memmap(tmp_path / "cube.np", mode="w+", dtype=float, shape=(5, 20, 20)) + mm[:] = data + mm.flush() + + wcs_in = _spectral_cube_wcs(0.0, 1e9) + wcs_out = _spectral_cube_wcs(0.02, 1e9 + 2e6) + shape_out = (4, 20, 20) + + reference, _ = reproject_function( + (data[1:], wcs_in), + wcs_out, + shape_out=shape_out, + non_reprojected_dims=(0,), + parallel=True, + block_size=(20, 20), + ) + + array_out, _ = reproject_function( + (mm[1:], wcs_in), + wcs_out, + shape_out=shape_out, + non_reprojected_dims=(0,), + parallel=True, + block_size=(20, 20), + ) + + assert_allclose(array_out, reference, equal_nan=True) + + def test_non_reprojected_dims_invalid_order(reproject_function): data = np.ones((4, 20, 20)) wcs = _spectral_cube_wcs(0.0, 1e9) From 2d40ab8dce16aab7d5ff976eb494e46035eea2e8 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 3 Jul 2026 11:27:33 +0000 Subject: [PATCH 05/11] Pass the input to the broadcasted reprojection path as a dask array with one chunk per non-reprojected slice, each routed through a delayed task so blockwise fusion cannot recompute it per output tile, restoring per-slice streaming for dask inputs instead of materializing them to a temporary memmap and letting distributed schedulers ship each task only its own slice --- reproject/_common.py | 69 ++++++++++++++------ reproject/tests/test_non_reprojected_dims.py | 47 +++++++++++++ 2 files changed, 95 insertions(+), 21 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index f1faf41b9..9ea1f741d 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -398,12 +398,12 @@ def reproject_single_block(a, array_or_path, block_info=None): # Along the reprojected dimensions the input is always kept whole (any # output pixel can map anywhere within it) while dask may tile the - # output; along the broadcasted dimensions each block is a single slice. - # slices_in/slices_out reduce the input/output WCS to this block, and - # slices_in_data selects the matching broadcasted slice of the input. + # output; along the broadcasted dimensions each block is a single + # slice. slices_in/slices_out reduce the input/output WCS to this + # block; the matching broadcasted slice of the input arrives as the + # aligned input block. slices_in = [] slices_out = [] - slices_in_data = [] for idx in range(len(shape_out)): interval = block_info[None]["array-location"][idx + 1] if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: @@ -414,11 +414,9 @@ def reproject_single_block(a, array_or_path, block_info=None): ) slices_in.append(interval[0]) slices_out.append(interval[0]) - slices_in_data.append(slice(*interval)) else: slices_in.append(slice(None)) slices_out.append(slice(*block_info[None]["array-location"][idx + 1])) - slices_in_data.append(slice(None)) slices_in = slices_in[-wcs_in.low_level_wcs.pixel_n_dim :] slices_out = slices_out[-wcs_out.low_level_wcs.pixel_n_dim :] @@ -440,22 +438,21 @@ def reproject_single_block(a, array_or_path, block_info=None): wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) - if isinstance(array_or_path, tuple): + if broadcasted_parallelization: + # The input was passed as an aligned dask array, so array_or_path + # is already this block's broadcasted slice of the input, kept + # whole along the reprojected dimensions (see above). + array_in = array_or_path + elif isinstance(array_or_path, tuple): array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r") elif isinstance(array_or_path, str): array_in = np.memmap(array_or_path, dtype=float, shape=shape_in, mode="r") else: array_in = array_or_path - if array_or_path is None: + if array_in is None: raise RuntimeError("array_or_path is not set") - if broadcasted_parallelization: - # Read just this broadcasted slice out of the whole input; the - # reprojected dimensions are kept whole (see above). For a memmap - # this stays a lazy view, so only the touched pages are loaded. - array_in = array_in[tuple(slices_in_data)] - array, footprint = reproject_func( array_in, wcs_in_sub, @@ -467,12 +464,42 @@ def reproject_single_block(a, array_or_path, block_info=None): return np.array([array, footprint]) - # The input is passed to map_blocks as an opaque (non-dask) argument - # rather than as a second dask array to align with the output, so that - # dask is free to tile the output however the block size dictates - # (including along the reprojected dimensions) while every task still sees - # the whole input; the block function then reads out the broadcasted slice - # it needs. As we use the synchronous or threads scheduler, we don't need + if broadcasted_parallelization: + # Pass the input as a second dask array with one chunk per broadcasted + # slice, kept whole along the reprojected dimensions (any output pixel + # can map anywhere within its slice). map_blocks broadcasts the single + # chunk along the reprojected dimensions to every output tile of that + # slice, so each slice is computed once and streamed to exactly the + # tasks that need it: dask array inputs are never materialized in + # full, sub-tiled planes do not recompute their input per tile, and + # under a distributed scheduler each task depends only on its own + # slice rather than embedding the whole input. + input_chunks = (1,) * (array_in.ndim - n_dim_reproject) + (-1,) * n_dim_reproject + if isinstance(array_in, da.core.Array): + array_in_dask = array_in.rechunk(input_chunks) + # Blockwise fusion would fold the input graph into every output + # tile task, recomputing each broadcasted slice once per tile of + # that slice; routing each slice through a delayed task pins it + # as a single node in the graph that all of its tiles share. + delayed_blocks = array_in_dask.to_delayed() + pieces = np.empty(delayed_blocks.shape, dtype=object) + for index in np.ndindex(delayed_blocks.shape): + shape = tuple( + array_in_dask.chunks[idim][index[idim]] + for idim in range(array_in_dask.ndim) + ) + pieces[index] = da.from_delayed( + delayed_blocks[index], shape=shape, dtype=array_in_dask.dtype + ) + array_in_or_path = da.block(pieces.tolist()) + else: + array_in_or_path = da.from_array( + array_in, name=f"reproject-input-{uuid.uuid4().hex}", chunks=input_chunks + ) + + # For the remaining (non-broadcasted) cases the input is passed to + # map_blocks as an opaque (non-dask) argument, so that every task sees the + # whole input. As we use the synchronous or threads scheduler, we don't need # to worry about the data getting copied, so if the data is already a Numpy # array (including a memory-mapped array) then we don't need to do anything # special. However, if the input array is a dask array, we should convert @@ -483,7 +510,7 @@ def reproject_single_block(a, array_or_path, block_info=None): # (e.g. a slice of a memmap) keep the parent's unadjusted .offset, so # reconstructing them would silently read the wrong file region. Views # fall through and are passed by reference like plain arrays. - if ( + elif ( isinstance(array_in, np.memmap) and array_in.flags.c_contiguous and isinstance(array_in.base, mmap.mmap) diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py index 7c8097a5a..0500d0126 100644 --- a/reproject/tests/test_non_reprojected_dims.py +++ b/reproject/tests/test_non_reprojected_dims.py @@ -166,6 +166,53 @@ def test_non_reprojected_dims_sliced_memmap(tmp_path, reproject_function): assert_allclose(array_out, reference, equal_nan=True) +def test_non_reprojected_dims_dask_input_streams_planes(reproject_function): + # The input is passed as a dask array with one chunk per non-reprojected + # slice, so each input plane must be computed exactly once, including when + # the output is sub-tiled (every tile of a plane shares that plane's chunk + # rather than recomputing it), and the whole input must never be + # materialized at once. + import dask.array as da + + n_time = 5 + shape_out = (n_time, 30, 30) + wcs_in = _drifting_cube_wcs(drift=0.6) + wcs_out = _drifting_cube_wcs(drift=0.0) + + data = np.random.default_rng(0).random((n_time, 30, 30)) + + computed_planes = [] + + def record_plane(plane, block_info=None): + if block_info: + computed_planes.append(block_info[None]["chunk-location"][0]) + return plane + + lazy = da.from_array(data, chunks=(1, 30, 30)).map_blocks(record_plane) + + array_out, _ = reproject_function( + (lazy, wcs_in), + wcs_out, + shape_out=shape_out, + non_reprojected_dims=(0,), + parallel=True, + block_size=(7, 7), + dask_method="none", + ) + + reference, _ = reproject_function( + (data, wcs_in), + wcs_out, + shape_out=shape_out, + non_reprojected_dims=(0,), + parallel=True, + block_size=(30, 30), + ) + + assert_allclose(array_out, reference, equal_nan=True) + assert sorted(computed_planes) == list(range(n_time)) + + def test_non_reprojected_dims_invalid_order(reproject_function): data = np.ones((4, 20, 20)) wcs = _spectral_cube_wcs(0.0, 1e9) From ceea6065d43cedea3aa500cf962e96d7f46289e9 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 3 Jul 2026 11:31:07 +0000 Subject: [PATCH 06/11] Keep a dask input lazy on the broadcasted path when dask_method='none' and it is chunked below one slice along the reprojected dimensions, so streaming reprojection cores only compute the input chunks each output tile touches and never need to hold a whole slice, while slice-chunked inputs are still materialized exactly once per slice --- reproject/_common.py | 42 +++++++++++++++++--- reproject/tests/test_non_reprojected_dims.py | 17 ++++---- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index 9ea1f741d..46dc59622 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -400,10 +400,12 @@ def reproject_single_block(a, array_or_path, block_info=None): # output pixel can map anywhere within it) while dask may tile the # output; along the broadcasted dimensions each block is a single # slice. slices_in/slices_out reduce the input/output WCS to this - # block; the matching broadcasted slice of the input arrives as the - # aligned input block. + # block; the matching broadcasted slice of the input either arrives as + # the aligned input block or, when the input was passed whole (lazy + # dask input), is read out below using slices_in_data. slices_in = [] slices_out = [] + slices_in_data = [] for idx in range(len(shape_out)): interval = block_info[None]["array-location"][idx + 1] if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: @@ -414,9 +416,11 @@ def reproject_single_block(a, array_or_path, block_info=None): ) slices_in.append(interval[0]) slices_out.append(interval[0]) + slices_in_data.append(slice(*interval)) else: slices_in.append(slice(None)) slices_out.append(slice(*block_info[None]["array-location"][idx + 1])) + slices_in_data.append(slice(None)) slices_in = slices_in[-wcs_in.low_level_wcs.pixel_n_dim :] slices_out = slices_out[-wcs_out.low_level_wcs.pixel_n_dim :] @@ -438,7 +442,7 @@ def reproject_single_block(a, array_or_path, block_info=None): wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) - if broadcasted_parallelization: + if broadcasted_parallelization and input_aligned: # The input was passed as an aligned dask array, so array_or_path # is already this block's broadcasted slice of the input, kept # whole along the reprojected dimensions (see above). @@ -453,6 +457,12 @@ def reproject_single_block(a, array_or_path, block_info=None): if array_in is None: raise RuntimeError("array_or_path is not set") + if broadcasted_parallelization and not input_aligned: + # The input was passed whole as a lazy dask array; read out a lazy + # view of this block's broadcasted slice so a streaming + # reprojection core only computes the input chunks it touches. + array_in = array_in[tuple(slices_in_data)] + array, footprint = reproject_func( array_in, wcs_in_sub, @@ -464,16 +474,26 @@ def reproject_single_block(a, array_or_path, block_info=None): return np.array([array, footprint]) - if broadcasted_parallelization: + input_aligned = False + if broadcasted_parallelization and ( + not isinstance(array_in, da.core.Array) + or dask_method != "none" + or all(len(chunks) == 1 for chunks in array_in.chunks[-n_dim_reproject:]) + ): # Pass the input as a second dask array with one chunk per broadcasted # slice, kept whole along the reprojected dimensions (any output pixel # can map anywhere within its slice). map_blocks broadcasts the single # chunk along the reprojected dimensions to every output tile of that - # slice, so each slice is computed once and streamed to exactly the + # slice, so each slice is computed exactly once and streamed to the # tasks that need it: dask array inputs are never materialized in # full, sub-tiled planes do not recompute their input per tile, and # under a distributed scheduler each task depends only on its own - # slice rather than embedding the whole input. + # slice rather than embedding the whole input. The exception is a dask + # input with dask_method='none' that is chunked below one slice along + # the reprojected dimensions: materializing it here would forgo the + # ability of streaming reprojection cores to work chunk by chunk + # without ever holding a whole slice, so it is kept lazy below. + input_aligned = True input_chunks = (1,) * (array_in.ndim - n_dim_reproject) + (-1,) * n_dim_reproject if isinstance(array_in, da.core.Array): array_in_dask = array_in.rechunk(input_chunks) @@ -497,6 +517,16 @@ def reproject_single_block(a, array_or_path, block_info=None): array_in, name=f"reproject-input-{uuid.uuid4().hex}", chunks=input_chunks ) + elif broadcasted_parallelization: + # A dask input with dask_method='none' chunked below one slice along + # the reprojected dimensions: pass it whole as an opaque constant and + # let each block read out a lazy view of its own slice, so that a + # streaming reprojection core (e.g. interpolation via dask-image) only + # ever computes the input chunks that each output tile touches and a + # full slice need never be materialized at once. The tradeoff is that + # input chunks touched by several tiles are computed once per tile. + array_in_or_path = _ArrayContainer(array_in) + # For the remaining (non-broadcasted) cases the input is passed to # map_blocks as an opaque (non-dask) argument, so that every task sees the # whole input. As we use the synchronous or threads scheduler, we don't need diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py index 0500d0126..bf06f4c89 100644 --- a/reproject/tests/test_non_reprojected_dims.py +++ b/reproject/tests/test_non_reprojected_dims.py @@ -91,13 +91,16 @@ def test_non_reprojected_dims_subtiled(reproject_function, block_size): assert_allclose(footprint_sub, footprint_full, equal_nan=True) +@pytest.mark.parametrize("chunks", [(1, 30, 30), (1, 15, 15)]) @pytest.mark.parametrize("block_size", [(20, 20), (7, 7)]) -def test_non_reprojected_dims_dask_input(reproject_function, block_size): - # A dask-array input is passed through map_blocks and reprojected per block (for - # interp, via dask-image's map_coordinates, which streams the input chunks). The - # result must match the identical numpy input, both for full-plane and sub-tiled - # blocks. The WCS drifts along the non-reprojected axis so each slice really is - # reprojected with its own WCS. +def test_non_reprojected_dims_dask_input(reproject_function, block_size, chunks): + # A dask-array input must match the identical numpy input, both for + # full-plane and sub-tiled blocks. With dask_method='none', an input chunked + # one slice at a time is materialized per slice (exactly once), while an + # input chunked below one slice is kept lazy so streaming cores never need a + # whole slice at once; both must give the same answer. The WCS drifts along + # the non-reprojected axis so each slice really is reprojected with its own + # WCS. import dask.array as da n_time = 5 @@ -117,7 +120,7 @@ def test_non_reprojected_dims_dask_input(reproject_function, block_size): ) array_out, _ = reproject_function( - (da.from_array(data, chunks=(1, 30, 30)), wcs_in), + (da.from_array(data, chunks=chunks), wcs_in), wcs_out, shape_out=shape_out, non_reprojected_dims=(0,), From d819f1e79fc8244f211955128abbf4a5ee6ebf4d Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 3 Jul 2026 11:36:10 +0000 Subject: [PATCH 07/11] Raise an error for block_size entries along the non-reprojected dimensions that are neither 1 nor the full extent instead of silently reinterpreting them as one slice per block --- reproject/_common.py | 14 +++++++++++++- reproject/tests/test_non_reprojected_dims.py | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/reproject/_common.py b/reproject/_common.py index 46dc59622..503511b05 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -323,7 +323,6 @@ def _reproject_dispatcher( broadcasted_parallelization = False if broadcasting and block_size is not None and block_size != "auto": if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]: - # TODO: maybe error if block_size was given in full and is wrong broadcasted_parallelization = True elif wcs_slicing_required: # A block smaller than the output along the reprojected dimensions @@ -344,6 +343,19 @@ def _reproject_dispatcher( if broadcasted_parallelization: # One broadcasted slice per block; dask tiles the reprojected # dimensions using whatever block size was requested along them. + # The block size along the non-reprojected dimensions must be 1 + # or span the full extent (equivalent here, since blocks are + # single slices either way); anything else would be silently + # reinterpreted, so raise instead. + if any( + entry not in (1, shape_out[idim]) + for idim, entry in enumerate(block_size[: len(shape_out) - n_dim_reproject]) + ): + raise ValueError( + f"block_size {block_size} should be 1 or match the output shape " + "along the non-reprojected dimensions (each block covers a " + "single non-reprojected slice)" + ) block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[ -n_dim_reproject: ] diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py index bf06f4c89..366c2d679 100644 --- a/reproject/tests/test_non_reprojected_dims.py +++ b/reproject/tests/test_non_reprojected_dims.py @@ -216,6 +216,25 @@ def record_plane(plane, block_info=None): assert sorted(computed_planes) == list(range(n_time)) +def test_non_reprojected_dims_invalid_leading_block_size(reproject_function): + # Since each block covers a single non-reprojected slice, block_size entries + # along the non-reprojected dimensions must be 1 or the full extent; other + # values would be silently reinterpreted as 1 so they raise instead. + data = np.ones((4, 20, 20)) + wcs_in = _spectral_cube_wcs(0.0, 1e9) + wcs_out = _spectral_cube_wcs(0.02, 1e9 + 2e6) + for block_size in [(2, 7, 7), (999, 7, 7), (2, 20, 20)]: + with pytest.raises(ValueError, match="single non-reprojected slice"): + reproject_function( + (data, wcs_in), + wcs_out, + shape_out=(4, 20, 20), + non_reprojected_dims=(0,), + parallel=True, + block_size=block_size, + ) + + def test_non_reprojected_dims_invalid_order(reproject_function): data = np.ones((4, 20, 20)) wcs = _spectral_cube_wcs(0.0, 1e9) From aa55de7449ab259a3ddaf58cba341c306783de70 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 3 Jul 2026 11:38:38 +0000 Subject: [PATCH 08/11] Declare the exact possibly ragged output chunks in the dispatcher map_blocks call so edge blocks are reprojected at their true size instead of being computed at the full block size and truncated, which also avoids an irregular trailing chunk in dask results --- reproject/_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index 503511b05..2e8718098 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -598,22 +598,22 @@ def reproject_single_block(a, array_or_path, block_info=None): logger.info("Setting up output dask array with map_blocks") + # Declare the exact (possibly ragged) chunks of the output template so + # that edge blocks are computed at their true size rather than being + # reprojected at the full block size and truncated afterwards. result = da.map_blocks( reproject_single_block, array_out_dask, array_in_or_path, dtype=" Date: Fri, 3 Jul 2026 11:39:39 +0000 Subject: [PATCH 09/11] Move ArrayWrapper into the interpolation test that uses it as a stand-in for a custom lazy array, since the dispatcher no longer needs it --- reproject/_array_utils.py | 14 +------------- reproject/interpolation/tests/test_core.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/reproject/_array_utils.py b/reproject/_array_utils.py index 26a6d2686..f2b23f3f1 100644 --- a/reproject/_array_utils.py +++ b/reproject/_array_utils.py @@ -4,7 +4,7 @@ from dask_image.ndinterp import spline_filter from scipy.ndimage import spline_filter as scipy_spline_filter -__all__ = ["map_coordinates", "dask_map_coordinates", "sample_array_edges", "ArrayWrapper"] +__all__ = ["map_coordinates", "dask_map_coordinates", "sample_array_edges"] def find_chunk_shape(shape, max_chunk_size=None): @@ -337,15 +337,3 @@ def aligned_chunks(lo, hi, edges=edges): if len(pieces) > 1: array = da.concatenate(pieces, axis=idim) return array - - -class ArrayWrapper: - - def __init__(self, array): - self._array = array - self.ndim = array.ndim - self.shape = array.shape - self.dtype = array.dtype - - def __getitem__(self, item): - return self._array[item] diff --git a/reproject/interpolation/tests/test_core.py b/reproject/interpolation/tests/test_core.py index c2232ceb8..4ae9595c0 100644 --- a/reproject/interpolation/tests/test_core.py +++ b/reproject/interpolation/tests/test_core.py @@ -13,13 +13,25 @@ from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS from numpy.testing import assert_allclose -from reproject._array_utils import ArrayWrapper from reproject.interpolation._high_level import reproject_interp from reproject.tests.helpers import array_footprint_to_hdulist # TODO: add reference comparisons +class ArrayWrapper: + # Minimal getitem-only array-like, standing in for a custom lazy array + + def __init__(self, array): + self._array = array + self.ndim = array.ndim + self.shape = array.shape + self.dtype = array.dtype + + def __getitem__(self, item): + return self._array[item] + + @pytest.fixture( params=[None, "memmap", "none"], ) From e3e1e1b606581a5dd03d19c94d3872ee18905d41 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 3 Jul 2026 12:01:33 +0000 Subject: [PATCH 10/11] Restore ArrayWrapper around the broadcasted path's from_array call and document that it prevents dask from hashing the whole buffer to name the array, which for a memmap silently loads the entire file into memory --- reproject/_array_utils.py | 24 +++++++++++++++++++++- reproject/_common.py | 9 +++++++- reproject/interpolation/tests/test_core.py | 14 +------------ 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/reproject/_array_utils.py b/reproject/_array_utils.py index f2b23f3f1..7691f573c 100644 --- a/reproject/_array_utils.py +++ b/reproject/_array_utils.py @@ -4,7 +4,7 @@ from dask_image.ndinterp import spline_filter from scipy.ndimage import spline_filter as scipy_spline_filter -__all__ = ["map_coordinates", "dask_map_coordinates", "sample_array_edges"] +__all__ = ["map_coordinates", "dask_map_coordinates", "sample_array_edges", "ArrayWrapper"] def find_chunk_shape(shape, max_chunk_size=None): @@ -337,3 +337,25 @@ def aligned_chunks(lo, hi, edges=edges): if len(pieces) > 1: array = da.concatenate(pieces, axis=idim) return array + + +class ArrayWrapper: + """ + A minimal getitem-only wrapper hiding an array from dask's tokenizer. + + Passing a Numpy array (in particular a memmap) directly to + ``da.from_array`` can make dask hash the whole buffer to compute the array + name, which silently loads the entire file into memory (see + https://github.com/dask/dask/issues/11850). Wrapping the array so that + dask can only access it through ``__getitem__``, combined with an explicit + ``name=``, guarantees the data is only ever read chunk by chunk. + """ + + def __init__(self, array): + self._array = array + self.ndim = array.ndim + self.shape = array.shape + self.dtype = array.dtype + + def __getitem__(self, item): + return self._array[item] diff --git a/reproject/_common.py b/reproject/_common.py index 2e8718098..69a91420a 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -12,6 +12,7 @@ from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper from dask import delayed +from ._array_utils import ArrayWrapper from .utils import _dask_to_numpy_memmap __all__ = ["_reproject_dispatcher"] @@ -525,8 +526,14 @@ def reproject_single_block(a, array_or_path, block_info=None): ) array_in_or_path = da.block(pieces.tolist()) else: + # ArrayWrapper (plus the explicit name) prevents dask from + # hashing the whole buffer to name the array, which for a memmap + # would silently load the entire file into memory (see + # https://github.com/dask/dask/issues/11850). array_in_or_path = da.from_array( - array_in, name=f"reproject-input-{uuid.uuid4().hex}", chunks=input_chunks + ArrayWrapper(array_in), + name=f"reproject-input-{uuid.uuid4().hex}", + chunks=input_chunks, ) elif broadcasted_parallelization: diff --git a/reproject/interpolation/tests/test_core.py b/reproject/interpolation/tests/test_core.py index 4ae9595c0..c2232ceb8 100644 --- a/reproject/interpolation/tests/test_core.py +++ b/reproject/interpolation/tests/test_core.py @@ -13,25 +13,13 @@ from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS from numpy.testing import assert_allclose +from reproject._array_utils import ArrayWrapper from reproject.interpolation._high_level import reproject_interp from reproject.tests.helpers import array_footprint_to_hdulist # TODO: add reference comparisons -class ArrayWrapper: - # Minimal getitem-only array-like, standing in for a custom lazy array - - def __init__(self, array): - self._array = array - self.ndim = array.ndim - self.shape = array.shape - self.dtype = array.dtype - - def __getitem__(self, item): - return self._array[item] - - @pytest.fixture( params=[None, "memmap", "none"], ) From 67f27d70f8c4a14411761ed5e3162c092306e6ea Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Fri, 3 Jul 2026 12:13:03 +0000 Subject: [PATCH 11/11] Refer to non-reprojected slices instead of planes and celestial dimensions in the sub-tiling comments, docstrings and error message, since the reprojected dimensions are not necessarily two-dimensional or celestial --- reproject/_common.py | 19 ++++++++++--------- reproject/adaptive/_high_level.py | 5 +++-- reproject/interpolation/_high_level.py | 5 +++-- reproject/tests/test_non_reprojected_dims.py | 8 ++++---- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/reproject/_common.py b/reproject/_common.py index 69a91420a..6530c3016 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -105,7 +105,7 @@ def _reproject_dispatcher( way. Reprojecting fewer dimensions than the WCS currently requires an explicit ``block_size``; its entries along the reprojected dimensions may either match the output shape or be smaller, in which case each - plane is reprojected in sub-tiles of that size. + non-reprojected slice is reprojected in sub-tiles of that size. array_out : `~numpy.ndarray`, optional An array in which to store the reprojected data. This can be any numpy array including a memory map, which may be helpful when dealing with @@ -313,14 +313,14 @@ def _reproject_dispatcher( # about the number of entries but about which dimensions the block spans the # full output extent along: # - if the block spans the full extent along the reprojected (trailing) - # dimensions, each block is one whole reprojected plane, so we parallelize - # over the broadcasted dimensions (one broadcasted slice per block); + # dimensions, each block is one whole non-reprojected slice, so we + # parallelize over the broadcasted dimensions (one slice per block); # - if instead it spans the full extent along the broadcasted (leading) - # dimensions, the block tiles the reprojected plane and we do not + # dimensions, the block tiles the reprojected dimensions and we do not # parallelize over the broadcasted dimensions; # - if it spans the full extent along neither, we raise, unless - # non_reprojected_dims requires slicing the WCS per plane, in which case a - # block smaller than the plane sub-tiles each plane. + # non_reprojected_dims requires slicing the WCS per non-reprojected slice, + # in which case a block smaller than the slice sub-tiles each slice. broadcasted_parallelization = False if broadcasting and block_size is not None and block_size != "auto": if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]: @@ -332,7 +332,7 @@ def _reproject_dispatcher( # slice per block and let dask additionally tile the reprojected # dimensions according to the block size, which bounds the # coordinate-transform memory (it would otherwise scale with the - # full plane size). Each output tile is still reprojected from the + # full slice size). Each output tile is still reprojected from the # whole input slice, since any output pixel can map anywhere within # it. broadcasted_parallelization = True @@ -376,7 +376,8 @@ def _reproject_dispatcher( "(for example using non_reprojected_dims) currently requires " "passing an explicit block_size whose entries along the reprojected " "dimensions either match the output shape or are smaller (in which " - "case each plane is reprojected in sub-tiles of that size), " + "case each non-reprojected slice is reprojected in sub-tiles of " + "that size), " "optionally with parallel=True to compute the blocks concurrently" ) @@ -499,7 +500,7 @@ def reproject_single_block(a, array_or_path, block_info=None): # chunk along the reprojected dimensions to every output tile of that # slice, so each slice is computed exactly once and streamed to the # tasks that need it: dask array inputs are never materialized in - # full, sub-tiled planes do not recompute their input per tile, and + # full, sub-tiled slices do not recompute their input per tile, and # under a distributed scheduler each task depends only on its own # slice rather than embedding the whole input. The exception is a dask # input with dask_method='none' that is chunked below one slice along diff --git a/reproject/adaptive/_high_level.py b/reproject/adaptive/_high_level.py index 463f89055..961a001d2 100644 --- a/reproject/adaptive/_high_level.py +++ b/reproject/adaptive/_high_level.py @@ -215,8 +215,9 @@ def reproject_adaptive( sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). This currently requires passing an explicit ``block_size``; its entries along the reprojected dimensions may either match ``shape_out`` or be - smaller, in which case each plane is reprojected in sub-tiles of that - size to keep the coordinate-transform memory bounded (optionally + smaller, in which case each non-reprojected slice is reprojected in + sub-tiles of that size to keep the coordinate-transform memory bounded + (optionally combined with ``parallel`` to compute the blocks concurrently). parallel : bool or int or str, optional If `True`, the reprojection is carried out in parallel, and if a diff --git a/reproject/interpolation/_high_level.py b/reproject/interpolation/_high_level.py index 88e7e7993..4fbef5aaf 100644 --- a/reproject/interpolation/_high_level.py +++ b/reproject/interpolation/_high_level.py @@ -111,8 +111,9 @@ def reproject_interp( sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). This currently requires passing an explicit ``block_size``; its entries along the reprojected dimensions may either match ``shape_out`` or be - smaller, in which case each plane is reprojected in sub-tiles of that - size to keep the coordinate-transform memory bounded (optionally + smaller, in which case each non-reprojected slice is reprojected in + sub-tiles of that size to keep the coordinate-transform memory bounded + (optionally combined with ``parallel`` to compute the blocks concurrently). parallel : bool or int or str, optional If `True`, the reprojection is carried out in parallel, and if a diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py index 366c2d679..5b847227f 100644 --- a/reproject/tests/test_non_reprojected_dims.py +++ b/reproject/tests/test_non_reprojected_dims.py @@ -59,10 +59,10 @@ def test_non_reprojected_dims(reproject_function): @pytest.mark.parametrize("block_size", [(1, 7, 7), (7, 7), (1, 12, 20)]) def test_non_reprojected_dims_subtiled(reproject_function, block_size): - # A block_size smaller than the output along the reprojected (celestial) - # dimensions should reproject each plane in sub-tiles and give exactly the - # same result as reprojecting each full plane in one go. This is what keeps - # the coordinate-transform memory bounded for large planes. + # A block_size smaller than the output along the reprojected dimensions + # (the celestial ones here) should reproject each slice in sub-tiles and give + # exactly the same result as reprojecting each full slice in one go. This is + # what keeps the coordinate-transform memory bounded for large slices. data = np.arange(4 * 20 * 20, dtype=float).reshape((4, 20, 20)) wcs_in = _spectral_cube_wcs(0.0, 1e9)