diff --git a/reproject/_array_utils.py b/reproject/_array_utils.py index 26a6d2686..7691f573c 100644 --- a/reproject/_array_utils.py +++ b/reproject/_array_utils.py @@ -340,6 +340,16 @@ def aligned_chunks(lo, hi, edges=edges): class ArrayWrapper: + """ + A minimal getitem-only wrapper hiding an array from dask's tokenizer. + + Passing a Numpy array (in particular a memmap) directly to + ``da.from_array`` can make dask hash the whole buffer to compute the array + name, which silently loads the entire file into memory (see + https://github.com/dask/dask/issues/11850). Wrapping the array so that + dask can only access it through ``__getitem__``, combined with an explicit + ``name=``, guarantees the data is only ever read chunk by chunk. + """ def __init__(self, array): self._array = array diff --git a/reproject/_common.py b/reproject/_common.py index f8a52b54c..6530c3016 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -1,4 +1,5 @@ import logging +import mmap import os import tempfile import uuid @@ -101,9 +102,10 @@ def _reproject_dispatcher( given as a tuple of sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). If `None` (the default), any leading dimensions for which the WCS has fewer dimensions than the data are treated this - way. Reprojecting fewer dimensions than the WCS currently requires a - ``block_size`` that matches the output shape along the reprojected - dimensions. + way. Reprojecting fewer dimensions than the WCS currently requires an + explicit ``block_size``; its entries along the reprojected dimensions + may either match the output shape or be smaller, in which case each + non-reprojected slice is reprojected in sub-tiles of that size. array_out : `~numpy.ndarray`, optional An array in which to store the reprojected data. This can be any numpy array including a memory map, which may be helpful when dealing with @@ -184,15 +186,15 @@ def _reproject_dispatcher( "non_reprojected_dims should leave at least one dimension to be " "reprojected" ) - # If we are reprojecting fewer dimensions than the input or output WCS has, - # the WCS needs to be sliced down to the reprojected dimensions for each - # non-reprojected slice. This is currently only done when parallelizing over - # the non-reprojected (broadcasted) dimensions, so any other code path would - # silently reproject the dimensions that should have been left untouched. - # This is gated on non_reprojected_dims being set since that is the only way - # to opt into reprojecting fewer dimensions than the WCS; a plain mismatch - # between the input and output WCS dimensionality is instead a validation - # error raised by the underlying reprojection function. + # ``wcs_slicing_required`` flags that we are reprojecting fewer dimensions than + # the input or output WCS describes, so the WCS must be sliced down to the + # reprojected dimensions for each non-reprojected slice. That slicing is only + # implemented on the path that parallelizes over the non-reprojected + # (broadcasted) dimensions; the other code paths raise NotImplementedError below + # rather than attempting it. It is gated on non_reprojected_dims being set, the + # only way to opt into reprojecting fewer dimensions than the WCS; a plain + # mismatch between the input and output WCS dimensionality is instead a + # validation error raised by the underlying reprojection function. wcs_slicing_required = non_reprojected_dims is not None and ( n_dim_reproject < wcs_in.low_level_wcs.pixel_n_dim or n_dim_reproject < wcs_out.low_level_wcs.pixel_n_dim @@ -305,28 +307,59 @@ def _reproject_dispatcher( for i in range(len(block_size)) ) - # Check block size and determine whether block size indicates we should - # parallelize over broadcasted dimension. The logic is as follows: if - # the block size and output shape are the same size, then either the - # block size should match the output shape along the broadcasted - # dimensions or along the non-broadcasted dimensions. If it matches the - # non-broadcasted dimensions we can parallelize over the broadcasted - # dimensions. If the block size does not match the output shape, we - # don't make any assumptions for now and assume a single chunk in the - # missing dimensions. + # Decide whether the requested block size means we should parallelize over + # the broadcasted (non-reprojected, leading) dimensions. block_size has + # already been padded above to one entry per output dimension, so this is not + # about the number of entries but about which dimensions the block spans the + # full output extent along: + # - if the block spans the full extent along the reprojected (trailing) + # dimensions, each block is one whole non-reprojected slice, so we + # parallelize over the broadcasted dimensions (one slice per block); + # - if instead it spans the full extent along the broadcasted (leading) + # dimensions, the block tiles the reprojected dimensions and we do not + # parallelize over the broadcasted dimensions; + # - if it spans the full extent along neither, we raise, unless + # non_reprojected_dims requires slicing the WCS per non-reprojected slice, + # in which case a block smaller than the slice sub-tiles each slice. broadcasted_parallelization = False if broadcasting and block_size is not None and block_size != "auto": if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]: - # TODO: maybe error if block_size was given in full and is wrong broadcasted_parallelization = True - block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[ - -n_dim_reproject: - ] + elif wcs_slicing_required: + # A block smaller than the output along the reprojected dimensions + # is only meaningful when the WCS has to be sliced per broadcasted + # slice (i.e. non_reprojected_dims). We parallelize one broadcasted + # slice per block and let dask additionally tile the reprojected + # dimensions according to the block size, which bounds the + # coordinate-transform memory (it would otherwise scale with the + # full slice size). Each output tile is still reprojected from the + # whole input slice, since any output pixel can map anywhere within + # it. + broadcasted_parallelization = True elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]: raise ValueError( "block shape should either match output data shape along " "reprojected dimensions or non-reprojected dimensions" ) + if broadcasted_parallelization: + # One broadcasted slice per block; dask tiles the reprojected + # dimensions using whatever block size was requested along them. + # The block size along the non-reprojected dimensions must be 1 + # or span the full extent (equivalent here, since blocks are + # single slices either way); anything else would be silently + # reinterpreted, so raise instead. + if any( + entry not in (1, shape_out[idim]) + for idim, entry in enumerate(block_size[: len(shape_out) - n_dim_reproject]) + ): + raise ValueError( + f"block_size {block_size} should be 1 or match the output shape " + "along the non-reprojected dimensions (each block covers a " + "single non-reprojected slice)" + ) + block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[ + -n_dim_reproject: + ] logger.info( f"{'P' if broadcasted_parallelization else 'Not p'}arallelizing along " @@ -341,9 +374,11 @@ def _reproject_dispatcher( raise NotImplementedError( "Reprojecting fewer dimensions than the input or output WCS " "(for example using non_reprojected_dims) currently requires " - "passing a block_size whose entries along the reprojected " - "dimensions match the output shape (optionally with parallel=True " - "to compute the blocks concurrently)" + "passing an explicit block_size whose entries along the reprojected " + "dimensions either match the output shape or are smaller (in which " + "case each non-reprojected slice is reprojected in sub-tiles of " + "that size), " + "optionally with parallel=True to compute the blocks concurrently" ) if output_footprint is None and return_footprint and return_type != "dask": @@ -359,8 +394,8 @@ def reproject_single_block(a, array_or_path, block_info=None): ): return np.array([a, a]) - if isinstance(array_or_path, str) and array_or_path == "from-dict": - array_or_path = dask_arrays["array"] + if isinstance(array_or_path, _ArrayContainer): + array_or_path = array_or_path._array shape_out = block_info[None]["chunk-shape"][1:] @@ -375,8 +410,16 @@ def reproject_single_block(a, array_or_path, block_info=None): wcs_in_cp = wcs_in.deepcopy() if isinstance(wcs_in, WCS) else wcs_in wcs_out_cp = wcs_out.deepcopy() if isinstance(wcs_out, WCS) else wcs_out + # Along the reprojected dimensions the input is always kept whole (any + # output pixel can map anywhere within it) while dask may tile the + # output; along the broadcasted dimensions each block is a single + # slice. slices_in/slices_out reduce the input/output WCS to this + # block; the matching broadcasted slice of the input either arrives as + # the aligned input block or, when the input was passed whole (lazy + # dask input), is read out below using slices_in_data. slices_in = [] slices_out = [] + slices_in_data = [] for idx in range(len(shape_out)): interval = block_info[None]["array-location"][idx + 1] if broadcasted_parallelization and idx < len(shape_out) - n_dim_reproject: @@ -387,9 +430,11 @@ def reproject_single_block(a, array_or_path, block_info=None): ) slices_in.append(interval[0]) slices_out.append(interval[0]) + slices_in_data.append(slice(*interval)) else: slices_in.append(slice(None)) slices_out.append(slice(*block_info[None]["array-location"][idx + 1])) + slices_in_data.append(slice(None)) slices_in = slices_in[-wcs_in.low_level_wcs.pixel_n_dim :] slices_out = slices_out[-wcs_out.low_level_wcs.pixel_n_dim :] @@ -411,16 +456,27 @@ def reproject_single_block(a, array_or_path, block_info=None): wcs_out_sub = HighLevelWCSWrapper(low_level_wcs_out) - if isinstance(array_or_path, tuple): + if broadcasted_parallelization and input_aligned: + # The input was passed as an aligned dask array, so array_or_path + # is already this block's broadcasted slice of the input, kept + # whole along the reprojected dimensions (see above). + array_in = array_or_path + elif isinstance(array_or_path, tuple): array_in = np.memmap(array_or_path[0], **array_or_path[1], mode="r") elif isinstance(array_or_path, str): array_in = np.memmap(array_or_path, dtype=float, shape=shape_in, mode="r") else: array_in = array_or_path - if array_or_path is None: + if array_in is None: raise RuntimeError("array_or_path is not set") + if broadcasted_parallelization and not input_aligned: + # The input was passed whole as a lazy dask array; read out a lazy + # view of this block's broadcasted slice so a streaming + # reprojection core only computes the input chunks it touches. + array_in = array_in[tuple(slices_in_data)] + array, footprint = reproject_func( array_in, wcs_in_sub, @@ -432,98 +488,140 @@ def reproject_single_block(a, array_or_path, block_info=None): return np.array([array, footprint]) - if broadcasted_parallelization: - - array_out_dask = da.empty(shape_out, chunks=block_size) - - # The input is reprojected in full for each output block, so it must - # not be chunked along the reprojected dimensions (which can have a - # different size from the output); only the broadcasted dimensions are - # chunked, matching array_out_dask block for block. + input_aligned = False + if broadcasted_parallelization and ( + not isinstance(array_in, da.core.Array) + or dask_method != "none" + or all(len(chunks) == 1 for chunks in array_in.chunks[-n_dim_reproject:]) + ): + # Pass the input as a second dask array with one chunk per broadcasted + # slice, kept whole along the reprojected dimensions (any output pixel + # can map anywhere within its slice). map_blocks broadcasts the single + # chunk along the reprojected dimensions to every output tile of that + # slice, so each slice is computed exactly once and streamed to the + # tasks that need it: dask array inputs are never materialized in + # full, sub-tiled slices do not recompute their input per tile, and + # under a distributed scheduler each task depends only on its own + # slice rather than embedding the whole input. The exception is a dask + # input with dask_method='none' that is chunked below one slice along + # the reprojected dimensions: materializing it here would forgo the + # ability of streaming reprojection cores to work chunk by chunk + # without ever holding a whole slice, so it is kept lazy below. + input_aligned = True input_chunks = (1,) * (array_in.ndim - n_dim_reproject) + (-1,) * n_dim_reproject if isinstance(array_in, da.core.Array): - array_in = array_in.rechunk(input_chunks) + array_in_dask = array_in.rechunk(input_chunks) + # Blockwise fusion would fold the input graph into every output + # tile task, recomputing each broadcasted slice once per tile of + # that slice; routing each slice through a delayed task pins it + # as a single node in the graph that all of its tiles share. + delayed_blocks = array_in_dask.to_delayed() + pieces = np.empty(delayed_blocks.shape, dtype=object) + for index in np.ndindex(delayed_blocks.shape): + shape = tuple( + array_in_dask.chunks[idim][index[idim]] + for idim in range(array_in_dask.ndim) + ) + pieces[index] = da.from_delayed( + delayed_blocks[index], shape=shape, dtype=array_in_dask.dtype + ) + array_in_or_path = da.block(pieces.tolist()) else: - array_in = da.asarray( - ArrayWrapper(array_in), name=str(uuid.uuid4()), chunks=input_chunks + # ArrayWrapper (plus the explicit name) prevents dask from + # hashing the whole buffer to name the array, which for a memmap + # would silently load the entire file into memory (see + # https://github.com/dask/dask/issues/11850). + array_in_or_path = da.from_array( + ArrayWrapper(array_in), + name=f"reproject-input-{uuid.uuid4().hex}", + chunks=input_chunks, ) - result = da.map_blocks( - reproject_single_block, - array_out_dask, - array_in, - dtype="