import contextlib
import logging
import os
import typing as T
import warnings
from pathlib import Path
import dask.array as da
import numpy as np
import xarray as xr
from dask.delayed import Delayed
from rasterio import open as rio_open
from rasterio.coords import BoundingBox
from rasterio.windows import Window
from ..config import config
from ..core.util import parse_filename_dates
from ..core.windows import get_window_offsets
from rasterio.enums import Resampling as RioResampling
from rasterio.warp import reproject
from ..handler import add_handler
from .rasterio_ import get_file_bounds, get_ref_image_meta
from .rasterio_ import transform_crs as rio_transform_crs
from .rasterio_ import (
unpack_bounding_box,
unpack_window,
warp,
warp_images,
window_to_bounds,
)
from .xarray_rasterio_ import open_rasterio
# Module-level logger; it is passed through the package's ``add_handler``
# helper before use.
logger = logging.getLogger(__name__)
logger = add_handler(logger)
def _update_kwarg(ref_obj, ref_kwargs, key):
"""Updates keyword arguments for global config parameters.
Args:
ref_obj (str or object)
ref_kwargs (dict)
key (str)
Returns:
``dict``
"""
if isinstance(ref_obj, str) and Path(ref_obj).is_file():
# Get the metadata from the reference image
ref_meta = get_ref_image_meta(ref_obj)
ref_kwargs[key] = getattr(ref_meta, key)
else:
if ref_obj is not None:
ref_kwargs[key] = ref_obj
return ref_kwargs
def _get_raster_coords(filename: T.Union[str, Path]) -> tuple:
    """Returns the x and y coordinate arrays of a raster, shifted by half a cell.

    The x values are decreased by half of ``res[0]`` and the y values are
    increased by half of ``res[1]``.

    Args:
        filename (str or Path): The raster to open with ``open_rasterio``.

    Returns:
        ``tuple``: ``(x, y)``
    """
    with open_rasterio(filename) as raster:
        half_cell_x = raster.res[0] / 2.0
        half_cell_y = raster.res[1] / 2.0
        # NOTE(review): presumably moves cell centers to upper-left cell
        # corners -- confirm against the consumer of the 'tac' keyword.
        shifted_x = raster.x.values - half_cell_x
        shifted_y = raster.y.values + half_cell_y

    return shifted_x, shifted_y
def _check_config_globals(
filenames: T.Union[str, Path, T.Sequence[T.Union[str, Path]]],
bounds_by: str,
ref_kwargs: dict,
) -> dict:
"""Checks global configuration parameters.
Args:
filenames (str or str list)
bounds_by (str)
ref_kwargs (dict)
"""
assert bounds_by.lower() in (
"intersection",
"reference",
"union",
), "The bounds_by argument must be 'intersection', 'reference', or 'union'."
if config['nodata'] is not None:
ref_kwargs = _update_kwarg(config['nodata'], ref_kwargs, 'nodata')
# Check if there is a reference image
if config['ref_image']:
if (
isinstance(config['ref_image'], (Path, str))
and Path(config['ref_image']).is_file()
):
# Get the metadata from the reference image
ref_meta = get_ref_image_meta(config['ref_image'])
ref_kwargs['bounds'] = ref_meta.bounds
ref_kwargs['crs'] = ref_meta.crs
ref_kwargs['res'] = ref_meta.res
else:
if not config['ignore_warnings']:
logger.warning(' The reference image does not exist')
else:
if config['ref_bounds']:
if isinstance(config['ref_bounds'], str) and config[
'ref_bounds'
].startswith('Window'):
ref_bounds = window_to_bounds(
filenames, unpack_window(config['ref_bounds'])
)
elif isinstance(config['ref_bounds'], str) and config[
'ref_bounds'
].startswith('BoundingBox'):
ref_bounds = unpack_bounding_box(config['ref_bounds'])
elif isinstance(config['ref_bounds'], Window):
ref_bounds = window_to_bounds(filenames, config['ref_bounds'])
elif isinstance(config['ref_bounds'], BoundingBox):
ref_bounds = config['ref_bounds']
else:
ref_bounds = config['ref_bounds']
ref_kwargs = _update_kwarg(tuple(ref_bounds), ref_kwargs, 'bounds')
else:
if isinstance(filenames, (Path, str)):
# Use the bounds of the input image
ref_kwargs['bounds'] = get_file_bounds(
[filenames],
bounds_by='reference',
crs=ref_kwargs['crs'],
res=ref_kwargs['res'],
return_bounds=True,
)
else:
# Get the union bounds of all images
ref_kwargs['bounds'] = get_file_bounds(
filenames,
bounds_by=bounds_by.lower(),
crs=ref_kwargs['crs'],
res=ref_kwargs['res'],
return_bounds=True,
)
config['ref_bounds'] = ref_kwargs['bounds']
if config['ref_crs'] is not None:
ref_kwargs = _update_kwarg(config['ref_crs'], ref_kwargs, 'crs')
if config['ref_res'] is not None:
ref_kwargs = _update_kwarg(config['ref_res'], ref_kwargs, 'res')
if config['ref_tar'] is not None:
if isinstance(config['ref_tar'], str):
if Path(config['ref_tar']).is_file():
ref_kwargs = _update_kwarg(
_get_raster_coords(config['ref_tar']),
ref_kwargs,
'tac',
)
else:
if not config['ignore_warnings']:
logger.warning(
' The target aligned raster does not exist.'
)
else:
if not config['ignore_warnings']:
logger.warning(
' The target aligned raster must be an image.'
)
return ref_kwargs
def delayed_to_xarray(
    delayed_data: Delayed,
    shape: tuple,
    dtype: T.Union[str, np.dtype],
    chunks: tuple,
    coords: dict,
    attrs: dict,
) -> xr.DataArray:
    """Wraps a ``dask.delayed`` object in a ``xarray.DataArray``.

    Args:
        delayed_data (Delayed): The delayed object that produces the array.
        shape (tuple): The shape passed to ``dask.array.from_delayed``.
        dtype (str or dtype): The data type passed to ``dask.array.from_delayed``.
        chunks (tuple): The chunks to rechunk the dask array to.
        coords (dict): The ``DataArray`` coordinates.
        attrs (dict): The ``DataArray`` attributes.

    Returns:
        ``xarray.DataArray`` with dimensions ``('band', 'y', 'x')``
    """
    lazy_array = da.from_delayed(delayed_data, shape=shape, dtype=dtype)
    rechunked = lazy_array.rechunk(chunks)

    return xr.DataArray(
        rechunked,
        dims=('band', 'y', 'x'),
        coords=coords,
        attrs=attrs,
    )
def _attach_nodata_mask(filename, src, nodata_value):
    """Create a nodata mask from the original file and warp it.

    Reads band 1 of *filename* at native resolution, builds a binary
    mask (1 = nodata), warps it with nearest-neighbour to the grid of
    *src*, and attaches it as the ``_nodata_mask`` coordinate (bool,
    ``True`` where nodata).

    The mask is best-effort. *src* is returned unchanged when band 1 has no
    nodata pixels, when *src* has no ``transform`` attribute, or when any
    step fails. A failure is logged as a warning unless
    ``config['ignore_warnings']`` is set.

    Args:
        filename (str or Path): The original (pre-warp) file.
        src (DataArray): The warped array to attach the mask to.
        nodata_value (float or int): The 'no data' value to look for.

    Returns:
        ``xarray.DataArray``
    """
    try:
        with rio_open(filename) as raw:
            native_data = raw.read(1)
            src_transform = raw.transform
            src_crs = raw.crs

        # Build mask: 1 where band 1 holds the nodata value. NaN never
        # compares equal to itself, so it needs its own test; NumPy float
        # scalars (e.g., np.float32) are not Python floats.
        if isinstance(nodata_value, (float, np.floating)) and np.isnan(
            nodata_value
        ):
            nd_native = np.isnan(native_data).astype(np.uint8)
        else:
            nd_native = (native_data == nodata_value).astype(np.uint8)

        if not nd_native.any():
            return src

        # Destination grid from the DataArray
        dst_height, dst_width = src.shape[-2], src.shape[-1]
        t = src.attrs.get('transform')
        if t is None:
            return src

        from rasterio.transform import Affine

        dst_transform = t if isinstance(t, Affine) else Affine(*t[:6])
        # NOTE(review): the destination CRS is taken from the source file,
        # i.e. the warp is assumed to keep the CRS. Confirm this holds when
        # a different reference CRS is configured.
        dst_crs = src_crs

        dst_mask = np.zeros((dst_height, dst_width), dtype=np.uint8)
        reproject(
            nd_native,
            dst_mask,
            src_transform=src_transform,
            src_crs=src_crs,
            dst_transform=dst_transform,
            dst_crs=dst_crs,
            resampling=RioResampling.nearest,
        )

        mask_da = xr.DataArray(
            dst_mask.astype(bool),
            dims=('y', 'x'),
            coords={'y': src.y, 'x': src.x},
        )
        src = src.assign_coords(_nodata_mask=mask_da)
    except Exception as exc:
        # The mask is optional, so a failure must not break the open call,
        # but it should not go unnoticed either.
        if not config['ignore_warnings']:
            logger.warning(
                ' The nodata mask could not be created for %s (%s).',
                filename,
                exc,
            )

    return src
def warp_open(
    filename: T.Union[str, Path],
    band_names: T.Optional[T.Sequence[T.Union[int, str]]] = None,
    resampling: str = 'nearest',
    dtype: T.Optional[str] = None,
    netcdf_vars: T.Optional[T.Sequence[T.Union[int, str]]] = None,
    nodata: T.Optional[T.Union[int, float]] = None,
    return_windows: bool = False,
    warp_mem_limit: int = 512,
    num_threads: int = 1,
    tap: bool = False,
    **kwargs,
):
    """Warps and opens a file.

    Args:
        filename (str or Path): The file to open.
        band_names (Optional[int, str, or list]): The band names.
        resampling (Optional[str]): The resampling method.
        dtype (Optional[str]): A data type to force the output to. If not given, the data type is extracted
            from the file.
        netcdf_vars (Optional[list]): NetCDF variables to open as a band stack.
        nodata (Optional[float | int]): A 'no data' value to set. Default is ``None``.
        return_windows (Optional[bool]): Whether to return block windows. Requires a ``chunks``
            keyword argument.
        warp_mem_limit (Optional[int]): The memory limit (in MB) for the ``rasterio.vrt.WarpedVRT`` function.
        num_threads (Optional[int]): The number of warp worker threads.
        tap (Optional[bool]): Whether to target align pixels.
        kwargs (Optional[dict]): Keyword arguments passed to ``open_rasterio``.

    Returns:
        ``xarray.DataArray``
    """
    ref_kwargs = {
        'bounds': None,
        'crs': None,
        'res': None,
        'nodata': nodata,
        'warp_mem_limit': warp_mem_limit,
        'num_threads': num_threads,
        'tap': tap,
        'tac': None,
    }

    # The NetCDF variable stack is warped with ``warp_images``, which is
    # given ``bounds_by`` instead of ``tap``. Copy before ``ref_kwargs`` is
    # updated from the global config.
    ref_kwargs_netcdf_stack = ref_kwargs.copy()
    ref_kwargs_netcdf_stack['bounds_by'] = 'union'
    del ref_kwargs_netcdf_stack['tap']

    ref_kwargs = _check_config_globals(filename, 'reference', ref_kwargs)

    # Create a list of variables to open. ``str()`` keeps ``Path`` inputs
    # from failing on ``.lower()``.
    filenames = None
    if str(filename).lower().startswith('netcdf:') and netcdf_vars:
        filenames = [f'{filename}:{var_name}' for var_name in netcdf_vars]

    # The first sub-dataset stands in for the file when variables are stacked
    reference_file = filenames[0] if filenames else filename
    ref_kwargs_netcdf_stack = _check_config_globals(
        reference_file, 'reference', ref_kwargs_netcdf_stack
    )
    with rio_open(reference_file) as src:
        tags = src.tags()

    @contextlib.contextmanager
    def warp_netcdf_vars():
        # Warp all images to the same grid.
        warped_objects = warp_images(
            filenames, resampling=resampling, **ref_kwargs_netcdf_stack
        )
        yield xr.concat(
            (
                open_rasterio(
                    wobj,
                    nodata=ref_kwargs['nodata'],
                    num_threads=num_threads,
                    **kwargs,
                ).assign_coords(
                    band=[band_names[wi]] if band_names else [netcdf_vars[wi]]
                )
                for wi, wobj in enumerate(warped_objects)
            ),
            dim='band',
        )

    if filenames:
        src_manager = warp_netcdf_vars()
    else:
        src_manager = open_rasterio(
            warp(filename, resampling=resampling, **ref_kwargs),
            nodata=ref_kwargs['nodata'],
            num_threads=num_threads,
            **kwargs,
        )

    with src_manager as src:
        if band_names:
            # Ignore surplus names rather than fail on a length mismatch
            if len(band_names) > src.gw.nbands:
                src.coords['band'] = band_names[: src.gw.nbands]
            else:
                src.coords['band'] = band_names
        else:
            if src.gw.sensor:
                if src.gw.sensor not in src.gw.avail_sensors:
                    if not src.gw.config['ignore_warnings']:
                        logger.warning(
                            ' The {} sensor is not currently supported.\nChoose from [{}].'.format(
                                src.gw.sensor, ', '.join(src.gw.avail_sensors)
                            )
                        )
                else:
                    new_band_names = list(
                        src.gw.wavelengths[src.gw.sensor]._fields
                    )
                    # Avoid nested opens within a `config` context
                    if len(new_band_names) != len(src.band.values.tolist()):
                        if not src.gw.config['ignore_warnings']:
                            logger.warning(
                                ' The new bands, {}, do not match the sensor bands, {}.'.format(
                                    new_band_names, src.band.values.tolist()
                                )
                            )
                    else:
                        src = src.assign_coords(**{'band': new_band_names})
                        src = src.assign_attrs(
                            **{'sensor': src.gw.sensor_names[src.gw.sensor]}
                        )

        if return_windows:
            # Square windows sized by the last chunk dimension
            if isinstance(kwargs['chunks'], tuple):
                chunksize = kwargs['chunks'][-1]
            else:
                chunksize = kwargs['chunks']

            src.attrs['block_windows'] = get_window_offsets(
                src.shape[-2],
                src.shape[-1],
                chunksize,
                chunksize,
                return_as='list',
            )

        src = src.assign_attrs(
            **{'filename': filename, 'resampling': resampling}
        )

        if tags:
            attrs = src.attrs.copy()
            attrs.update(tags)
            src = src.assign_attrs(**attrs)

        # Build a nodata mask from the original (pre-warp) file.
        # After GDAL warping, nodata pixels may get interpolated
        # to non-nodata values, making them undetectable via
        # simple value comparison. This warps a binary mask with
        # nearest-neighbor to preserve the nodata footprint.
        nd = ref_kwargs.get('nodata')
        if nd is not None and not filenames:
            src = _attach_nodata_mask(filename, src, nd)

        if dtype:
            # ``astype`` is followed by an explicit re-attach of the attributes
            attrs = src.attrs.copy()
            return src.astype(dtype).assign_attrs(**attrs)
        else:
            return src
def mosaic(
    filenames: T.Sequence[T.Union[str, Path]],
    overlap: str = 'max',
    bounds_by: str = 'reference',
    resampling: str = 'nearest',
    band_names: T.Optional[T.Sequence[T.Union[int, str]]] = None,
    nodata: T.Optional[T.Union[float, int]] = None,
    dtype: T.Optional[str] = None,
    warp_mem_limit: int = 512,
    num_threads: int = 1,
    **kwargs,
) -> xr.DataArray:
    """Mosaics a list of images.

    Args:
        filenames (list): A list of file names to mosaic.
        overlap (Optional[str]): The keyword that determines how to handle overlapping data.
            Choices are ['min', 'max', 'mean'].
        bounds_by (Optional[str]): How to concatenate the output extent. Choices are ['intersection', 'union', 'reference'].

            * reference: Use the bounds of the reference image
            * intersection: Use the intersection (i.e., minimum extent) of all the image bounds
            * union: Use the union (i.e., maximum extent) of all the image bounds

        resampling (Optional[str]): The resampling method.
        band_names (Optional[1d array-like]): A list of names to give the band dimension.
        nodata (Optional[float | int]): A 'no data' value to set. Default is ``None``.
        dtype (Optional[str]): A data type to force the output to. If not given, the data type is extracted
            from the file.
        warp_mem_limit (Optional[int]): The memory limit (in MB) for the ``rasterio.vrt.WarpedVRT`` function.
        num_threads (Optional[int]): The number of warp worker threads.
        kwargs (Optional[dict]): Keyword arguments passed to ``open_rasterio``.

    Returns:
        ``xarray.DataArray``

    Raises:
        ValueError: If ``overlap`` is not one of 'min', 'max', or 'mean'.
    """
    # Fail before any file is warped; an unknown method would otherwise
    # surface later as an unbound temporary 'no data' value.
    if overlap not in ('min', 'max', 'mean'):
        raise ValueError(
            "The overlap argument must be one of 'min', 'max', or 'mean'."
        )

    ref_kwargs = {
        'bounds': None,
        'crs': None,
        'res': None,
        'nodata': nodata,
        'warp_mem_limit': warp_mem_limit,
        'num_threads': num_threads,
        'tac': None,
    }
    ref_kwargs = _check_config_globals(filenames, bounds_by, ref_kwargs)

    # Warp all images to the same grid.
    warped_objects = warp_images(
        filenames, bounds_by=bounds_by, resampling=resampling, **ref_kwargs
    )

    with rio_open(filenames[0]) as src_:
        tags = src_.tags()

    # Collect the geometry of every unwarped image; the attributes of the
    # first image become the attributes of the mosaic.
    attrs = None
    geometries = []
    for fn in filenames:
        with open_rasterio(
            fn,
            nodata=ref_kwargs['nodata'],
            num_threads=num_threads,
            **kwargs,
        ) as src_:
            if attrs is None:
                attrs = src_.attrs.copy()
            geometries.append(src_.gw.geometry)

    # The temporary 'no data' value sits on the losing side of the
    # reduction ('min' -> large, 'max'/'mean' -> small).
    tmp_nodata = 1e9 if overlap == 'min' else -1e9

    # Open all the data pointers
    data_arrays = [
        open_rasterio(
            wo,
            nodata=ref_kwargs['nodata'],
            num_threads=num_threads,
            **kwargs,
        )
        .gw.set_nodata(
            src_nodata=ref_kwargs['nodata'],
            dst_nodata=tmp_nodata,
            dtype='float64',
        )
        .gw.mask_nodata()
        for wo in warped_objects
    ]

    # Stack all arrays and reduce in one operation (O(1) graph
    # depth instead of O(N) from pairwise functools.reduce)
    stacked = da.stack([d.data for d in data_arrays], axis=0)
    reducers = {'min': da.nanmin, 'max': da.nanmax, 'mean': da.nanmean}
    reduced = reducers[overlap](stacked, axis=0)

    darray = xr.DataArray(
        reduced,
        dims=data_arrays[0].dims,
        coords=data_arrays[0].coords,
    )

    # Reset the 'no data' values
    darray = darray.gw.set_nodata(
        src_nodata=tmp_nodata,
        dst_nodata=ref_kwargs['nodata'],
    ).assign_attrs(**attrs)

    if band_names:
        darray.coords['band'] = band_names
    else:
        if darray.gw.sensor:
            if darray.gw.sensor not in darray.gw.avail_sensors:
                if not darray.gw.config['ignore_warnings']:
                    logger.warning(
                        ' The {} sensor is not currently supported.\nChoose from [{}].'.format(
                            darray.gw.sensor,
                            ', '.join(darray.gw.avail_sensors),
                        )
                    )
            else:
                new_band_names = list(
                    darray.gw.wavelengths[darray.gw.sensor]._fields
                )
                if len(new_band_names) != len(darray.band.values.tolist()):
                    if not darray.gw.config['ignore_warnings']:
                        logger.warning(
                            ' The band list length does not match the sensor bands.'
                        )
                else:
                    darray = darray.assign_coords(**{'band': new_band_names})
                    darray = darray.assign_attrs(
                        **{'sensor': darray.gw.sensor_names[darray.gw.sensor]}
                    )

    darray = darray.assign_attrs(
        **{'resampling': resampling, 'geometries': geometries}
    )

    if tags:
        attrs = darray.attrs.copy()
        attrs.update(tags)
        darray = darray.assign_attrs(**attrs)

    if dtype is not None:
        # ``astype`` is followed by an explicit re-attach of the attributes
        attrs = darray.attrs.copy()
        return darray.astype(dtype).assign_attrs(**attrs)
    else:
        return darray
def check_alignment(concat_list: T.Sequence[xr.DataArray]) -> None:
    """Warns if neighboring arrays in a sequence do not align exactly.

    Each array is compared against the next one with
    ``xarray.align(..., join='exact')``. The first ``ValueError`` stops the
    check and triggers both a ``UserWarning`` and a logged warning.

    Args:
        concat_list (list): The arrays that are about to be concatenated.
    """
    try:
        for next_idx in range(1, len(concat_list)):
            previous_array = concat_list[next_idx - 1]
            next_array = concat_list[next_idx]
            xr.align(previous_array, next_array, join='exact')
    except ValueError:
        message = (
            'The stacked dimensions are not aligned. If this was not intentional, '
            'use gw.config.update to align coordinates. To suppress this message, use '
            'with gw.config.update(ignore_warnings=True):.'
        )
        warnings.warn(message, UserWarning)
        logger.warning(message)
def concat(
    filenames: T.Sequence[T.Union[str, Path]],
    stack_dim: str = 'time',
    bounds_by: str = 'reference',
    resampling: str = 'nearest',
    time_names: T.Optional[T.Sequence[T.Any]] = None,
    band_names: T.Optional[T.Sequence[T.Any]] = None,
    nodata: T.Optional[T.Union[float, int]] = None,
    dtype: T.Optional[str] = None,
    netcdf_vars: T.Optional[T.Sequence[T.Any]] = None,
    overlap: str = 'max',
    warp_mem_limit: int = 512,
    num_threads: int = 1,
    tap: bool = False,
    **kwargs,
):
    """Concatenates a list of images.

    Args:
        filenames (list): A list of file names to concatenate.
        stack_dim (Optional[str]): The stack dimension. Choices are ['time', 'band'] (case-insensitive).
        bounds_by (Optional[str]): How to concatenate the output extent. Choices are ['intersection', 'union', 'reference'].

            * reference: Use the bounds of the reference image
            * intersection: Use the intersection (i.e., minimum extent) of all the image bounds
            * union: Use the union (i.e., maximum extent) of all the image bounds

        resampling (Optional[str]): The resampling method.
        time_names (Optional[1d array-like]): A list of names to give the time dimension.
        band_names (Optional[1d array-like]): A list of names to give the band dimension.
        nodata (Optional[float | int]): A 'no data' value to set. Default is ``None``.
        dtype (Optional[str]): A data type to force the output to. If not given, the data type is extracted
            from the file.
        netcdf_vars (Optional[list]): NetCDF variables to open as a band stack.
        overlap (Optional[str]): The keyword that determines how to handle overlapping data.
            Choices are ['min', 'max', 'mean']. Only used when mosaicking arrays from the same timeframe.
        warp_mem_limit (Optional[int]): The memory limit (in MB) for the ``rasterio.vrt.WarpedVRT`` function.
        num_threads (Optional[int]): The number of warp worker threads.
        tap (Optional[bool]): Whether to target align pixels. Applied to images opened with ``warp_open``.
        kwargs (Optional[dict]): Keyword arguments passed to ``open_rasterio``.

    Returns:
        ``xarray.DataArray``

    Raises:
        ValueError: If ``stack_dim`` is not 'band' or 'time'.
    """
    # Normalize once so that every later comparison sees the same value
    stack_dim = stack_dim.lower()
    if stack_dim not in ('band', 'time'):
        raise ValueError("The stack dimension should be 'band' or 'time'.")

    is_netcdf = str(filenames[0]).lower().startswith('netcdf:')

    with rio_open(filenames[0]) as src_:
        tags = src_.tags()

    # Open the first image only to capture its attributes
    src_ = warp_open(
        f'{filenames[0]}:{netcdf_vars[0]}' if netcdf_vars else filenames[0],
        resampling=resampling,
        band_names=[netcdf_vars[0]] if netcdf_vars else band_names,
        nodata=nodata,
        warp_mem_limit=warp_mem_limit,
        num_threads=num_threads,
        tap=tap,
        **kwargs,
    )
    attrs = src_.attrs.copy()
    src_.close()
    src_ = None

    if time_names and not is_netcdf:
        concat_list = []
        new_time_names = []
        time_list = list(time_names)

        # Check the time names for duplicates
        for tidx, time_name in enumerate(time_list):
            if time_list.count(time_name) > 1:
                # Images that share a time name are mosaicked once, at the
                # first occurrence of the name
                if time_name not in new_time_names:
                    # Get the file names to mosaic
                    filenames_mosaic = [
                        filenames[i]
                        for i, other_name in enumerate(time_list)
                        if other_name == time_name
                    ]
                    # Mosaic the images into a single-date array
                    concat_list.append(
                        mosaic(
                            filenames_mosaic,
                            overlap=overlap,
                            bounds_by=bounds_by,
                            resampling=resampling,
                            band_names=band_names,
                            nodata=nodata,
                            warp_mem_limit=warp_mem_limit,
                            num_threads=num_threads,
                            **kwargs,
                        )
                    )
                    new_time_names.append(time_name)
            else:
                new_time_names.append(time_name)
                # Warp the date
                concat_list.append(
                    warp_open(
                        filenames[tidx],
                        resampling=resampling,
                        band_names=band_names,
                        nodata=nodata,
                        warp_mem_limit=warp_mem_limit,
                        num_threads=num_threads,
                        tap=tap,
                        **kwargs,
                    )
                )

        if not concat_list[0].gw.config['ignore_warnings']:
            check_alignment(concat_list)

        # Warp all images and concatenate along the 'time' axis into a DataArray
        src = xr.concat(concat_list, dim=stack_dim).assign_coords(
            time=new_time_names
        )
    else:
        warp_list = [
            warp_open(
                fn,
                resampling=resampling,
                band_names=band_names,
                netcdf_vars=netcdf_vars,
                nodata=nodata,
                warp_mem_limit=warp_mem_limit,
                num_threads=num_threads,
                tap=tap,
                **kwargs,
            )
            for fn in filenames
        ]

        if not warp_list[0].gw.config['ignore_warnings']:
            check_alignment(warp_list)

        src = xr.concat(warp_list, dim=stack_dim)

    src = src.assign_attrs(**{'filename': [Path(fn).name for fn in filenames]})

    if tags:
        attrs = src.attrs.copy()
        attrs.update(tags)
        src = src.assign_attrs(**attrs)

    if stack_dim == 'time':
        if is_netcdf:
            if time_names:
                src.coords['time'] = time_names
            else:
                src.coords['time'] = parse_filename_dates(filenames)
            try:
                # Collapse entries that share a time stamp
                src = src.groupby('time').max().assign_attrs(**attrs)
            except ValueError:
                # Keep the ungrouped stack when grouping is not possible
                pass
        else:
            if not time_names:
                src.coords['time'] = parse_filename_dates(filenames)

    if band_names:
        src = src.assign_coords(band=band_names)
    else:
        if src.gw.sensor:
            if src.gw.sensor not in src.gw.avail_sensors:
                if not src.gw.config['ignore_warnings']:
                    logger.warning(
                        ' The {} sensor is not currently supported.\nChoose from [{}].'.format(
                            src.gw.sensor, ', '.join(src.gw.avail_sensors)
                        )
                    )
            else:
                new_band_names = list(
                    src.gw.wavelengths[src.gw.sensor]._fields
                )
                if len(new_band_names) != len(src.band.values.tolist()):
                    if not src.gw.config['ignore_warnings']:
                        logger.warning(
                            ' The new bands, {}, do not match the sensor bands, {}.'.format(
                                new_band_names, src.band.values.tolist()
                            )
                        )
                else:
                    src = src.assign_coords(**{'band': new_band_names})
                    src = src.assign_attrs(
                        **{'sensor': src.gw.sensor_names[src.gw.sensor]}
                    )
        else:
            # Without a sensor, number the bands from 1
            src = src.assign_coords(band=range(1, src.gw.nbands + 1))

    if dtype:
        # ``astype`` is followed by an explicit re-attach of the attributes
        attrs = src.attrs.copy()
        return src.astype(dtype).assign_attrs(**attrs)
    else:
        return src