import itertools
import logging
import multiprocessing as multi
import os
import tempfile
import typing as T
import warnings
from collections import defaultdict
from datetime import datetime
from pathlib import Path
import dask
import dask.array as da
import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from affine import Affine
from dask.distributed import Client, LocalCluster
from pyproj.enums import WktVersion
from pyproj.exceptions import CRSError
from rasterio import features
from rasterio.coords import BoundingBox
from scipy.spatial import cKDTree
from scipy.stats import mode as sci_mode
from shapely.geometry import Polygon
try:
import arosics
from geoarray import GeoArray
AROSICS_INSTALLED = True
except ImportError:
AROSICS_INSTALLED = False
try:
import pymorph
PYMORPH_INSTALLED = True
except ImportError:
PYMORPH_INSTALLED = False
from ..backends.rasterio_ import array_bounds, check_crs, get_dims_from_bounds
from ..handler import add_handler
from .base import PropertyMixin as _PropertyMixin
from .base import _client_dummy, _cluster_dummy
from .conversion import Converters
from .parallel import ParallelTask
from .util import lazy_wombat
logger = logging.getLogger(__name__)
logger = add_handler(logger)
def _remove_near_points(
dataframe: gpd.GeoDataFrame, r: T.Union[float, int]
) -> gpd.GeoDataFrame:
"""Removes points less than a specified distance to another point.
Args:
dataframe (GeoDataFrame): The ``GeoDataFrame`` with point geometry.
r (float or int): The minimum distance (radius) in the CRS units of ``dataframe``.
Returns:
``geopandas.GeoDataFrame``
"""
# Setup a KD tree
tree = cKDTree(np.c_[dataframe.geometry.x, dataframe.geometry.y])
# Query all pairs within ``min_dist`` of each other
near_pairs = tree.query_pairs(r=r, output_type='ndarray')
if near_pairs.shape[0] > 0:
# Get a list of pairs to remove
rm_idx = list(sorted(set(near_pairs[:, 0].tolist())))
return dataframe.query("index != {}".format(rm_idx))
return dataframe
def _transform_and_shift(
affine_transform: Affine,
col_indices: T.Sequence[int],
row_indices: T.Sequence[int],
cellxh: float,
cellyh: float,
) -> T.Tuple[np.ndarray, np.ndarray]:
"""Transforms indices to coordinates and applies a half pixel shift.
Args:
affine_transform (object): The affine transform.
col_indices (1d array): The column indices.
row_indices (1d array): The row indices.
cellxh (float): The half cell width in the x direction.
cellyh (float): The half cell width in the y direction.
Returns:
``numpy.ndarray``, ``numpy.ndarray``
"""
x_coords, y_coords = affine_transform * (col_indices, row_indices)
x_coords += abs(cellxh)
y_coords -= abs(cellyh)
return x_coords, y_coords
[docs]class SpatialOperations(_PropertyMixin):
[docs] @staticmethod
def calc_area(
data: xr.DataArray,
values: T.Sequence[T.Union[float, int]],
op: str = 'eq',
units: str = 'km2',
row_chunks: int = None,
col_chunks: int = None,
n_workers: int = 1,
n_threads: int = 1,
scheduler: str = 'threads',
n_chunks: int = 100,
) -> pd.DataFrame:
"""Calculates the area of data values.
Args:
data (DataArray): The ``xarray.DataArray`` to calculate area.
values (list): A list of values.
op (Optional[str]): The value sign. Choices are ['gt', 'ge', 'lt', 'le', 'eq'].
units (Optional[str]): The units to return. Choices are ['km2', 'ha'].
row_chunks (Optional[int]): The row chunk size to process in parallel.
col_chunks (Optional[int]): The column chunk size to process in parallel.
n_workers (Optional[int]): The number of parallel workers for ``scheduler``.
n_threads (Optional[int]): The number of parallel threads for ``dask.compute()``.
scheduler (Optional[str]): The parallel task scheduler to use. Choices are ['processes', 'threads', 'mpool'].
mpool: process pool of workers using ``multiprocessing.Pool``
processes: process pool of workers using ``concurrent.futures``
threads: thread pool of workers using ``concurrent.futures``
n_chunks (Optional[int]): The chunk size of windows. If not given, equal to ``n_workers`` x 50.
Returns:
``pandas.DataFrame``
Example:
>>> import geowombat as gw
>>>
>>> # Read a land cover image with 512x512 chunks
>>> with gw.open('land_cover.tif', chunks=512) as src:
>>>
>>> df = gw.calc_area(
>>> src,
>>> [1, 2, 5], # calculate the area of classes 1, 2, and 5
>>> units='km2', # return area in kilometers squared
>>> n_workers=4,
>>> row_chunks=1024, # iterate over larger chunks to use 512 chunks in parallel
>>> col_chunks=1024
>>> )
"""
def area_func(*args):
data_chunk, window_id, uvalues, area_units, n_threads = list(
itertools.chain(*args)
)
sqm = abs(data_chunk.gw.celly) * abs(data_chunk.gw.cellx)
area_conversion = 1e-6 if area_units == 'km2' else 0.0001
data_totals_ = defaultdict(float)
for value in uvalues:
chunk_value_total = (
data_chunk.gw.compare(op, value, return_binary=True)
.sum(skipna=True)
.data.compute(scheduler='threads', num_workers=n_threads)
)
data_totals_[value] += (
chunk_value_total * sqm
) * area_conversion
return dict(data_totals_)
pt = ParallelTask(
data,
row_chunks=row_chunks,
col_chunks=col_chunks,
scheduler=scheduler,
n_workers=n_workers,
n_chunks=n_chunks,
)
futures = pt.map(area_func, values, units, n_threads)
# Combine the results
data_totals = defaultdict(float)
for future in futures:
for k, v in future.items():
data_totals[k] += v
data_totals = dict(data_totals)
data_totals = dict(sorted(data_totals.items()))
df = pd.DataFrame.from_dict(
data_totals, orient='index', columns=[units]
)
df['area_value'] = df.index
return df
[docs] def sample(
self,
data: xr.DataArray,
method: str = 'random',
band: T.Union[int, str] = None,
n: int = None,
strata: T.Optional[T.Dict[str, T.Union[float, int]]] = None,
spacing: T.Optional[float] = None,
min_dist: T.Optional[T.Union[float, int]] = None,
max_attempts: T.Optional[int] = 10,
num_workers: T.Optional[int] = 1,
verbose: T.Optional[int] = 1,
**kwargs,
) -> gpd.GeoDataFrame:
"""Generates samples from a raster.
Args:
data (DataArray): The ``xarray.DataArray`` to extract data from.
method (Optional[str]): The sampling method. Choices are ['random', 'systematic'].
band (Optional[int or str]): The band name to extract from. Only required if ``method`` = 'random' and ``strata`` is given.
n (Optional[int]): The total number of samples. Only required if ``method`` = 'random'.
strata (Optional[dict]): The strata to sample within. The dictionary key-->value pairs should be {'conditional,value': sample size}.
E.g.,
strata = {'==,1': 0.5, '>=,2': 0.5}
... would sample 50% of total samples within class 1 and 50% of total samples in class >= 2.
strata = {'==,1': 10, '>=,2': 20}
... would sample 10 samples within class 1 and 20 samples in class >= 2.
spacing (Optional[float]): The spacing (in map projection units) when ``method`` = 'systematic'.
min_dist (Optional[float or int]): A minimum distance allowed between samples. Only applies when ``method`` = 'random'.
max_attempts (Optional[int]): The maximum numer of attempts to sample points > ``min_dist`` from each other.
num_workers (Optional[int]): The number of parallel workers for :func:`dask.compute`.
verbose (Optional[int]): The verbosity level.
kwargs (Optional[dict]): Keyword arguments passed to ``geowombat.extract``.
Returns:
``geopandas.GeoDataFrame``
Examples:
>>> import geowombat as gw
>>>
>>> # Sample 100 points randomly across the image
>>> with gw.open('image.tif') as src:
>>> df = gw.sample(src, n=100)
>>>
>>> # Sample points systematically (with 10km spacing) across the image
>>> with gw.open('image.tif') as src:
>>> df = gw.sample(src, method='systematic', spacing=10000.0)
>>>
>>> # Sample 50% of 100 in class 1 and 50% in classes >= 2
>>> strata = {'==,1': 0.5, '>=,2': 0.5}
>>> with gw.open('image.tif') as src:
>>> df = gw.sample(src, band=1, n=100, strata=strata)
>>>
>>> # Specify a per-stratum minimum allowed point distance of 1,000 meters
>>> with gw.open('image.tif') as src:
>>> df = gw.sample(src, band=1, n=100, min_dist=1000, strata=strata)
"""
if method.strip().lower() not in ['random', 'systematic']:
raise NameError("The method must be 'random' or 'systematic'.")
if method.strip().lower() == 'systematic':
if not isinstance(spacing, float):
if not isinstance(spacing, int):
logger.exception(
" If the method is 'systematic', the spacing should be provided as a float or integer."
)
raise TypeError
if strata and not band and (method.strip().lower() == 'random'):
logger.exception(
' The band name must be provided with random stratified sampling.'
)
raise NameError
df = None
if not strata:
if method == 'systematic':
x_samples = []
y_samples = []
for i in range(0, data.gw.nrows, int(spacing / data.gw.celly)):
for j in range(
0, data.gw.ncols, int(spacing / data.gw.cellx)
):
x_samples.append(j)
y_samples.append(i)
x_samples = np.array(x_samples, dtype='int64')
y_samples = np.array(y_samples, dtype='int64')
# Convert the map indices to map coordinates
x_coords, y_coords = _transform_and_shift(
data.gw.meta.affine,
x_samples,
y_samples,
data.gw.cellxh,
data.gw.cellyh,
)
df = gpd.GeoDataFrame(
data=range(0, x_coords.shape[0]),
geometry=gpd.points_from_xy(x_coords, y_coords),
crs=data.crs,
columns=['point'],
)
else:
dfs = None
sample_size = n
attempts = 0
while True:
if attempts >= max_attempts:
if verbose > 0:
logger.warning(
' Max attempts reached. Try relaxing the distance threshold.'
)
break
# Sample directly from the coordinates
y_coords = np.random.choice(
data.y.values,
size=sample_size
if sample_size < data.y.values.shape[0]
else data.y.values.shape[0] - 1,
replace=False,
)
x_coords = np.random.choice(
data.x.values,
size=sample_size
if sample_size < data.x.values.shape[0]
else data.x.values.shape[0] - 1,
replace=False,
)
if isinstance(dfs, gpd.GeoDataFrame):
dfs = pd.concat(
(
dfs,
gpd.GeoDataFrame(
data=range(0, x_coords.shape[0]),
geometry=gpd.points_from_xy(
x_coords, y_coords
),
crs=data.crs,
columns=['point'],
),
),
axis=0,
)
else:
dfs = gpd.GeoDataFrame(
data=range(0, x_coords.shape[0]),
geometry=gpd.points_from_xy(x_coords, y_coords),
crs=data.crs,
columns=['point'],
)
if isinstance(min_dist, float) or isinstance(
min_dist, int
):
# Remove samples within a minimum distance
dfn = _remove_near_points(dfs, min_dist)
df_diff = dfs.shape[0] - dfn.shape[0]
if df_diff > 0:
dfs = dfn.copy()
sample_size = df_diff
attempts += 1
continue
break
df = dfs.copy()
else:
counter = 0
dfs = None
for cond, stratum_size in strata.items():
conditionals = cond.split(';')
if len(conditionals) > 1:
sign, value = conditionals[0].split(',')
sign2, value2 = conditionals[1].split(',')
value = float(value)
value2 = float(value2)
else:
sign, value = cond.split(',')
sign = sign.strip()
value = float(value)
if isinstance(stratum_size, int):
sample_size = stratum_size
else:
sample_size = int(np.ceil(n * stratum_size))
if sample_size == 0:
continue
attempts = 0
while True:
if attempts >= max_attempts:
if verbose > 0:
logger.warning(
" Max attempts reached for value {:f}. Try relaxing the distance threshold.".format(
value
)
)
if not isinstance(df, gpd.GeoDataFrame):
df = dfs.copy()
else:
df = pd.concat((df, dfs), axis=0)
break
if len(conditionals) > 1:
if (sign == '>') and (sign2 == '<'):
valid_samples = da.where(
(data.sel(band=band).data > value)
& data.sel(band=band).data
< value2
)
elif (sign == '>') and (sign2 == '>'):
valid_samples = da.where(
(data.sel(band=band).data > value)
& data.sel(band=band).data
> value2
)
elif (sign == '>') and (sign2 == '<='):
valid_samples = da.where(
(data.sel(band=band).data > value)
& data.sel(band=band).data
<= value2
)
elif (sign == '>') and (sign2 == '>='):
valid_samples = da.where(
(data.sel(band=band).data > value)
& data.sel(band=band).data
>= value2
)
elif (sign == '<') and (sign2 == '<'):
valid_samples = da.where(
(data.sel(band=band).data < value)
& data.sel(band=band).data
< value2
)
elif (sign == '<') and (sign2 == '>'):
valid_samples = da.where(
(data.sel(band=band).data < value)
& data.sel(band=band).data
> value2
)
elif (sign == '<') and (sign2 == '<='):
valid_samples = da.where(
(data.sel(band=band).data < value)
& data.sel(band=band).data
<= value2
)
elif (sign == '<') and (sign2 == '>='):
valid_samples = da.where(
(data.sel(band=band).data < value)
& data.sel(band=band).data
>= value2
)
elif (sign == '>=') and (sign2 == '<'):
valid_samples = da.where(
(data.sel(band=band).data >= value)
& data.sel(band=band).data
< value2
)
elif (sign == '>=') and (sign2 == '>'):
valid_samples = da.where(
(data.sel(band=band).data >= value)
& data.sel(band=band).data
> value2
)
elif (sign == '>=') and (sign2 == '<='):
valid_samples = da.where(
(data.sel(band=band).data >= value)
& data.sel(band=band).data
<= value2
)
elif (sign == '>=') and (sign2 == '>='):
valid_samples = da.where(
(data.sel(band=band).data >= value)
& data.sel(band=band).data
>= value2
)
elif (sign == '<=') and (sign2 == '<'):
valid_samples = da.where(
(data.sel(band=band).data <= value)
& data.sel(band=band).data
< value2
)
elif (sign == '<=') and (sign2 == '>'):
valid_samples = da.where(
(data.sel(band=band).data <= value)
& data.sel(band=band).data
> value2
)
elif (sign == '<=') and (sign2 == '<='):
valid_samples = da.where(
(data.sel(band=band).data <= value)
& data.sel(band=band).data
<= value2
)
elif (sign == '<=') and (sign2 == '>='):
valid_samples = da.where(
(data.sel(band=band).data <= value)
& data.sel(band=band).data
>= value2
)
else:
logger.exception(
' The conditional sign was not recognized.'
)
raise NameError
else:
if sign == '>':
valid_samples = da.where(
data.sel(band=band).data > value
)
elif sign == '>=':
valid_samples = da.where(
data.sel(band=band).data >= value
)
elif sign == '<':
valid_samples = da.where(
data.sel(band=band).data < value
)
elif sign == '<=':
valid_samples = da.where(
data.sel(band=band).data <= value
)
elif sign == '==':
valid_samples = da.where(
data.sel(band=band).data == value
)
else:
logger.exception(
" The conditional sign was not recognized. Use one of '>', '>=', '<', '<=', or '=='."
)
raise NameError
valid_samples = dask.compute(
valid_samples,
num_workers=num_workers,
scheduler='threads',
)[0]
y_samples = valid_samples[0]
x_samples = valid_samples[1]
if y_samples.shape[0] > 0:
ssize = (
sample_size
if sample_size < y_samples.shape[0]
else y_samples.shape[0] - 1
)
if ssize == 0:
break
# Get indices within the stratum
idx = np.random.choice(
range(0, y_samples.shape[0]),
size=ssize,
replace=False,
)
y_samples = y_samples[idx]
x_samples = x_samples[idx]
# Convert the map indices to map coordinates
x_coords, y_coords = _transform_and_shift(
data.gw.meta.affine,
x_samples,
y_samples,
data.gw.cellxh,
data.gw.cellyh,
)
if isinstance(dfs, gpd.GeoDataFrame):
dfs = pd.concat(
(
dfs,
gpd.GeoDataFrame(
data=range(0, x_coords.shape[0]),
geometry=gpd.points_from_xy(
x_coords, y_coords
),
crs=data.crs,
columns=['point'],
),
),
axis=0,
)
else:
dfs = gpd.GeoDataFrame(
data=range(0, x_coords.shape[0]),
geometry=gpd.points_from_xy(
x_coords, y_coords
),
crs=data.crs,
columns=['point'],
)
if isinstance(min_dist, float) or isinstance(
min_dist, int
):
# Remove samples within a minimum distance
dfn = _remove_near_points(dfs, min_dist)
df_diff = dfs.shape[0] - dfn.shape[0]
if df_diff > 0:
dfs = dfn.copy()
sample_size = df_diff
attempts += 1
continue
if not isinstance(df, gpd.GeoDataFrame):
df = dfs.copy()
else:
df = pd.concat((df, dfs), axis=0)
dfs = None
break
counter += 1
if isinstance(df, gpd.GeoDataFrame):
return self.extract(data, df, **kwargs)
else:
return None
[docs] def clip_by_polygon(
self,
data: xr.DataArray,
df: T.Union[str, Path, gpd.GeoDataFrame],
query: T.Optional[str] = None,
mask_data: T.Optional[bool] = False,
expand_by: T.Optional[int] = 0,
) -> xr.DataArray:
"""Clips a DataArray by vector polygon geometry.
Args:
data (DataArray): The ``xarray.DataArray`` to subset.
df (GeoDataFrame or str): The ``geopandas.GeoDataFrame`` or filename to clip to.
query (Optional[str]): A query to apply to ``df``.
mask_data (Optional[bool]): Whether to mask values outside of the ``df`` geometry envelope.
expand_by (Optional[int]): Expand the clip array bounds by ``expand_by`` pixels on each side.
Returns:
``xarray.DataArray``
Examples:
>>> import geowombat as gw
>>>
>>> with gw.open('image.tif') as ds:
>>> ds = gw.clip_by_polygon(ds, df, query="Id == 1")
"""
if isinstance(df, (Path, str)):
if not Path(df).is_file():
raise FileExistsError(f'{str(df)} does not exist.')
df = gpd.read_file(df)
if query is not None:
df = df.query(query)
data_crs_ = check_crs(data.crs)
df_crs_ = check_crs(df.crs)
# Re-project the DataFrame to match the image CRS
if data_crs_ != df_crs_:
df = df.to_crs(data_crs_)
# Get the target bounds
left, bottom, right, top = df.total_bounds
# Get the target dimensions
dst_bounds = BoundingBox(
left=left,
bottom=bottom,
right=right,
top=top,
)
align_height, align_width = get_dims_from_bounds(
dst_bounds, (data.gw.cellx, data.gw.celly)
)
align_transform = Affine(
data.gw.cellx, 0.0, left, 0.0, -data.gw.celly, top
)
# Get the new bounds
new_left, new_bottom, new_right, new_top = array_bounds(
align_height, align_width, align_transform
)
if expand_by > 0:
new_left -= data.gw.cellx * expand_by
new_bottom -= data.gw.celly * expand_by
new_right += data.gw.cellx * expand_by
new_top += data.gw.celly * expand_by
# Subset the array
data = self.subset(
data,
left=new_left,
bottom=new_bottom,
right=new_right,
top=new_top,
)
if mask_data:
converters = Converters()
mask = converters.polygon_to_array(df, data=data)
data = (
(data * mask.where(lambda x: x == 1))
.fillna(data.gw.nodataval)
.assign_attrs(**data.attrs)
)
return data
[docs] def clip(
self,
data: xr.DataArray,
df: T.Union[str, gpd.GeoDataFrame],
query: T.Optional[str] = None,
mask_data: T.Optional[bool] = False,
expand_by: T.Optional[int] = 0,
) -> xr.DataArray:
"""Clips a DataArray by vector polygon geometry.
.. deprecated:: 2.1.7
Use :func:`geowombat.clip_by_polygon`.
Args:
data (DataArray): The ``xarray.DataArray`` to subset.
df (GeoDataFrame or str): The ``geopandas.GeoDataFrame`` or filename to clip to.
query (Optional[str]): A query to apply to ``df``.
mask_data (Optional[bool]): Whether to mask values outside of the ``df`` geometry envelope.
expand_by (Optional[int]): Expand the clip array bounds by ``expand_by`` pixels on each side.
Returns:
``xarray.DataArray``
Examples:
>>> import geowombat as gw
>>>
>>> with gw.open('image.tif') as ds:
>>> ds = gw.clip(ds, df, query="Id == 1")
>>>
>>> # or
>>>
>>> with gw.open('image.tif') as ds:
>>> ds = ds.gw.clip(df, query="Id == 1")
"""
warnings.warn(
'The method clip() will be deprecated in >=2.2.0. Use clip_by_polygon() instead.',
DeprecationWarning,
stacklevel=2,
)
if isinstance(df, (Path, str)):
if not Path(df).is_file():
raise FileExistsError(f'{str(df)} does not exist.')
df = gpd.read_file(df)
if query is not None:
df = df.query(query)
data_crs_ = check_crs(data.crs)
df_crs_ = check_crs(df.crs)
# Re-project the DataFrame to match the image CRS
if data_crs_ != df_crs_:
df = df.to_crs(data_crs_)
# Get the target bounds
left, bottom, right, top = df.total_bounds
# Get the target dimensions
dst_bounds = BoundingBox(
left=left,
bottom=bottom,
right=right,
top=top,
)
align_height, align_width = get_dims_from_bounds(
dst_bounds, (data.gw.cellx, data.gw.celly)
)
align_transform = Affine(
data.gw.cellx, 0.0, left, 0.0, -data.gw.celly, top
)
# Get the new bounds
new_left, new_bottom, new_right, new_top = array_bounds(
align_height, align_width, align_transform
)
if expand_by > 0:
new_left -= data.gw.cellx * expand_by
new_bottom -= data.gw.celly * expand_by
new_right += data.gw.cellx * expand_by
new_top += data.gw.celly * expand_by
# Subset the array
data = self.subset(
data,
left=new_left,
bottom=new_bottom,
right=new_right,
top=new_top,
)
if mask_data:
converters = Converters()
mask = converters.polygon_to_array(df, data=data)
data = (
(data * mask.where(lambda x: x == 1))
.fillna(data.gw.nodataval)
.assign_attrs(**data.attrs)
)
return data
[docs] @staticmethod
@lazy_wombat
def mask(
data: xr.DataArray,
dataframe: T.Union[str, Path, gpd.GeoDataFrame],
query: T.Optional[str] = None,
keep: T.Optional[str] = 'in',
) -> xr.DataArray:
"""Masks a DataArray by vector polygon geometry.
Args:
data (DataArray): The ``xarray.DataArray`` to mask.
dataframe (GeoDataFrame or str): The ``geopandas.GeoDataFrame`` or filename to use for masking.
query (Optional[str]): A query to apply to ``dataframe``.
keep (Optional[str]): If ``keep`` = 'in', mask values outside of the geometry (keep inside).
Otherwise, if ``keep`` = 'out', mask values inside (keep outside).
Returns:
``xarray.DataArray``
Examples:
>>> import geowombat as gw
>>>
>>> with gw.open('image.tif') as ds:
>>> ds = ds.gw.mask(df)
"""
if isinstance(dataframe, str) and os.path.isfile(dataframe):
dataframe = gpd.read_file(dataframe)
if query:
dataframe = dataframe.query(query)
data_crs_ = check_crs(data.crs)
df_crs_ = check_crs(dataframe.crs)
if data_crs_ != df_crs_:
dataframe = dataframe.to_crs(data_crs_)
# Rasterize the geometry and store as a DataArray
mask = xr.DataArray(
data=da.from_array(
features.rasterize(
list(dataframe.geometry.values),
out_shape=(data.gw.nrows, data.gw.ncols),
transform=data.gw.transform,
fill=0,
out=None,
all_touched=True,
default_value=1,
dtype='int32',
),
chunks=(data.gw.row_chunks, data.gw.col_chunks),
),
dims=['y', 'x'],
coords={'y': data.y.values, 'x': data.x.values},
)
# Return the masked array
if keep == 'out':
return data.where(mask != 1)
else:
return data.where(mask == 1)
[docs] @staticmethod
def replace(
data: xr.DataArray, to_replace: T.Dict[int, T.Union[int, str]]
) -> xr.DataArray:
"""Replace values given in to_replace with value.
Args:
data (DataArray): The ``xarray.DataArray`` to recode.
to_replace (dict): How to find the values to replace. Dictionary mappings should be given
as {from: to} pairs. If ``to_replace`` is an integer/string mapping, the to string should be 'mode'.
{1: 5}:
recode values of 1 to 5
{1: 'mode'}:
recode values of 1 to the polygon mode
Returns:
``xarray.DataArray``
Example:
>>> import geowombat as gw
>>>
>>> with gw.open('image.tif', chunks=512) as ds:
>>> # Replace 1 with 5
>>> res = gw.replace(ds, {1: 5})
"""
attrs = data.attrs.copy()
dtype = data.dtype.name
if not isinstance(to_replace, dict):
raise TypeError(
"The replace values must be a dictionary of {from: to} mappings."
)
data = data.astype('int64')
for k, v in to_replace.items():
data = xr.where(data == k, v + 100000, data)
for v in to_replace.values():
data = xr.where(data == v + 100000, data - 100000, data)
return data.assign_attrs(**attrs).astype(dtype)
[docs] @lazy_wombat
def recode(
self,
data: xr.DataArray,
polygon: T.Union[str, Path, gpd.GeoDataFrame],
to_replace: dict,
num_workers: T.Optional[int] = 1,
) -> xr.DataArray:
"""Recodes a DataArray with polygon mappings.
Args:
data (DataArray): The ``xarray.DataArray`` to recode.
polygon (GeoDataFrame | str): The ``geopandas.DataFrame`` or file with polygon geometry.
to_replace (dict): How to find the values to replace. Dictionary mappings should be given
as {from: to} pairs. If ``to_replace`` is an integer/string mapping, the to string should be 'mode'.
{1: 5}:
recode values of 1 to 5
{1: 'mode'}:
recode values of 1 to the polygon mode
num_workers (Optional[int]): The number of parallel Dask workers (only used if ``to_replace``
has a 'mode' mapping).
Returns:
``xarray.DataArray``
Example:
>>> import geowombat as gw
>>>
>>> with gw.open('image.tif', chunks=512) as ds:
>>> # Recode 1 with 5 within a polygon
>>> res = gw.recode(ds, 'poly.gpkg', {1: 5})
"""
dtype = data.dtype.name
attrs = data.attrs.copy()
converters = Converters()
poly_array = converters.polygon_to_array(polygon, data=data)
for k, v in to_replace.items():
if isinstance(v, str) and (v.lower() == 'mode'):
data_array_np = (
data.squeeze()
.where(poly_array.squeeze() == 1)
.data.compute(num_workers=num_workers)
)
to_replace[k] = int(
sci_mode(
data_array_np, axis=None, nan_policy='omit'
).mode.flatten()
)
return (
xr.where(poly_array == 1, self.replace(data, to_replace), data)
.assign_attrs(**attrs)
.astype(dtype)
)
[docs] @staticmethod
def subset(
data: xr.DataArray,
left: T.Optional[float] = None,
top: T.Optional[float] = None,
right: T.Optional[float] = None,
bottom: T.Optional[float] = None,
rows: T.Optional[int] = None,
cols: T.Optional[int] = None,
center: T.Optional[bool] = False,
mask_corners: T.Optional[bool] = False,
) -> xr.DataArray:
"""Subsets a DataArray.
Args:
data (DataArray): The ``xarray.DataArray`` to subset.
left (Optional[float]): The left coordinate.
top (Optional[float]): The top coordinate.
right (Optional[float]): The right coordinate.
bottom (Optional[float]): The bottom coordinate.
rows (Optional[int]): The number of output rows.
cols (Optional[int]): The number of output rows.
center (Optional[bool]): Whether to center the subset on ``left`` and ``top``.
mask_corners (Optional[bool]): Whether to mask corners (*requires ``pymorph``).
Returns:
``xarray.DataArray``
Example:
>>> import geowombat as gw
>>>
>>> with gw.open('image.tif', chunks=512) as ds:
>>> ds_sub = gw.subset(
>>> ds,
>>> left=-263529.884,
>>> top=953985.314,
>>> rows=2048,
>>> cols=2048
>>> )
"""
if isinstance(right, (float, int)):
cols = int((right - left) / data.gw.celly)
if not isinstance(cols, int):
logger.exception(
' The right coordinate or columns must be specified.'
)
raise NameError
if isinstance(bottom, (float, int)):
rows = int((top - bottom) / data.gw.celly)
if not isinstance(rows, int):
logger.exception(
' The bottom coordinate or rows must be specified.'
)
raise NameError
converters = Converters()
x_idx, y_idx = converters.coords_to_indices(
left + abs(data.gw.cellxh),
top - abs(data.gw.cellyh),
data.gw.affine,
)
if center:
y_idx += (rows / 2.0) * abs(data.gw.celly)
x_idx -= (cols / 2.0) * abs(data.gw.cellx)
ds_sub = data.isel(
indexers={
'band': slice(0, None),
'y': slice(y_idx, y_idx + rows),
'x': slice(x_idx, x_idx + cols),
}
)
ds_sub.attrs['transform'] = ds_sub.gw.transform
if mask_corners:
if PYMORPH_INSTALLED:
try:
disk = da.from_array(
pymorph.sedisk(r=int(rows / 2.0))[:rows, :cols],
chunks=ds_sub.data.chunksize,
).astype('uint8')
ds_sub = ds_sub.where(disk == 1)
except ValueError:
logger.warning(
' Cannot mask corners without a square subset.'
)
else:
logger.warning(' Cannot mask corners without Pymorph.')
return ds_sub
[docs] @staticmethod
def coregister(
target: T.Union[str, Path, xr.DataArray],
reference: T.Union[str, Path, xr.DataArray],
band_names_reference: T.Optional[T.Sequence[str]] = None,
band_names_target: T.Optional[T.Sequence[str]] = None,
wkt_version: T.Optional[str] = None,
**kwargs,
) -> xr.DataArray:
"""Co-registers an image, or images, using AROSICS.
While the required inputs are DataArrays, the intermediate results are stored as NumPy arrays.
Therefore, memory usage is constrained to the size of the input data. Dask is not used for any of the
computation in this function.
Args:
target (DataArray or str): The target ``xarray.DataArray`` or file name to co-register to ``reference``.
reference (DataArray or str): The reference ``xarray.DataArray`` or file name used to co-register ``target``.
band_names_reference (Optional[list | tuple]): Band names to open for the reference data.
band_names_target (Optional[list | tuple]): Band names to open for the target data.
wkt_version (Optional[str]): The WKT version to use with ``to_wkt()``.
kwargs (Optional[dict]): Keyword arguments passed to ``arosics``.
Reference:
https://pypi.org/project/arosics
Returns:
``xarray.DataArray``
Example:
>>> import geowombat as gw
>>>
>>> # Co-register a single image to a reference image
>>> with gw.open('target.tif') as tar, gw.open('reference.tif') as ref:
>>> results = gw.coregister(
>>> tar, ref, q=True, ws=(512, 512), max_shift=3, CPUs=4
>>> )
>>>
>>> # or
>>>
>>> results = gw.coregister(
>>> 'target.tif',
>>> 'reference.tif',
>>> q=True,
>>> ws=(512, 512),
>>> max_shift=3,
>>> CPUs=4
>>> )
"""
import geowombat as gw_
if not AROSICS_INSTALLED:
logger.exception(
'\nAROSICS must be installed to co-register data.\nSee https://pypi.org/project/arosics for details'
)
raise NameError
if isinstance(reference, str):
if not os.path.isfile(reference):
logger.exception(' The reference file does not exist.')
raise OSError
with gw_.open(reference) as reference:
pass
if isinstance(target, str):
if not os.path.isfile(target):
logger.exception(' The target file does not exist.')
raise OSError
with gw_.open(target) as target:
pass
if wkt_version is None:
wkt_version = WktVersion.WKT2_2019.value
elif isinstance(wkt_version, str):
try:
wkt_version = WktVersion[wkt_version].value
except KeyError as e:
logger.exception(e)
raise KeyError
with tempfile.TemporaryDirectory() as tmp:
ref_path = Path(tmp) / '_reference.tif'
tar_path = Path(tmp) / '_target.tif'
ref_kwargs = (
{'engine': 'h5netcdf'}
if str(reference.filename).lower().endswith('.nc')
else {}
)
with gw_.open(
reference.filename,
band_names=band_names_reference,
chunks={
'band': reference.gw.band_chunks,
'y': reference.gw.row_chunks,
'x': reference.gw.col_chunks,
},
**ref_kwargs,
) as ref_src:
ref_src = ref_src.assign_attrs(
{
'crs': ref_src.gw.crs_to_pyproj.to_wkt(
version=wkt_version
)
}
)
if 'nodata' in kwargs:
ref_src = ref_src.fillna(
kwargs['nodata'][0]
).gw.assign_nodata_attrs(kwargs['nodata'][0])
ref_src.gw.save(ref_path, overwrite=True, log_progress=False)
tar_kwargs = (
{'engine': 'h5netcdf'}
if str(target.filename).lower().endswith('.nc')
else {}
)
with gw_.open(
target.filename,
band_names=band_names_target,
chunks={
'band': target.gw.band_chunks,
'y': target.gw.row_chunks,
'x': target.gw.col_chunks,
},
**tar_kwargs,
) as tar_src:
tar_src = tar_src.assign_attrs(
{
'crs': tar_src.gw.crs_to_pyproj.to_wkt(
version=wkt_version
)
}
)
if 'nodata' in kwargs:
tar_src = tar_src.fillna(
kwargs['nodata'][1]
).gw.assign_nodata_attrs(kwargs['nodata'][1])
tar_src.gw.save(tar_path, overwrite=True, log_progress=False)
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
try:
cr = arosics.COREG(
GeoArray(
GeoArray(str(ref_path)), projection=ref_src.crs
),
GeoArray(
GeoArray(str(tar_path)), projection=tar_src.crs
),
**kwargs,
)
except CRSError as e:
logger.exception(e)
raise AttributeError(
"Try using an integer EPSG format (e.g., data.attrs['crs'] = 'epsg:32621')"
)
try:
cr.calculate_spatial_shifts()
except RuntimeError as e:
logger.warning(
f" Could not co-register the data because -> {e}"
)
return target
# Apply spatial shifts
shift_info = cr.correct_shifts()
left = shift_info['updated geotransform'][0]
top = shift_info['updated geotransform'][3]
# Update the transform
transform = (target.gw.cellx, 0.0, left, 0.0, -target.gw.celly, top)
target.attrs['transform'] = transform
# Check if multi-band
if len(shift_info['arr_shifted'].shape) > 2:
data = shift_info['arr_shifted'].transpose(2, 0, 1)
else:
data = shift_info['arr_shifted'][np.newaxis]
yidx = 1
xidx = 2
# Get the updated coordinates
ycoords = np.linspace(
top - target.gw.cellyh,
top - target.gw.cellyh - (data.shape[yidx] * target.gw.celly),
data.shape[yidx],
)
xcoords = np.linspace(
left + target.gw.cellxh,
left + target.gw.cellxh + (data.shape[xidx] * target.gw.cellx),
data.shape[xidx],
)
target_attrs = target.attrs.copy()
if hasattr(cr, 'x_shift_px'):
target_attrs['x_shift_px'] = cr.x_shift_px
if hasattr(cr, 'y_shift_px'):
target_attrs['y_shift_px'] = cr.y_shift_px
return xr.DataArray(
data=da.from_array(
data,
chunks=(
target.gw.band_chunks,
target.gw.row_chunks,
target.gw.col_chunks,
),
),
dims=('band', 'y', 'x'),
coords={
'band': target.band.values.tolist(),
'y': ycoords,
'x': xcoords,
},
attrs=target_attrs,
)