Source code for geowombat.core.api

# https://github.com/pydata/xarray/issues/2560
try:
    import netCDF4
except ImportError:
    pass

try:
    import h5netcdf
except ImportError:
    pass

import concurrent.futures
import logging
import threading
import typing as T
import warnings
from contextlib import contextmanager
from pathlib import Path

import dask
import dask.array as da
import numpy as np
import rasterio as rio
import xarray as xr
from rasterio.coords import BoundingBox
from rasterio.windows import Window, from_bounds
from tqdm.auto import tqdm

from ..backends import concat as gw_concat
from ..backends import mosaic as gw_mosaic
from ..backends import warp_open
from ..backends.rasterio_ import check_src_crs
from ..config import _set_defaults, config
from ..handler import add_handler
from . import geoxarray
from .series import BaseSeries, SeriesStats, TransferLib
from .util import Chunks, get_file_extension, parse_wildcard

logger = logging.getLogger(__name__)
logger = add_handler(logger)
warnings.filterwarnings("ignore")

ch = Chunks()

IO_DICT = dict(
    rasterio=[
        ".tif",
        ".tiff",
        ".TIF",
        ".TIFF",
        ".img",
        ".IMG",
        ".kea",
        ".vrt",
        ".VRT",
        ".jp2",
        ".JP2",
        ".hgt",
        ".HGT",
        ".hdf",
        ".HDF",
        ".h5",
        ".H5",
    ],
    xarray=[".nc"],
)
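
# Illustrative sketch (not part of the module): how a file extension can be
# routed to a reader backend with IO_DICT. The helper name ``_get_backend`` is
# hypothetical; the module itself does this inline inside ``open.__init__``.
#
#     from pathlib import Path
#
#     def _get_backend(filename):
#         ext = Path(filename).suffix
#         for backend, extensions in IO_DICT.items():
#             if ext in extensions:
#                 return backend
#         raise OSError(f"The file format {ext!r} is not recognized.")
#
#     # _get_backend('image.tif') -> 'rasterio'; _get_backend('image.nc') -> 'xarray'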


@contextmanager
def _tqdm(*args, **kwargs):
    yield None


def _get_attrs(src, **kwargs):
    cellxh = src.res[0] / 2.0
    cellyh = src.res[1] / 2.0

    left_ = src.bounds.left + (kwargs["window"].col_off * src.res[0]) + cellxh
    top_ = src.bounds.top - (kwargs["window"].row_off * src.res[1]) - cellyh

    xcoords = np.arange(
        left_, left_ + kwargs["window"].width * src.res[0], src.res[0]
    )
    ycoords = np.arange(
        top_, top_ - kwargs["window"].height * src.res[1], -src.res[1]
    )

    attrs = {}
    attrs["transform"] = (
        src.gw.transform if hasattr(src, "gw") else src.transform
    )

    if hasattr(src, "crs"):
        src_crs = check_src_crs(src)
        try:
            attrs["crs"] = src_crs.to_proj4()
        except Exception:
            attrs["crs"] = src_crs.to_string()

    if hasattr(src, "res"):
        attrs["res"] = src.res

    if hasattr(src, "is_tiled"):
        attrs["is_tiled"] = np.uint8(src.is_tiled)

    if hasattr(src, "nodatavals"):
        attrs["nodatavals"] = tuple(
            np.nan if nodataval is None else nodataval
            for nodataval in src.nodatavals
        )

    if hasattr(src, "offsets"):
        attrs["offsets"] = src.scales

    if hasattr(src, "offsets"):
        attrs["offsets"] = src.offsets

    if hasattr(src, "descriptions") and any(src.descriptions):
        attrs["descriptions"] = src.descriptions

    if hasattr(src, "units") and any(src.units):
        attrs["units"] = src.units

    return ycoords, xcoords, attrs
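
# Illustrative example (not part of the module) of the coordinate math above.
# Assuming a 10 m grid with bounds.left=500000, bounds.top=4000000 and a window
# offset of (row_off=2, col_off=3):
#
#     left_ = 500000 + (3 * 10) + 5  -> 500035.0 (center of the first column)
#     top_  = 4000000 - (2 * 10) - 5 -> 3999975.0 (center of the first row)
#
# xcoords then steps east by +10 m for ``window.width`` cells, and ycoords steps
# south by -10 m for ``window.height`` cells.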


@dask.delayed
def read_delayed(fname, chunks, **kwargs):
    with rio.open(fname) as src:
        data_slice = src.read(**kwargs)
        single_band = len(data_slice.shape) == 2

        if isinstance(chunks, int):
            chunks_ = (1, chunks, chunks)
        elif isinstance(chunks, tuple):
            chunks_ = (1,) + chunks if len(chunks) < 3 else chunks

        if single_band:
            # Expand to 1 band
            data_slice = da.from_array(data_slice[np.newaxis], chunks=chunks_)
        else:
            data_slice = da.from_array(data_slice, chunks=chunks)

        return data_slice
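
# Illustrative usage (not part of the module): ``read_delayed`` returns a dask
# delayed object, so nothing is read until ``dask.compute`` is called. The file
# name and window below are hypothetical.
#
#     from rasterio.windows import Window
#
#     delayed_chunk = read_delayed(
#         'image.tif',
#         chunks=256,
#         window=Window(row_off=0, col_off=0, height=256, width=256),
#     )
#     array = dask.compute(delayed_chunk, num_workers=2)[0]  # (band, y, x) dask array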


def read_list(file_list, chunks, **kwargs):
    return [read_delayed(fn, chunks, **kwargs) for fn in file_list]


def read(
    filename,
    band_names=None,
    time_names=None,
    bounds=None,
    chunks=256,
    num_workers=1,
    **kwargs,
):
    """Reads a window slice in-memory.

    Args:
        filename (str or list): A file name or list of file names to open and read.
        band_names (Optional[list]): A list of names to give the output band dimension.
        time_names (Optional[list]): A list of names to give the time dimension.
        bounds (Optional[1d array-like]): A bounding box to subset to, given as
            [minx, miny, maxx, maxy] or [left, bottom, right, top].
        chunks (Optional[tuple]): The data chunk size.
        num_workers (Optional[int]): The number of parallel ``dask`` workers.
        kwargs (Optional[dict]): Keyword arguments to pass to ``rasterio``'s ``read`` method.

    Returns:
        ``xarray.DataArray``
    """
    # Cannot pass 'chunks' to rasterio
    if "chunks" in kwargs:
        del kwargs["chunks"]

    if isinstance(filename, str):
        with rio.open(filename) as src:
            src_transform = (
                src.gw.transform if hasattr(src, "gw") else src.transform
            )

            if bounds and ("window" not in kwargs):
                kwargs["window"] = from_bounds(*bounds, transform=src_transform)

            ycoords, xcoords, attrs = _get_attrs(src, **kwargs)

        data = dask.compute(
            read_delayed(filename, chunks, **kwargs), num_workers=num_workers
        )[0]

        if not band_names:
            band_names = np.arange(1, data.shape[0] + 1)

        if len(band_names) != data.shape[0]:
            logger.exception(
                " The band names do not match the output dimensions."
            )
            raise ValueError

        data = xr.DataArray(
            data,
            dims=("band", "y", "x"),
            coords={
                "band": band_names,
                "y": ycoords[: data.shape[-2]],
                "x": xcoords[: data.shape[-1]],
            },
            attrs=attrs,
        )

    else:
        with rio.open(filename[0]) as src:
            src_transform = (
                src.gw.transform if hasattr(src, "gw") else src.transform
            )

            if bounds and ("window" not in kwargs):
                kwargs["window"] = from_bounds(*bounds, transform=src_transform)

            ycoords, xcoords, attrs = _get_attrs(src, **kwargs)

        data = da.concatenate(
            dask.compute(
                read_list(filename, chunks, **kwargs), num_workers=num_workers
            ),
            axis=0,
        )

        if not band_names:
            band_names = np.arange(1, data.shape[-3] + 1)

        if len(band_names) != data.shape[-3]:
            logger.exception(
                " The band names do not match the output dimensions."
            )
            raise ValueError

        if not time_names:
            time_names = np.arange(1, len(filename) + 1)

        if len(time_names) != data.shape[-4]:
            logger.exception(
                " The time names do not match the output dimensions."
            )
            raise ValueError

        data = xr.DataArray(
            data,
            dims=("time", "band", "y", "x"),
            coords={
                "time": time_names,
                "band": band_names,
                "y": ycoords[: data.shape[-2]],
                "x": xcoords[: data.shape[-1]],
            },
            attrs=attrs,
        )

    return data

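# Illustrative usage of ``read`` (not part of the module). The file name, band
# names, and bounds below are hypothetical, and a 3-band image is assumed:
#
#     data = read(
#         'image.tif',
#         band_names=['red', 'green', 'blue'],
#         bounds=(left, bottom, right, top),  # subset derived via rasterio.windows.from_bounds
#         chunks=256,
#         num_workers=2,
#     )
#     # data: xarray.DataArray with dims ('band', 'y', 'x')
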
data_ = None

class open(object):
    """Opens one or more raster files.

    Args:
        filename (str or list): The file name, search string, or a list of files to open.
        band_names (Optional[1d array-like]): A list of band names if ``bounds`` is given or
            ``window`` is given. Default is None.
        time_names (Optional[1d array-like]): A list of names to give the time dimension if
            ``bounds`` is given. Default is None.
        stack_dim (Optional[str]): The stack dimension. Choices are ['time', 'band'].
        bounds (Optional[1d array-like]): A bounding box to subset to, given as
            [minx, miny, maxx, maxy]. Default is None.
        bounds_by (Optional[str]): How to concatenate the output extent if ``filename`` is a
            ``list`` and ``mosaic`` = ``False``. Choices are ['intersection', 'union', 'reference'].

            * reference: Use the bounds of the reference image. If a ``ref_image`` is not given,
              the first image in the ``filename`` list is used.
            * intersection: Use the intersection (i.e., minimum extent) of all the image bounds
            * union: Use the union (i.e., maximum extent) of all the image bounds

        resampling (Optional[str]): The resampling method if ``filename`` is a ``list``.
            Choices are ['average', 'bilinear', 'cubic', 'cubic_spline', 'gauss', 'lanczos',
            'max', 'med', 'min', 'mode', 'nearest'].
        persist_filenames (Optional[bool]): Whether to persist the filenames list with the
            ``xarray.DataArray`` attributes. By default, ``persist_filenames=False`` to avoid
            storing large file lists.
        netcdf_vars (Optional[list]): NetCDF variables to open as a band stack.
        mosaic (Optional[bool]): If ``filename`` is a ``list``, whether to mosaic the arrays
            instead of stacking.
        overlap (Optional[str]): The keyword that determines how to handle overlapping data if
            ``filenames`` is a ``list``. Choices are ['min', 'max', 'mean'].
        nodata (Optional[float | int]): A 'no data' value to set. Default is ``None``. If ``nodata``
            is ``None``, the 'no data' value is set from the file metadata. If ``nodata`` is given,
            then the file 'no data' value is overridden. See docstring examples for use of ``nodata``
            in ``geowombat.config.update``.

            .. note:: The ``geowombat.config.update`` overrides this argument. Thus, preference is
                always given in the following order:

                    1. ``geowombat.config.update(nodata not None)``
                    2. ``open(nodata not None)``
                    3. file 'no data' value from metadata '_FillValue' or 'nodatavals'

        scale_factor (Optional[float | int]): A scale value to apply to the opened data. The same
            rules used in ``nodata`` apply. I.e.,

            .. note:: The ``geowombat.config.update`` overrides this argument. Thus, preference is
                always given in the following order:

                    1. ``geowombat.config.update(scale_factor not None)``
                    2. ``open(scale_factor not None)``
                    3. file scale value from metadata 'scales'

        offset (Optional[float | int]): An offset value to apply to the opened data. The same rules
            used in ``nodata`` apply. I.e.,

            .. note:: The ``geowombat.config.update`` overrides this argument. Thus, preference is
                always given in the following order:

                    1. ``geowombat.config.update(offset not None)``
                    2. ``open(offset not None)``
                    3. file offset value from metadata 'offsets'

        dtype (Optional[str]): A data type to force the output to. If not given, the data type is
            extracted from the file.
        scale_data (Optional[bool]): Whether to apply scaling to the opened data. Default is
            ``False``. Scaled data are returned as:

                scaled = data * gain + offset

            See the arguments ``nodata``, ``scale_factor``, and ``offset`` for rules regarding how
            scaling is applied.
        num_workers (Optional[int]): The number of parallel workers for Dask if ``bounds`` is given
            or ``window`` is given. Default is 1.
        kwargs (Optional[dict]): Keyword arguments passed to the file opener.

    Returns:
        ``xarray.DataArray`` or ``xarray.Dataset``

    Examples:
        >>> import geowombat as gw
        >>>
        >>> # Open an image
        >>> with gw.open('image.tif') as ds:
        >>>     print(ds)
        >>>
        >>> # Open a list of images, stacking along the 'time' dimension
        >>> with gw.open(['image1.tif', 'image2.tif']) as ds:
        >>>     print(ds)
        >>>
        >>> # Open all GeoTiffs in a directory, stack along the 'time' dimension
        >>> with gw.open('*.tif') as ds:
        >>>     print(ds)
        >>>
        >>> # Use a context manager to handle images of different sizes and projections
        >>> with gw.config.update(ref_image='image1.tif'):
        >>>     # Use 'time' names to stack and mosaic non-aligned images with identical dates
        >>>     with gw.open(['image1.tif', 'image2.tif', 'image3.tif'],
        >>>
        >>>                  # The first two images were acquired on the same date
        >>>                  # and will be merged into a single time layer
        >>>                  time_names=['date1', 'date1', 'date2']) as ds:
        >>>
        >>>         print(ds)
        >>>
        >>> # Mosaic images across space using a reference
        >>> # image for the CRS and cell resolution
        >>> with gw.config.update(ref_image='image1.tif'):
        >>>     with gw.open(['image1.tif', 'image2.tif'], mosaic=True) as ds:
        >>>         print(ds)
        >>>
        >>> # Mix configuration keywords
        >>> with gw.config.update(ref_crs='image1.tif', ref_res='image1.tif', ref_bounds='image2.tif'):
        >>>     # The ``bounds_by`` keyword overrides the extent bounds
        >>>     with gw.open(['image1.tif', 'image2.tif'], bounds_by='union') as ds:
        >>>         print(ds)
        >>>
        >>> # Resample an image to 10m x 10m cell size
        >>> with gw.config.update(ref_crs=(10, 10)):
        >>>     with gw.open('image.tif', resampling='cubic') as ds:
        >>>         print(ds)
        >>>
        >>> # Open a list of images at a window slice
        >>> from rasterio.windows import Window
        >>> # Stack two images, opening band 3
        >>> with gw.open(
        >>>     ['image1.tif', 'image2.tif'],
        >>>     band_names=['date1', 'date2'],
        >>>     num_workers=8,
        >>>     indexes=3,
        >>>     window=Window(row_off=0, col_off=0, height=100, width=100),
        >>>     dtype='float32'
        >>> ) as ds:
        >>>     print(ds)
        >>>
        >>> # Scale data upon opening, using the image metadata to get scales and offsets
        >>> with gw.open('image.tif', scale_data=True) as ds:
        >>>     print(ds)
        >>>
        >>> # Scale data upon opening, specifying scales and overriding metadata
        >>> with gw.open('image.tif', scale_data=True, scale_factor=1e-4) as ds:
        >>>     print(ds)
        >>>
        >>> # Scale data upon opening, specifying scales and overriding metadata
        >>> with gw.config.update(scale_factor=1e-4):
        >>>     with gw.open('image.tif', scale_data=True) as ds:
        >>>         print(ds)
        >>>
        >>> # Open a NetCDF variable, specifying a NetCDF prefix and variable to open
        >>> with gw.open('netcdf:image.nc:blue') as src:
        >>>     print(src)
        >>>
        >>> # Open a NetCDF image without access to transforms by providing the full file path
        >>> # NOTE: This will be faster than the above method
        >>> # as it uses ``xarray.open_dataset`` and bypasses CRS checks.
        >>> # NOTE: The chunks must be provided by the user.
        >>> # NOTE: Providing band names will ensure the correct order when reading from a NetCDF dataset.
        >>> with gw.open(
        >>>     'image.nc',
        >>>     chunks={'band': -1, 'y': 256, 'x': 256},
        >>>     band_names=['blue', 'green', 'red', 'nir', 'swir1', 'swir2'],
        >>>     engine='h5netcdf'
        >>> ) as src:
        >>>     print(src)
        >>>
        >>> # Open multiple NetCDF variables as an array stack
        >>> with gw.open('netcdf:image.nc', netcdf_vars=['blue', 'green', 'red']) as src:
        >>>     print(src)
    """

    def __init__(
        self,
        filename: T.Union[str, Path, T.Sequence[T.Union[str, Path]]],
        band_names: T.Optional[T.Union[T.Sequence, int, str]] = None,
        time_names: T.Optional[T.Sequence] = None,
        stack_dim: T.Optional[str] = "time",
        bounds: T.Optional[T.Union[BoundingBox, T.Sequence[float]]] = None,
        bounds_by: T.Optional[str] = "reference",
        resampling: T.Optional[str] = "nearest",
        persist_filenames: T.Optional[bool] = False,
        netcdf_vars: T.Optional[T.Union[T.Sequence, int, str]] = None,
        mosaic: T.Optional[bool] = False,
        overlap: T.Optional[str] = "max",
        nodata: T.Optional[T.Union[float, int]] = None,
        scale_factor: T.Optional[T.Union[float, int]] = None,
        offset: T.Optional[T.Union[float, int]] = None,
        dtype: T.Optional[T.Union[str, np.dtype]] = None,
        scale_data: T.Optional[bool] = False,
        num_workers: T.Optional[int] = 1,
        **kwargs,
    ):
        if stack_dim not in ["band", "time"]:
            logger.exception(
                f" The 'stack_dim' keyword argument must be either 'band' or 'time', but not {stack_dim}"
            )
            raise NameError

        if isinstance(filename, Path):
            filename = str(filename)
        elif isinstance(filename, list) and len(filename) == 1:
            filename = str(filename[0])

        self.data = data_
        self.__is_context_manager = False
        self.__data_are_separate = False
        self.__data_are_stacked = False
        self.__filenames = []

        band_chunks = -1

        if "chunks" in kwargs:
            if kwargs["chunks"] is not None:
                kwargs["chunks"] = ch.check_chunktype(
                    kwargs["chunks"], output="3d"
                )

        if bounds or (
            "window" in kwargs and isinstance(kwargs["window"], Window)
        ):
            if "chunks" not in kwargs:
                if isinstance(filename, list):
                    with rio.open(filename[0]) as src_:
                        w = src_.block_window(1, 0, 0)
                        chunks = (band_chunks, w.height, w.width)
                else:
                    with rio.open(filename) as src_:
                        w = src_.block_window(1, 0, 0)
                        chunks = (band_chunks, w.height, w.width)
            else:
                chunks = kwargs["chunks"]
                del kwargs["chunks"]

            self.data = read(
                filename,
                band_names=band_names,
                time_names=time_names,
                bounds=bounds,
                chunks=chunks,
                num_workers=num_workers,
                **kwargs,
            )

            self.__filenames = [str(filename)]

        else:
            if (isinstance(filename, str) and "*" in filename) or isinstance(
                filename, list
            ):
                # Build the filename list
                if isinstance(filename, str):
                    filename = parse_wildcard(filename)

                if "chunks" not in kwargs:
                    with rio.open(filename[0]) as src:
                        w = src.block_window(1, 0, 0)
                        kwargs["chunks"] = (band_chunks, w.height, w.width)

                if mosaic:
                    # Mosaic images over space
                    self.data = gw_mosaic(
                        filename,
                        overlap=overlap,
                        bounds_by=bounds_by,
                        resampling=resampling,
                        band_names=band_names,
                        nodata=nodata,
                        dtype=dtype,
                        **kwargs,
                    )
                else:
                    # Stack images along the 'time' axis
                    self.data = gw_concat(
                        filename,
                        stack_dim=stack_dim,
                        bounds_by=bounds_by,
                        resampling=resampling,
                        time_names=time_names,
                        band_names=band_names,
                        nodata=nodata,
                        overlap=overlap,
                        dtype=dtype,
                        netcdf_vars=netcdf_vars,
                        **kwargs,
                    )
                    self.__data_are_stacked = True

                self.__data_are_separate = True
                self.__filenames = [str(fn) for fn in filename]

            else:
                self.__filenames = [str(filename)]
                file_names = get_file_extension(filename)

                if (
                    file_names.f_ext.lower()
                    not in IO_DICT["rasterio"] + IO_DICT["xarray"]
                ) and not filename.lower().startswith("netcdf:"):
                    logger.exception(" The file format is not recognized.")
                    raise OSError

                if (
                    file_names.f_ext.lower() in IO_DICT["rasterio"]
                ) or filename.lower().startswith("netcdf:"):
                    if "chunks" not in kwargs:
                        with rio.open(filename) as src:
                            w = src.block_window(1, 0, 0)
                            kwargs["chunks"] = (band_chunks, w.height, w.width)

                    self.data = warp_open(
                        filename,
                        band_names=band_names,
                        resampling=resampling,
                        dtype=dtype,
                        netcdf_vars=netcdf_vars,
                        nodata=nodata,
                        **kwargs,
                    )
                else:
                    if "chunks" in kwargs and not isinstance(
                        kwargs["chunks"], dict
                    ):
                        logger.exception(" The chunks should be a dictionary.")
                        raise TypeError

                    with xr.open_dataset(filename, **kwargs) as src:
                        self.data = src.to_array(dim="band")

                    # Ensure the filename attribute gets updated as the NetCDF file
                    self.data = self.data.assign_attrs(
                        **{"filename": str(filename)}
                    )
                    self.__filenames = [str(filename)]

                    # Order bands from the NetCDF dataset
                    if band_names is not None:
                        if len(band_names) != self.data["band"].shape[0]:
                            raise ValueError(
                                "The length of band_names must match the length of the band coordinate."
                            )

                        band_names_new = []
                        band_names_old = []
                        for bname_new, bname_old in zip(
                            band_names, self.data["band"].values
                        ):
                            band_names_new.append(bname_new)
                            if bname_new in self.data["band"].values:
                                band_names_old.append(bname_new)
                            else:
                                band_names_old.append(bname_old)

                        self.data = self.data.sel(band=band_names_old)
                        self.data = self.data.assign_coords(
                            **{"band": band_names_new}
                        )

        self.data = self.data.assign_attrs(
            {
                "_data_are_separate": int(self.__data_are_separate),
                "_data_are_stacked": int(self.__data_are_stacked),
            }
        )

        if persist_filenames:
            self.data = self.data.assign_attrs(
                **{"_filenames": self.__filenames}
            )

        if scale_data:
            self.data = self.data.gw.set_nodata(
                src_nodata=nodata,
                dst_nodata=np.nan,
                out_range=None,
                dtype="float64",
                scale_factor=scale_factor,
                offset=offset,
            )
        else:
            # No scaling is applied, but the user assigned a scale factor to update the attributes
            if scale_factor is not None:
                self.data = self.data.assign_attrs(
                    **{"scales": (scale_factor,) * self.data.gw.nbands}
                )
            # No scaling is applied, but the user assigned an offset to update the attributes
            if offset is not None:
                self.data = self.data.assign_attrs(
                    **{"offsets": (offset,) * self.data.gw.nbands}
                )

    def __enter__(self):
        self.__is_context_manager = True
        return self.data

    def __exit__(self, *args, **kwargs):
        if not self.data.gw.config["with_config"]:
            _set_defaults(config)

        self.close()
        d = self.data
        self._reset(d)

    @staticmethod
    def _reset(d):
        d = None

    @contextmanager
    def _optional_lock(self, needs_lock):
        """Context manager for optionally acquiring a lock."""
        if needs_lock:
            with threading.Lock():
                yield
        else:
            yield
    def close(self):
        if hasattr(self, "data"):
            if hasattr(self.data, "gw"):
                if hasattr(self.data.gw, "_obj"):
                    self.data.gw._obj = None

            if hasattr(self.data, "close"):
                self.data.close()

            if "gw" in self.data._cache:
                with self._optional_lock(True):
                    file = self.data._cache.pop("gw", None)

        self.data = None

def load(
    image_list,
    time_names,
    band_names,
    chunks=512,
    nodata=65535,
    in_range=None,
    out_range=None,
    data_slice=None,
    num_workers=1,
    src=None,
    scheduler="ray",
):
    """Loads data into memory using :func:`xarray.open_mfdataset` and ``ray``.

    This function does not check data alignments and CRSs. It assumes each image
    in ``image_list`` has the same y and x dimensions and that the coordinates align.

    The ``load`` function cannot be used if ``dataclasses`` was pip installed.

    Args:
        image_list (list): The list of image file paths.
        time_names (list): The list of image ``datetime`` objects.
        band_names (list): The list of bands to open.
        chunks (Optional[int]): The dask chunk size.
        nodata (Optional[float | int]): The 'no data' value.
        in_range (Optional[tuple]): The input (min, max) range. If not given, defaults to (0, 10000).
        out_range (Optional[tuple]): The output (min, max) range. If not given, defaults to (0, 1).
        data_slice (Optional[tuple]): The slice object to read, given as (time, bands, rows, columns).
        num_workers (Optional[int]): The number of threads.
        scheduler (Optional[str]): The distributed scheduler. Currently not implemented.

    Returns:
        ``list``, ``numpy.ndarray``:
            Datetime list, array of (time x bands x rows x columns)

    Example:
        >>> import datetime
        >>> import geowombat as gw
        >>>
        >>> image_names = ['LT05_L1TP_227082_19990311_20161220_01_T1.nc',
        >>>                'LT05_L1TP_227081_19990311_20161220_01_T1.nc',
        >>>                'LT05_L1TP_227082_19990327_20161220_01_T1.nc']
        >>>
        >>> image_dates = [datetime.datetime(1999, 3, 11, 0, 0),
        >>>                datetime.datetime(1999, 3, 11, 0, 0),
        >>>                datetime.datetime(1999, 3, 27, 0, 0)]
        >>>
        >>> data_slice = (slice(0, None), slice(0, None), slice(0, 64), slice(0, 64))
        >>>
        >>> # Load data into memory
        >>> dates, y = gw.load(image_names,
        >>>                    image_dates,
        >>>                    ['red', 'nir'],
        >>>                    chunks=512,
        >>>                    nodata=65535,
        >>>                    data_slice=data_slice,
        >>>                    num_workers=4)
    """
    import dask
    import ray
    from dask.diagnostics import ProgressBar
    from ray.util.dask import ray_dask_get

    netcdf_prepend = [
        True for fn in image_list if str(fn).startswith("netcdf:")
    ]

    if any(netcdf_prepend):
        raise NameError(
            "The NetCDF names cannot be prepended with netcdf: when using `geowombat.load()`."
        )

    if not in_range:
        in_range = (0, 10000)

    if not out_range:
        out_range = (0, 1)

    scale_factor = float(out_range[1]) / float(in_range[1])

    if src is None:
        with open(
            image_list[0],
            time_names=time_names[0],
            band_names=band_names
            if not str(image_list[0]).endswith(".nc")
            else None,
            netcdf_vars=band_names
            if str(image_list[0]).endswith(".nc")
            else None,
            chunks=chunks,
        ) as src:
            pass

    attrs = src.attrs.copy()
    nrows = src.gw.nrows
    ncols = src.gw.ncols
    ycoords = src.y
    xcoords = src.x

    if data_slice is None:
        data_slice = (
            slice(0, None),
            slice(0, None),
            slice(0, None),
            slice(0, None),
        )

    def expand_time(dataset):
        """``open_mfdataset`` preprocess function."""
        # Convert the Dataset into a DataArray,
        # rename the band coordinate,
        # select the required VI bands,
        # assign y/x coordinates from a reference,
        # add the time coordinate, and
        # get the sub-array slice
        darray = (
            dataset.to_array()
            .rename({"variable": "band"})[:, :nrows, :ncols]
            .sel(band=band_names)
            .assign_coords(y=ycoords, x=xcoords)
            .expand_dims(dim="time")
            .clip(0, max(in_range[1], nodata))[data_slice]
        )

        # Scale from [0-10000] -> [0,1]
        darray = xr.where(darray == nodata, 0, darray * scale_factor).astype(
            "float64"
        )

        return (
            darray.where(np.isfinite(darray))
            .fillna(0)
            .clip(min=out_range[0], max=out_range[1])
        )

    ray.shutdown()
    ray.init(num_cpus=num_workers)

    with dask.config.set(scheduler=ray_dask_get):
        # Open all arrays
        ds = (
            xr.open_mfdataset(
                image_list,
                concat_dim="time",
                chunks=chunks,
                combine="nested",
                engine="h5netcdf",
                preprocess=expand_time,
                parallel=True,
            )
            .assign_coords(time=time_names)
            .groupby("time.date")
            .max()
            .rename({"date": "time"})
            .assign_attrs(**attrs)
        )

        # Get the time series dates after grouping
        real_proc_times = ds.gw.pydatetime

        # Convert the DataArray into a NumPy array
        # ds.data.visualize(filename='graph.svg')
        # with performance_report(filename='dask-report.html'):
        with ProgressBar():
            y = ds.data.compute()

    ds.close()
    ray.shutdown()

    return real_proc_times, y

class _ImportGPU(object):
    try:
        import jax
        import jax.numpy as jnp

        JAX_INSTALLED = True
    except ImportError:
        JAX_INSTALLED = False

    try:
        import torch

        PYTORCH_INSTALLED = True
    except ImportError:
        PYTORCH_INSTALLED = False

    try:
        import tensorflow as tf

        TENSORFLOW_INSTALLED = True
    except ImportError:
        TENSORFLOW_INSTALLED = False

    try:
        from tensorflow import keras

        KERAS_INSTALLED = True
    except ImportError:
        KERAS_INSTALLED = False

class series(BaseSeries):
    """A class for time series concurrent processing on a GPU.

    Args:
        filenames (list): The list of filenames to open.
        band_names (Optional[list]): The associated band names.
        transfer_lib (Optional[str]): The library to transfer data to.
            Choices are ['jax', 'keras', 'numpy', 'pytorch', 'tensorflow'].
        crs (Optional[str]): The coordinate reference system.
        res (Optional[list | tuple]): The cell resolution.
        bounds (Optional[object]): The coordinate bounds.
        resampling (Optional[str]): The resampling method.
        nodata (Optional[float | int]): The 'no data' value.
        warp_mem_limit (Optional[int]): The ``rasterio`` warping memory limit (in MB).
        num_threads (Optional[int]): The number of ``rasterio`` warping threads.
        window_size (Optional[int | list | tuple]): The concurrent processing window size
            (height, width) or -1 (i.e., entire array).
        padding (Optional[list | tuple]): Padding for each window. ``padding`` should be given as
            a tuple of (left pad, bottom pad, right pad, top pad). If ``padding`` is given, the
            returned list will contain a tuple of ``rasterio.windows.Window`` objects as (w1, w2),
            where w1 contains the normal window offsets and w2 contains the padded window offsets.

    Requirement:

        > # CUDA 11.1
        > pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
    """

    def __init__(
        self,
        filenames: list,
        time_names: list = None,
        band_names: list = None,
        transfer_lib: str = "jax",
        crs: str = None,
        res: T.Union[list, tuple] = None,
        bounds: T.Union[BoundingBox, list, tuple] = None,
        resampling: str = "nearest",
        nodata: T.Union[float, int] = 0,
        warp_mem_limit: int = 256,
        num_threads: int = 1,
        window_size: T.Union[int, list, tuple] = None,
        padding: T.Union[list, tuple] = None,
    ):
        imports_ = _ImportGPU()

        if not imports_.JAX_INSTALLED and (transfer_lib == "jax"):
            logger.exception("JAX must be installed.")
            raise ImportError("JAX must be installed.")

        if not imports_.PYTORCH_INSTALLED and (transfer_lib == "pytorch"):
            logger.exception("PyTorch must be installed.")
            raise ImportError("PyTorch must be installed.")

        if not imports_.TENSORFLOW_INSTALLED and (transfer_lib == "tensorflow"):
            logger.exception("Tensorflow must be installed.")
            raise ImportError("Tensorflow must be installed.")

        if not imports_.KERAS_INSTALLED and (transfer_lib == "keras"):
            logger.exception("Keras must be installed.")
            raise ImportError("Keras must be installed.")

        self.filenames = filenames
        self.time_names = time_names
        self.band_names = band_names
        self.padding = padding

        self.srcs_ = None
        self.vrts_ = None
        self.windows_ = None

        if transfer_lib == "jax":
            self.out_array_type = imports_.jax.Array
        elif transfer_lib == "numpy":
            self.out_array_type = np.ndarray
        elif transfer_lib == "pytorch":
            self.out_array_type = imports_.torch.Tensor
        elif transfer_lib in ["keras", "tensorflow"]:
            self.out_array_type = imports_.tf.Tensor

        self.put = TransferLib(transfer_lib)

        self.open(filenames)

        self.warp(
            dst_crs=crs,
            dst_res=res,
            dst_bounds=bounds,
            resampling=resampling,
            nodata=nodata,
            warp_mem_limit=warp_mem_limit,
            num_threads=num_threads,
            window_size=window_size,
            padding=self.padding,
        )
    def read(
        self,
        bands: T.Union[int, list],
        window: Window = None,
        gain: float = 1.0,
        offset: T.Union[float, int] = 0.0,
        pool: T.Any = None,
        num_workers: int = None,
        tqdm_obj: T.Any = None,
    ) -> T.Any:
        """Reads a window."""
        if isinstance(bands, int):
            if bands == -1:
                band_list = list(range(1, self.count + 1))
            else:
                band_list = [bands]
        else:
            band_list = bands

        def _read(vrt_ptr, bd):
            array = vrt_ptr.read(bd, window=window)
            mask = vrt_ptr.read_masks(bd, window=window)
            array = array * gain + offset
            array[mask == 0] = np.nan

            return array

        if pool is not None:

            def _read_bands(vrt_):
                return np.stack([_read(vrt_, band) for band in band_list])

            with pool(num_workers) as executor:
                data_gen = (vrt for vrt in self.vrts_)

                results = []
                for res in tqdm_obj(
                    executor.map(_read_bands, data_gen), total=len(self.vrts_)
                ):
                    results.append(res)

                return self.put(np.array(results))

        else:
            return self.put(
                np.array(
                    [
                        np.stack([_read(vrt, band) for band in band_list])
                        for vrt in self.vrts_
                    ]
                )
            )
    @staticmethod
    def _create_file(filename, **profile):
        if Path(filename).is_file():
            Path(filename).unlink()

        with rio.open(filename, mode="w", **profile) as dst:
            pass
    def apply(
        self,
        func: T.Union[T.Callable, str, list, tuple],
        bands: T.Union[list, int],
        gain: float = 1.0,
        offset: T.Union[float, int] = 0.0,
        processes: bool = False,
        num_workers: int = 1,
        monitor_progress: bool = True,
        outfile: T.Union[Path, str] = None,
        bigtiff: str = "NO",
        kwargs: dict = {},
    ):
        """Applies a function concurrently over windows.

        Args:
            func (object | str | list | tuple): The function to apply. If ``func`` is a string,
                choices are ['cv', 'max', 'mean', 'min'].
            bands (list | int): The bands to read.
            gain (Optional[float]): A gain factor to apply.
            offset (Optional[float | int]): An offset factor to apply.
            processes (Optional[bool]): Whether to use process workers, otherwise use threads.
            num_workers (Optional[int]): The number of concurrent workers.
            monitor_progress (Optional[bool]): Whether to monitor progress with a ``tqdm`` bar.
            outfile (Optional[Path | str]): The output file.
            bigtiff (Optional[str]): Whether to create a BigTIFF file.
                Choices are ['YES', 'NO', 'IF_NEEDED', 'IF_SAFER']. Default is 'NO'.
            kwargs (Optional[dict]): Keyword arguments passed to the ``rasterio`` open profile.

        Returns:
            If ``outfile`` is None:
                Window, array, [datetime, ...]
            If ``outfile`` is not None:
                None, writes to ``outfile``

        Example:
            >>> import itertools
            >>> import geowombat as gw
            >>> import rasterio as rio
            >>>
            >>> # Import an image with 3 bands
            >>> from geowombat.data import l8_224078_20200518
            >>>
            >>> # Create a custom class
            >>> class TemporalMean(gw.TimeModule):
            >>>
            >>>     def __init__(self):
            >>>         super(TemporalMean, self).__init__()
            >>>
            >>>     # The main function
            >>>     def calculate(self, array):
            >>>
            >>>         sl1 = (slice(0, None), slice(self.band_dict['red'], self.band_dict['red']+1), slice(0, None), slice(0, None))
            >>>         sl2 = (slice(0, None), slice(self.band_dict['green'], self.band_dict['green']+1), slice(0, None), slice(0, None))
            >>>
            >>>         vi = (array[sl1] - array[sl2]) / ((array[sl1] + array[sl2]) + 1e-9)
            >>>
            >>>         return vi.mean(axis=0).squeeze()
            >>>
            >>> with rio.open(l8_224078_20200518) as src:
            >>>     res = src.res
            >>>     bounds = src.bounds
            >>>     nodata = 0
            >>>
            >>> # Open many files, each with 3 bands
            >>> with gw.series([l8_224078_20200518]*100,
            >>>                band_names=['blue', 'green', 'red'],
            >>>                crs='epsg:32621',
            >>>                res=res,
            >>>                bounds=bounds,
            >>>                nodata=nodata,
            >>>                num_threads=4,
            >>>                window_size=(1024, 1024)) as src:
            >>>
            >>>     src.apply(TemporalMean(),
            >>>               bands=-1,         # open all bands
            >>>               gain=0.0001,      # scale from [0,10000] -> [0,1]
            >>>               processes=False,  # use threads
            >>>               num_workers=4,    # use 4 concurrent threads, one per window
            >>>               outfile='vi_mean.tif')
            >>>
            >>> # Import a single-band image
            >>> from geowombat.data import l8_224078_20200518_B4
            >>>
            >>> # Open many files, each with 1 band
            >>> with gw.series([l8_224078_20200518_B4]*100,
            >>>                band_names=['red'],
            >>>                crs='epsg:32621',
            >>>                res=res,
            >>>                bounds=bounds,
            >>>                nodata=nodata,
            >>>                num_threads=4,
            >>>                window_size=(1024, 1024)) as src:
            >>>
            >>>     src.apply('mean',           # built-in function over single-band images
            >>>               bands=1,          # open the single band
            >>>               gain=0.0001,      # scale from [0,10000] -> [0,1]
            >>>               num_workers=4,    # use 4 concurrent threads, one per window
            >>>               outfile='red_mean.tif')
            >>>
            >>> with gw.series([l8_224078_20200518_B4]*100,
            >>>                band_names=['red'],
            >>>                crs='epsg:32621',
            >>>                res=res,
            >>>                bounds=bounds,
            >>>                nodata=nodata,
            >>>                num_threads=4,
            >>>                window_size=(1024, 1024)) as src:
            >>>
            >>>     src.apply(['mean', 'max', 'cv'],  # calculate multiple statistics
            >>>               bands=1,                # open the single band
            >>>               gain=0.0001,            # scale from [0,10000] -> [0,1]
            >>>               num_workers=4,          # use 4 concurrent threads, one per window
            >>>               outfile='stack_mean.tif')
        """
        pool = (
            concurrent.futures.ProcessPoolExecutor
            if processes
            else concurrent.futures.ThreadPoolExecutor
        )
        tqdm_obj = tqdm if monitor_progress else _tqdm

        if (
            isinstance(func, str)
            or isinstance(func, list)
            or isinstance(func, tuple)
        ):
            if isinstance(bands, list) or isinstance(bands, tuple):
                logger.exception(
                    "Only single-band images can be used with built-in functions."
                )
                raise ValueError(
                    "Only single-band images can be used with built-in functions."
                )

            apply_func_ = SeriesStats(func)
        else:
            apply_func_ = func

        if outfile is not None:
            profile = {
                "count": apply_func_.count,
                "width": self.width,
                "height": self.height,
                "crs": self.crs,
                "transform": self.transform,
                "driver": "GTiff",
                "dtype": apply_func_.dtype,
                "compress": apply_func_.compress,
                "sharing": False,
                "tiled": True,
                "nodata": self.nodata,
                "blockxsize": self.blockxsize,
                "blockysize": self.blockysize,
                **kwargs,
            }

            # Check for a BIGTIFF setting in the user config
            if config["with_config"]:
                bigtiff = config["bigtiff"].upper()

            if bigtiff not in (
                "YES",
                "NO",
                "IF_NEEDED",
                "IF_SAFER",
            ):
                raise NameError(
                    "The GDAL BIGTIFF must be one of 'YES', 'NO', 'IF_NEEDED', or 'IF_SAFER'. See https://gdal.org/drivers/raster/gtiff.html#creation-issues for more information."
                )

            profile["BIGTIFF"] = bigtiff

            # Create the file
            self._create_file(outfile, **profile)

        if outfile is not None:
            with rio.open(outfile, mode="r+", sharing=False) as dst:
                with pool(num_workers) as executor:
                    data_gen = (
                        (
                            w,
                            self.read(
                                bands, window=w[1], gain=gain, offset=offset
                            ),
                            self.band_dict,
                        )
                        if self.padding
                        else (
                            w,
                            self.read(
                                bands, window=w, gain=gain, offset=offset
                            ),
                            self.band_dict,
                        )
                        for w in self.windows_
                    )

                    for w, res in tqdm_obj(
                        executor.map(lambda f: apply_func_(*f), data_gen),
                        total=self.nchunks,
                    ):
                        with threading.Lock():
                            self._write_window(dst, res, apply_func_.count, w)

        else:
            if self.padding:
                w, res = apply_func_(
                    self.windows_[0],
                    self.read(
                        bands,
                        window=self.windows_[0][1],
                        gain=gain,
                        offset=offset,
                        pool=pool,
                        num_workers=num_workers,
                        tqdm_obj=tqdm_obj,
                    ),
                    self.band_dict,
                )
            else:
                w, res = apply_func_(
                    self.windows_[0],
                    self.read(
                        bands,
                        window=self.windows_[0],
                        gain=gain,
                        offset=offset,
                        pool=pool,
                        num_workers=num_workers,
                        tqdm_obj=tqdm_obj,
                    ),
                    self.band_dict,
                )

            # Group duplicate dates
            res, image_dates = self.group_dates(
                res, self.time_names, self.band_names
            )

            return w, res, image_dates
    def _write_window(self, dst_, out_data_, count, w):
        if self.padding:
            window_ = w[0]
            padded_window_ = w[1]

            # Get the non-padded array slice
            row_diff = abs(window_.row_off - padded_window_.row_off)
            col_diff = abs(window_.col_off - padded_window_.col_off)

            out_data_ = out_data_[
                :,
                row_diff : row_diff + window_.height,
                col_diff : col_diff + window_.width,
            ]

        dst_.write(
            out_data_,
            indexes=1 if count == 1 else range(1, count + 1),
            window=w[0] if self.padding else w,
        )

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs):
        for src in self.srcs_:
            src.close()

        for vrt in self.vrts_:
            vrt.close()
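
# Illustrative example (not part of the module) of the padded-window slicing in
# ``_write_window`` above. Assuming a 512 x 512 window at (row_off=512, col_off=512)
# that was padded by 16 cells on each side (padded offsets at 496, 496):
#
#     row_diff = abs(512 - 496) = 16
#     col_diff = abs(512 - 496) = 16
#
# so the written block is out_data_[:, 16:16 + 512, 16:16 + 512], i.e. the
# original (unpadded) window extracted from the padded result.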