Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 100 additions & 25 deletions reproject/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from astropy.wcs.wcsapi.high_level_wcs_wrapper import HighLevelWCSWrapper
from dask import delayed

from ._array_utils import ArrayWrapper
from ._array_utils import ArrayWrapper, iterate_chunks
from .utils import _dask_to_numpy_memmap

__all__ = ["_reproject_dispatcher"]
Expand Down Expand Up @@ -69,6 +69,7 @@ def _reproject_dispatcher(
reproject_func_kwargs=None,
return_type=None,
dask_method=None,
zarr_path=None,
):
"""
Main function that handles either calling the core algorithms directly or
Expand Down Expand Up @@ -101,9 +102,10 @@ def _reproject_dispatcher(
given as a tuple of sequential integers starting from zero (e.g.
``(0,)`` or ``(0, 1)``). If `None` (the default), any leading dimensions
for which the WCS has fewer dimensions than the data are treated this
way. Reprojecting fewer dimensions than the WCS currently requires a
``block_size`` that matches the output shape along the reprojected
dimensions.
way. Reprojecting fewer dimensions than the WCS currently requires an
explicit ``block_size``; its entries along the reprojected dimensions
may either match the output shape or be smaller, in which case each
plane is reprojected in sub-tiles of that size.
array_out : `~numpy.ndarray`, optional
An array in which to store the reprojected data. This can be any numpy
array including a memory map, which may be helpful when dealing with
Expand All @@ -123,8 +125,13 @@ def _reproject_dispatcher(
dask.distributed), set this to ``'current-scheduler'``.
reproject_func_kwargs : dict, optional
Keyword arguments to pass through to ``reproject_func``
return_type : {'numpy', 'dask' }, optional
Whether to return numpy or dask arrays.
return_type : {'numpy', 'dask', 'zarr'}, optional
Whether to return numpy or dask arrays or whether to dump the result to
a zarr array on disk. If 'zarr', then the ``zarr_path`` keyword has to
also be specified, and the function will return dask arrays constructed
from the zarr array. In this case the reprojection is always carried out
in blocks (using dask, on the synchronous scheduler when ``parallel`` is
`False`), and ``block_size`` defaults to ``'auto'`` when not specified.
dask_method : {'memmap', 'none'}, optional
Method to use when input array is a dask array. The methods are:
* ``'memmap'``: write out the entire input dask array to a temporary
Expand All @@ -137,14 +144,20 @@ def _reproject_dispatcher(
fits into memory (as this will then be faster than ``'memmap'``),
and when the data contains more dimensions than the input WCS and
the block_size is chosen to iterate over the extra dimensions.
zarr_path : str, optional
If return_type is 'zarr', this specifies the path to use for the zarr
array. This should be a non-existent path.
"""

logger = logging.getLogger(__name__)

if return_type is None:
return_type = "numpy"
elif return_type not in ("numpy", "dask"):
raise ValueError("return_type should be set to 'numpy' or 'dask'")
elif return_type not in ("numpy", "dask", "zarr"):
raise ValueError("return_type should be set to 'numpy', 'dask', or 'zarr'")

if return_type == "zarr" and block_size is None:
block_size = "auto"

if dask_method is None:
dask_method = "memmap"
Expand All @@ -154,6 +167,12 @@ def _reproject_dispatcher(
if reproject_func_kwargs is None:
reproject_func_kwargs = {}

if return_type == "zarr":
if zarr_path is None:
raise ValueError("zarr_path needs to be set if return_type is 'zarr'")
elif os.path.exists(zarr_path):
raise ValueError(f"Path {zarr_path} already exists")

# For now, we are quite restrictive in what non_reprojected_dims can
# be, but it is designed so that if we wanted we could support more use
# cases in future. For now, it has to be a tuple where each element is
Expand Down Expand Up @@ -203,7 +222,7 @@ def _reproject_dispatcher(

with tempfile.TemporaryDirectory() as local_tmp_dir:
if array_out is None:
if return_type != "dask":
if return_type == "numpy":
array_out = np.zeros(shape_out, dtype=float)
elif array_out.shape != tuple(shape_out):
raise ValueError(
Expand All @@ -226,9 +245,9 @@ def _reproject_dispatcher(
# If a dask array was passed as input, we first convert this to a
# Numpy memory mapped array

if return_type == "dask":
if return_type in ("dask", "zarr"):
raise ValueError(
"Output cannot be returned as dask arrays "
"Output cannot be returned as dask or zarr arrays "
"when parallel=False and no block size has "
"been specified"
)
Expand Down Expand Up @@ -315,13 +334,32 @@ def _reproject_dispatcher(
# don't make any assumptions for now and assume a single chunk in the
# missing dimensions.
broadcasted_parallelization = False
# When sub-tiling the reprojected dimensions within each broadcasted block
# (see below), this holds the shape of each sub-tile; otherwise it is None
# and each block is reprojected in one go.
sub_tile_shape = None
if broadcasting and block_size is not None and block_size != "auto":
if block_size[-n_dim_reproject:] == shape_out[-n_dim_reproject:]:
# TODO: maybe error if block_size was given in full and is wrong
broadcasted_parallelization = True
block_size = (1,) * (len(shape_out) - n_dim_reproject) + block_size[
-n_dim_reproject:
]
elif wcs_slicing_required:
# A block smaller than the output along the reprojected dimensions
# is only meaningful when the WCS has to be sliced per broadcasted
# slice (i.e. non_reprojected_dims). The whole input slice has to
# be reprojected as one (it cannot be chunked along the reprojected
# dimensions, since any output pixel can map anywhere in the input),
# so we still parallelize one full reprojected plane per block, but
# additionally sub-tile the reprojection within each plane to bound
# the coordinate-transform memory, which would otherwise scale with
# the full plane size.
broadcasted_parallelization = True
sub_tile_shape = block_size[-n_dim_reproject:]
block_size = (1,) * (len(shape_out) - n_dim_reproject) + shape_out[
-n_dim_reproject:
]
elif block_size[:-n_dim_reproject] != shape_out[:-n_dim_reproject]:
raise ValueError(
"block shape should either match output data shape along "
Expand All @@ -341,9 +379,10 @@ def _reproject_dispatcher(
raise NotImplementedError(
"Reprojecting fewer dimensions than the input or output WCS "
"(for example using non_reprojected_dims) currently requires "
"passing a block_size whose entries along the reprojected "
"dimensions match the output shape (optionally with parallel=True "
"to compute the blocks concurrently)"
"passing an explicit block_size whose entries along the reprojected "
"dimensions either match the output shape or are smaller (in which "
"case each plane is reprojected in sub-tiles of that size), "
"optionally with parallel=True to compute the blocks concurrently"
)

if output_footprint is None and return_footprint:
Expand Down Expand Up @@ -421,14 +460,40 @@ def reproject_single_block(a, array_or_path, block_info=None):
if array_or_path is None:
raise RuntimeError("array_or_path is not set")

array, footprint = reproject_func(
array_in,
wcs_in_sub,
wcs_out_sub,
shape_out=shape_out,
array_out=np.zeros(shape_out),
**reproject_func_kwargs,
)
if sub_tile_shape is None:
array, footprint = reproject_func(
array_in,
wcs_in_sub,
wcs_out_sub,
shape_out=shape_out,
array_out=np.zeros(shape_out),
**reproject_func_kwargs,
)
else:
# Reproject the plane in sub-tiles along the reprojected
# dimensions so that the coordinate transform (which is computed
# over the whole output sub-tile at once) does not have to be
# evaluated for the full plane in one go. The input slice is left
# whole since any output pixel can map anywhere within it.
array = np.zeros(shape_out)
footprint = np.zeros(shape_out)
n_broadcast = len(shape_out) - n_dim_reproject
for sub_tile in iterate_chunks(
shape_out[n_broadcast:], max_chunk_size=int(np.prod(sub_tile_shape))
):
sub_wcs_out = HighLevelWCSWrapper(
SlicedLowLevelWCS(low_level_wcs_out, slices=sub_tile)
)
full_slices = (slice(None),) * n_broadcast + sub_tile
sub_shape = shape_out[:n_broadcast] + tuple(s.stop - s.start for s in sub_tile)
array[full_slices], footprint[full_slices] = reproject_func(
array_in,
wcs_in_sub,
sub_wcs_out,
shape_out=sub_shape,
array_out=np.zeros(sub_shape),
**reproject_func_kwargs,
)

return np.array([array, footprint])

Expand Down Expand Up @@ -532,7 +597,7 @@ def reproject_single_block(a, array_or_path, block_info=None):

# We now convert the dask arrays back to Numpy arrays

if parallel:
if parallel or return_type == "zarr":
# As discussed in https://github.com/dask/dask/issues/9556, da.store
# will not work well in parallel mode when the destination is a
# Numpy array. Instead, in this case we save the dask array to a zarr
Expand All @@ -541,11 +606,15 @@ def reproject_single_block(a, array_or_path, block_info=None):
# 'synchronous' scheduler since that is I/O limited so does not need
# to be done in parallel.

zarr_path = os.path.join(local_tmp_dir, f"{uuid.uuid4()}.zarr")
if return_type != "zarr":
zarr_path = os.path.join(local_tmp_dir, f"{uuid.uuid4()}.zarr")

logger.info(f"Computing output array directly to zarr array at {zarr_path}")

if parallel == "current-scheduler":
if not parallel:
with dask.config.set(scheduler="synchronous"):
result.to_zarr(zarr_path)
elif parallel == "current-scheduler":
# Just use whatever is the current active scheduler, which can
# be used for e.g. dask.distributed
result.to_zarr(zarr_path)
Expand All @@ -565,6 +634,12 @@ def reproject_single_block(a, array_or_path, block_info=None):

result = da.from_zarr(zarr_path)

if return_type == "zarr":
if return_footprint:
return result[0], result[1]
else:
return result[0]

logger.info("Copying output zarr array into output Numpy arrays")

if return_footprint:
Expand Down
22 changes: 17 additions & 5 deletions reproject/adaptive/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def reproject_adaptive(
parallel=False,
return_type=None,
dask_method=None,
zarr_path=None,
):
"""
Reproject a 2D array from one WCS to another using the DeForest (2004)
Expand Down Expand Up @@ -213,18 +214,25 @@ def reproject_adaptive(
even when the input and output WCS have the same number of dimensions as
the data. The dimensions must be the leading ones, given as a tuple of
sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``).
This currently requires passing a ``block_size`` whose entries along
the reprojected dimensions match ``shape_out`` (optionally combined
with ``parallel`` to compute the blocks concurrently).
This currently requires passing an explicit ``block_size``; its entries
along the reprojected dimensions may either match ``shape_out`` or be
smaller, in which case each plane is reprojected in sub-tiles of that
size to keep the coordinate-transform memory bounded (optionally
combined with ``parallel`` to compute the blocks concurrently).
parallel : bool or int or str, optional
If `True`, the reprojection is carried out in parallel, and if a
positive integer, this specifies the number of threads to use.
The reprojection will be parallelized over output array blocks specified
by ``block_size`` (if the block size is not set, it will be determined
automatically). To use the currently active dask scheduler (e.g.
dask.distributed), set this to ``'current-scheduler'``.
return_type : {'numpy', 'dask'}, optional
Whether to return numpy or dask arrays.
return_type : {'numpy', 'dask', 'zarr'}, optional
Whether to return numpy or dask arrays, or to write the output to a zarr
array on disk. If ``'zarr'``, ``zarr_path`` must also be given; the
output is then computed in blocks (using dask, on the synchronous
scheduler when ``parallel`` is `False`), ``block_size`` defaults to
``'auto'`` when not specified, and dask arrays backed by the zarr array
are returned.
dask_method : {'memmap', 'none'}, optional
Method to use when input array is a dask array. The methods are:
* ``'memmap'``: write out the entire input dask array to a temporary
Expand All @@ -237,6 +245,9 @@ def reproject_adaptive(
fits into memory (as this will then be faster than ``'memmap'``),
and when the data contains more dimensions than the input WCS and
the block_size is chosen to iterate over the extra dimensions.
zarr_path : str, optional
Path to use for the output zarr array when ``return_type='zarr'``. This
must be a path that does not already exist.

Returns
-------
Expand Down Expand Up @@ -284,4 +295,5 @@ def reproject_adaptive(
bad_fill_value=bad_fill_value,
),
return_type=return_type,
zarr_path=zarr_path,
)
22 changes: 17 additions & 5 deletions reproject/interpolation/_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def reproject_interp(
parallel=False,
return_type=None,
dask_method=None,
zarr_path=None,
):
"""
Reproject data to a new projection using interpolation (this is typically
Expand Down Expand Up @@ -109,25 +110,35 @@ def reproject_interp(
even when the input and output WCS have the same number of dimensions as
the data. The dimensions must be the leading ones, given as a tuple of
sequential integers starting from zero (e.g. ``(0,)`` or ``(0, 1)``).
This currently requires passing a ``block_size`` whose entries along
the reprojected dimensions match ``shape_out`` (optionally combined
with ``parallel`` to compute the blocks concurrently).
This currently requires passing an explicit ``block_size``; its entries
along the reprojected dimensions may either match ``shape_out`` or be
smaller, in which case each plane is reprojected in sub-tiles of that
size to keep the coordinate-transform memory bounded (optionally
combined with ``parallel`` to compute the blocks concurrently).
parallel : bool or int or str, optional
If `True`, the reprojection is carried out in parallel, and if a
positive integer, this specifies the number of threads to use.
The reprojection will be parallelized over output array blocks specified
by ``block_size`` (if the block size is not set, it will be determined
automatically). To use the currently active dask scheduler (e.g.
dask.distributed), set this to ``'current-scheduler'``.
return_type : {'numpy', 'dask'}, optional
Whether to return numpy or dask arrays.
return_type : {'numpy', 'dask', 'zarr'}, optional
Whether to return numpy or dask arrays, or to write the output to a zarr
array on disk. If ``'zarr'``, ``zarr_path`` must also be given; the
output is then computed in blocks (using dask, on the synchronous
scheduler when ``parallel`` is `False`), ``block_size`` defaults to
``'auto'`` when not specified, and dask arrays backed by the zarr array
are returned.
dask_method : {'memmap', 'none'}, optional
Method to use when input array is a dask array. The methods are:
* ``'memmap'``: write out the entire input dask array to a temporary
memory-mapped array. This requires enough disk space to store
the entire input array.
* ``'none'`` (default): use native dask interpolation, which avoids
having to write the array to disk.
zarr_path : str, optional
Path to use for the output zarr array when ``return_type='zarr'``. This
must be a path that does not already exist.

Returns
-------
Expand Down Expand Up @@ -170,4 +181,5 @@ def reproject_interp(
),
return_type=return_type,
dask_method=dask_method,
zarr_path=zarr_path,
)
2 changes: 1 addition & 1 deletion reproject/interpolation/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ def test_auto_block_size(dask_method):
wcs_out = WCS(naxis=2)

# When block size and parallel aren't specified, can't return as dask arrays
with pytest.raises(ValueError, match="Output cannot be returned as dask arrays"):
with pytest.raises(ValueError, match="Output cannot be returned as dask or zarr arrays"):
reproject_interp(
(array_in, wcs_in),
wcs_out,
Expand Down
Loading
Loading