diff --git a/reproject/_common.py b/reproject/_common.py index bc0e45ef0..7df7a0500 100644 --- a/reproject/_common.py +++ b/reproject/_common.py @@ -11,7 +11,7 @@ from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper from dask import delayed -from ._array_utils import ArrayWrapper +from ._array_utils import ArrayWrapper, iterate_chunks from .utils import _dask_to_numpy_memmap __all__ = ["_reproject_dispatcher"] @@ -69,6 +69,7 @@ def _reproject_dispatcher( reproject_func_kwargs=None, return_type=None, dask_method=None, + zarr_path=None, ): """ Main function that handles either calling the core algorithms directly or @@ -101,9 +102,10 @@ def _reproject_dispatcher( given as a tuple of sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). If `None` (the default), any leading dimensions for which the WCS has fewer dimensions than the data are treated this - way. Reprojecting fewer dimensions than the WCS currently requires a - ``block_size`` that matches the output shape along the reprojected - dimensions. + way. Reprojecting fewer dimensions than the WCS currently requires an + explicit ``block_size``; its entries along the reprojected dimensions + may either match the output shape or be smaller, in which case each + plane is reprojected in sub-tiles of that size. array_out : `~numpy.ndarray`, optional An array in which to store the reprojected data. This can be any numpy array including a memory map, which may be helpful when dealing with @@ -123,8 +125,13 @@ def _reproject_dispatcher( dask.distributed), set this to ``'current-scheduler'``. reproject_func_kwargs : dict, optional Keyword arguments to pass through to ``reproject_func`` - return_type : {'numpy', 'dask' }, optional - Whether to return numpy or dask arrays. + return_type : {'numpy', 'dask', 'zarr'}, optional + Whether to return numpy or dask arrays or whether to dump the result to + a zarr array on disk. If 'zarr', then the ``zarr_path`` keyword has to + also be specified, and the function will return dask arrays constructed + from the zarr array. In this case the reprojection is always carried out + in blocks (using dask, on the synchronous scheduler when ``parallel`` is + `False`), and ``block_size`` defaults to ``'auto'`` when not specified. dask_method : {'memmap', 'none'}, optional Method to use when input array is a dask array. The methods are: * ``'memmap'``: write out the entire input dask array to a temporary @@ -137,14 +144,20 @@ def _reproject_dispatcher( fits into memory (as this will then be faster than ``'memmap'``), and when the data contains more dimensions than the input WCS and the block_size is chosen to iterate over the extra dimensions. + zarr_path : str, optional + If return_type is 'zarr', this specifies the path to use for the zarr + array. This should be a non-existent path. """ logger = logging.getLogger(__name__) if return_type is None: return_type = "numpy" - elif return_type not in ("numpy", "dask"): - raise ValueError("return_type should be set to 'numpy' or 'dask'") + elif return_type not in ("numpy", "dask", "zarr"): + raise ValueError("return_type should be set to 'numpy', 'dask', or 'zarr'") + + if return_type == "zarr" and block_size is None: + block_size = "auto" if dask_method is None: dask_method = "memmap" @@ -154,6 +167,12 @@ def _reproject_dispatcher( if reproject_func_kwargs is None: reproject_func_kwargs = {} + if return_type == "zarr": + if zarr_path is None: + raise ValueError("zarr_path needs to be set if return_type is 'zarr'") + elif os.path.exists(zarr_path): + raise ValueError(f"Path {zarr_path} already exists") + # For now, we are quite restrictive in what non_reprojected_dims can # be, but it is designed so that if we wanted we could support more use # cases in future. For now, it has to be a tuple where each element is @@ -203,7 +222,7 @@ def _reproject_dispatcher( with tempfile.TemporaryDirectory() as local_tmp_dir: if array_out is None: - if return_type != "dask": + if return_type == "numpy": array_out = np.zeros(shape_out, dtype=float) elif array_out.shape != tuple(shape_out): raise ValueError( @@ -226,9 +245,9 @@ def _reproject_dispatcher( # If a dask array was passed as input, we first convert this to a # Numpy memory mapped array - if return_type == "dask": + if return_type in ("dask", "zarr"): raise ValueError( - "Output cannot be returned as dask arrays " + "Output cannot be returned as dask or zarr arrays " "when parallel=False and no block size has " "been specified" ) @@ -315,6 +334,10 @@ def _reproject_dispatcher( # don't make any assumptions for now and assume a single chunk in the # missing dimensions. broadcasted_parallelization = False + # When sub-tiling the reprojected dimensions within each broadcasted block + # (see below), this holds the shape of each sub-tile; otherwise it is None + # and each block is reprojected in one go. + sub_tile_shape = None if broadcasting and block_size is not None and block_size != "auto": if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]: # TODO: maybe error if block_size was given in full and is wrong @@ -322,6 +345,21 @@ def _reproject_dispatcher( block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[ -n_dim_reproject: ] + elif wcs_slicing_required: + # A block smaller than the output along the reprojected dimensions + # is only meaningful when the WCS has to be sliced per broadcasted + # slice (i.e. non_reprojected_dims). The whole input slice has to + # be reprojected as one (it cannot be chunked along the reprojected + # dimensions, since any output pixel can map anywhere in the input), + # so we still parallelize one full reprojected plane per block, but + # additionally sub-tile the reprojection within each plane to bound + # the coordinate-transform memory, which would otherwise scale with + # the full plane size. + broadcasted_parallelization = True + sub_tile_shape = block_size[-n_dim_reproject:] + block_size = (1,) * (len(shape_out) - n_dim_reproject) + shape_out[ + -n_dim_reproject: + ] elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]: raise ValueError( "block shape should either match output data shape along " @@ -341,9 +379,10 @@ def _reproject_dispatcher( raise NotImplementedError( "Reprojecting fewer dimensions than the input or output WCS " "(for example using non_reprojected_dims) currently requires " - "passing a block_size whose entries along the reprojected " - "dimensions match the output shape (optionally with parallel=True " - "to compute the blocks concurrently)" + "passing an explicit block_size whose entries along the reprojected " + "dimensions either match the output shape or are smaller (in which " + "case each plane is reprojected in sub-tiles of that size), " + "optionally with parallel=True to compute the blocks concurrently" ) if output_footprint is None and return_footprint: @@ -421,14 +460,40 @@ def reproject_single_block(a, array_or_path, block_info=None): if array_or_path is None: raise RuntimeError("array_or_path is not set") - array, footprint = reproject_func( - array_in, - wcs_in_sub, - wcs_out_sub, - shape_out=shape_out, - array_out=np.zeros(shape_out), - **reproject_func_kwargs, - ) + if sub_tile_shape is None: + array, footprint = reproject_func( + array_in, + wcs_in_sub, + wcs_out_sub, + shape_out=shape_out, + array_out=np.zeros(shape_out), + **reproject_func_kwargs, + ) + else: + # Reproject the plane in sub-tiles along the reprojected + # dimensions so that the coordinate transform (which is computed + # over the whole output sub-tile at once) does not have to be + # evaluated for the full plane in one go. The input slice is left + # whole since any output pixel can map anywhere within it. + array = np.zeros(shape_out) + footprint = np.zeros(shape_out) + n_broadcast = len(shape_out) - n_dim_reproject + for sub_tile in iterate_chunks( + shape_out[n_broadcast:], max_chunk_size=int(np.prod(sub_tile_shape)) + ): + sub_wcs_out = HighLevelWCSWrapper( + SlicedLowLevelWCS(low_level_wcs_out, slices=sub_tile) + ) + full_slices = (slice(None),) * n_broadcast + sub_tile + sub_shape = shape_out[:n_broadcast] + tuple(s.stop - s.start for s in sub_tile) + array[full_slices], footprint[full_slices] = reproject_func( + array_in, + wcs_in_sub, + sub_wcs_out, + shape_out=sub_shape, + array_out=np.zeros(sub_shape), + **reproject_func_kwargs, + ) return np.array([array, footprint]) @@ -532,7 +597,7 @@ def reproject_single_block(a, array_or_path, block_info=None): # We now convert the dask arrays back to Numpy arrays - if parallel: + if parallel or return_type == "zarr": # As discussed in https://github.com/dask/dask/issues/9556, da.store # will not work well in parallel mode when the destination is a # Numpy array. Instead, in this case we save the dask array to a zarr @@ -541,11 +606,15 @@ def reproject_single_block(a, array_or_path, block_info=None): # 'synchronous' scheduler since that is I/O limited so does not need # to be done in parallel. - zarr_path = os.path.join(local_tmp_dir, f"{uuid.uuid4()}.zarr") + if return_type != "zarr": + zarr_path = os.path.join(local_tmp_dir, f"{uuid.uuid4()}.zarr") logger.info(f"Computing output array directly to zarr array at {zarr_path}") - if parallel == "current-scheduler": + if not parallel: + with dask.config.set(scheduler="synchronous"): + result.to_zarr(zarr_path) + elif parallel == "current-scheduler": # Just use whatever is the current active scheduler, which can # be used for e.g. dask.distributed result.to_zarr(zarr_path) @@ -565,6 +634,12 @@ def reproject_single_block(a, array_or_path, block_info=None): result = da.from_zarr(zarr_path) + if return_type == "zarr": + if return_footprint: + return result[0], result[1] + else: + return result[0] + logger.info("Copying output zarr array into output Numpy arrays") if return_footprint: diff --git a/reproject/adaptive/_high_level.py b/reproject/adaptive/_high_level.py index 2fede6f8e..bc5117afa 100644 --- a/reproject/adaptive/_high_level.py +++ b/reproject/adaptive/_high_level.py @@ -34,6 +34,7 @@ def reproject_adaptive( parallel=False, return_type=None, dask_method=None, + zarr_path=None, ): """ Reproject a 2D array from one WCS to another using the DeForest (2004) @@ -213,9 +214,11 @@ def reproject_adaptive( even when the input and output WCS have the same number of dimensions as the data. The dimensions must be the leading ones, given as a tuple of sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). - This currently requires passing a ``block_size`` whose entries along - the reprojected dimensions match ``shape_out`` (optionally combined - with ``parallel`` to compute the blocks concurrently). + This currently requires passing an explicit ``block_size``; its entries + along the reprojected dimensions may either match ``shape_out`` or be + smaller, in which case each plane is reprojected in sub-tiles of that + size to keep the coordinate-transform memory bounded (optionally + combined with ``parallel`` to compute the blocks concurrently). parallel : bool or int or str, optional If `True`, the reprojection is carried out in parallel, and if a positive integer, this specifies the number of threads to use. @@ -223,8 +226,13 @@ def reproject_adaptive( by ``block_size`` (if the block size is not set, it will be determined automatically). To use the currently active dask scheduler (e.g. dask.distributed), set this to ``'current-scheduler'``. - return_type : {'numpy', 'dask'}, optional - Whether to return numpy or dask arrays. + return_type : {'numpy', 'dask', 'zarr'}, optional + Whether to return numpy or dask arrays, or to write the output to a zarr + array on disk. If ``'zarr'``, ``zarr_path`` must also be given; the + output is then computed in blocks (using dask, on the synchronous + scheduler when ``parallel`` is `False`), ``block_size`` defaults to + ``'auto'`` when not specified, and dask arrays backed by the zarr array + are returned. dask_method : {'memmap', 'none'}, optional Method to use when input array is a dask array. The methods are: * ``'memmap'``: write out the entire input dask array to a temporary @@ -237,6 +245,9 @@ def reproject_adaptive( fits into memory (as this will then be faster than ``'memmap'``), and when the data contains more dimensions than the input WCS and the block_size is chosen to iterate over the extra dimensions. + zarr_path : str, optional + Path to use for the output zarr array when ``return_type='zarr'``. This + must be a path that does not already exist. Returns ------- @@ -284,4 +295,5 @@ def reproject_adaptive( bad_fill_value=bad_fill_value, ), return_type=return_type, + zarr_path=zarr_path, ) diff --git a/reproject/interpolation/_high_level.py b/reproject/interpolation/_high_level.py index 876b63605..ba4c9ac22 100644 --- a/reproject/interpolation/_high_level.py +++ b/reproject/interpolation/_high_level.py @@ -29,6 +29,7 @@ def reproject_interp( parallel=False, return_type=None, dask_method=None, + zarr_path=None, ): """ Reproject data to a new projection using interpolation (this is typically @@ -109,9 +110,11 @@ def reproject_interp( even when the input and output WCS have the same number of dimensions as the data. The dimensions must be the leading ones, given as a tuple of sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``). - This currently requires passing a ``block_size`` whose entries along - the reprojected dimensions match ``shape_out`` (optionally combined - with ``parallel`` to compute the blocks concurrently). + This currently requires passing an explicit ``block_size``; its entries + along the reprojected dimensions may either match ``shape_out`` or be + smaller, in which case each plane is reprojected in sub-tiles of that + size to keep the coordinate-transform memory bounded (optionally + combined with ``parallel`` to compute the blocks concurrently). parallel : bool or int or str, optional If `True`, the reprojection is carried out in parallel, and if a positive integer, this specifies the number of threads to use. @@ -119,8 +122,13 @@ def reproject_interp( by ``block_size`` (if the block size is not set, it will be determined automatically). To use the currently active dask scheduler (e.g. dask.distributed), set this to ``'current-scheduler'``. - return_type : {'numpy', 'dask'}, optional - Whether to return numpy or dask arrays. + return_type : {'numpy', 'dask', 'zarr'}, optional + Whether to return numpy or dask arrays, or to write the output to a zarr + array on disk. If ``'zarr'``, ``zarr_path`` must also be given; the + output is then computed in blocks (using dask, on the synchronous + scheduler when ``parallel`` is `False`), ``block_size`` defaults to + ``'auto'`` when not specified, and dask arrays backed by the zarr array + are returned. dask_method : {'memmap', 'none'}, optional Method to use when input array is a dask array. The methods are: * ``'memmap'``: write out the entire input dask array to a temporary @@ -128,6 +136,9 @@ def reproject_interp( the entire input array. * ``'none'`` (default): use native dask interpolation, which avoids having to write the array to disk. + zarr_path : str, optional + Path to use for the output zarr array when ``return_type='zarr'``. This + must be a path that does not already exist. Returns ------- @@ -170,4 +181,5 @@ def reproject_interp( ), return_type=return_type, dask_method=dask_method, + zarr_path=zarr_path, ) diff --git a/reproject/interpolation/tests/test_core.py b/reproject/interpolation/tests/test_core.py index c2232ceb8..67966c7e2 100644 --- a/reproject/interpolation/tests/test_core.py +++ b/reproject/interpolation/tests/test_core.py @@ -946,7 +946,7 @@ def test_auto_block_size(dask_method): wcs_out = WCS(naxis=2) # When block size and parallel aren't specified, can't return as dask arrays - with pytest.raises(ValueError, match="Output cannot be returned as dask arrays"): + with pytest.raises(ValueError, match="Output cannot be returned as dask or zarr arrays"): reproject_interp( (array_in, wcs_in), wcs_out, diff --git a/reproject/mosaicking/_coadd.py b/reproject/mosaicking/_coadd.py index 087856c4b..7c1469eaf 100644 --- a/reproject/mosaicking/_coadd.py +++ b/reproject/mosaicking/_coadd.py @@ -1,20 +1,22 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst import os +import shutil import sys import tempfile import uuid from logging import getLogger +import dask.array as da import numpy as np from astropy.wcs import WCS -from astropy.wcs.utils import pixel_to_pixel from astropy.wcs.wcsapi import SlicedLowLevelWCS -from .._array_utils import iterate_chunks, sample_array_edges +from .._array_utils import iterate_chunks from ..interpolation._core import _validate_wcs from ..utils import parse_input_data, parse_input_weights, parse_output_projection from ._background import determine_offset_matrix, solve_corrections_sgd from ._subset_array import DEFAULT_MAX_CHUNK_SIZE, ReprojectedArraySubset +from ._wcs_helpers import sample_input_edges_in_output __all__ = ["reproject_and_coadd"] @@ -28,8 +30,12 @@ def _noop(iterable): def _safe_remove(path): try: - os.remove(path) - except PermissionError: + if os.path.isdir(path): + # zarr stores are directories rather than single files + shutil.rmtree(path, ignore_errors=True) + else: + os.remove(path) + except (PermissionError, FileNotFoundError): pass @@ -169,9 +175,12 @@ def reproject_and_coadd( blank_pixel_value : float, optional Value to use for areas of the resulting mosaic that do not have input data. - intermediate_memmap : bool, optional - If `True`, use `numpy.memmap` to store intermediate output arrays for - reprojected data. + intermediate_memmap : bool or {'zarr'}, optional + If `True`, use `numpy.memmap` to store intermediate reprojected arrays on + disk. If ``'zarr'``, store the intermediate arrays as zarr arrays on disk + instead, which is typically more efficient (each image is then reprojected + in blocks and the zarr store is removed once the image has been combined), + but cannot be used together with ``match_background=True``. **kwargs Keyword arguments to be passed to the reprojection function. @@ -215,6 +224,9 @@ def reproject_and_coadd( if progress_bar is None: progress_bar = _noop + if match_background and intermediate_memmap == "zarr": + raise ValueError("Cannot use intermediate_memmap='zarr' when match_background=True") + # Parse the output projection to avoid having to do it for each wcs_out, shape_out = parse_output_projection(output_projection, shape_out=shape_out) @@ -306,10 +318,7 @@ def reproject_and_coadd( # which provides a lot of redundant information. try: - edges = sample_array_edges( - array_in.shape[-wcs_in.low_level_wcs.pixel_n_dim :], n_samples=11 - )[::-1] - edges_out = pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1] + edges_out = sample_input_edges_in_output(array_in.shape, wcs_in, wcs_out) except Exception: # If the edge coordinates cannot be transformed (for example if # they fall outside the validity region of the WCS), fall back to @@ -402,7 +411,16 @@ def reproject_and_coadd( # able to handle weights, and make the footprint become the combined # footprint + weight map - if intermediate_memmap: + extra_kwargs = {} + array = footprint = None + + if intermediate_memmap == "zarr": + + array_zarr_path = os.path.join(local_tmp_dir, f"array_{uuid.uuid4()}.zarr") + extra_kwargs["return_type"] = "zarr" + extra_kwargs["zarr_path"] = array_zarr_path + + elif intermediate_memmap: array_path = os.path.join(local_tmp_dir, f"array_{uuid.uuid4()}.np") @@ -430,10 +448,6 @@ def reproject_and_coadd( dtype=float, ) - else: - - array = footprint = None - logger.info(f"Calling {reproject_function.__name__} with shape_out={shape_out_indiv}") array, footprint = reproject_function( @@ -445,11 +459,21 @@ def reproject_and_coadd( output_footprint=footprint, block_size=block_size, **kwargs, + **extra_kwargs, ) if weights_in is not None: - if intermediate_memmap: + extra_kwargs = {} + weights = None + + if intermediate_memmap == "zarr": + + weights_zarr_path = os.path.join(local_tmp_dir, f"weights_{uuid.uuid4()}.zarr") + extra_kwargs["return_type"] = "zarr" + extra_kwargs["zarr_path"] = weights_zarr_path + + elif intermediate_memmap: weights_path = os.path.join(local_tmp_dir, f"weights_{uuid.uuid4()}.np") @@ -464,10 +488,6 @@ def reproject_and_coadd( dtype=float, ) - else: - - weights = None - logger.info( f"Calling {reproject_function.__name__} with shape_out={shape_out_indiv} for weights" ) @@ -480,28 +500,43 @@ def reproject_and_coadd( output_array=weights, return_footprint=False, **kwargs, + **extra_kwargs, ) # For the purposes of mosaicking, we mask out NaN values from the array - # and set the footprint to 0 at these locations. We do this in chunks - # to avoid excessive memory usage. - for chunk in iterate_chunks(array.shape, max_chunk_size=DEFAULT_MAX_CHUNK_SIZE): - - # Determine location of NaNs - reset = np.isnan(array[chunk]) + # and set the footprint to 0 at these locations, and fold any weights + # into the footprint. + if isinstance(array, da.core.Array): + # When intermediate_memmap='zarr' the arrays are dask arrays which + # do not support in-place assignment, so build the masked arrays + # lazily instead; the masking is then applied chunk by chunk when + # the arrays are combined below. + reset = da.isnan(array) if weights_in is not None: - reset |= np.isnan(weights[chunk]) + reset = reset | da.isnan(weights) + footprint = da.where(reset, 0.0, footprint * weights) + else: + footprint = da.where(reset, 0.0, footprint) + array = da.where(reset, 0.0, array) + else: + # We do this in chunks to avoid excessive memory usage. + for chunk in iterate_chunks(array.shape, max_chunk_size=DEFAULT_MAX_CHUNK_SIZE): - # Mask them in-place in the arrays - array[chunk][reset] = 0.0 - footprint[chunk][reset] = 0.0 + # Determine location of NaNs + reset = np.isnan(array[chunk]) + if weights_in is not None: + reset |= np.isnan(weights[chunk]) - # Combine weights and footprint - if weights_in is not None: - weights[chunk][reset] = 0.0 - footprint[chunk] *= weights[chunk] + # Mask them in-place in the arrays + array[chunk][reset] = 0.0 + footprint[chunk][reset] = 0.0 + + # Combine weights and footprint + if weights_in is not None: + weights[chunk][reset] = 0.0 + footprint[chunk] *= weights[chunk] - if weights_in is not None and intermediate_memmap: + if weights_in is not None and intermediate_memmap is True: # Remove the reference to the memmap before trying to remove the file itself logger.info("Removing memory-mapped weight array") weights = None @@ -519,12 +554,22 @@ def reproject_and_coadd( logger.info("Adding reprojected array to final array in chunks") _combine_array_into_output(combine_function, array, output_array, output_footprint) - if intermediate_memmap: + if intermediate_memmap is True: logger.info("Removing memory-mapped array and footprint arrays") array = None footprint = None for path in (array_path, footprint_path): _safe_remove(path) + elif intermediate_memmap == "zarr": + # The array and footprint share a single zarr store, and the + # footprint may lazily reference the weights zarr, so these + # can only be removed now that the arrays have been combined. + logger.info("Removing intermediate zarr arrays") + array = None + footprint = None + _safe_remove(array_zarr_path) + if weights_in is not None: + _safe_remove(weights_zarr_path) else: diff --git a/reproject/mosaicking/_subset_array.py b/reproject/mosaicking/_subset_array.py index b13381155..db8709f4f 100644 --- a/reproject/mosaicking/_subset_array.py +++ b/reproject/mosaicking/_subset_array.py @@ -3,6 +3,10 @@ import operator from math import prod +import dask.array as da +import numpy as np +from dask.array.core import slices_from_chunks + from .._array_utils import iterate_chunks __all__ = ["ReprojectedArraySubset"] @@ -125,9 +129,19 @@ def _operation(self, other, op): def as_chunks(self, max_chunk_size=None): - for chunk in iterate_chunks( - self.shape, max_chunk_size=max_chunk_size or DEFAULT_MAX_CHUNK_SIZE - ): + if isinstance(self.array, da.core.Array): + # For dask-backed arrays (e.g. when reprojected output was written to + # a zarr array), iterate over the native chunks so that each on-disk + # chunk is read and decompressed exactly once, rather than slicing + # across chunk boundaries which would re-read chunks many times. The + # chunk size was already chosen so that a single block fits in memory. + chunks = slices_from_chunks(self.array.chunks) + else: + chunks = iterate_chunks( + self.shape, max_chunk_size=max_chunk_size or DEFAULT_MAX_CHUNK_SIZE + ) + + for chunk in chunks: bounds_chunk = tuple( (self.bounds[idim][0] + chunk[idim].start, self.bounds[idim][0] + chunk[idim].stop) @@ -135,7 +149,7 @@ def as_chunks(self, max_chunk_size=None): ) yield ReprojectedArraySubset( - array=self.array[chunk], - footprint=self.footprint[chunk], + array=np.asarray(self.array[chunk]), + footprint=np.asarray(self.footprint[chunk]), bounds=bounds_chunk, ) diff --git a/reproject/mosaicking/_wcs_helpers.py b/reproject/mosaicking/_wcs_helpers.py index fd8d8ef05..fe2bbf989 100644 --- a/reproject/mosaicking/_wcs_helpers.py +++ b/reproject/mosaicking/_wcs_helpers.py @@ -1,5 +1,6 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst +import itertools import warnings import numpy as np @@ -9,12 +10,15 @@ from astropy.wcs import WCS from astropy.wcs.utils import ( celestial_frame_to_wcs, + pixel_to_pixel, pixel_to_skycoord, skycoord_to_pixel, wcs_to_celestial_frame, ) -from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS +from astropy.wcs.wcsapi import BaseHighLevelWCS, BaseLowLevelWCS, SlicedLowLevelWCS +from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper +from .._array_utils import sample_array_edges from .._wcs_utils import pixel_scale from ..utils import parse_input_shape @@ -321,3 +325,67 @@ def find_optimal_celestial_wcs( naxis2 = int(round(ymax - ymin)) return wcs_final, (naxis2, naxis1) + + +def sample_input_edges_in_output(array_shape, wcs_in, wcs_out, n_samples=11): + """ + Sample the edges of an input array and return their pixel coordinates in the + output WCS, for the reprojected (trailing) dimensions. + + This is used to determine the minimal region of the output that an input + image covers. If the input WCS has more pixel dimensions than the output WCS + (for example when using ``non_reprojected_dims`` to reproject a cube into a + celestial-only output), the input WCS is sliced down to its reprojected + (trailing) dimensions before being related to the output, since + ``pixel_to_pixel`` requires the two WCS to describe the same number of world + coordinates. Because the reprojected WCS may vary along the non-reprojected + axes (for example a drifting pointing, possibly non-linear), the input WCS is + sliced at ``n_samples`` positions along each of those axes and the resulting + footprints are combined. + + Parameters + ---------- + array_shape : tuple + The shape of the input array. + wcs_in : `~astropy.wcs.wcsapi.BaseHighLevelWCS` + The WCS of the input array. + wcs_out : `~astropy.wcs.wcsapi.BaseHighLevelWCS` + The WCS of the output array. + n_samples : int, optional + The number of samples to take along each edge. + + Returns + ------- + list of `~numpy.ndarray` + The output pixel coordinates of the sampled edges, in array (numpy) + dimension order. + """ + n_extra_in = wcs_in.low_level_wcs.pixel_n_dim - wcs_out.low_level_wcs.pixel_n_dim + + if n_extra_in <= 0: + edges = sample_array_edges( + array_shape[-wcs_in.low_level_wcs.pixel_n_dim :], n_samples=n_samples + )[::-1] + return pixel_to_pixel(wcs_in, wcs_out, *edges)[::-1] + + n_reproject = wcs_out.low_level_wcs.pixel_n_dim + edges = sample_array_edges(array_shape[-n_reproject:], n_samples=n_samples)[::-1] + leading_shape = array_shape[:n_extra_in] + # Sample positions along each non-reprojected axis (not just its end points) + # so that non-linear variation of the reprojected WCS along that axis is + # captured. Use integer pixel indices and de-duplicate for short axes. + leading_samples = [ + sorted({int(round(idx)) for idx in np.linspace(0, size - 1, n_samples)}) + for size in leading_shape + ] + edges_out_corners = [] + for corner in itertools.product(*leading_samples): + slices = list(corner) + [slice(None)] * n_reproject + wcs_in_reproject = HighLevelWCSWrapper( + SlicedLowLevelWCS(wcs_in.low_level_wcs, slices=slices) + ) + edges_out_corners.append(pixel_to_pixel(wcs_in_reproject, wcs_out, *edges)[::-1]) + return [ + np.concatenate([corner[idim] for corner in edges_out_corners]) + for idim in range(n_reproject) + ] diff --git a/reproject/mosaicking/tests/test_coadd.py b/reproject/mosaicking/tests/test_coadd.py index 480998634..66829c177 100644 --- a/reproject/mosaicking/tests/test_coadd.py +++ b/reproject/mosaicking/tests/test_coadd.py @@ -25,11 +25,16 @@ def reproject_function(request): return request.param -@pytest.fixture(params=[False, True]) +@pytest.fixture(params=[False, True, "zarr"]) def intermediate_memmap(request): return request.param +@pytest.fixture(params=[False, True]) +def intermediate_memmap_nozarr(request): + return request.param + + class TestReprojectAndCoAdd: def setup_method(self, method): self.wcs = WCS(naxis=2) @@ -92,6 +97,7 @@ def test_coadd_no_overlap(self, combine_function, reproject_function, intermedia shape_out=self.array.shape, combine_function=combine_function, reproject_function=reproject_function, + intermediate_memmap=intermediate_memmap, ) assert_allclose(array, self.array, atol=ATOL) @@ -109,10 +115,39 @@ def test_coadd_with_overlap(self, reproject_function, intermediate_memmap): shape_out=self.array.shape, combine_function="mean", reproject_function=reproject_function, + intermediate_memmap=intermediate_memmap, ) assert_allclose(array, self.array, atol=ATOL) + def test_coadd_zarr_interior_nan(self, reproject_function, monkeypatch): + # Regression test: with intermediate_memmap='zarr' the reprojected arrays + # are dask arrays, which do not support in-place assignment, so NaN values + # inside the footprint have to be masked out lazily. We force a small + # chunk size so the masking spans more than one chunk, which previously + # silently failed for dask arrays and left NaNs in the output. + monkeypatch.setattr("reproject.mosaicking._coadd.DEFAULT_MAX_CHUNK_SIZE", 100) + + # Two identical full-frame tiles so every pixel is covered twice, with + # NaNs in one tile that must be ignored in favor of the other. + input_data = [ + (self.array.copy(), self.wcs.deepcopy()), + (self.array.copy(), self.wcs.deepcopy()), + ] + input_data[0][0][:20, :20] = np.nan + + array, footprint = reproject_and_coadd( + input_data, + self.wcs, + shape_out=self.array.shape, + combine_function="mean", + reproject_function=reproject_function, + intermediate_memmap="zarr", + ) + + assert not np.any(np.isnan(array)) + assert_allclose(array, self.array, atol=ATOL) + def test_coadd_with_outputs(self, tmp_path, reproject_function, intermediate_memmap): # Test the options to specify output array/footprint @@ -133,6 +168,7 @@ def test_coadd_with_outputs(self, tmp_path, reproject_function, intermediate_mem reproject_function=reproject_function, output_array=output_array, output_footprint=output_footprint, + intermediate_memmap=intermediate_memmap, ) assert_allclose(output_array, self.array, atol=ATOL) @@ -179,7 +215,7 @@ def test_coadd_with_overlap_first_last(self, reproject_function, combine_functio assert_allclose(output_values, (i + 7) % 20) array[view] = np.nan - def test_coadd_background_matching(self, reproject_function, intermediate_memmap): + def test_coadd_background_matching(self, reproject_function, intermediate_memmap_nozarr): # Test out the background matching input_data = self._get_tiles(self._overlapping_views) @@ -195,6 +231,7 @@ def test_coadd_background_matching(self, reproject_function, intermediate_memmap shape_out=self.array.shape, combine_function="mean", reproject_function=reproject_function, + intermediate_memmap=intermediate_memmap_nozarr, ) assert not np.allclose(array, self.array, atol=ATOL) @@ -208,6 +245,7 @@ def test_coadd_background_matching(self, reproject_function, intermediate_memmap combine_function="mean", reproject_function=reproject_function, match_background=True, + intermediate_memmap=intermediate_memmap_nozarr, ) # The absolute values of the two arrays will be offset since any @@ -215,7 +253,9 @@ def test_coadd_background_matching(self, reproject_function, intermediate_memmap assert_allclose(array - np.mean(array), self.array - np.mean(self.array), atol=ATOL) - def test_coadd_background_matching_one_array(self, reproject_function, intermediate_memmap): + def test_coadd_background_matching_one_array( + self, reproject_function, intermediate_memmap_nozarr + ): # Test that background matching doesn't affect the output when there's # only one input image. @@ -228,6 +268,7 @@ def test_coadd_background_matching_one_array(self, reproject_function, intermedi combine_function="mean", reproject_function=reproject_function, match_background=True, + intermediate_memmap=intermediate_memmap_nozarr, ) array, footprint = reproject_and_coadd( @@ -237,6 +278,7 @@ def test_coadd_background_matching_one_array(self, reproject_function, intermedi combine_function="mean", reproject_function=reproject_function, match_background=False, + intermediate_memmap=intermediate_memmap_nozarr, ) np.testing.assert_allclose(array, array_matched) np.testing.assert_allclose(footprint, footprint_matched) @@ -306,7 +348,9 @@ def test_background_matching_consistent_tiles(self, reproject_function, combine_ np.testing.assert_allclose(footprint_match, footprint_nomatch, atol=ATOL) np.testing.assert_allclose(array_match, array_nomatch, atol=ATOL) - def test_coadd_background_matching_with_nan(self, reproject_function, intermediate_memmap): + def test_coadd_background_matching_with_nan( + self, reproject_function, intermediate_memmap_nozarr + ): # Test out the background matching when NaN values are present. We do # this by using three arrays with the same footprint but with different # parts masked. @@ -329,6 +373,7 @@ def test_coadd_background_matching_with_nan(self, reproject_function, intermedia combine_function="mean", reproject_function=reproject_function, match_background=True, + intermediate_memmap=intermediate_memmap_nozarr, ) # The absolute values of the two arrays will be offset since any @@ -374,6 +419,7 @@ def test_coadd_with_weights(self, tmpdir, reproject_function, mode, intermediate input_weights=input_weights, reproject_function=reproject_function, match_background=False, + intermediate_memmap=intermediate_memmap, ) expected = self.array + (2 * (weight1 / weight1.max()) - 1) @@ -408,6 +454,7 @@ def test_coadd_with_weights_with_wcs(self, tmpdir, reproject_function, intermedi input_weights=input_weights, reproject_function=reproject_function, match_background=False, + intermediate_memmap=intermediate_memmap, ) weights1_reprojected = reproject_function( @@ -456,6 +503,7 @@ def test_coadd_with_broadcasting( shape_out=(3,) + self.array.shape, combine_function="mean", reproject_function=reproject_function, + intermediate_memmap=intermediate_memmap, **kwargs, ) @@ -569,3 +617,50 @@ def test_coadd_non_reprojected_dims(combine_function): assert_allclose(array, reference, atol=ATOL) assert_allclose(footprint, reference_footprint, atol=ATOL) + + +@pytest.mark.parametrize("combine_function", ["mean", "sum"]) +def test_coadd_non_reprojected_dims_celestial_output(combine_function): + # Co-add a drifting cube into a celestial-only (2D) output WCS, treating the + # leading axis as non-reprojected. Here the input WCS has more pixel + # dimensions than the output WCS, so computing each tile's footprint requires + # relating only the reprojected (celestial) sub-space of the input WCS to the + # output. The result should match co-adding each time slice independently + # with the input WCS sliced at that time. + n_time = 5 + shape_out = (n_time, 30, 30) + wcs_in = _drifting_cube_wcs(drift=0.6) + wcs_out = _drifting_cube_wcs(drift=0.0).celestial + + rng = np.random.default_rng(12345) + data1 = rng.random((n_time, 30, 30)) + data2 = rng.random((n_time, 30, 30)) + + array, footprint = reproject_and_coadd( + [(data1, wcs_in), (data2, wcs_in)], + wcs_out, + shape_out=shape_out, + reproject_function=reproject_interp, + combine_function=combine_function, + non_reprojected_dims=(0,), + parallel=True, + block_size=(1,) + shape_out[1:], + roundtrip_coords=False, + ) + + reference = np.zeros(shape_out) + reference_footprint = np.zeros(shape_out) + for itime in range(n_time): + ref, ref_fp = reproject_and_coadd( + [(data1[itime], wcs_in[itime]), (data2[itime], wcs_in[itime])], + wcs_out, + shape_out=shape_out[1:], + reproject_function=reproject_interp, + combine_function=combine_function, + roundtrip_coords=False, + ) + reference[itime] = ref + reference_footprint[itime] = ref_fp + + assert_allclose(array, reference, atol=ATOL) + assert_allclose(footprint, reference_footprint, atol=ATOL) diff --git a/reproject/spherical_intersect/_high_level.py b/reproject/spherical_intersect/_high_level.py index 33dafd9ea..d676fa2b9 100644 --- a/reproject/spherical_intersect/_high_level.py +++ b/reproject/spherical_intersect/_high_level.py @@ -20,6 +20,7 @@ def reproject_exact( parallel=False, return_type=None, dask_method=None, + zarr_path=None, ): """ Reproject data to a new projection using flux-conserving spherical @@ -86,8 +87,13 @@ def reproject_exact( by ``block_size`` (if the block size is not set, it will be determined automatically). To use the currently active dask scheduler (e.g. dask.distributed), set this to ``'current-scheduler'``. - return_type : {'numpy', 'dask'}, optional - Whether to return numpy or dask arrays + return_type : {'numpy', 'dask', 'zarr'}, optional + Whether to return numpy or dask arrays, or to write the output to a zarr + array on disk. If ``'zarr'``, ``zarr_path`` must also be given; the + output is then computed in blocks (using dask, on the synchronous + scheduler when ``parallel`` is `False`), ``block_size`` defaults to + ``'auto'`` when not specified, and dask arrays backed by the zarr array + are returned. dask_method : {'memmap', 'none'}, optional Method to use when input array is a dask array. The methods are: * ``'memmap'``: write out the entire input dask array to a temporary @@ -100,6 +106,9 @@ def reproject_exact( fits into memory (as this will then be faster than ``'memmap'``), and when the data contains more dimensions than the input WCS and the block_size is chosen to iterate over the extra dimensions. + zarr_path : str, optional + Path to use for the output zarr array when ``return_type='zarr'``. This + must be a path that does not already exist. Returns ------- @@ -129,6 +138,7 @@ def reproject_exact( return_footprint=return_footprint, output_footprint=output_footprint, return_type=return_type, + zarr_path=zarr_path, ) else: raise NotImplementedError( diff --git a/reproject/tests/test_non_reprojected_dims.py b/reproject/tests/test_non_reprojected_dims.py index 81d17f15d..a88dc77f5 100644 --- a/reproject/tests/test_non_reprojected_dims.py +++ b/reproject/tests/test_non_reprojected_dims.py @@ -57,6 +57,40 @@ def test_non_reprojected_dims(reproject_function): assert_allclose(array_out, reference, equal_nan=True) +@pytest.mark.parametrize("block_size", [(1, 7, 7), (7, 7), (1, 12, 20)]) +def test_non_reprojected_dims_subtiled(reproject_function, block_size): + # A block_size smaller than the output along the reprojected (celestial) + # dimensions should reproject each plane in sub-tiles and give exactly the + # same result as reprojecting each full plane in one go. This is what keeps + # the coordinate-transform memory bounded for large planes. + + data = np.arange(4 * 20 * 20, dtype=float).reshape((4, 20, 20)) + wcs_in = _spectral_cube_wcs(0.0, 1e9) + wcs_out = _spectral_cube_wcs(0.02, 1e9 + 2e6) + shape_out = (4, 20, 20) + + array_full, footprint_full = reproject_function( + (data, wcs_in), + wcs_out, + shape_out=shape_out, + non_reprojected_dims=(0,), + parallel=True, + block_size=(20, 20), + ) + + array_sub, footprint_sub = reproject_function( + (data, wcs_in), + wcs_out, + shape_out=shape_out, + non_reprojected_dims=(0,), + parallel=True, + block_size=block_size, + ) + + assert_allclose(array_sub, array_full, equal_nan=True) + assert_allclose(footprint_sub, footprint_full, equal_nan=True) + + def test_non_reprojected_dims_invalid_order(reproject_function): data = np.ones((4, 20, 20)) wcs = _spectral_cube_wcs(0.0, 1e9) @@ -82,12 +116,13 @@ def test_non_reprojected_dims_inconsistent_with_wcs(reproject_function): @pytest.mark.parametrize( - "kwargs", [{}, {"parallel": True}, {"parallel": True, "block_size": (4, 10, 10)}] + "kwargs", [{}, {"parallel": True}, {"parallel": True, "block_size": "auto"}] ) def test_non_reprojected_dims_unsupported_mode(reproject_function, kwargs): # non_reprojected_dims with a full-dimensional WCS is only supported when - # parallelizing over the non-reprojected dimensions; other modes should - # raise rather than silently reprojecting the non-reprojected axis. + # parallelizing over the non-reprojected dimensions, which requires an + # explicit block_size; modes without one (including block_size='auto') + # should raise rather than silently reprojecting the non-reprojected axis. data = np.ones((4, 20, 20)) wcs_in = _spectral_cube_wcs(0.0, 1e9) wcs_out = _spectral_cube_wcs(0.02, 1e9 + 2e6)