1#!/usr/bin/env python3
2"""
3The network of task relations.
4
Uses an nx.DiGraph internally.
It is built 'backwards', as this preserves the meaning
7of graph.pred[x] = [y] as y.depends_on[x]
8and graph.succ[x] = [y] as y.required_for[x]
9
10"""
11# ruff: noqa: ERA001
12# Imports:
13from __future__ import annotations
14
15# ##-- stdlib imports
16import datetime
17import enum
18import functools as ftz
19import itertools as itz
20import logging as logmod
21import pathlib as pl
22import re
23import time
24import types
25from collections import defaultdict
26from itertools import chain, cycle
27from uuid import UUID, uuid1
28
29# ##-- end stdlib imports
30
31# ##-- 3rd party imports
32from jgdv import Proto, Mixin
33import networkx as nx
34from jgdv.structs.chainguard import ChainGuard
35from jgdv.structs.dkey import DKey
36# ##-- end 3rd party imports
37
38# ##-- 1st party imports
39import doot
40import doot.errors
41from ._interface import EdgeType_e
42from doot.workflow import ActionSpec, TaskName, TaskSpec, DootTask, RelationSpec, TaskArtifact
43# ##-- end 1st party imports
44
45import matplotlib.pyplot as plt
46from . import _interface as API # noqa: N812
47from doot.workflow._interface import TaskName_p, Artifact_i, RelationSpec_i
48
49# ##-- types
50# isort: off
51import abc
52import collections.abc
53from typing import TYPE_CHECKING, cast, assert_type, assert_never
54from typing import Generic, NewType
55# Protocols:
56from typing import Protocol, runtime_checkable
57# Typing Decorators:
58from typing import no_type_check, final, override, overload
59
60if TYPE_CHECKING:
61 import weakref
62 from doot.workflow._interface import TaskStatus_e, ArtifactStatus_e
63 from jgdv import Maybe
64 from typing import Final
65 from typing import ClassVar, Any, LiteralString
66 from typing import Never, Self, Literal
67 from typing import TypeGuard
68 from collections.abc import Iterable, Iterator, Callable, Generator
69 from collections.abc import Sequence, Mapping, MutableMapping, Hashable
70 from .track_registry import TrackRegistry
71 from doot.worfklow._interface import TaskSpec_i
72
73 type Abstract[T] = T
74 type Concrete[T] = T
75
76 type ActionElem = ActionSpec|RelationSpec
77 type ActionGroup = list[ActionElem]
78##--|
79
80# isort: on
81# ##-- end types
82
##-- logging
# Module logger. NOTE: the name deliberately shadows the stdlib `logging`
# module, which this file imports as `logmod`.
logging = logmod.getLogger(__name__)
logging.disabled = False
##-- end logging

# Read once at import time from doot's config (settings.commands.run.show),
# falling back to False. Gates the drawing done in report_tree().
show_graph : Final[bool] = doot.config.on_fail(False, bool).settings.commands.run.show() # type: ignore[attr-defined] # noqa: FBT003

# Keyword arguments forwarded to nx.draw when the network is rendered.
DRAW_OPTIONS : Final[dict] = {
    "with_labels"       : True,
    # "arrowstyle"      : "->",
    "node_color"        : "green",
    "verticalalignment" : "baseline",
    "bbox"              : {"edgecolor": "k", "facecolor": "white", "alpha": 0.5 },
}
97##--|
98
[docs]
99class _Expansion_m:
100
101 _tracker : API.WorkflowTracker_i
102 pred : Mapping
103 succ : Mapping
104 nodes : Mapping
105 edges : Mapping
106 _graph : Any
107 non_expanded : set
108
[docs]
109 def build_network(self, *, sources:Maybe[Literal[True]|list[Concrete[TaskName_p]|Artifact_i]]=None) -> None:
110 """
111 for each task queued (ie: connected to the root node)
112 expand its dependencies and add into the _graph, until no more nodes to expand.
113 then connect concrete _tracker._registry.artifacts to abstract _tracker._registry.artifacts.
114
115 passing sources=True forces build of any non_expanded nodes that have an edge
116
117 # TODO _graph could be built in total, or on demand
118 """
119 x : Any
120 processed : set
121 queue : list
122 additions : set
123 ##--|
124 logging.info("[Network.Build] -> Start")
125 match sources:
126 case None:
127 queue = list(self.pred[self._tracker._root_node].keys())
128 case True:
129 queue_set = set(self.nodes.keys())
130 queue_set.update([x for x in self.non_expanded if self.pred[x] or self.succ[x]])
131 queue = list(queue_set)
132 case [*xs]:
133 queue = list(sources)
134 case x:
135 raise TypeError(type(x))
136 processed = { self._tracker._root_node }
137 logging.info("[Build.Initial] Network Queue: %s", queue)
138 while bool(queue): # expand tasks
139 match (current:=queue.pop()):
140 case x if x in processed or self.nodes[x].get(API.EXPANDED, False):
141 logging.debug("[Build.Processed] %s", current)
142 processed.add(x)
143 case TaskName_p() as x if x in self.nodes:
144 additions = self._expand_task_node(x)
145 queue += additions
146 processed.add(x)
147 case Artifact_i() as x if x in self.nodes:
148 additions = self._expand_artifact(x)
149 logging.debug("[Build.Artifact] Expansion produced: %s", additions)
150 queue += additions
151 processed.add(x)
152 case _:
153 raise doot.errors.TrackingError("Unknown value in _graph")
154
155 else:
156 logging.debug("[Network.Build] <- Nodes: %s Edges: %s", len(self.nodes), len(self.edges))
157 self.report_tree() # type: ignore[attr-defined]
158
[docs]
159 def connect(self, left:Concrete[TaskName_p]|Artifact_i, right:Maybe[Literal[False]|Concrete[TaskName_p]|Artifact_i]=None, **kwargs) -> None: # noqa: ANN003
160 """
161 Connect a task node to another. left -> right
162 If given left, None, connect left -> API.ROOT
163 if given left, False, just add the node
164
165 (This preserves graph.pred[x] as the nodes x is dependent on)
166 """
167 assert("type" not in kwargs)
168 self._add_node(left)
169 match right:
170 case False:
171 return
172 case None:
173 right = self._tracker._root_node
174 self._add_node(right)
175 case x:
176 self._add_node(right)
177
178 if left in self.succ and right in self.succ[left]:
179 # nothing to do
180 return
181
182 # Add the edge, with metadata
183 match left, right:
184 case TaskName_p(), TaskName_p():
185 logging.debug("[Connect] %s -> %s", left, right)
186 self._graph.add_edge(left, right, type=EdgeType_e.TASK, **kwargs)
187 case TaskName_p(), Artifact_i():
188 logging.debug("[Connect] %s -> %s", left[:], right)
189 self._graph.add_edge(left, right, type=EdgeType_e.TASK_CROSS, **kwargs)
190 case Artifact_i(), TaskName_p():
191 logging.debug("[Connect] %s -> %s", left, right[:])
192 self._graph.add_edge(left, right, type=EdgeType_e.ARTIFACT_CROSS, **kwargs)
193 case Artifact_i(), Artifact_i() if left.is_concrete() and right.is_concrete():
194 raise doot.errors.TrackingError("Tried to connect two concrete _tracker._registry.artifacts", left, right)
195 case Artifact_i(), Artifact_i() if right.is_concrete():
196 logging.debug("[Connect] %s -> %s", left, right)
197 self._graph.add_edge(left, right, type=EdgeType_e.ARTIFACT_UP, **kwargs)
198 case Artifact_i(), Artifact_i() if not right.is_concrete():
199 logging.debug("[Connect] %s -> %s", left, right)
200 self._graph.add_edge(left, right, type=EdgeType_e.ARTIFACT_DOWN, **kwargs)
201
202 ##--| internal
203
[docs]
204 def _add_node(self, name:Concrete[TaskName_p]|Artifact_i) -> None:
205 """idempotent"""
206 match name:
207 case x if x in self.nodes:
208 return
209 case x if x is self._tracker._root_node:
210 if x in self._graph:
211 return
212 self._graph.add_node(name)
213 self.nodes[name][API.EXPANDED] = True
214 self.nodes[name][API.REACTIVE_ADD] = False
215 case TaskName_p() as x if not x.uuid():
216 raise doot.errors.TrackingError("Nodes should only be instantiated spec names", x)
217 case TaskName_p() as x if x not in self._tracker.specs:
218 raise doot.errors.TrackingError("Can't connect a non-existent task", x)
219 case Artifact_i() as x if x not in self._tracker.artifacts:
220 raise doot.errors.TrackingError("Can't connect a non-existent artifact", x)
221 case Artifact_i(): # Add node with metadata
222 logging.debug("[Network.Artifact.+] %s", name)
223 self._graph.add_node(name)
224 self.nodes[name][API.EXPANDED] = False
225 self.nodes[name][API.REACTIVE_ADD] = False
226 self.non_expanded.add(name)
227 case TaskName_p(): # Add node with metadata
228 logging.debug("[Network.Task.+] %s", name)
229 self._graph.add_node(name)
230 self.nodes[name][API.EXPANDED] = False
231 self.nodes[name][API.REACTIVE_ADD] = False
232 self.non_expanded.add(name)
233 case x:
234 raise TypeError(type(x))
235
[docs]
236 def _expand_task_node(self, name:Concrete[TaskName_p]) -> set[Concrete[TaskName_p]|Artifact_i]:
237 """ expand a task node, instantiating and connecting to its dependencies and dependents,
238 *without* expanding those new nodes.
239 returns a list of the new nodes tasknames
240 """
241 to_expand : set[TaskName_p|Artifact_i]
242 spec : TaskSpec_i
243 assert(name.uuid())
244 assert(not self.nodes[name].get(API.EXPANDED, False))
245 spec = self._tracker.specs[name].spec
246 spec_pred, spec_succ = self.pred[name], self.succ[name]
247 to_expand = set()
248
249 logging.info("[Build.Expand.Task] -> %s : Pre(%s), Post(%s)", name, len(spec.depends_on), len(spec.required_for))
250
251 # Connect Relations
252 for rel in self._tracker._factory.action_group_elements(spec):
253 if not isinstance(rel, RelationSpec_i):
254 # Ignore Actions
255 continue
256 relevant_edges = spec_succ if rel.forward_dir_p() else spec_pred
257 match rel:
258 case RelationSpec_i(target=Artifact_i() as target):
259 # Connect the artifact mentioned
260 assert(target in self._tracker.artifacts)
261 self.connect(*rel.to_ordered_pair(name)) # type: ignore[arg-type]
262 to_expand.add(target)
263 case RelationSpec_i(target=TaskName_p() as target):
264 # Get specs and instances with matching target
265 instance = self._tracker._instantiate(rel, control=name)
266 self.connect(*rel.to_ordered_pair(name, target=instance)) # type: ignore[arg-type]
267 to_expand.add(instance)
268 else:
269 assert(name in self.nodes)
270 self.nodes[name][API.EXPANDED] = True
271 self.non_expanded.remove(name)
272
273 to_expand.update(self._generate_blockers(name))
274 to_expand.update(self._generate_successor_nodes(spec))
275 logging.debug("[Build.Expand.Task] <- %s : %s", name, to_expand)
276 return to_expand
277
[docs]
278 def _generate_successor_nodes(self, spec:Concrete[TaskSpec]) -> list[Concrete[TaskName_p]]:
279 """
280 instantiate and connect a job's head task
281
282 for a spec S, find the tasks T that have registered a relation
283 of T < S.
284 (S would not know about these blockers).
285
286 For these T, link instantiated nodes that match constraints and link them to S,
287 or if no nodes exist, create and link them.
288 """
289 result = []
290 logging.debug("[Build.Task.Successor] : %s", spec.name)
291 names = self._tracker._subfactory.generate_names(spec)
292 assert(len(names) <= 1)
293 for x in names:
294 assert(x.uuid() == spec.name.uuid())
295 assert(x in self._tracker.specs)
296 self.connect(spec.name, x)
297 result.append(x)
298 else:
299 logging.debug("[Successors] : %s", result)
300 return result
301
[docs]
302 def _generate_blockers(self, name:TaskName_p) -> set[Concrete[TaskName_p]|Artifact_i]:
303 x : Any
304 target : TaskName_p
305 ##--|
306 results = []
307 blockers = set()
308 logging.info("[Blockers] : %s", name)
309 if name in self._tracker.specs:
310 blockers.update(self._tracker.specs[name].blocked_by)
311 if (x:=name.de_uniq()) in self._tracker.specs:
312 blockers.update(self._tracker.specs[x].blocked_by)
313 for blocker in blockers:
314 if blocker in self.pred[name]:
315 continue
316 instance = self._tracker._instantiate(blocker)
317 if instance in self.pred[name]:
318 continue
319 self.connect(instance, name)
320 results.append(instance)
321 logging.info("[Blocker.Connect] %s -> %s", instance, name)
322 else:
323 logging.debug("[Build.Generate.Blockers] %s -> %s", name, results)
324 return results
325
[docs]
326 def _expand_artifact(self, artifact:Artifact_i) -> set[Concrete[TaskName_p]|Artifact_i]:
327 """ expand _tracker._registry.artifacts, instantiating related tasks,
328 and connecting the task to its abstract/concrete related _tracker._registry.artifacts
329 """
330 to_expand : set[TaskName_p|Artifact_i]
331 meta : API.ArtifactMeta_d
332 abstract : TaskName_p | Artifact_i
333 ##--|
334 assert(artifact in self._tracker.artifacts)
335 assert(artifact in self.nodes)
336 assert(not self.nodes[artifact].get(API.EXPANDED, False))
337 logging.info("[Build.Expand.Artifact] --> %s", artifact)
338 to_expand = set()
339
340 meta = self._tracker.artifacts[artifact]
341 relevant = list(meta.builders)
342 logging.debug("-- Instantiating Artifact relevant tasks: %s", len(relevant))
343 for name in relevant:
344 instance = self._tracker._instantiate(name)
345 assert(instance is not None)
346 self.connect(instance, False) # noqa: FBT003
347 to_expand.add(instance)
348
349 match artifact.is_concrete():
350 case True:
351 logging.debug("-- Connecting concrete artifact to parent abstracts")
352 art_path = DKey[pl.Path](artifact[1,:])(relative=True) # type: ignore[operator]
353 for abstract in self._tracker.abstract:
354 match abstract:
355 case TaskName_p():
356 continue
357 case Artifact_i() as x if art_path not in x and artifact not in x:
358 continue
359 case _:
360 self.connect(artifact, abstract)
361 to_expand.add(abstract)
362 case False:
363 logging.debug("-- Connecting abstract task to child concrete _tracker._registry.artifacts")
364 for conc in self._tracker.concrete:
365 match conc:
366 case TaskName_p():
367 continue
368 case Artifact_i():
369 assert(conc.is_concrete())
370 conc_path = DKey[pl.Path](conc[1,:])(relative=True) # type: ignore[operator]
371 if conc_path not in artifact:
372 continue
373 self.connect(conc, artifact)
374 to_expand.add(conc)
375
376 logging.info("[Build.Expand.Artifact] <-- %s -> %s", artifact, to_expand)
377 self.nodes[artifact][API.EXPANDED] = True
378 self.non_expanded.remove(artifact)
379 return to_expand
380
[docs]
class _Validation_m:
    """Mixin: consistency checks and inspection helpers for the network.

    The host class is expected to supply the attributes annotated below.
    """

    _tracker  : API.WorkflowTracker_i
    _graph    : Any
    nodes     : Mapping
    edges     : Mapping
    pred      : Mapping
    succ      : Mapping

    def validate_network(self, *, strict:bool=True) -> bool: # noqa: PLR0912
        """Check that the task graph is in an acceptable, finished state.

        Checks, in order:
          - the graph is a directed acyclic graph;
          - every task node (bar the root) is expanded, has a uuid, and has a spec;
          - every artifact node is expanded, and a glob artifact has predecessors;
          - the tracker reports itself as valid.

        Args:
            strict: raise on per-node failures instead of logging a warning.

        Returns:
            True when no node failed; False when some did and strict is off.

        Raises:
            doot.errors.TrackingError: not a DAG, tracker not valid,
                or (when strict) any per-node failure.
        """
        logging.info("Validating Task Network")
        if not nx.is_directed_acyclic_graph(self._graph):
            raise doot.errors.TrackingError("Network isn't a DAG")

        problems : list[str] = []
        for node, data in self.nodes.items():
            if isinstance(node, TaskName_p):
                if node == self._tracker._root_node: # Ignore the root
                    continue
                if not data.get(API.EXPANDED, False): # every node is expanded
                    problems.append(f"{node} is not expanded")
                if not node.uuid(): # every node is uniq
                    problems.append(f"{node} is not unique")
                if node not in self._tracker.specs: # every node has a spec
                    problems.append(f"{node} has no backing spec")
            elif isinstance(node, Artifact_i):
                if not data.get(API.EXPANDED, False): # Every node is expanded
                    problems.append(f"{node} is not expanded")
                has_glob = TaskArtifact.Wild.glob in node
                if has_glob and not bool(self._graph.pred[node]):
                    problems.append(f"{node} has no concrete predecessors")

        if not self._tracker.is_valid:
            raise doot.errors.TrackingError("Network is not marked as valid")

        if not problems:
            return True
        if strict:
            raise doot.errors.TrackingError("Errors in network", problems)

        logging.warning("Failures in network: %s", problems)
        return False

    def concrete_edges(self, name:Concrete[TaskName_p|TaskArtifact]) -> ChainGuard:
        """Describe the edges a node has in the graph.

        These are the concrete edges, not the abstract ones listed in a spec.
        Predecessors and successors are each split into tasks, abstract
        artifacts and concrete artifacts. The root node is left out of the
        successor tasks, and reported separately under "root".
        """
        assert(name in self.nodes)
        upstream   = self.pred[name]
        downstream = self.succ[name]

        def split_artifacts(neighbours:Mapping) -> dict:
            return {
                "abstract" : [n for n in neighbours if isinstance(n, TaskArtifact) and not n.is_concrete()],
                "concrete" : [n for n in neighbours if isinstance(n, TaskArtifact) and n.is_concrete()],
            }

        # NOTE(review): the "_tracker._registry.artifacts" key reads like a
        # search-and-replace artefact of "artifacts". Left untouched, as
        # consumers may depend on it - confirm.
        pred_info = {
            "tasks"                        : [n for n in upstream if isinstance(n, TaskName)],
            "_tracker._registry.artifacts" : split_artifacts(upstream),
        }
        succ_info = {
            "tasks"                        : [n for n in downstream if isinstance(n, TaskName) and n is not self._tracker._root_node],
            "_tracker._registry.artifacts" : split_artifacts(downstream),
        }
        return ChainGuard({
            "pred" : pred_info,
            "succ" : succ_info,
            "root" : self._tracker._root_node in downstream,
        })

    def report_tree(self) -> None:
        """Draw the part of the graph connected to the root, for inspection.

        Does nothing unless the module level `show_graph` is set.
        Artifacts are labelled with their string form; tasks with their
        sliced name plus a running counter. The root keeps its own name.
        """
        labels : dict[TaskName_p|Artifact_i, str]
        if not show_graph:
            return

        labels     = {}
        task_count = 0
        for node in self._graph.nodes:
            if isinstance(node, Artifact_i):
                labels[node] = str(node)
            elif isinstance(node, TaskName_p):
                labels[node] = f"{node[:]}.{task_count}"
                task_count += 1

        # Written last, so the root is never given a counter suffix.
        labels[self._tracker._root_node] = cast("str", self._tracker._root_node)
        undirected = nx.relabel_nodes(nx.Graph(self._graph), labels)

        reachable = nx.node_connected_component(undirected, self._tracker._root_node)
        sub       = undirected.subgraph(reachable)
        layout    = nx.bfs_layout(sub, self._tracker._root_node)
        nx.draw(sub, pos=layout, **DRAW_OPTIONS)
        plt.show()
469
470##--|
471
[docs]
@Mixin(_Expansion_m, _Validation_m)
class TrackNetwork:
    """The directed graph of concrete tasks/artifacts and their dependencies.

    Wraps an nx.DiGraph, and exposes its views as read-only properties.
    Expansion and validation behaviour is mixed in.
    """
    # TODO use this instead of _tracker._registry
    _tracker      : API.WorkflowTracker_p
    _graph        : nx.DiGraph[Concrete[TaskName_p]|TaskArtifact]

    # Nodes that are in the graph, but have not been expanded yet
    non_expanded  : set[TaskName_p|Artifact_i]

    def __init__(self, *, tracker:API.WorkflowTracker_p) -> None:
        """Create an empty network holding only the tracker's root node.

        Raises:
            TypeError: `tracker` is not a WorkflowTracker_p.
        """
        if not isinstance(tracker, API.WorkflowTracker_p):
            raise TypeError(type(tracker))
        self._tracker     = tracker
        self._graph       = nx.DiGraph()
        self.non_expanded = set()
        self._add_node(self._tracker._root_node) # type: ignore[attr-defined]

    ##--| properties

    @property
    def nodes(self) -> dict:
        """The graph's node view."""
        graph = self._graph
        return graph.nodes

    @property
    def edges(self) -> dict:
        """The graph's edge view."""
        graph = self._graph
        return graph.edges

    @property
    def pred(self) -> dict:
        """The graph's predecessor view."""
        graph = self._graph
        return graph.pred

    @property
    def adj(self) -> dict:
        """The graph's adjacency view."""
        graph = self._graph
        return graph.adj

    @property
    def succ(self) -> dict:
        """The graph's successor view."""
        graph = self._graph
        return graph.succ

    ##--| dunders

    def __len__(self) -> int:
        """The number of nodes in the graph."""
        node_view = self._graph.nodes
        return len(node_view)

    def __contains__(self, other:Concrete[TaskName_p]|TaskArtifact) -> bool:
        """Whether `other` is a node of the graph."""
        graph = self._graph
        return other in graph