import concurrent.futures
import typing as T
from abc import abstractmethod
from datetime import datetime
import numpy as np
import pandas as pd
import rasterio as rio
import xarray as xr
from affine import Affine
from rasterio.coords import BoundingBox
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rasterio.windows import Window
from .windows import get_window_offsets
try:
import torch
PYTORCH_INSTALLED = True
except ImportError:
PYTORCH_INSTALLED = False
try:
import tensorflow as tf
TENSORFLOW_INSTALLED = True
except ImportError:
TENSORFLOW_INSTALLED = False
try:
import jax.numpy as jnp
JAX_INSTALLED = True
except ImportError:
JAX_INSTALLED = False
class TransferLib(object):
    """Device transfers.

    Converts an input array to the tensor type of the chosen backend.

    Args:
        transfer_lib (str): The device library to transfer to.
            Choices are ['jax', 'keras', 'numpy', 'pytorch', 'tensorflow'].

            'jax' -> GPU
            'keras' -> GPU
            'numpy' -> CPU
            'pytorch' -> GPU
            'tensorflow' -> GPU
    """

    def __init__(self, transfer_lib: str):
        # Stored name is resolved to the static method of the same name
        # each time the instance is called.
        self.transfer_lib = transfer_lib

    @staticmethod
    def jax(array):
        # Cast to 32-bit floats
        return jnp.asarray(array, dtype='float32')

    @staticmethod
    def keras(array):
        # Keras transfers are not supported.
        raise NotImplementedError

    @staticmethod
    def numpy(array):
        # Cast to 64-bit floats on the CPU
        return np.asarray(array, dtype='float64')

    @staticmethod
    def pytorch(array):
        # NOTE: hard-coded to the first CUDA device.
        return torch.from_numpy(array).float().to('cuda:0')

    @staticmethod
    def tensorflow(array):
        return tf.convert_to_tensor(array, tf.float64)

    def __call__(self, array):
        transfer_func = getattr(self, self.transfer_lib)
        return transfer_func(array)
class _Warp(object):
    """Mixin that warps the open source datasets (``self.srcs_``) onto a
    common reference grid as in-memory VRTs (``self.vrts_``) and builds
    the read windows (``self.windows_``).
    """

    def warp(
        self,
        dst_crs=None,
        dst_res=None,
        dst_bounds=None,
        resampling='nearest',
        nodata=None,
        warp_mem_limit=None,
        num_threads=None,
        window_size=None,
        padding=None,
    ):
        """Warps each source dataset to a common grid.

        Args:
            dst_crs: The destination CRS. Defaults to the first source's CRS.
            dst_res: The destination (x, y) cell resolution. Defaults to the
                first source's resolution.
            dst_bounds: The destination bounds, as a ``BoundingBox`` or a
                (left, bottom, right, top) sequence. Defaults to the first
                source's bounds.
            resampling (str): A ``rasterio.enums.Resampling`` method name.
            nodata: The 'no data' value. Defaults to the first source's value.
            warp_mem_limit (int): The ``rasterio`` warp memory limit (MB).
                Defaults to 256.
            num_threads (int): The number of threads used to build the VRTs.
                Defaults to 1.
            window_size: ``-1`` for a single full-extent window, a 2-sequence
                chunk size passed to ``get_window_offsets`` (presumably
                (height, width) -- matches the argument order there), or
                ``None``/falsey to use the VRT's native block windows.
            padding: Window padding passed through to ``get_window_offsets``.

        Side effects:
            Sets ``self.vrts_`` and ``self.windows_``.
        """
        # Fall back to the first source's grid for any unspecified target
        if dst_crs is None:
            dst_crs = self.srcs_[0].crs
        if dst_res is None:
            dst_res = self.srcs_[0].res
        if dst_bounds is None:
            dst_bounds = self.srcs_[0].bounds
        elif isinstance(dst_bounds, (list, tuple)):
            dst_bounds = BoundingBox(
                left=dst_bounds[0],
                bottom=dst_bounds[1],
                right=dst_bounds[2],
                top=dst_bounds[3],
            )
        if nodata is None:
            nodata = self.srcs_[0].nodata
        if warp_mem_limit is None:
            warp_mem_limit = 256
        if num_threads is None:
            num_threads = 1
        # The destination transform (north-up grid anchored at the upper-left)
        dst_transform = Affine(
            dst_res[0], 0.0, dst_bounds.left, 0.0, -dst_res[1], dst_bounds.top
        )
        # The destination size, in pixels
        dst_width = int((dst_bounds.right - dst_bounds.left) / dst_res[0])
        dst_height = int((dst_bounds.top - dst_bounds.bottom) / dst_res[1])
        # The warp parameters shared by every source
        vrt_options = {
            'resampling': getattr(Resampling, resampling),
            'crs': dst_crs,
            'transform': dst_transform,
            'height': dst_height,
            'width': dst_width,
            'nodata': nodata,
            'warp_mem_limit': warp_mem_limit,
        }

        def _warp_window(src_):
            return WarpedVRT(
                src_,
                src_crs=src_.crs,
                src_transform=src_.transform,
                **vrt_options,
            )

        # Warp all inputs into virtual in-memory objects
        with concurrent.futures.ThreadPoolExecutor(num_threads) as executor:
            self.vrts_ = list(executor.map(_warp_window, self.srcs_))
        if window_size == -1:
            # One window covering the full destination extent
            self.windows_ = [
                Window(
                    row_off=0, col_off=0, height=dst_height, width=dst_width
                )
            ]
        elif window_size:
            # A list of fixed-size chunk windows
            self.windows_ = get_window_offsets(
                dst_height,
                dst_width,
                window_size[0],
                window_size[1],
                return_as='list',
                padding=padding,
            )
        else:
            # Use the first VRT's native block windows (the original built
            # the window list for every VRT and kept only the first).
            self.windows_ = [w[1] for w in self.vrts_[0].block_windows(1)]
class _SeriesProps(object):
@property
def crs(self):
return self.vrts_[0].crs
@property
def transform(self):
return self.vrts_[0].transform
@property
def count(self):
return self.vrts_[0].count
@property
def width(self):
return self.vrts_[0].width
@property
def height(self):
return self.vrts_[0].height
@property
def blockxsize(self):
return self.windows_[0].width
@property
def blockysize(self):
return self.windows_[0].height
@property
def nchunks(self):
return len(self.windows_)
@property
def nodata(self):
return self.vrts_[0].nodata
@property
def band_dict(self):
return (
dict(zip(self.band_names, range(0, self.count)))
if self.band_names
else None
)
class BaseSeries(_SeriesProps, _Warp):
    """Base class for a raster time series: opens the source files and
    provides array conversion and duplicate-date grouping helpers.
    """

    def open(self, filenames):
        """Opens each file path and stores the open datasets in ``self.srcs_``.

        Args:
            filenames: An iterable of file paths openable by ``rasterio``.
        """
        self.srcs_ = [rio.open(fn) for fn in filenames]

    @staticmethod
    def ndarray_to_darray(
        data: np.ndarray,
        image_dates: T.List[datetime],
        band_names: T.List[str],
        y: np.ndarray,
        x: np.ndarray,
        attrs: T.Optional[T.Dict] = None,
    ) -> xr.DataArray:
        """Converts a NumPy array to an ``xarray.DataArray``.

        Args:
            data (``numpy.ndarray``): The data, shaped (time, band, y, x).
            image_dates (list): Coordinate labels for the 'time' dimension.
            band_names (list): Coordinate labels for the 'band' dimension.
            y (``numpy.ndarray``): The 'y' coordinates.
            x (``numpy.ndarray``): The 'x' coordinates.
            attrs (dict, optional): Attributes to attach to the DataArray.

        Returns:
            ``xarray.DataArray``
        """
        return xr.DataArray(
            data,
            dims=('time', 'band', 'y', 'x'),
            coords={'time': image_dates, 'band': band_names, 'y': y, 'x': x},
            attrs=attrs,
        )

    def group_dates(
        self,
        data: np.ndarray,
        image_dates: T.List[datetime],
        band_names: T.List[str],
    ) -> T.Tuple[np.ndarray, T.List[datetime]]:
        """Groups data by dates.

        Duplicated dates are merged by averaging; zeros are masked to NaN
        first, so they are excluded from the mean.

        Args:
            data (``numpy.ndarray``): The data, shaped (time, band, y, x).
            image_dates (list): One ``datetime`` per time layer.
            band_names (list): The band names.

        Returns:
            tuple: ``(data, image_dates)`` -- unchanged when no date is
            duplicated, otherwise with duplicates averaged together.
        """
        time_df = pd.DataFrame(data=image_dates, columns=['date'])
        dupe_dates = time_df.duplicated(keep='first')
        if not dupe_dates.any():
            # Nothing to merge
            return data, image_dates
        # Convert the NumPy array to a DataArray
        da = self.ndarray_to_darray(
            data,
            image_dates=image_dates,
            band_names=band_names,
            y=np.arange(data.shape[2]),
            x=np.arange(data.shape[3]),
        )
        # Group duplicated dates, masking zeros to NaN so the skipna mean
        # ignores them
        da = (
            da.where(lambda x: x != 0)
            .groupby('time')
            .mean('time', skipna=True)
        )
        # NOTE(review): relies on the geowombat ``.gw`` accessor being
        # registered on DataArrays -- confirm geowombat is imported wherever
        # this is called.
        return da.values, da.gw.pydatetime.tolist()
class TimeModule(object):
    """Abstract base for a user-defined time-series function.

    Subclasses implement :func:`calculate` and may override the output
    settings assigned in ``__init__``.
    """

    def __init__(self):
        # Output raster settings
        self.dtype = 'float64'
        self.count = 1
        self.compress = 'lzw'
        self.bigtiff = 'NO'
        # Updated from the caller-supplied mapping on every call
        self.band_dict = None

    def __call__(self, w, array, band_dict):
        """Stores ``band_dict`` and applies :func:`calculate` to ``array``,
        returning the result alongside the window ``w``."""
        self.band_dict = band_dict
        return w, self.calculate(array)

    def _settings_text(self):
        # Shared formatting used by both __repr__ and __str__
        return (
            f"{self.__class__.__name__}():\n "
            f"self.dtype='{self.dtype}'\n "
            f"self.count={self.count}\n "
            f"self.compress='{self.compress}'\n "
            f"self.bigtiff='{self.bigtiff}'"
        )

    def __repr__(self):
        return self._settings_text()

    def __str__(self):
        return (
            self._settings_text()
            + "\n "
            + "-> Array(numpy.ndarray | jax.Array | torch.Tensor | tensorflow.Tensor)[bands x height x width]"
        )

    def __add__(self, other):
        # Combining modules yields a pipeline; prepend self to an existing
        # pipeline, otherwise pair the two modules.
        if isinstance(other, TimeModulePipeline):
            combined = [self] + other.modules
        else:
            combined = [self, other]
        return TimeModulePipeline(combined)

    @abstractmethod
    def calculate(self, data: T.Any) -> T.Any:
        """Calculates the user function.

        Args:
            data (``numpy.ndarray`` | ``jax.Array`` | ``torch.Tensor`` |
                ``tensorflow.Tensor``): The input array, shaped
                [time x bands x rows x columns].

        Returns:
            ``numpy.ndarray`` | ``jax.Array`` | ``torch.Tensor`` |
            ``tensorflow.Tensor``: Shaped (time|bands x rows x columns)
        """
        raise NotImplementedError
class TimeModulePipeline(object):
    """A sequence of :class:`TimeModule` instances applied to the same
    input, with their band outputs stacked in order."""

    def __init__(self, module_list: T.List[TimeModule]):
        self.modules = module_list
        # Total output band count across all modules
        self.count = sum(module.count for module in self.modules)
        # Output raster settings are taken from the final module
        final = self.modules[-1]
        self.dtype = final.dtype
        self.compress = final.compress
        self.bigtiff = final.bigtiff

    def __add__(self, other):
        # Adding a pipeline concatenates module lists; adding a single
        # module appends it.
        if isinstance(other, TimeModulePipeline):
            return TimeModulePipeline(self.modules + other.modules)
        return TimeModulePipeline(self.modules + [other])

    def __call__(self, w, array, band_dict):
        """Runs every module on ``array`` and stacks the band results."""
        results = []
        for module in self.modules:
            res = module(w, array, band_dict)[1]
            # Promote a 2-d (rows x columns) result to a single band
            if len(res.shape) == 2:
                res = res[np.newaxis]
            results.append(res)
        return w, jnp.vstack(results).squeeze()
class SeriesStats(TimeModule):
    """Computes temporal statistics over a series.

    Args:
        time_stats (str | Sequence[str]): A single statistic name (e.g.
            'mean') or a sequence of names to stack (e.g.
            ('mean', 'max', 'percentile25')).
    """

    def __init__(self, time_stats):
        super(SeriesStats, self).__init__()
        if isinstance(time_stats, str):
            self.time_stats = time_stats
            self.count = 1
        else:
            # Materialize to a list so a generator input is not exhausted
            # here and then iterated again (empty) in calculate().
            self.time_stats = list(time_stats)
            self.count = len(self.time_stats)

    def calculate(self, array):
        """Calculates the statistic(s) over the time axis.

        Args:
            array: The input array, shaped [time x bands x rows x columns].

        Returns:
            ``numpy.ndarray``
        """
        if isinstance(self.time_stats, str):
            return np.asarray(getattr(self, self.time_stats)(array))
        return np.asarray(self._stack(array, self.time_stats))

    @staticmethod
    def _scale_min_max(xv, mni, mxi, mno, mxo):
        # Linearly rescale xv from [mni, mxi] to [mno, mxo], clipping to
        # the output range.
        return ((((mxo - mno) * (xv - mni)) / (mxi - mni)) + mno).clip(
            mno, mxo
        )

    @staticmethod
    def _lstsq(data):
        ndims, nbands, nrows, ncols = data.shape
        # Flatten to (time x samples): one column per pixel
        M = data.squeeze().transpose(1, 2, 0).reshape(nrows * ncols, ndims).T
        x = jnp.arange(0, M.shape[0])
        # Fit a least squares solution to each sample
        return jnp.linalg.lstsq(jnp.c_[x, jnp.ones_like(x)], M, rcond=None)[0]

    def _abs_slope(self, data):
        # Shared body for the abs_slope_q* methods: fit a per-pixel slope,
        # zero out NaN/inf fits, then scale |slope| from [0, 0.05] to [0, 1].
        b1 = self._lstsq(data)[0]
        # JAX arrays are immutable, so the original in-place masked
        # assignment (b1[mask] = 0) raises a TypeError; use jnp.where().
        b1 = jnp.where(jnp.isnan(b1) | jnp.isinf(b1), 0.0, b1)
        return self._scale_min_max(jnp.fabs(b1), 0.0, 0.05, 0.0, 1.0)

    def abs_slope_q1(self, data):
        """Calculates the absolute slope of the first quarter."""
        return self._abs_slope(data[: int(0.25 * data.shape[0])])

    def abs_slope_q2(self, data):
        """Calculates the absolute slope of the second quarter."""
        return self._abs_slope(
            data[int(0.25 * data.shape[0]) : int(0.5 * data.shape[0])]
        )

    def abs_slope_q3(self, data):
        """Calculates the absolute slope of the third quarter."""
        return self._abs_slope(
            data[int(0.5 * data.shape[0]) : int(0.75 * data.shape[0])]
        )

    def abs_slope_q4(self, data):
        """Calculates the absolute slope of the fourth quarter."""
        return self._abs_slope(data[int(0.75 * data.shape[0]) :])

    @staticmethod
    def amp(array):
        """Calculates the amplitude (max - min over time)."""
        return (
            jnp.nanmax(array, axis=0).squeeze()
            - jnp.nanmin(array, axis=0).squeeze()
        )

    @staticmethod
    def cv(array):
        """Calculates the coefficient of variation."""
        # The epsilon guards against division by zero
        return jnp.nanstd(array, axis=0).squeeze() / (
            jnp.nanmean(array, axis=0).squeeze() + 1e-9
        )

    @staticmethod
    def max(array):
        """Calculates the max."""
        return jnp.nanmax(array, axis=0).squeeze()

    @staticmethod
    def mean(array):
        """Calculates the mean."""
        return jnp.nanmean(array, axis=0).squeeze()

    def mean_abs_diff(self, array):
        """Calculates the mean absolute first difference, scaled to [0, 1]."""
        d = jnp.nanmean(
            jnp.fabs(jnp.diff(array, n=1, axis=0)), axis=0
        ).squeeze()
        return self._scale_min_max(d, 0.0, 0.05, 0.0, 1.0)

    @staticmethod
    def min(array):
        """Calculates the min."""
        return jnp.nanmin(array, axis=0).squeeze()

    @staticmethod
    def norm_abs_energy(array):
        """Calculates the normalized absolute energy."""
        return (
            jnp.nansum(array**2, axis=0).squeeze()
            / (jnp.nanmax(array, axis=0) ** 2 * array.shape[0]).squeeze()
        )

    @staticmethod
    def percentile(array, p):
        """Calculates the pth percentile."""
        return jnp.nanpercentile(array, p, axis=0).squeeze()

    def _stack(self, array, stats):
        """Calculates a stack of statistics.

        Names of the form 'percentile<p>' are dispatched to
        :func:`percentile` with ``p`` parsed from the name; any other name
        is looked up as a method of the same name.
        """
        return jnp.vstack(
            [
                self.percentile(array, int(stat[10:]))[np.newaxis]
                if stat.startswith('percentile')
                else getattr(self, stat)(array)[np.newaxis]
                for stat in stats
            ]
        ).squeeze()