diff --git a/openeo_pg_parser_networkx/graph.py b/openeo_pg_parser_networkx/graph.py index c22e56e..8110606 100644 --- a/openeo_pg_parser_networkx/graph.py +++ b/openeo_pg_parser_networkx/graph.py @@ -24,6 +24,18 @@ ProcessGraphUnflattener, parse_nested_parameter, ) +import xarray as xr +import dask.array as da +from datetime import datetime +import os +from functools import wraps +## For yprov4wfs +import json +from yprov4wfs.datamodel.workflow import Workflow +from yprov4wfs.datamodel.task import Task +from yprov4wfs.datamodel.data import Data +import uuid + logger = logging.getLogger(__name__) @@ -66,6 +78,10 @@ def __repr__(self): class OpenEOProcessGraph: def __init__(self, pg_data: dict): + # Make a workflow object + self.workflow = Workflow('wfs1', 'Workflow 1') + self.workflow._engineWMS = "Openeo-LocalProcessing" + self.workflow._level= "0" self.G = nx.DiGraph() nested_raw_graph = self._unflatten_raw_process_graph(pg_data) @@ -295,7 +311,7 @@ def to_callable( return self._map_node_to_callable( self.result_node, process_registry, results_cache, parameters ) - + def _map_node_to_callable( self, node: str, @@ -352,8 +368,27 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs): for func in parent_callables: func(*args, named_parameters=named_parameters, **kwargs) + cache_users = {} try: # If this node has already been computed once, just grab that result from the results_cache instead of recomputing it. + # This cannot be done for aggregated data as the wrapped function has to be called multiple times with different values. + # This also means the results_cache will be useless for these functions. 
+ # TODO: track how often functions need to be called and check if they have been called that many times, if yes, we can + # use the cache for aggregate functions, but this is probably not super necessary + parent_node_id = [edge[0] for edge in self.edges if edge[1] == node] + + if parent_node_id: + parent_node_process_id = [ + n[1]["process_id"] + for n in self.nodes + if n[0] == parent_node_id[0] + ] + + if parent_node_process_id and parent_node_process_id[0] in [ + "aggregate_temporal_period", + "aggregate_spatial", + ]: + raise KeyError() return results_cache.__getitem__(node) except KeyError: for _, source_node, data in self.G.out_edges(node, data=True): @@ -366,13 +401,102 @@ def node_callable(*args, parent_callables, named_parameters=None, **kwargs): kwargs[arg_sub.arg_name] = self.G.nodes(data=True)[node][ "resolved_kwargs" ].__getitem__(arg_sub.arg_name) - - result = prebaked_process_impl( - *args, named_parameters=named_parameters, **kwargs - ) - + # Make a dictionary from the nodes that uses the outputs of the other nodes + if source_node not in cache_users: + cache_users[source_node] = [] + cache_users[source_node].append(node) + # Make the tasks + task = Task(node, node_with_data['process_id']) + result, execution_data= self.profile_function(prebaked_process_impl)( + *args, named_parameters=named_parameters, **kwargs + ) + + if isinstance(result, xr.DataArray): + processed_result = { + "entity_type": "xarray.DataArray", + "info": { + "shape": result.shape, + "dimensions": list(result.dims), + # "attributes": result.attrs, + "dtype": str(result.dtype) + } + } + + elif isinstance(result, da.Array): + processed_result = { + "entity_type": "dask.Array", + "info": { + "shape": result.shape, + "dtype": str(result.dtype), + "chunk_size": result.chunksize, + "chunk_type": type(result._meta).__name__ + } + } + else: + processed_result = {} + processed_result['info'] = result + processed_result['entity_type'] = type(result).__name__ + if result is not None: + 
results_cache_node = Data(str((uuid.uuid4())), processed_result['entity_type']) + results_cache_node._info = processed_result['info'] + task.add_output(results_cache_node) + self.workflow.add_data(results_cache_node) results_cache[node] = result + # Loading data info + process_id = node_with_data.get("process_id") + resolved_kwargs = node_with_data.get("resolved_kwargs", {}) + + if process_id in ("load_stac", "load_collection"): + key = "url" if process_id == "load_stac" else "id" + raw_source = resolved_kwargs.get(key, "") + data_source = raw_source.split("\\")[-1] + + data_src = Data(str(uuid.uuid4()), data_source) + # Extract extra information + if process_id == "load_stac": + data_src._info = resolved_kwargs + + + task._start_time = execution_data['start_time'] + task._end_time = execution_data['end_time'] + task._status = execution_data['task_status'] + task._level = "1" + + # This is just for load stac ( for the temporary usage) + if node_with_data['process_id'] in ["load_stac", "load_collection"]: + task.add_input(data_src) + + self.workflow.add_task(task) + + if cache_users: + for source_node, target_node in cache_users.items(): + output_data_from_source = self.workflow.get_task_by_id(source_node)._outputs[0]._id + for target in target_node: + self.workflow.get_task_by_id(target) .add_input( + self.workflow.get_data_by_id(output_data_from_source) + ) + + edges = [ + {"source": source, "target": target, "type": data["reference_type"]} + for source, target, data in self.G.edges(node, data=True)] + + for edge in edges: + self.workflow.get_task_by_id(edge['source']).set_next(self.workflow.get_task_by_id(edge['target'])) + + if node == self.result_node: + self.workflow._status= "Ok" + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + save_path = os.path.join(os.getcwd(), f"run_{timestamp}") + print(f"Provenance file saved to: {save_path}") + + # Create the new directory + os.makedirs(save_path, exist_ok=True) + + + 
+ self.workflow.prov_to_json(directory_path=save_path) + return result return partial(node_callable, parent_callables=parent_callables)
@@ -471,3 +595,53 @@ def plot(self, reverse=False):
 
         if reverse:
             self.G = self.G.reverse()
+
+    @staticmethod
+    def profile_function(func):
+        """Wrap ``func`` so each call also returns timing and status metadata.
+
+        Further metrics of interest (such as CPU and memory usage) may be
+        added to the returned metadata in the future.
+
+        Args:
+            func: Callable to profile; it is invoked as
+                ``func(*args, named_parameters=named_parameters, **kwargs)``.
+
+        Returns:
+            A wrapper that returns ``(result, execution_data)``, where
+            ``execution_data`` holds ``task_status``, ``start_time``,
+            ``end_time`` (millisecond precision) and ``execution_time_sec``.
+            If ``func`` raises, the exception is not propagated: ``result``
+            is the exception message and ``task_status`` is ``"Error: "``
+            followed by the first 70 characters of that message.
+        """
+
+        @wraps(func)
+        def wrapper(*args, named_parameters, **kwargs):
+            start_dt = datetime.now()
+            start_timestamp = start_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+
+            try:
+                # Forward named_parameters by keyword: passed positionally it
+                # would land in the callee's *args instead of its
+                # named_parameters argument.
+                result = func(*args, named_parameters=named_parameters, **kwargs)
+                status = "Ok"
+            except Exception as e:
+                # Best-effort: the failure is reported through the status and
+                # the message is returned in place of a result.
+                result = str(e)
+                status = f"Error: {result[:70]}"
+
+            end_dt = datetime.now()
+            end_timestamp = end_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
+            execution_time = (end_dt - start_dt).total_seconds()
+            execution_data = {
+                "task_status": status,
+                "start_time": start_timestamp,
+                "end_time": end_timestamp,
+                "execution_time_sec": round(execution_time, 4),
+            }
+            return result, execution_data
+
+        return wrapper