import concurrent.futures
import multiprocessing as multi
import rasterio as rio
import xarray as xr
from tqdm import tqdm, trange
from .base import _executor_dummy
from .windows import get_window_offsets
_EXEC_DICT = {
'mpool': multi.Pool,
'ray': None,
'processes': concurrent.futures.ProcessPoolExecutor,
'threads': concurrent.futures.ThreadPoolExecutor,
}
class ParallelTask(object):
"""A class for parallel tasks over a ``xarray.DataArray`` with returned
results for each chunk.
Args:
data (DataArray | list): The ``xarray.DataArray`` to process, or a list of raster file paths to iterate over.
row_chunks (Optional[int]): The row chunk size to process in parallel.
col_chunks (Optional[int]): The column chunk size to process in parallel.
padding (Optional[tuple]): Padding for each window. ``padding`` should be given as a tuple
of (left pad, bottom pad, right pad, top pad). If ``padding`` is given, the returned list will contain
tuples of ``rasterio.windows.Window`` objects as (w1, w2), where w1 contains the normal window offsets
and w2 contains the padded window offsets (see the padding example below).
scheduler (Optional[str]): The parallel task scheduler to use. Choices are ['processes', 'threads', 'mpool', 'ray'].
mpool: process pool of workers using ``multiprocessing.Pool``.
ray: process pool of workers using ``ray.remote``.
processes: process pool of workers using ``concurrent.futures.ProcessPoolExecutor``.
threads: thread pool of workers using ``concurrent.futures.ThreadPoolExecutor``.
get_ray (Optional[bool]): Whether to retrieve results from ``ray`` futures (i.e., ``ray.get``) within ``map``. If ``False``, the futures themselves are returned.
n_workers (Optional[int]): The number of parallel workers for ``scheduler``.
n_chunks (Optional[int]): The number of windows to process in each chunked iteration. If not given, equal to ``n_workers`` x 50.
Examples:
>>> import itertools
>>> import geowombat as gw
>>> from geowombat.core.parallel import ParallelTask
>>>
>>> ########################
>>> # Use concurrent threads
>>> ########################
>>>
>>> def user_func_threads(*args):
>>> data, window_id, num_workers = list(itertools.chain(*args))
>>> return data.data.sum().compute(scheduler='threads', num_workers=num_workers)
>>>
>>> # Process 8 windows in parallel using threads
>>> # Process 4 dask chunks in parallel using threads
>>> # 32 total workers are needed
>>> with gw.open('image.tif') as src:
>>> pt = ParallelTask(src, scheduler='threads', n_workers=8)
>>> res = pt.map(user_func_threads, 4)
>>>
>>> #########
>>> # Use Ray
>>> #########
>>>
>>> import ray
>>>
>>> @ray.remote
>>> def user_func_ray(data_block_id, data_slice, window_id, num_workers):
>>> return (
>>> data_block_id[data_slice].data.sum()
>>> .compute(scheduler='threads', num_workers=num_workers)
>>> )
>>>
>>> ray.init(num_cpus=8)
>>>
>>> with gw.open('image.tif', chunks=512) as src:
>>> pt = ParallelTask(
>>> src, row_chunks=1024, col_chunks=1024, scheduler='ray', n_workers=8
>>> )
>>> res = ray.get(pt.map(user_func_ray, 4))
>>>
>>> ray.shutdown()
>>>
>>> #####################################
>>> # Use with a dask.distributed cluster
>>> #####################################
>>>
>>> from dask.distributed import LocalCluster
>>>
>>> with LocalCluster(
>>> n_workers=4,
>>> threads_per_worker=2,
>>> scheduler_port=0,
>>> processes=False,
>>> memory_limit='4GB'
>>> ) as cluster:
>>>
>>> with gw.open('image.tif') as src:
>>> pt = ParallelTask(src, scheduler='threads', n_workers=4, n_chunks=50)
>>> res = pt.map(user_func_threads, 2)
>>>
>>> # Map over multiple rasters
>>> for pt in ParallelTask(
>>> ['image1.tif', 'image2.tif'], scheduler='threads', n_workers=4, n_chunks=500
>>> ):
>>> res = pt.map(user_func_threads, 2)
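>>>
>>> ###################################
>>> # Use window padding (a sketch)
>>> ###################################
>>>
>>> # A minimal sketch only; 'image.tif' and ``user_func_padded`` are hypothetical.
>>> # With ``padding``, each task receives the data of the padded window.
>>> def user_func_padded(*args):
>>>     data, window_id, num_workers = list(itertools.chain(*args))
>>>     return data.data.sum().compute(scheduler='threads', num_workers=num_workers)
>>>
>>> with gw.open('image.tif') as src:
>>>     pt = ParallelTask(
>>>         src, scheduler='threads', n_workers=4, padding=(16, 16, 16, 16)
>>>     )
>>>     res = pt.map(user_func_padded, 2)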
"""
def __init__(
self,
data=None,
chunks=None,
row_chunks=None,
col_chunks=None,
padding=None,
scheduler='threads',
get_ray=False,
n_workers=1,
n_chunks=None,
):
self.chunks = 512 if not isinstance(chunks, int) else chunks
if isinstance(data, list):
self.data_list = data
self.data = xr.open_rasterio(self.data_list[0])
else:
self.data = data
self.row_chunks = row_chunks
self.col_chunks = col_chunks
self.padding = padding
self.n = None
self.scheduler = scheduler
self.get_ray = get_ray
self.executor = _EXEC_DICT[scheduler]
self.n_workers = n_workers
self.n_chunks = n_chunks
self._in_session = False
self.windows = None
self.slices = None
self.n_windows = None
if not isinstance(self.n_chunks, int):
self.n_chunks = self.n_workers * 50
if not isinstance(data, list):
self._setup()
def __iter__(self):
self.i_ = 0
return self
def __next__(self):
self.data.close()
if self.i_ >= len(self.data_list):
raise StopIteration
self.data = xr.open_rasterio(self.data_list[self.i_])
self._setup()
self.i_ += 1
return self
def __enter__(self):
self._in_session = True
return self
def __exit__(self, *args, **kwargs):
self._in_session = False
def _setup(self):
default_rchunks = (
self.data.block_window(1, 0, 0).height
if isinstance(self.data, rio.io.DatasetReader)
else self.data.gw.row_chunks
)
default_cchunks = (
self.data.block_window(1, 0, 0).width
if isinstance(self.data, rio.io.DatasetReader)
else self.data.gw.col_chunks
)
rchunksize = (
self.row_chunks
if isinstance(self.row_chunks, int)
else default_rchunks
)
cchunksize = (
self.col_chunks
if isinstance(self.col_chunks, int)
else default_cchunks
)
self.windows = get_window_offsets(
self.data.height
if isinstance(self.data, rio.io.DatasetReader)
else self.data.gw.nrows,
self.data.width
if isinstance(self.data, rio.io.DatasetReader)
else self.data.gw.ncols,
rchunksize,
cchunksize,
return_as='list',
padding=self.padding,
)
# Convert windows into slices
if len(self.data.shape) == 2:
self.slices = [
(
slice(w.row_off, w.row_off + w.height),
slice(w.col_off, w.col_off + w.width),
)
for w in self.windows
]
else:
self.slices = [
tuple([slice(0, None)] * (len(self.data.shape) - 2))
+ (
slice(w.row_off, w.row_off + w.height),
slice(w.col_off, w.col_off + w.width),
)
for w in self.windows
]
self.n_windows = len(self.windows)
def map(self, func, *args, **kwargs):
"""Maps a function over a DataArray.
Args:
func (func): The function to apply to the ``data`` chunks.
When using any scheduler other than 'ray' (i.e., 'mpool', 'threads', 'processes'), the function
should always be defined with ``*args``. With these schedulers, the function will always receive
the ``DataArray`` window and the window id as its first two arguments. If no user arguments are
passed to ``map``, the function will look like:
def my_func(*args):
data, window_id = list(itertools.chain(*args))
# do something
return results
If user arguments are passed, e.g., ``map(my_func, arg1, arg2)``, the function will look like:
def my_func(*args):
data, window_id, arg1, arg2 = list(itertools.chain(*args))
# do something
return results
When ``scheduler`` = 'ray', the user function requires an additional slice argument that looks like:
@ray.remote
def my_ray_func(data_block_id, data_slice, window_id):
# do something
return results
Note the addition of the ``@ray.remote`` decorator, as well as the explicit arguments in the function
signature. Extra user arguments would look like:
@ray.remote
def my_ray_func(data_block_id, data_slice, window_id, arg1, arg2):
# do something
return results
Other ``ray`` classes can also be used in place of a function.
args (items): Additional positional arguments to pass to ``func``.
kwargs (Optional[dict]): Keyword arguments passed to ``multiprocessing.Pool().imap``.
Returns:
``list``: Results for each data chunk.
Examples:
>>> import geowombat as gw
>>> from geowombat.core.parallel import ParallelTask
>>> from geowombat.data import l8_224078_20200518_points, l8_224078_20200518
>>> import geopandas as gpd
>>> import rasterio as rio
>>> import ray
>>> from ray.util import ActorPool
>>>
>>> @ray.remote
>>> class Actor(object):
>>>
>>> def __init__(self, aoi_id=None, id_column=None, band_names=None):
>>>
>>> self.aoi_id = aoi_id
>>> self.id_column = id_column
>>> self.band_names = band_names
>>>
>>> # While the names can differ, these three arguments are required.
>>> # For ``ParallelTask``, the callable function within an ``Actor`` must be named exec_task.
>>> def exec_task(self, data_block_id, data_slice, window_id):
>>>
>>> data_block = data_block_id[data_slice]
>>> left, bottom, right, top = data_block.gw.bounds
>>> aoi_sub = self.aoi_id.cx[left:right, bottom:top]
>>>
>>> if aoi_sub.empty:
>>> return aoi_sub
>>>
>>> # Return a GeoDataFrame for each actor
>>> return gw.extract(data_block,
>>> aoi_sub,
>>> id_column=self.id_column,
>>> band_names=self.band_names)
>>>
>>> ray.init(num_cpus=8)
>>>
>>> band_names = [1, 2, 3]
>>> df_id = ray.put(gpd.read_file(l8_224078_20200518_points))
>>>
>>> with rio.Env(GDAL_CACHEMAX=256*1e6) as env:
>>>
>>> # Since we are iterating over the image block by block, we do not need to load
>>> # a lazy dask array (i.e., chunked).
>>> with gw.open(l8_224078_20200518, band_names=band_names, chunks=None) as src:
>>>
>>> # Setup the pool of actors, one for each resource available to ``ray``.
>>> actor_pool = ActorPool([Actor.remote(aoi_id=df_id, id_column='id', band_names=band_names)
>>> for n in range(0, int(ray.cluster_resources()['CPU']))])
>>>
>>> # Setup the task object
>>> pt = ParallelTask(src, row_chunks=4096, col_chunks=4096, scheduler='ray', n_chunks=1000)
>>> results = pt.map(actor_pool)
>>>
>>> del df_id, actor_pool
>>>
>>> ray.shutdown()
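>>>
>>> ##########################################
>>> # Pass extra user arguments (a sketch)
>>> ##########################################
>>>
>>> # A minimal sketch only; 'image.tif', ``user_func_args``, and the ``scale``
>>> # value are hypothetical. Extra arguments passed to ``map`` follow the
>>> # window id when the chained arguments are unpacked.
>>> import itertools
>>>
>>> def user_func_args(*args):
>>>     data, window_id, scale, num_workers = list(itertools.chain(*args))
>>>     return (data.data * scale).sum().compute(scheduler='threads', num_workers=num_workers)
>>>
>>> with gw.open('image.tif') as src:
>>>     pt = ParallelTask(src, scheduler='threads', n_workers=4)
>>>     res = pt.map(user_func_args, 0.0001, 2)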
"""
if (self.n_workers == 1) or (self.scheduler == 'ray'):
executor_pool = _executor_dummy
ranger = range
else:
executor_pool = self.executor
ranger = trange
if self.scheduler == 'ray':
if self.padding:
raise SyntaxError('Ray cannot be used with array padding.')
import ray
if isinstance(self.data, rio.io.DatasetReader):
data_id = self.data.name
else:
data_id = ray.put(self.data)
results = []
with executor_pool(self.n_workers) as executor:
# Iterate over the windows in chunks
for wchunk in ranger(0, self.n_windows, self.n_chunks):
if self.padding:
window_slice = self.windows[
wchunk : wchunk + self.n_chunks
]
# Read the padded window
if len(self.data.shape) == 2:
data_gen = (
(
self.data[
w[1].row_off : w[1].row_off + w[1].height,
w[1].col_off : w[1].col_off + w[1].width,
],
widx + wchunk,
*args,
)
for widx, w in enumerate(window_slice)
)
elif len(self.data.shape) == 3:
data_gen = (
(
self.data[
:,
w[1].row_off : w[1].row_off + w[1].height,
w[1].col_off : w[1].col_off + w[1].width,
],
widx + wchunk,
*args,
)
for widx, w in enumerate(window_slice)
)
else:
data_gen = (
(
self.data[
:,
:,
w[1].row_off : w[1].row_off + w[1].height,
w[1].col_off : w[1].col_off + w[1].width,
],
widx + wchunk,
*args,
)
for widx, w in enumerate(window_slice)
)
else:
window_slice = self.slices[wchunk : wchunk + self.n_chunks]
if self.scheduler == 'ray':
data_gen = (
(data_id, slice_, widx + wchunk, *args)
for widx, slice_ in enumerate(window_slice)
)
else:
data_gen = (
(self.data[slice_], widx + wchunk, *args)
for widx, slice_ in enumerate(window_slice)
)
if (self.n_workers == 1) and (self.scheduler != 'ray'):
for result in map(func, data_gen):
results.append(result)
else:
if self.scheduler == 'mpool':
for result in executor.imap(func, data_gen, **kwargs):
results.append(result)
elif self.scheduler == 'ray':
if isinstance(func, ray.util.actor_pool.ActorPool):
for result in tqdm(
func.map(
lambda a, v: a.exec_task.remote(*v),
data_gen,
),
total=len(window_slice),
):
results.append(result)
else:
if isinstance(func, ray.actor.ActorHandle):
futures = [
func.exec_task.remote(*dargs)
for dargs in data_gen
]
else:
futures = [
func.remote(*dargs) for dargs in data_gen
]
if self.get_ray:
with tqdm(total=len(futures)) as pbar:
results_ = []
while len(futures):
done_id, futures = ray.wait(futures)
results_.append(ray.get(done_id[0]))
pbar.update(1)
results += results_
else:
results += futures
else:
for result in executor.map(func, data_gen):
results.append(result)
if self.scheduler == 'ray':
del data_id
return results