import concurrent.futures
import itertools
import logging
import multiprocessing as multi
import os
import random
import shutil
import string
import threading
import typing as T
import warnings
from pathlib import Path
import dask
import numpy as np
import pyproj
import rasterio as rio
import xarray as xr
from affine import Affine
from dask import is_dask_collection
from dask.distributed import Client, progress
from osgeo import gdal
from rasterio import shutil as rio_shutil
from rasterio.drivers import driver_from_extension
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rasterio.windows import Window
from threadpoolctl import threadpool_limits
from tqdm import tqdm
from tqdm.dask import TqdmCallback
try:
import zarr
from ..backends.zarr_ import to_zarr
ZARR_INSTALLED = True
except ImportError:
ZARR_INSTALLED = False
from ..backends.rasterio_ import RasterioStore, to_gtiff
from ..config import config
from ..handler import add_handler
from .windows import get_window_offsets
logger = logging.getLogger(__name__)
logger = add_handler(logger)
def get_norm_indices(n_bands, window_slice, indexes_multi):
    """Prepends the band position index to a window slice.
    Args:
        n_bands (int): The number of image bands.
        window_slice (tuple): The (row, column) window slice.
        indexes_multi (list): The band indexes to use when ``n_bands`` > 1.
    Returns:
        ``tuple``: The expanded window slice and the band indexes to write to.
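    Example:
        >>> # A sketch of the band-expansion logic (hypothetical shapes)
        >>> get_norm_indices(3, (slice(0, 256), slice(0, 256)), [1, 2, 3])
        ((slice(0, 3, None), slice(0, 256, None), slice(0, 256, None)), [1, 2, 3])
    """
    # Prepend the band position index to the window slice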
if n_bands == 1:
window_slice = tuple([slice(0, 1)] + list(window_slice))
indexes = 1
else:
window_slice = tuple([slice(0, n_bands)] + list(window_slice))
indexes = indexes_multi
return window_slice, indexes
def _window_worker(w):
"""Helper to return window slice."""
return slice(w.row_off, w.row_off + w.height), slice(
w.col_off, w.col_off + w.width
)
def _window_worker_time(w, n_bands, tidx, n_time):
"""Helper to return window slice."""
window_slice = (
slice(w.row_off, w.row_off + w.height),
slice(w.col_off, w.col_off + w.width),
)
# Prepend the band position index to the window slice
if n_bands == 1:
window_slice = tuple(
[slice(tidx, n_time)] + [slice(0, 1)] + list(window_slice)
)
else:
window_slice = tuple(
[slice(tidx, n_time)] + [slice(0, n_bands)] + list(window_slice)
)
return window_slice
def _block_read_func(fn_, g_, t_):
"""Function for block writing with ``concurrent.futures``"""
if t_ == "zarr":
group_node = zarr.open(fn_, mode="r")[g_]
w_ = Window(
row_off=group_node.attrs["row_off"],
col_off=group_node.attrs["col_off"],
height=group_node.attrs["height"],
width=group_node.attrs["width"],
)
out_data_ = np.squeeze(group_node["data"][:])
    else:
        # Parse the window offsets from the block file name
        parts = os.path.splitext(os.path.basename(fn_))[0].split("_")
        w_ = Window(
            row_off=int(parts[-4][1:]),
            col_off=int(parts[-3][1:]),
            height=int(parts[-2][1:]),
            width=int(parts[-1][1:]),
        )
        with rio.open(fn_) as src_:
            out_data_ = np.squeeze(src_.read(window=w_))
out_indexes_ = (
1
if len(out_data_.shape) == 2
else list(range(1, out_data_.shape[0] + 1))
)
return w_, out_indexes_, out_data_
def _check_offsets(
block, out_data_, window_, oleft, otop, ocols, orows, left_, top_
):
# Check if the data were read at larger
# extents than the write bounds.
    obottom = otop - (orows * abs(block.gw.celly))
    oright = oleft + (ocols * abs(block.gw.cellx))
    bottom_ = top_ - (window_.height * abs(block.gw.celly))
    right_ = left_ + (window_.width * abs(block.gw.cellx))
    # Default to the full array extent
    left_diff = 0
    top_diff = 0
    right_diff = out_data_.shape[-1]
    bottom_diff = out_data_.shape[-2]
    trim = False
    if left_ < oleft:
        # Trim the left overhang
        left_diff = int(abs(oleft - left_) / abs(block.gw.cellx))
        trim = True
    elif right_ > oright:
        # Trim the right overhang
        right_diff = out_data_.shape[-1] - int(
            abs(oright - right_) / abs(block.gw.cellx)
        )
        trim = True
    if bottom_ < obottom:
        # Trim the bottom overhang
        bottom_diff = out_data_.shape[-2] - int(
            abs(obottom - bottom_) / abs(block.gw.celly)
        )
        trim = True
    elif top_ > otop:
        # Trim the top overhang
        top_diff = int(abs(otop - top_) / abs(block.gw.celly))
        trim = True
    if trim:
dshape = out_data_.shape
if len(dshape) == 2:
out_data_ = out_data_[top_diff:bottom_diff, left_diff:right_diff]
elif len(dshape) == 3:
out_data_ = out_data_[
:, top_diff:bottom_diff, left_diff:right_diff
]
elif len(dshape) == 4:
out_data_ = out_data_[
:, :, top_diff:bottom_diff, left_diff:right_diff
]
window_ = Window(
col_off=window_.col_off,
row_off=window_.row_off,
width=out_data_.shape[-1],
height=out_data_.shape[-2],
)
return out_data_, window_
def _compute_block(
block, wid, window_, padded_window_, n_workers, num_workers
):
"""Computes a DataArray window block of data.
Args:
block (DataArray): The ``xarray.DataArray`` to compute.
wid (int): The window id.
window_ (namedtuple): The window ``rasterio.windows.Window`` object.
padded_window_ (namedtuple): A padded window ``rasterio.windows.Window`` object.
n_workers (int): The number of parallel workers for chunks.
num_workers (int): The number of parallel workers for ``dask.compute``.
oleft (float): The output image left coordinate.
otop (float): The output image top coordinate.
ocols (int): The output image columns.
orows (int): The output image rows.
Returns:
``numpy.ndarray`` | ``rasterio.windows.Window`` | ``int`` | ``list``
"""
out_data_ = None
if "apply" in block.attrs:
attrs = block.attrs.copy()
# Update the block transform
attrs["transform"] = Affine(*block.gw.transform)
attrs["window_id"] = wid
block = block.assign_attrs(**attrs)
if ("apply" in block.attrs) and hasattr(
block.attrs["apply"], "wombat_func_"
):
if padded_window_:
logger.warning(" Padding is not supported with lazy functions.")
if block.attrs["apply"].wombat_func_:
# Add the data to the keyword arguments
block.attrs["apply_kwargs"]["data"] = block
out_data_ = block.attrs["apply"](**block.attrs["apply_kwargs"])
if n_workers == 1:
out_data_ = out_data_.data.compute(
scheduler="threads", num_workers=num_workers
)
else:
with threading.Lock():
out_data_ = out_data_.data.compute(
scheduler="threads", num_workers=num_workers
)
else:
logger.exception(" The lazy wombat function is turned off.")
else:
# Get the data as a NumPy array
if n_workers == 1:
out_data_ = block.data.compute(
scheduler="threads", num_workers=num_workers
)
else:
with threading.Lock():
out_data_ = block.data.compute(
scheduler="threads", num_workers=num_workers
)
if ("apply" in block.attrs) and not hasattr(
block.attrs["apply"], "wombat_func_"
):
if padded_window_:
# Add extra padding on the image borders
rspad = (
padded_window_.height - window_.height
if window_.row_off == 0
else 0
)
cspad = (
padded_window_.width - window_.width
if window_.col_off == 0
else 0
)
repad = (
padded_window_.height - window_.height
if (window_.row_off != 0)
and (window_.height < block.gw.row_chunks)
else 0
)
cepad = (
padded_window_.width - window_.width
if (window_.col_off != 0)
and (window_.width < block.gw.col_chunks)
else 0
)
dshape = out_data_.shape
if (rspad > 0) or (cspad > 0) or (repad > 0) or (cepad > 0):
if len(dshape) == 2:
out_data_ = np.pad(
out_data_,
((rspad, repad), (cspad, cepad)),
mode="reflect",
)
elif len(dshape) == 3:
out_data_ = np.pad(
out_data_,
((0, 0), (rspad, repad), (cspad, cepad)),
mode="reflect",
)
elif len(dshape) == 4:
out_data_ = np.pad(
out_data_,
((0, 0), (0, 0), (rspad, repad), (cspad, cepad)),
mode="reflect",
)
# Apply the user function
if ("apply_args" in block.attrs) and (
"apply_kwargs" in block.attrs
):
out_data_ = block.attrs["apply"](
out_data_,
*block.attrs["apply_args"],
**block.attrs["apply_kwargs"],
)
elif ("apply_args" in block.attrs) and (
"apply_kwargs" not in block.attrs
):
out_data_ = block.attrs["apply"](
out_data_, *block.attrs["apply_args"]
)
elif ("apply_args" not in block.attrs) and (
"apply_kwargs" in block.attrs
):
out_data_ = block.attrs["apply"](
out_data_, **block.attrs["apply_kwargs"]
)
else:
out_data_ = block.attrs["apply"](out_data_)
if padded_window_:
# Remove the extra padding
dshape = out_data_.shape
if len(dshape) == 2:
out_data_ = out_data_[
rspad : rspad + padded_window_.height,
cspad : cspad + padded_window_.width,
]
elif len(dshape) == 3:
out_data_ = out_data_[
:,
rspad : rspad + padded_window_.height,
cspad : cspad + padded_window_.width,
]
elif len(dshape) == 4:
out_data_ = out_data_[
:,
:,
rspad : rspad + padded_window_.height,
cspad : cspad + padded_window_.width,
]
dshape = out_data_.shape
# Remove the padding
# Get the non-padded array slice
row_diff = abs(window_.row_off - padded_window_.row_off)
col_diff = abs(window_.col_off - padded_window_.col_off)
if len(dshape) == 2:
out_data_ = out_data_[
row_diff : row_diff + window_.height,
col_diff : col_diff + window_.width,
]
elif len(dshape) == 3:
out_data_ = out_data_[
:,
row_diff : row_diff + window_.height,
col_diff : col_diff + window_.width,
]
elif len(dshape) == 4:
out_data_ = out_data_[
:,
:,
row_diff : row_diff + window_.height,
col_diff : col_diff + window_.width,
]
else:
if padded_window_:
logger.warning(
" Padding is only supported with user functions."
)
    if not isinstance(out_data_, np.ndarray):
        logger.exception(
            " The data were not computed properly for block {:,d}".format(wid)
        )
        raise TypeError
dshape = out_data_.shape
if len(dshape) > 2:
out_data_ = out_data_.squeeze()
if len(dshape) == 2:
indexes_ = 1
else:
indexes_ = 1 if dshape[0] == 1 else list(range(1, dshape[0] + 1))
return out_data_, indexes_, window_
@threadpool_limits.wrap(limits=1, user_api="blas")
def _write_xarray(*args):
"""Writes a DataArray to file.
Args:
args (iterable): A tuple from the window generator.
Reference:
https://github.com/dask/dask/issues/3600
Returns:
``str`` | ``None``
"""
zarr_file = None
(
block,
filename,
wid,
block_window,
padded_window,
n_workers,
n_threads,
separate,
chunks,
root,
out_block_type,
tags,
oleft,
otop,
ocols,
orows,
kwargs,
) = list(itertools.chain(*args))
output, out_indexes, block_window = _compute_block(
block, wid, block_window, padded_window, n_workers, n_threads
)
if separate and (out_block_type.lower() == "zarr"):
zarr_file = to_zarr(filename, output, block_window, chunks, root=root)
else:
to_gtiff(
filename,
output,
block_window,
out_indexes,
block.gw.transform,
n_workers,
separate,
tags,
kwargs,
)
return zarr_file
def to_vrt(
data,
filename,
overwrite=False,
resampling=None,
nodata=None,
init_dest_nodata=True,
warp_mem_limit=128,
):
"""Writes a file to a VRT file.
Args:
data (DataArray): The ``xarray.DataArray`` to write.
filename (str): The output file name to write to.
overwrite (Optional[bool]): Whether to overwrite an existing VRT file.
resampling (Optional[object]): The resampling algorithm for ``rasterio.vrt.WarpedVRT``. Default is 'nearest'.
nodata (Optional[float or int]): The 'no data' value for ``rasterio.vrt.WarpedVRT``.
init_dest_nodata (Optional[bool]): Whether or not to initialize output to ``nodata`` for ``rasterio.vrt.WarpedVRT``.
warp_mem_limit (Optional[int]): The GDAL memory limit for ``rasterio.vrt.WarpedVRT``.
Returns:
``None``, writes to ``filename``
Examples:
>>> import geowombat as gw
>>> from rasterio.enums import Resampling
>>>
>>> # Transform a CRS and save to VRT
>>> with gw.config.update(ref_crs=102033):
>>> with gw.open('image.tif') as src:
>>> gw.to_vrt(
>>> src,
>>> 'output.vrt',
>>> resampling=Resampling.cubic,
>>> warp_mem_limit=256
>>> )
>>>
>>> # Load multiple files set to a common geographic extent
>>> bounds = (left, bottom, right, top)
>>> with gw.config.update(ref_bounds=bounds):
>>> with gw.open(
>>> ['image1.tif', 'image2.tif'], mosaic=True
>>> ) as src:
>>> gw.to_vrt(src, 'output.vrt')
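        >>>
        >>> # A sketch: overwrite an existing VRT
        >>> with gw.open('image.tif') as src:
        >>>     gw.to_vrt(src, 'output.vrt', overwrite=True)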
"""
if Path(filename).is_file():
if overwrite:
Path(filename).unlink()
else:
logger.warning(f" The VRT file {filename} already exists.")
return
if not resampling:
resampling = Resampling.nearest
if isinstance(data.attrs["filename"], str) or isinstance(
data.attrs["filename"], Path
):
# Open the input file on disk
with rio.open(data.attrs["filename"]) as src:
with WarpedVRT(
src,
src_crs=src.crs, # the original CRS
crs=data.crs, # the transformed CRS
                src_transform=src.transform,  # the original transform
transform=data.gw.transform, # the new transform
dtype=data.dtype,
resampling=resampling,
nodata=nodata,
init_dest_nodata=init_dest_nodata,
warp_mem_limit=warp_mem_limit,
) as vrt:
rio_shutil.copy(vrt, filename, driver="VRT")
else:
if not data.gw.filenames:
logger.exception(
" The data filenames attribute is empty. Use gw.open(..., persist_filenames=True)."
)
raise KeyError
        separate = bool(
            data.gw.data_are_separate and data.gw.data_are_stacked
        )
vrt_options = gdal.BuildVRTOptions(
outputBounds=data.gw.bounds,
xRes=data.gw.cellx,
yRes=data.gw.celly,
separate=separate,
outputSRS=data.gw.crs_to_pyproj.to_wkt(),
)
ds = gdal.BuildVRT(
str(filename), data.gw.filenames, options=vrt_options
)
        # Dereference to flush and close the GDAL dataset
        ds = None
def to_netcdf(
data: xr.DataArray,
filename: T.Union[str, Path],
overwrite: T.Optional[bool] = False,
compute: T.Optional[bool] = True,
*args,
**kwargs,
):
"""Writes an Xarray DataArray to a NetCDF file.
Args:
data (DataArray): The ``xarray.DataArray`` to write.
filename (str): The output file name to write to.
overwrite (Optional[bool]): Whether to overwrite an existing file. Default is ``False``.
        compute (Optional[bool]): Whether to compute and write to ``filename``. If ``False``, return
            the ``dask`` task graph without writing. Default is ``True``.
        args (DataArray): Additional ``DataArrays`` to stack.
        kwargs (dict): Encoding arguments.
    Returns:
        ``None``, writes to ``filename``
Examples:
>>> import geowombat as gw
>>> import xarray as xr
>>>
>>> # Write a single DataArray to a .nc file
>>> with gw.config.update(sensor='l7'):
>>> with gw.open('LC08_L1TP_225078_20200219_20200225_01_T1.tif') as src:
>>> gw.to_netcdf(src, 'filename.nc', zlib=True, complevel=5)
>>>
>>> # Add extra layers
>>> with gw.config.update(sensor='l7'):
>>> with gw.open(
>>> 'LC08_L1TP_225078_20200219_20200225_01_T1.tif'
>>> ) as src, gw.open(
>>> 'LC08_L1TP_225078_20200219_20200225_01_T1_angles.tif',
>>> band_names=['zenith', 'azimuth']
>>> ) as ang:
>>> src = (
>>> xr.where(
>>> src == 0, -32768, src
>>> )
>>> .astype('int16')
>>> .assign_attrs(**src.attrs)
>>> )
>>>
>>> gw.to_netcdf(
>>> src,
>>> 'filename.nc',
>>> ang.astype('int16'),
>>> zlib=True,
>>> complevel=5,
>>> _FillValue=-32768
>>> )
>>>
>>> # Open the data and convert to a DataArray
>>> with xr.open_dataset(
>>> 'filename.nc', engine='h5netcdf', chunks=256
>>> ) as ds:
>>> src = ds.to_array(dim='band')
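        >>>
        >>> # A sketch: build the write task lazily and compute it later
        >>> task = gw.to_netcdf(src, 'filename.nc', compute=False)
        >>> task.compute()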
"""
if Path(filename).is_file():
if overwrite:
Path(filename).unlink()
else:
logger.warning(f"The file {str(filename)} already exists.")
return
encodings = {}
chunksize = min(data.gw.row_chunks, data.gw.col_chunks)
for band_name in data.band.values.tolist():
encode_dict = {
"chunksizes": (chunksize, chunksize),
"dtype": (
data.dtype.name
if isinstance(data.dtype, np.dtype)
else data.dtype
),
}
encode_dict.update(**kwargs)
encodings[band_name] = encode_dict
for other_data in args:
chunksize = min(other_data.gw.row_chunks, other_data.gw.col_chunks)
for band_name in other_data.band.values.tolist():
encode_dict = {
"chunksizes": (chunksize, chunksize),
"dtype": (
other_data.dtype.name
if isinstance(other_data.dtype, np.dtype)
else other_data.dtype
),
}
encode_dict.update(**kwargs)
encodings[band_name] = encode_dict
data = xr.concat((data, other_data), dim="band")
attrs = data.attrs.copy()
attrs["crs"] = f"epsg:{data.gw.crs_to_pyproj.to_epsg()}"
ds = data.to_dataset(dim="band").assign_attrs(**attrs)
    if compute:
        ds.to_netcdf(
            path=filename,
            mode="w",
            format="NETCDF4",
            engine="h5netcdf",
            encoding=encodings,
            compute=True,
        )
else:
return ds.to_netcdf(
path=filename,
mode="w",
format="NETCDF4",
engine="h5netcdf",
encoding=encodings,
compute=False,
)
def save(
data: xr.DataArray,
filename: T.Union[str, Path],
mode: T.Optional[str] = "w",
nodata: T.Optional[T.Union[float, int]] = None,
overwrite: T.Optional[bool] = False,
client: T.Optional[Client] = None,
compute: T.Optional[bool] = True,
tags: T.Optional[dict] = None,
compress: T.Optional[str] = "none",
compression: T.Optional[str] = None,
num_workers: T.Optional[int] = 1,
log_progress: T.Optional[bool] = True,
tqdm_kwargs: T.Optional[dict] = None,
bigtiff: T.Optional[str] = None,
):
"""Saves a DataArray to raster using rasterio/dask.
Args:
filename (str | Path): The output file name to write to.
overwrite (Optional[bool]): Whether to overwrite an existing file. Default is False.
mode (Optional[str]): The file storage mode. Choices are ['w', 'r+'].
nodata (Optional[float | int]): The 'no data' value. If ``None`` (default), the 'no data'
value is taken from the ``DataArray`` metadata.
client (Optional[Client object]): A ``dask.distributed.Client`` client object to persist data.
Default is None.
        compute (Optional[bool]): Whether to compute and write to ``filename``. If ``False``, return
            the ``dask`` task graph without writing. Default is ``True``.
tags (Optional[dict]): Metadata tags to write to file. Default is None.
compress (Optional[str]): The file compression type. Default is 'none', or no compression.
compression (Optional[str]): The file compression type. Default is 'none', or no compression.
.. deprecated:: 2.1.4
Use 'compress' -- 'compression' will be removed in >=2.2.0.
num_workers (Optional[int]): The number of dask workers (i.e., chunks) to write concurrently.
Default is 1.
log_progress (Optional[bool]): Whether to log the progress bar during writing. Default is True.
tqdm_kwargs (Optional[dict]): Keyword arguments to pass to ``tqdm``.
bigtiff (Optional[str]): A GDAL BIGTIFF flag. Choices are ["YES", "NO", "IF_NEEDED", "IF_SAFER"].
Returns:
``None``, writes to ``filename``
Example:
>>> import geowombat as gw
>>>
>>> with gw.open('file.tif') as src:
>>> result = ...
>>> gw.save(result, 'output.tif', compress='lzw', num_workers=8)
>>>
>>> # Create delayed write tasks and compute later
>>> tasks = [gw.save(array, 'output.tif', compute=False) for array in array_list]
>>> # Write and close files
>>> dask.compute(tasks, num_workers=8)
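        >>>
        >>> # A sketch (assumes a running dask.distributed cluster):
        >>> # persist the write through a client
        >>> from dask.distributed import Client
        >>> with Client(n_workers=4) as client:
        >>>     gw.save(result, 'output.tif', client=client, compress='lzw')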
"""
    if compression is not None:
        warnings.warn(
            "The argument 'compression' is deprecated and will be removed in >=2.2.0. Use 'compress'.",
            DeprecationWarning,
            stacklevel=2,
        )
        compress = compression
if mode not in ["w", "r+"]:
raise AttributeError("The mode must be either 'w' or 'r+'.")
if Path(filename).is_file():
if overwrite:
Path(filename).unlink()
else:
logger.warning(f"The file {str(filename)} already exists.")
return
if nodata is None:
if hasattr(data, "_FillValue"):
nodata = data.attrs["_FillValue"]
else:
if hasattr(data, "nodatavals"):
nodata = data.attrs["nodatavals"][0]
else:
raise AttributeError(
"The DataArray does not have any 'no data' attributes."
)
dtype = data.dtype.name if isinstance(data.dtype, np.dtype) else data.dtype
if isinstance(nodata, float):
if dtype != "float32":
dtype = "float64"
blockxsize = (
data.gw.check_chunksize(512, data.gw.ncols)
if not data.gw.array_is_dask
else data.gw.col_chunks
)
blockysize = (
data.gw.check_chunksize(512, data.gw.nrows)
if not data.gw.array_is_dask
else data.gw.row_chunks
)
tiled = True
if config["with_config"]:
if config["bigtiff"] is not None:
if isinstance(config["bigtiff"], bool):
bigtiff = "YES" if config["bigtiff"] else "NO"
else:
bigtiff = config["bigtiff"].upper()
if bigtiff not in (
"YES",
"NO",
"IF_NEEDED",
"IF_SAFER",
):
raise NameError(
"The GDAL BIGTIFF must be one of 'YES', 'NO', 'IF_NEEDED', or 'IF_SAFER'. See https://gdal.org/drivers/raster/gtiff.html#creation-issues for more information."
)
if config["compress"] is not None:
compress = config["compress"]
if config["tiled"] is not None:
tiled = config["tiled"]
kwargs = dict(
driver=driver_from_extension(filename),
width=data.gw.ncols,
height=data.gw.nrows,
count=data.gw.nbands,
dtype=dtype,
nodata=nodata,
blockxsize=blockxsize,
blockysize=blockysize,
crs=data.gw.crs_to_pyproj,
transform=data.gw.transform,
compress=compress,
tiled=tiled if max(blockxsize, blockysize) >= 16 else False,
sharing=False,
BIGTIFF=bigtiff,
)
if tqdm_kwargs is None:
tqdm_kwargs = {}
if not compute:
return (
RasterioStore(filename, mode=mode, tags=tags, **kwargs)
.open()
.write_delayed(data)
)
else:
with RasterioStore(
filename, mode=mode, tags=tags, **kwargs
) as rio_store:
# Store the data and return a lazy evaluator
res = rio_store.write(data)
if client is not None:
results = client.persist(res)
if log_progress:
progress(results)
dask.compute(results)
else:
if log_progress:
with TqdmCallback(**tqdm_kwargs):
dask.compute(res, num_workers=num_workers)
else:
dask.compute(res, num_workers=num_workers)
return None
def to_raster(
data,
filename,
readxsize=None,
readysize=None,
separate=False,
out_block_type="gtiff",
keep_blocks=False,
verbose=0,
overwrite=False,
gdal_cache=512,
scheduler="mpool",
n_jobs=1,
n_workers=None,
n_threads=None,
n_chunks=None,
padding=None,
tags=None,
tqdm_kwargs=None,
**kwargs,
):
"""Writes a ``dask`` array to a raster file.
.. note::
We advise using :func:`save` in place of this method.
Args:
data (DataArray): The ``xarray.DataArray`` to write.
filename (str): The output file name to write to.
readxsize (Optional[int]): The size of column chunks to read. If not given, ``readxsize`` defaults to Dask
chunk size.
readysize (Optional[int]): The size of row chunks to read. If not given, ``readysize`` defaults to Dask
chunk size.
separate (Optional[bool]): Whether to write blocks as separate files. Otherwise, write to a single file.
out_block_type (Optional[str]): The output block type. Choices are ['gtiff', 'zarr'].
Only used if ``separate`` = ``True``.
keep_blocks (Optional[bool]): Whether to keep the blocks stored on disk. Only used if ``separate`` = ``True``.
verbose (Optional[int]): The verbosity level.
overwrite (Optional[bool]): Whether to overwrite an existing file.
gdal_cache (Optional[int]): The ``GDAL`` cache size (in MB).
scheduler (Optional[str]): The parallel task scheduler to use. Choices are ['processes', 'threads', 'mpool'].
mpool: process pool of workers using ``multiprocessing.Pool``
processes: process pool of workers using ``concurrent.futures``
threads: thread pool of workers using ``concurrent.futures``
n_jobs (Optional[int]): The total number of parallel jobs.
n_workers (Optional[int]): The number of process workers.
n_threads (Optional[int]): The number of thread workers.
n_chunks (Optional[int]): The chunk size of windows. If not given, equal to ``n_workers`` x 50.
padding (Optional[tuple]): Padding for each window. ``padding`` should be given as a tuple
of (left pad, bottom pad, right pad, top pad). If ``padding`` is given, the returned list will contain
a tuple of ``rasterio.windows.Window`` objects as (w1, w2), where w1 contains the normal window offsets
and w2 contains the padded window offsets.
tags (Optional[dict]): Image tags to write to file.
tqdm_kwargs (Optional[dict]): Additional keyword arguments to pass to ``tqdm``.
        kwargs (Optional[dict]): Additional keyword arguments to pass to ``rasterio.open``.
Returns:
``None``, writes to ``filename``
Examples:
>>> import geowombat as gw
>>>
>>> # Use 8 parallel workers
>>> with gw.open('input.tif') as ds:
>>> gw.to_raster(ds, 'output.tif', n_jobs=8)
>>>
>>> # Use 4 process workers and 2 thread workers
>>> with gw.open('input.tif') as ds:
>>> gw.to_raster(ds, 'output.tif', n_workers=4, n_threads=2)
>>>
>>> # Control the window chunks passed to concurrent.futures
>>> with gw.open('input.tif') as ds:
>>> gw.to_raster(ds, 'output.tif', n_workers=4, n_threads=2, n_chunks=16)
>>>
        >>> # Compress the output
        >>> with gw.open('input.tif') as ds:
        >>>     gw.to_raster(ds, 'output.tif', n_jobs=8, compress='lzw')
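        >>>
        >>> # A sketch: pad each window read by 16 pixels on every side
        >>> with gw.open('input.tif') as ds:
        >>>     gw.to_raster(ds, 'output.tif', n_workers=4, n_threads=2, padding=(16, 16, 16, 16))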
"""
if separate and not ZARR_INSTALLED and (out_block_type.lower() == "zarr"):
logger.exception(" zarr must be installed to write separate blocks.")
raise ImportError
pfile = Path(filename)
if scheduler.lower() == "mpool":
pool_executor = multi.Pool
else:
pool_executor = (
concurrent.futures.ProcessPoolExecutor
if scheduler.lower() == "processes"
else concurrent.futures.ThreadPoolExecutor
)
if overwrite:
if pfile.is_file():
pfile.unlink()
if pfile.is_file():
logger.warning(" The output file already exists.")
return
    if not is_dask_collection(data.data):
        logger.exception(" The data should be a dask array.")
        raise TypeError
if isinstance(n_workers, int) and isinstance(n_threads, int):
n_jobs = n_workers * n_threads
else:
n_workers = n_jobs
n_threads = 1
if not isinstance(n_chunks, int):
n_chunks = n_workers * 50
if not isinstance(readxsize, int):
readxsize = data.gw.col_chunks
if not isinstance(readysize, int):
readysize = data.gw.row_chunks
chunksize = (data.gw.row_chunks, data.gw.col_chunks)
if tqdm_kwargs is None:
tqdm_kwargs = {}
# Force tiled outputs with no file sharing
kwargs["sharing"] = False
if data.gw.tiled:
kwargs["tiled"] = True
if "compress" in kwargs:
        # ``compress`` can be a boolean or a compression type string
if kwargs["compress"]:
if (
isinstance(kwargs["compress"], str)
and kwargs["compress"].lower() == "none"
):
compress = False
else:
if "num_threads" in kwargs:
compress = False
else:
compress = True
if compress:
# Store the compression type because
# it is removed in concurrent writing
compress_type = kwargs["compress"]
del kwargs["compress"]
else:
compress = False
elif isinstance(data.gw.compress, str) and (
data.gw.compress.lower() in ["lzw", "deflate"]
):
compress = True
compress_type = data.gw.compress
else:
compress = False
if kwargs.get("nodata") is None:
if isinstance(data.gw.nodataval, (float, int)):
kwargs["nodata"] = data.gw.nodataval
if "blockxsize" not in kwargs:
kwargs["blockxsize"] = data.gw.col_chunks
if "blockysize" not in kwargs:
kwargs["blockysize"] = data.gw.row_chunks
if "bigtiff" not in kwargs:
kwargs["bigtiff"] = data.gw.bigtiff
if "driver" not in kwargs:
kwargs["driver"] = data.gw.driver
if "count" not in kwargs:
kwargs["count"] = data.gw.nbands
if "width" not in kwargs:
kwargs["width"] = data.gw.ncols
if "height" not in kwargs:
kwargs["height"] = data.gw.nrows
if "transform" not in kwargs:
kwargs["transform"] = data.gw.transform
if "num_threads" in kwargs:
if isinstance(kwargs["num_threads"], str):
kwargs["num_threads"] = "all_cpus"
if "crs" in kwargs:
crs = kwargs["crs"]
else:
crs = data.crs
    if str(crs).lower().startswith("epsg:"):
        kwargs["crs"] = pyproj.CRS.from_user_input(crs)
    else:
        try:
            kwargs["crs"] = pyproj.CRS.from_epsg(int(crs))
        except (TypeError, ValueError):
            kwargs["crs"] = pyproj.CRS.from_user_input(crs)
kwargs["crs"] = kwargs["crs"].to_wkt()
root = None
if separate and (out_block_type.lower() == "zarr"):
d_name = pfile.parent
sub_dir = d_name.joinpath("sub_tmp_")
sub_dir.mkdir(parents=True, exist_ok=True)
zarr_file = str(sub_dir.joinpath("data.zarr"))
root = zarr.open(zarr_file, mode="w")
else:
if not separate:
if verbose > 0:
logger.info(" Creating the file ...\n")
with rio.open(filename, mode="w", **kwargs) as rio_dst:
if tags:
rio_dst.update_tags(**tags)
if verbose > 0:
logger.info(" Writing data to file ...\n")
with rio.Env(GDAL_CACHEMAX=gdal_cache):
windows = get_window_offsets(
data.gw.nrows,
data.gw.ncols,
readysize,
readxsize,
return_as="list",
padding=padding,
)
n_windows = len(windows)
oleft, otop = kwargs["transform"][2], kwargs["transform"][5]
ocols, orows = kwargs["width"], kwargs["height"]
# Iterate over the windows in chunks
for wchunk in range(0, n_windows, n_chunks):
window_slice = windows[wchunk : wchunk + n_chunks]
n_windows_slice = len(window_slice)
if verbose > 0:
logger.info(
" Windows {:,d}--{:,d} of {:,d} ...".format(
wchunk + 1, wchunk + n_windows_slice, n_windows
)
)
            if padding:
                # Read the padded window. The leading ``...`` spans any
                # time/band dimensions, so 2d, 3d, and 4d arrays are
                # handled by the same slice.
                data_gen = (
                    (
                        data[
                            ...,
                            w[1].row_off : w[1].row_off + w[1].height,
                            w[1].col_off : w[1].col_off + w[1].width,
                        ],
                        filename,
                        widx + wchunk,
                        w[0],
                        w[1],
                        n_workers,
                        n_threads,
                        separate,
                        chunksize,
                        root,
                        out_block_type,
                        tags,
                        oleft,
                        otop,
                        ocols,
                        orows,
                        kwargs,
                    )
                    for widx, w in enumerate(window_slice)
                )
            else:
                data_gen = (
                    (
                        data[
                            ...,
                            w.row_off : w.row_off + w.height,
                            w.col_off : w.col_off + w.width,
                        ],
                        filename,
                        widx + wchunk,
                        w,
                        None,
                        n_workers,
                        n_threads,
                        separate,
                        chunksize,
                        root,
                        out_block_type,
                        tags,
                        oleft,
                        otop,
                        ocols,
                        orows,
                        kwargs,
                    )
                    for widx, w in enumerate(window_slice)
                )
if n_workers == 1:
for __ in tqdm(
map(_write_xarray, data_gen),
total=n_windows_slice,
**tqdm_kwargs,
):
pass
else:
with pool_executor(n_workers) as executor:
if scheduler == "mpool":
for __ in tqdm(
executor.imap_unordered(_write_xarray, data_gen),
total=n_windows_slice,
**tqdm_kwargs,
):
pass
else:
for __ in tqdm(
executor.map(_write_xarray, data_gen),
total=n_windows_slice,
**tqdm_kwargs,
):
pass
        if compress:
            if separate:
                if out_block_type.lower() == "zarr":
                    group_keys = list(root.group_keys())
                    n_groups = len(group_keys)
                    n_windows = n_groups
                    open_file = zarr_file
                kwargs["compress"] = compress_type
# Compress into one file
with rio.open(filename, mode="w", **kwargs) as dst_:
if tags:
dst_.update_tags(**tags)
# Iterate over the windows in chunks
for wchunk in range(0, n_groups, n_chunks):
group_keys_slice = group_keys[
wchunk : wchunk + n_chunks
]
n_windows_slice = len(group_keys_slice)
if verbose > 0:
logger.info(
" Windows {:,d}--{:,d} of {:,d} ...".format(
wchunk + 1,
wchunk + n_windows_slice,
n_windows,
)
)
data_gen = (
(open_file, group, "zarr")
for group in group_keys_slice
)
with concurrent.futures.ProcessPoolExecutor(
max_workers=n_workers
) as executor:
# Submit all the tasks as futures
futures = [
executor.submit(_block_read_func, f, g, t)
for f, g, t in data_gen
]
for f in tqdm(
concurrent.futures.as_completed(futures),
total=n_windows_slice,
**tqdm_kwargs,
):
(
out_window,
out_indexes,
out_block,
) = f.result()
dst_.write(
out_block,
window=out_window,
indexes=out_indexes,
)
futures = None
if not keep_blocks:
shutil.rmtree(sub_dir)
else:
if verbose > 0:
logger.info(" Compressing output file ...")
p = Path(filename)
d_name = p.parent
f_base, f_ext = os.path.splitext(p.name)
ld = string.ascii_letters + string.digits
rstr = "".join(random.choice(ld) for i in range(0, 9))
temp_file = d_name.joinpath(
"{f_base}_temp_{rstr}{f_ext}".format(
f_base=f_base, rstr=rstr, f_ext=f_ext
)
)
compress_raster(
filename,
str(temp_file),
n_jobs=n_jobs,
gdal_cache=gdal_cache,
compress=compress_type,
tags=tags,
)
temp_file.replace(filename)
if verbose > 0:
logger.info(" Finished compressing")
if verbose > 0:
logger.info("\nFinished writing the data.")
def _arg_gen(arg_, iter_):
for i_ in iter_:
yield arg_
def apply(
infile,
outfile,
block_func,
args=None,
count=1,
scheduler="processes",
gdal_cache=512,
n_jobs=4,
overwrite=False,
tags=None,
**kwargs,
):
"""Applies a function and writes results to file.
Args:
infile (str): The input file to process.
outfile (str): The output file.
        block_func (func): The user function to apply to each block. The function should accept the
            window, the block data, and at least one argument, and return the window and the output
            data. The block data inside the function will be a 2d array if the input image has
            1 band, otherwise a 3d array.
args (Optional[tuple]): Additional arguments to pass to ``block_func``.
count (Optional[int]): The band count for the output file.
scheduler (Optional[str]): The ``concurrent.futures`` scheduler to use. Choices are ['threads', 'processes'].
gdal_cache (Optional[int]): The ``GDAL`` cache size (in MB).
n_jobs (Optional[int]): The number of blocks to process in parallel.
overwrite (Optional[bool]): Whether to overwrite an existing output file.
tags (Optional[dict]): Image tags to write to file.
kwargs (Optional[dict]): Additional keyword arguments to pass to ``rasterio.open``.
Returns:
``None``, writes to ``outfile``
Examples:
>>> import geowombat as gw
>>>
        >>> # Here is a function with no extra arguments (the unused ``arg`` placeholder is still required)
>>> def my_func0(w, block, arg):
>>> return w, block
>>>
>>> gw.apply('input.tif',
>>> 'output.tif',
>>> my_func0,
>>> n_jobs=8)
>>>
>>> # Here is a function with 1 argument
>>> def my_func1(w, block, arg):
>>> return w, block * arg
>>>
>>> gw.apply('input.tif',
>>> 'output.tif',
>>> my_func1,
>>> args=(10.0,),
>>> n_jobs=8)
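        >>>
        >>> # A sketch: override rasterio creation options via keyword arguments
        >>> gw.apply('input.tif',
        >>>          'output.tif',
        >>>          my_func1,
        >>>          args=(10.0,),
        >>>          n_jobs=8,
        >>>          dtype='float64')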
"""
if not args:
args = (None,)
if overwrite:
if os.path.isfile(outfile):
os.remove(outfile)
kwargs["sharing"] = False
kwargs["tiled"] = True
io_mode = "r+" if os.path.isfile(outfile) else "w"
out_indexes = 1 if count == 1 else list(range(1, count + 1))
futures_executor = (
concurrent.futures.ThreadPoolExecutor
if scheduler == "threads"
else concurrent.futures.ProcessPoolExecutor
)
with rio.Env(GDAL_CACHEMAX=gdal_cache):
with rio.open(infile) as src:
profile = src.profile.copy()
if "dtype" not in kwargs:
kwargs["dtype"] = profile["dtype"]
if "nodata" not in kwargs:
kwargs["nodata"] = profile["nodata"]
if "blockxsize" not in kwargs:
kwargs["blockxsize"] = profile["blockxsize"]
if "blockxsize" not in kwargs:
kwargs["blockysize"] = profile["blockysize"]
# Create a destination dataset based on source params. The
# destination will be tiled, and we'll process the tiles
# concurrently.
profile.update(count=count, **kwargs)
with rio.open(outfile, io_mode, **profile) as dst:
if tags:
dst.update_tags(**tags)
# This generator comprehension gives us raster data
# arrays for each window. Later we will zip a mapping
# of it with the windows list to get (window, result)
# pairs.
data_gen = (
src.read(window=w, out_dtype=profile["dtype"])
for ij, w in src.block_windows(1)
)
if args:
args = [
_arg_gen(arg, src.block_windows(1)) for arg in args
]
with futures_executor(max_workers=n_jobs) as executor:
# Submit all of the tasks as futures
futures = [
executor.submit(block_func, iter_[0][1], *iter_[1:])
for iter_ in zip(
list(src.block_windows(1)), data_gen, *args
)
]
for f in tqdm(
concurrent.futures.as_completed(futures),
total=len(futures),
):
out_window, out_block = f.result()
dst.write(
np.squeeze(out_block),
window=out_window,
indexes=out_indexes,
)
def _compress_dummy(w, block, dummy):
"""Dummy function to pass to concurrent writing."""
return w, block
def compress_raster(
infile, outfile, n_jobs=1, gdal_cache=512, compress="lzw", tags=None
):
"""Compresses a raster file.
Args:
infile (str): The file to compress.
outfile (str): The output file.
n_jobs (Optional[int]): The number of concurrent blocks to write.
gdal_cache (Optional[int]): The ``GDAL`` cache size (in MB).
compress (Optional[str]): The compression method.
tags (Optional[dict]): Image tags to write to file.
Returns:
None
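    Examples:
        >>> import geowombat as gw
        >>>
        >>> # A sketch: recompress a raster to LZW using 4 concurrent block writes
        >>> gw.compress_raster('input.tif', 'output.tif', n_jobs=4, compress='lzw')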
"""
with rio.open(infile) as src:
profile = src.profile.copy()
profile.update(compress=compress)
apply(
infile,
outfile,
_compress_dummy,
scheduler="processes",
args=(None,),
gdal_cache=gdal_cache,
n_jobs=n_jobs,
tags=tags,
count=src.count,
dtype=src.profile["dtype"],
nodata=src.profile["nodata"],
tiled=src.profile["tiled"],
blockxsize=src.profile["blockxsize"],
blockysize=src.profile["blockysize"],
compress=compress,
)