import logging
import os
import random
import string
from abc import ABC, abstractmethod
from contextlib import ExitStack
from copy import copy
from datetime import datetime
from pathlib import Path
import graphviz
import xarray as xr
from .. import config as gw_config
from .. import open as gw_open
from ..handler import add_handler
# Module-level logger; ``add_handler`` (from ``..handler``) is expected to
# attach the package's handler and return the logger.
logger = add_handler(logging.getLogger(__name__))
# Graphviz node/edge styling used by ``GraphBuilder.visualize``.

# Task (process) nodes and the task-to-task edges.
PROC_NODE_ATTRS = dict(
    shape="oval",
    color="#3454b4",
    fontcolor="#131f43",
    style="filled",
    fillcolor="#c6d2f6",
)
PROC_EDGE_ATTRS = dict(color="#3454b4", style="bold")

# Shared look of the ``config.update()`` and ``to_raster()`` argument nodes;
# only their shape differs.
_FRAMED_NODE_STYLE = dict(
    color="black",
    fontcolor="#131f43",
    style="rounded,filled",
    fillcolor="none",
)

# ``geowombat.config.update()`` argument nodes.
CONFIG_NODE_ATTRS = dict(shape="diamond", **_FRAMED_NODE_STYLE)
CONFIG_EDGE_ATTRS = dict(color="grey", style="dashed")

# ``geowombat.to_raster()`` argument nodes.
OUT_NODE_ATTRS = dict(shape="pentagon", **_FRAMED_NODE_STYLE)
OUT_EDGE_ATTRS = dict(color="#edcec6", style="dashed")

# Input (and task output) data nodes.
INPUT_NODE_ATTRS = dict(
    shape="box",
    color="#b49434",
    fontcolor="#2d250d",
    style="filled",
    fillcolor="#f3e3b3",
)
INPUT_EDGE_ATTRS = dict(color="#b49434")

# Function keyword-argument nodes.
VAR_NODE_ATTRS = dict(
    shape="box",
    color="#555555",
    fontcolor="#555555",
    style="dashed",
)
VAR_EDGE_ATTRS = dict(color="#555555")
class BaseGeoTask(ABC):
    """Base class for geowombat task pipelines.

    Stores the pipeline definition, supports combining pipelines with ``+``
    and appends one status line per executed task to a log file.
    """

    @abstractmethod
    def __init__(
        self,
        inputs,
        outputs,
        tasks,
        clean=None,
        config_args=None,
        open_args=None,
        func_args=None,
        out_args=None,
        log_file=None,
    ):
        self.inputs = inputs
        self.outputs = outputs
        self.tasks = tasks
        # Each optional mapping falls back to an empty dict so it can be
        # iterated, copied or ``**``-unpacked without a ``None`` check.
        self.clean = clean if clean else {}
        self.config_args = config_args if config_args else {}
        self.open_args = open_args if open_args else {}
        self.func_args = func_args if func_args else {}
        self.out_args = out_args if out_args else {}
        self.log_file = log_file

        if not self.log_file:
            # Default to ``task.log`` next to this module, or in the home
            # directory when the package directory is not writable.
            log_home = os.path.abspath(os.path.dirname(__file__))
            if not os.access(log_home, os.W_OK):
                log_home = os.path.expanduser('~')
            self.log_file = os.path.join(log_home, 'task.log')

    def copy(self):
        """Returns a shallow copy of the pipeline."""
        return copy(self)

    def __add__(self, other):
        """Add another pipeline.

        Returns a new ``GeoTask`` that runs this pipeline's tasks followed by
        ``other``'s. The mappings are merged, with ``other`` winning on key
        collisions. Neither operand is modified, and the combined pipeline
        keeps this pipeline's log file.
        """

        def merged(attr):
            combined = (getattr(self, attr) or {}).copy()
            combined.update(getattr(other, attr) or {})
            return combined

        return GeoTask(
            merged('inputs'),
            merged('outputs'),
            tuple(self.tasks) + tuple(other.tasks),
            clean=merged('clean'),
            config_args=merged('config_args'),
            open_args=merged('open_args'),
            func_args=merged('func_args'),
            out_args=merged('out_args'),
            log_file=self.log_file,
        )

    @abstractmethod
    def execute(self, task_id, task, src, task_results, attrs, **kwargs):
        """Execute a task."""
        pass

    @abstractmethod
    def submit(self):
        """Submit a task pipeline."""
        raise NotImplementedError

    def _cleanup(self, level, task_id):
        """Removes the output file of ``task_id`` if its clean level is ``level``.

        Tasks without an entry in ``clean`` are left untouched. A file that
        cannot be removed is reported with a warning rather than an error.
        """
        if task_id not in self.clean or self.clean[task_id] != level:
            return

        fn = Path(self.outputs[task_id])
        if fn.is_file():
            try:
                fn.unlink()
            except OSError:
                logger.warning(
                    ' Could not remove task {task_id} output.'.format(
                        task_id=task_id
                    )
                )

    def _check_task(self, task_id):
        """Returns True if the output of ``task_id`` exists as a file."""
        return Path(self.outputs[task_id]).is_file()

    def _set_log(self, task_id):
        """Builds the log line for ``task_id``.

        The line holds a timestamp, the task output, a random 9-character id
        and ``ok``/``failed`` depending on whether the output exists as a
        file. Outputs that are not files on disk (e.g. in-memory ``mem|``
        outputs) are therefore reported as ``failed``.
        """
        letters_digits = string.ascii_letters + string.digits
        random_id = ''.join(random.choices(letters_digits, k=9))
        when = datetime.now().strftime('%A, %d-%m-%Y at %H:%M:%S')
        task_log = f"{when} | {self.outputs[task_id]} | task_id-{random_id}"
        status = 'ok' if self._check_task(task_id) else 'failed'

        return f'{task_log} {status}\n'

    def _log_task(self, task_id):
        """Appends the log line for ``task_id`` to ``log_file``, creating it if needed."""
        with open(self.log_file, mode='a') as f:
            f.write(self._set_log(task_id))

    def __len__(self):
        """Returns the number of tasks in the pipeline."""
        return len(self.tasks)
class GraphBuilder(object):
    """Mixin that renders a task pipeline as a Graphviz directed graph.

    The host class must provide the ``tasks``, ``inputs``, ``outputs``,
    ``clean``, ``config_args``, ``out_args`` and ``func_args`` attributes.

    Reference:
        https://github.com/benbovy/xarray-simlab/blob/master/xsimlab/dot.py
    """

    def visualize(self, **kwargs):
        """Builds a Graphviz graph of the pipeline.

        Args:
            kwargs (Optional[dict]): Graph-level Graphviz attributes.
                Defaults to ``{'rankdir': 'LR'}`` (left-to-right layout).

        Returns:
            ``graphviz.Digraph``
        """
        if not kwargs:
            kwargs = {'rankdir': 'LR'}

        self.seen = set()
        self.inputs_seen = set()
        self.outputs_seen = set()

        # Graph attributes go on the graph itself: calling
        # ``subgraph(graph_attr=...)`` outside a ``with`` block only returns an
        # unused context manager and never applies them.
        self.g = graphviz.Digraph(graph_attr=kwargs)

        task_ids = [task_id for task_id, __ in self.tasks]

        self._add_task_nodes(task_ids)
        self._add_config_nodes(task_ids)
        self._add_output_nodes()

        for counter, (task_id, inputs_) in enumerate(self.inputs.items()):
            input_list = [inputs_] if isinstance(inputs_, str) else inputs_
            self._add_inputs(counter, task_id, input_list)

        self._add_func_arg_nodes()

        return self.g

    def _add_task_nodes(self, task_ids):
        """Adds one node per task plus edges from the tasks it depends on."""
        for counter, (task_id, task) in enumerate(self.tasks):
            if task_id in self.seen:
                continue

            self.seen.add(task_id)
            task_name = getattr(task, '__name__', repr(task))
            self.g.node(
                task_id,
                label='Task {task_id}: {task_name}'.format(
                    task_id=task_id, task_name=task_name
                ),
                **PROC_NODE_ATTRS,
            )

            if task_id == task_ids[0]:
                # The first task has no upstream task.
                continue

            previous_id = task_ids[counter - 1]
            task_inputs = self.inputs[task_id]
            if isinstance(task_inputs, str):
                sources = [previous_id]
            else:
                # Inputs naming a task link to that task; anything else
                # (e.g. a file) links to the preceding task.
                sources = [
                    ctask if ctask in task_ids else previous_id
                    for ctask in task_inputs
                ]

            # De-duplicate (order-preserving) so several file inputs do not
            # draw repeated edges from the same task.
            for source in dict.fromkeys(sources):
                self.g.edge(source, task_id, **PROC_EDGE_ATTRS)

    def _add_config_nodes(self, task_ids):
        """Adds the ``geowombat.config.update`` arguments, linked to every task."""
        config_args = self.config_args or {}
        if not config_args or not task_ids:
            return

        with self.g.subgraph(name='cluster_0') as c:
            c.attr(style='filled', color='lightgrey')
            for config_key, config_setting in config_args.items():
                c.node(
                    config_key,
                    label='{config_key}: {config_setting}'.format(
                        config_key=config_key,
                        config_setting=config_setting,
                    ),
                    **CONFIG_NODE_ATTRS,
                )
            c.attr(label='geowombat.config.update() args')

        for task_id in dict.fromkeys(task_ids):
            for config_key in config_args:
                self.g.edge(config_key, task_id, **CONFIG_EDGE_ATTRS)

    def _add_output_nodes(self):
        """Adds one node per task output, plus its ``to_raster`` arguments."""
        output_ids = list(self.outputs)
        clean = self.clean or {}
        out_args = self.out_args or {}

        for counter, (task_id, output_) in enumerate(self.outputs.items()):
            if output_ in self.outputs_seen:
                continue

            self.outputs_seen.add(output_)

            node_attrs = INPUT_NODE_ATTRS.copy()
            edge_attrs = INPUT_EDGE_ATTRS.copy()

            clean_level = clean.get(task_id)
            if clean_level == 'task':
                node_attrs['color'] = 'red'
            elif clean_level == 'pipeline':
                node_attrs['color'] = 'purple'

            if task_id == output_ids[-1]:
                # Highlight the final pipeline output.
                node_attrs['color'] = 'blue'
                node_attrs['style'] = 'dashed'
                edge_attrs['style'] = 'dashed'

            output_node = '{task_id} {task_output}'.format(
                task_id=task_id, task_output=output_
            )
            self.g.node(output_node, label=output_, **node_attrs)
            self.g.edge(task_id, output_node, weight='200', **edge_attrs)

            # Only outputs written to disk receive ``to_raster`` arguments.
            if out_args and not output_.startswith('mem|'):
                with self.g.subgraph(name='cluster_1') as c:
                    c.attr(style='filled', color='#a6d5ab')
                    for out_key, out_setting in out_args.items():
                        c.node(
                            out_key,
                            label='{out_key}: {out_setting}'.format(
                                out_key=out_key, out_setting=out_setting
                            ),
                            **OUT_NODE_ATTRS,
                        )
                    c.attr(label='geowombat.to_raster() args')

                for out_key in out_args:
                    self.g.edge(out_key, output_node, **OUT_EDGE_ATTRS)

            if counter > 0:
                # Link the previous on-disk output to this task.
                task_id_ = output_ids[counter - 1]
                previous_output = self.outputs[task_id_]
                if not previous_output.startswith('mem|'):
                    self.g.edge(
                        '{task_id_} {task_output}'.format(
                            task_id_=task_id_, task_output=previous_output
                        ),
                        task_id,
                        weight='200',
                        **edge_attrs,
                    )

    def _add_func_arg_nodes(self):
        """Adds one node per function keyword argument, linked to its task."""
        for task_id, params in (self.func_args or {}).items():
            for k, param in params.items():
                var_node = '{task_id} {k}'.format(task_id=task_id, k=k)
                self.g.node(
                    var_node,
                    label='{k}: {param}'.format(k=k, param=param),
                    **VAR_NODE_ATTRS,
                )
                self.g.edge(var_node, task_id, weight='200', **VAR_EDGE_ATTRS)

    def _first_consumer(self, gen_item, default):
        """Returns the first task (in ``inputs`` order) whose inputs include ``gen_item``."""
        for itask, iinputs in self.inputs.items():
            # Compare strings exactly; ``in`` on a string is a substring test.
            if isinstance(iinputs, str):
                found = iinputs == gen_item
            else:
                found = gen_item in iinputs
            if found:
                return itask

        return default

    def _add_inputs(self, counter, task_id, input_list):
        """Adds input nodes for one task and links them to the task.

        Args:
            counter (int): Position of ``task_id`` in ``inputs``.
            task_id (str): The task that consumes the inputs.
            input_list (list | tuple): The task inputs. Tuple/list items are
                expanded one level.
        """
        task_ids = [tid for tid, __ in self.tasks]

        for input_ in input_list:
            gen = input_ if isinstance(input_, (tuple, list)) else (input_,)

            for gen_item in gen:
                if gen_item in self.outputs_seen:
                    # Already drawn as an output node.
                    continue

                first_sight = gen_item not in self.inputs_seen
                self.inputs_seen.add(gen_item)

                if gen_item in self.outputs:
                    # ``gen_item`` names an upstream task.
                    task_id_b = task_ids[counter - 1]
                else:
                    if first_sight:
                        gen_label = (
                            Path(gen_item).name
                            if Path(gen_item).is_file()
                            else gen_item
                        )
                        self.g.node(
                            '{task_id} {gen_item}'.format(
                                task_id=task_id, gen_item=gen_item
                            ),
                            label=gen_label,
                            **INPUT_NODE_ATTRS,
                        )
                    # The node was created under the first task that used this
                    # input, so link from there.
                    task_id_b = self._first_consumer(gen_item, default=task_id)

                if task_id_b:
                    self.g.edge(
                        '{task_id_b} {gen_item}'.format(
                            task_id_b=task_id_b, gen_item=gen_item
                        ),
                        task_id,
                        weight='200',
                        **INPUT_EDGE_ATTRS,
                    )
class GeoTask(BaseGeoTask, GraphBuilder):
    """A Geo-task scheduler.

    Args:
        inputs (dict): The input steps, keyed by task id. Each value is a file
            path, the id of an earlier task, or a tuple/list of those.
        outputs (dict): The outputs, keyed by task id.
        tasks (tuple): The ``(task_id, function)`` pairs to execute, in order.
        clean (Optional[dict]): Currently not implemented.
        config_args (Optional[dict]): The arguments for `geowombat.config.update`.
        open_args (Optional[dict]): The arguments for `geowombat.open`.
        func_args (Optional[dict]): The arguments to pass to each function in `tasks`.
        out_args (Optional[dict]): The arguments for `geowombat.to_raster`.
        log_file (Optional[str]): A file to write the log to.

    Examples:
        >>> import xarray as xr
        >>> import geowombat as gw
        >>> from geowombat.data import l8_224078_20200518_B3, l8_224078_20200518_B4, l8_224078_20200518
        >>> from geowombat.tasks import GeoTask
        >>>
        >>> # Task a and b take 1 input file
        >>> # Task c takes 2 input files
        >>> # Task d takes the output of task c
        >>> # Task e takes the outputs of a, b, and d
        >>> inputs = {'a': l8_224078_20200518, 'b': l8_224078_20200518, 'c': (l8_224078_20200518_B3, l8_224078_20200518_B4), 'd': 'c', 'e': ('a', 'b', 'd')}
        >>>
        >>> # The output task names
        >>> # All tasks are in-memory DataArrays
        >>> outputs = {'a': 'mem|r1', 'b': 'mem|r2', 'c': 'mem|r3', 'd': 'mem|mean', 'e': 'mem|stack'}
        >>>
        >>> # Task a and b compute the `norm_diff`
        >>> # Task c concatenates two images
        >>> # Task d takes the mean of c
        >>> # Task e concatenates a, b, and d
        >>> tasks = (('a', gw.norm_diff), ('b', gw.norm_diff), ('c', xr.concat), ('d', xr.DataArray.mean), ('e', xr.concat))
        >>>
        >>> # Task a and b take band name arguments
        >>> # Tasks c, d, and e take the coordinate dimension name as an argument
        >>> func_args = {'a': {'b1': 'green', 'b2': 'red'}, 'b': {'b1': 'blue', 'b2': 'green'}, 'c': {'dim': 'band'}, 'd': {'dim': 'band'}, 'e': {'dim': 'band'}}
        >>> open_args = {'chunks': 512}
        >>> config_args = {'sensor': 'bgr', 'nodata': 0, 'scale_factor': 0.0001}
        >>>
        >>> # Setup a task
        >>> task_mean = GeoTask(inputs, outputs, tasks, config_args=config_args, open_args=open_args, func_args=func_args)
        >>>
        >>> # Visualize the task
        >>> task_mean.visualize()
        >>>
        >>> # Create a task that takes the output of task e and writes the mean to file
        >>> task_write = GeoTask({'f': 'e'}, {'f': 'mean.tif'}, (('f', xr.DataArray.mean),),
        ...                      config_args=config_args,
        ...                      func_args={'f': {'dim': 'band'}},
        ...                      open_args=open_args,
        ...                      out_args={'compress': 'lzw', 'overwrite': True})
        >>>
        >>> # Add the new task
        >>> new_task = task_mean + task_write
        >>>
        >>> new_task.visualize()
        >>>
        >>> # Write the task pipeline to file
        >>> new_task.submit()
    """

    def __init__(
        self,
        inputs,
        outputs,
        tasks,
        clean=None,
        config_args=None,
        open_args=None,
        func_args=None,
        out_args=None,
        log_file=None,
    ):
        super().__init__(
            inputs,
            outputs,
            tasks,
            clean=clean,
            config_args=config_args,
            open_args=open_args,
            func_args=func_args,
            out_args=out_args,
            log_file=log_file,
        )

    def execute(self, task_id, task, src, task_results, attrs, **kwargs):
        """Executes an individual task.

        Args:
            task_id (str): The task id, used to look up the task output.
            task (func): The function to run.
            src (DataArray | iterable | tuple): The task input. A tuple is
                treated as upstream task ids whose results are passed to
                ``task``; anything else is passed to ``task`` unchanged
                (``submit`` passes a generator for multi-input tasks).
            task_results (dict): Results of already-executed tasks, keyed by
                task id.
            attrs (dict | None): Attributes assigned to the result before it
                is written to a ``.tif`` output, if the result has no ``crs``.
            kwargs (Optional[dict]): Keyword arguments passed to ``task``.

        Returns:
            The task result, with a ``band`` dimension added if it had none.
        """
        if isinstance(src, tuple):
            res = task((task_results[i] for i in src), **kwargs)
        else:
            res = task(src, **kwargs)

        if not hasattr(res, 'band'):
            res = res.expand_dims(dim='band').assign_coords({'band': ['res']})

        output = self.outputs.get(task_id)
        if isinstance(output, str) and output.lower().endswith('.tif'):
            # ``attrs`` stays None until a task reads a single input file, and
            # assigning None to ``DataArray.attrs`` would raise, so only copy
            # real attributes.
            if not hasattr(res, 'crs') and attrs is not None:
                res.attrs = attrs
            res.gw.to_raster(output, **(self.out_args or {}))

        return res

    def submit(self):
        """Submits a pipeline task.

        Runs every task in ``tasks`` order inside a
        ``geowombat.config.update(**config_args)`` context. A task input may be
        an existing file (opened with ``geowombat.open(**open_args)``), the id
        of a task that has already run (its stored result is used), or a
        tuple/list mixing both. Each result is stored under its task id, and a
        line is appended to ``log_file`` after every task.

        Returns:
            The result of the last task, or ``None`` when there are no tasks.

        Raises:
            TypeError: If a task input is not a string, path-like object,
                tuple or list.
            KeyError: If an input is neither an existing file nor the id of a
                task that has already run.
        """
        func_args = self.func_args or {}
        open_args = self.open_args or {}
        task_results = {}
        # Attributes of the most recent single-file input; ``execute`` uses
        # them when writing results that lack a ``crs``.
        attrs = None
        res = None

        with gw_config.update(**(self.config_args or {})):
            for task_id, task in self.tasks:
                kwargs = func_args.get(task_id, {})
                task_inputs = self.inputs[task_id]

                if isinstance(task_inputs, os.PathLike):
                    task_inputs = os.fspath(task_inputs)

                if isinstance(task_inputs, (tuple, list)):
                    with ExitStack() as stack:
                        # Files are opened lazily as the task consumes ``src``;
                        # the stack closes them once the task has returned.
                        src = (
                            stack.enter_context(gw_open(fn, **open_args))
                            if Path(fn).is_file()
                            else task_results[fn]
                            for fn in task_inputs
                        )
                        res = self.execute(
                            task_id, task, src, task_results, attrs, **kwargs
                        )
                elif isinstance(task_inputs, str):
                    if Path(task_inputs).is_file():
                        with gw_open(task_inputs, **open_args) as src:
                            attrs = src.attrs.copy()
                            res = self.execute(
                                task_id, task, src, task_results, attrs, **kwargs
                            )
                    else:
                        res = self.execute(
                            task_id,
                            task,
                            task_results[task_inputs],
                            task_results,
                            attrs,
                            **kwargs,
                        )
                else:
                    # Previously such inputs were skipped and the prior task's
                    # result was silently stored for this task.
                    raise TypeError(
                        f"Task {task_id!r} input must be a file path, a task id, "
                        f"or a tuple/list of those; got {type(task_inputs).__name__}."
                    )

                task_results[task_id] = res
                self._log_task(task_id)

        return res