Source code for doot.control.tracker.network

  1#!/usr/bin/env python3
  2"""
  3The network of task relations.
  4
  5Uses an nx.DiGraph internally.
  6It is built 'backwards', as this preserves the meaning
  7of graph.pred[x] = [y] meaning x depends on y,
  8and graph.succ[x] = [y] meaning x is required for y.
  9
 10"""
 11# ruff: noqa: ERA001
 12# Imports:
 13from __future__ import annotations
 14
 15# ##-- stdlib imports
 16import datetime
 17import enum
 18import functools as ftz
 19import itertools as itz
 20import logging as logmod
 21import pathlib as pl
 22import re
 23import time
 24import types
 25from collections import defaultdict
 26from itertools import chain, cycle
 27from uuid import UUID, uuid1
 28
 29# ##-- end stdlib imports
 30
 31# ##-- 3rd party imports
 32from jgdv import Proto, Mixin
 33import networkx as nx
 34from jgdv.structs.chainguard import ChainGuard
 35from jgdv.structs.dkey import DKey
 36# ##-- end 3rd party imports
 37
 38# ##-- 1st party imports
 39import doot
 40import doot.errors
 41from ._interface import EdgeType_e
 42from doot.workflow import ActionSpec, TaskName, TaskSpec, DootTask, RelationSpec, TaskArtifact
 43# ##-- end 1st party imports
 44
 45import matplotlib.pyplot as plt
 46from . import _interface as API # noqa: N812
 47from doot.workflow._interface import TaskName_p, Artifact_i, RelationSpec_i
 48
 49# ##-- types
 50# isort: off
 51import abc
 52import collections.abc
 53from typing import TYPE_CHECKING, cast, assert_type, assert_never
 54from typing import Generic, NewType
 55# Protocols:
 56from typing import Protocol, runtime_checkable
 57# Typing Decorators:
 58from typing import no_type_check, final, override, overload
 59
 60if TYPE_CHECKING:
 61    import weakref
 62    from doot.workflow._interface import TaskStatus_e, ArtifactStatus_e
 63    from jgdv import Maybe
 64    from typing import Final
 65    from typing import ClassVar, Any, LiteralString
 66    from typing import Never, Self, Literal
 67    from typing import TypeGuard
 68    from collections.abc import Iterable, Iterator, Callable, Generator
 69    from collections.abc import Sequence, Mapping, MutableMapping, Hashable
 70    from .track_registry import TrackRegistry
 71    from doot.worfklow._interface import TaskSpec_i
 72
 73    type Abstract[T] = T
 74    type Concrete[T] = T
 75
 76    type ActionElem  = ActionSpec|RelationSpec
 77    type ActionGroup = list[ActionElem]
 78##--|
 79
 80# isort: on
 81# ##-- end types
 82
 83##-- logging
 84logging          = logmod.getLogger(__name__)
 85logging.disabled = False
 86##-- end logging
 87
 88show_graph : Final[bool] = doot.config.on_fail(False, bool).settings.commands.run.show() # type: ignore[attr-defined]  # noqa: FBT003
 89
 90DRAW_OPTIONS : Final[dict]  = dict(
 91    with_labels=True,
 92    # arrowstyle="->",
 93    node_color="green",
 94    verticalalignment="baseline",
 95    bbox={"edgecolor": "k", "facecolor": "white", "alpha": 0.5 },
 96)
 97##--|
 98
[docs] 99class _Expansion_m: 100 101 _tracker : API.WorkflowTracker_i 102 pred : Mapping 103 succ : Mapping 104 nodes : Mapping 105 edges : Mapping 106 _graph : Any 107 non_expanded : set 108
[docs] 109 def build_network(self, *, sources:Maybe[Literal[True]|list[Concrete[TaskName_p]|Artifact_i]]=None) -> None: 110 """ 111 for each task queued (ie: connected to the root node) 112 expand its dependencies and add into the _graph, until no more nodes to expand. 113 then connect concrete _tracker._registry.artifacts to abstract _tracker._registry.artifacts. 114 115 passing sources=True forces build of any non_expanded nodes that have an edge 116 117 # TODO _graph could be built in total, or on demand 118 """ 119 x : Any 120 processed : set 121 queue : list 122 additions : set 123 ##--| 124 logging.info("[Network.Build] -> Start") 125 match sources: 126 case None: 127 queue = list(self.pred[self._tracker._root_node].keys()) 128 case True: 129 queue_set = set(self.nodes.keys()) 130 queue_set.update([x for x in self.non_expanded if self.pred[x] or self.succ[x]]) 131 queue = list(queue_set) 132 case [*xs]: 133 queue = list(sources) 134 case x: 135 raise TypeError(type(x)) 136 processed = { self._tracker._root_node } 137 logging.info("[Build.Initial] Network Queue: %s", queue) 138 while bool(queue): # expand tasks 139 match (current:=queue.pop()): 140 case x if x in processed or self.nodes[x].get(API.EXPANDED, False): 141 logging.debug("[Build.Processed] %s", current) 142 processed.add(x) 143 case TaskName_p() as x if x in self.nodes: 144 additions = self._expand_task_node(x) 145 queue += additions 146 processed.add(x) 147 case Artifact_i() as x if x in self.nodes: 148 additions = self._expand_artifact(x) 149 logging.debug("[Build.Artifact] Expansion produced: %s", additions) 150 queue += additions 151 processed.add(x) 152 case _: 153 raise doot.errors.TrackingError("Unknown value in _graph") 154 155 else: 156 logging.debug("[Network.Build] <- Nodes: %s Edges: %s", len(self.nodes), len(self.edges)) 157 self.report_tree() # type: ignore[attr-defined]
158
[docs] 159 def connect(self, left:Concrete[TaskName_p]|Artifact_i, right:Maybe[Literal[False]|Concrete[TaskName_p]|Artifact_i]=None, **kwargs) -> None: # noqa: ANN003 160 """ 161 Connect a task node to another. left -> right 162 If given left, None, connect left -> API.ROOT 163 if given left, False, just add the node 164 165 (This preserves graph.pred[x] as the nodes x is dependent on) 166 """ 167 assert("type" not in kwargs) 168 self._add_node(left) 169 match right: 170 case False: 171 return 172 case None: 173 right = self._tracker._root_node 174 self._add_node(right) 175 case x: 176 self._add_node(right) 177 178 if left in self.succ and right in self.succ[left]: 179 # nothing to do 180 return 181 182 # Add the edge, with metadata 183 match left, right: 184 case TaskName_p(), TaskName_p(): 185 logging.debug("[Connect] %s -> %s", left, right) 186 self._graph.add_edge(left, right, type=EdgeType_e.TASK, **kwargs) 187 case TaskName_p(), Artifact_i(): 188 logging.debug("[Connect] %s -> %s", left[:], right) 189 self._graph.add_edge(left, right, type=EdgeType_e.TASK_CROSS, **kwargs) 190 case Artifact_i(), TaskName_p(): 191 logging.debug("[Connect] %s -> %s", left, right[:]) 192 self._graph.add_edge(left, right, type=EdgeType_e.ARTIFACT_CROSS, **kwargs) 193 case Artifact_i(), Artifact_i() if left.is_concrete() and right.is_concrete(): 194 raise doot.errors.TrackingError("Tried to connect two concrete _tracker._registry.artifacts", left, right) 195 case Artifact_i(), Artifact_i() if right.is_concrete(): 196 logging.debug("[Connect] %s -> %s", left, right) 197 self._graph.add_edge(left, right, type=EdgeType_e.ARTIFACT_UP, **kwargs) 198 case Artifact_i(), Artifact_i() if not right.is_concrete(): 199 logging.debug("[Connect] %s -> %s", left, right) 200 self._graph.add_edge(left, right, type=EdgeType_e.ARTIFACT_DOWN, **kwargs)
201 202 ##--| internal 203
[docs] 204 def _add_node(self, name:Concrete[TaskName_p]|Artifact_i) -> None: 205 """idempotent""" 206 match name: 207 case x if x in self.nodes: 208 return 209 case x if x is self._tracker._root_node: 210 if x in self._graph: 211 return 212 self._graph.add_node(name) 213 self.nodes[name][API.EXPANDED] = True 214 self.nodes[name][API.REACTIVE_ADD] = False 215 case TaskName_p() as x if not x.uuid(): 216 raise doot.errors.TrackingError("Nodes should only be instantiated spec names", x) 217 case TaskName_p() as x if x not in self._tracker.specs: 218 raise doot.errors.TrackingError("Can't connect a non-existent task", x) 219 case Artifact_i() as x if x not in self._tracker.artifacts: 220 raise doot.errors.TrackingError("Can't connect a non-existent artifact", x) 221 case Artifact_i(): # Add node with metadata 222 logging.debug("[Network.Artifact.+] %s", name) 223 self._graph.add_node(name) 224 self.nodes[name][API.EXPANDED] = False 225 self.nodes[name][API.REACTIVE_ADD] = False 226 self.non_expanded.add(name) 227 case TaskName_p(): # Add node with metadata 228 logging.debug("[Network.Task.+] %s", name) 229 self._graph.add_node(name) 230 self.nodes[name][API.EXPANDED] = False 231 self.nodes[name][API.REACTIVE_ADD] = False 232 self.non_expanded.add(name) 233 case x: 234 raise TypeError(type(x))
235
[docs] 236 def _expand_task_node(self, name:Concrete[TaskName_p]) -> set[Concrete[TaskName_p]|Artifact_i]: 237 """ expand a task node, instantiating and connecting to its dependencies and dependents, 238 *without* expanding those new nodes. 239 returns a list of the new nodes tasknames 240 """ 241 to_expand : set[TaskName_p|Artifact_i] 242 spec : TaskSpec_i 243 assert(name.uuid()) 244 assert(not self.nodes[name].get(API.EXPANDED, False)) 245 spec = self._tracker.specs[name].spec 246 spec_pred, spec_succ = self.pred[name], self.succ[name] 247 to_expand = set() 248 249 logging.info("[Build.Expand.Task] -> %s : Pre(%s), Post(%s)", name, len(spec.depends_on), len(spec.required_for)) 250 251 # Connect Relations 252 for rel in self._tracker._factory.action_group_elements(spec): 253 if not isinstance(rel, RelationSpec_i): 254 # Ignore Actions 255 continue 256 relevant_edges = spec_succ if rel.forward_dir_p() else spec_pred 257 match rel: 258 case RelationSpec_i(target=Artifact_i() as target): 259 # Connect the artifact mentioned 260 assert(target in self._tracker.artifacts) 261 self.connect(*rel.to_ordered_pair(name)) # type: ignore[arg-type] 262 to_expand.add(target) 263 case RelationSpec_i(target=TaskName_p() as target): 264 # Get specs and instances with matching target 265 instance = self._tracker._instantiate(rel, control=name) 266 self.connect(*rel.to_ordered_pair(name, target=instance)) # type: ignore[arg-type] 267 to_expand.add(instance) 268 else: 269 assert(name in self.nodes) 270 self.nodes[name][API.EXPANDED] = True 271 self.non_expanded.remove(name) 272 273 to_expand.update(self._generate_blockers(name)) 274 to_expand.update(self._generate_successor_nodes(spec)) 275 logging.debug("[Build.Expand.Task] <- %s : %s", name, to_expand) 276 return to_expand
277
[docs] 278 def _generate_successor_nodes(self, spec:Concrete[TaskSpec]) -> list[Concrete[TaskName_p]]: 279 """ 280 instantiate and connect a job's head task 281 282 for a spec S, find the tasks T that have registered a relation 283 of T < S. 284 (S would not know about these blockers). 285 286 For these T, link instantiated nodes that match constraints and link them to S, 287 or if no nodes exist, create and link them. 288 """ 289 result = [] 290 logging.debug("[Build.Task.Successor] : %s", spec.name) 291 names = self._tracker._subfactory.generate_names(spec) 292 assert(len(names) <= 1) 293 for x in names: 294 assert(x.uuid() == spec.name.uuid()) 295 assert(x in self._tracker.specs) 296 self.connect(spec.name, x) 297 result.append(x) 298 else: 299 logging.debug("[Successors] : %s", result) 300 return result
301
[docs] 302 def _generate_blockers(self, name:TaskName_p) -> set[Concrete[TaskName_p]|Artifact_i]: 303 x : Any 304 target : TaskName_p 305 ##--| 306 results = [] 307 blockers = set() 308 logging.info("[Blockers] : %s", name) 309 if name in self._tracker.specs: 310 blockers.update(self._tracker.specs[name].blocked_by) 311 if (x:=name.de_uniq()) in self._tracker.specs: 312 blockers.update(self._tracker.specs[x].blocked_by) 313 for blocker in blockers: 314 if blocker in self.pred[name]: 315 continue 316 instance = self._tracker._instantiate(blocker) 317 if instance in self.pred[name]: 318 continue 319 self.connect(instance, name) 320 results.append(instance) 321 logging.info("[Blocker.Connect] %s -> %s", instance, name) 322 else: 323 logging.debug("[Build.Generate.Blockers] %s -> %s", name, results) 324 return results
325
[docs] 326 def _expand_artifact(self, artifact:Artifact_i) -> set[Concrete[TaskName_p]|Artifact_i]: 327 """ expand _tracker._registry.artifacts, instantiating related tasks, 328 and connecting the task to its abstract/concrete related _tracker._registry.artifacts 329 """ 330 to_expand : set[TaskName_p|Artifact_i] 331 meta : API.ArtifactMeta_d 332 abstract : TaskName_p | Artifact_i 333 ##--| 334 assert(artifact in self._tracker.artifacts) 335 assert(artifact in self.nodes) 336 assert(not self.nodes[artifact].get(API.EXPANDED, False)) 337 logging.info("[Build.Expand.Artifact] --> %s", artifact) 338 to_expand = set() 339 340 meta = self._tracker.artifacts[artifact] 341 relevant = list(meta.builders) 342 logging.debug("-- Instantiating Artifact relevant tasks: %s", len(relevant)) 343 for name in relevant: 344 instance = self._tracker._instantiate(name) 345 assert(instance is not None) 346 self.connect(instance, False) # noqa: FBT003 347 to_expand.add(instance) 348 349 match artifact.is_concrete(): 350 case True: 351 logging.debug("-- Connecting concrete artifact to parent abstracts") 352 art_path = DKey[pl.Path](artifact[1,:])(relative=True) # type: ignore[operator] 353 for abstract in self._tracker.abstract: 354 match abstract: 355 case TaskName_p(): 356 continue 357 case Artifact_i() as x if art_path not in x and artifact not in x: 358 continue 359 case _: 360 self.connect(artifact, abstract) 361 to_expand.add(abstract) 362 case False: 363 logging.debug("-- Connecting abstract task to child concrete _tracker._registry.artifacts") 364 for conc in self._tracker.concrete: 365 match conc: 366 case TaskName_p(): 367 continue 368 case Artifact_i(): 369 assert(conc.is_concrete()) 370 conc_path = DKey[pl.Path](conc[1,:])(relative=True) # type: ignore[operator] 371 if conc_path not in artifact: 372 continue 373 self.connect(conc, artifact) 374 to_expand.add(conc) 375 376 logging.info("[Build.Expand.Artifact] <-- %s -> %s", artifact, to_expand) 377 
self.nodes[artifact][API.EXPANDED] = True 378 self.non_expanded.remove(artifact) 379 return to_expand
380
[docs] 381class _Validation_m: 382 383 _tracker : API.WorkflowTracker_i 384 _graph : Any 385 nodes : Mapping 386 edges : Mapping 387 pred : Mapping 388 succ : Mapping 389
[docs] 390 def validate_network(self, *, strict:bool=True) -> bool: # noqa: PLR0912 391 """ Finalise and ensure consistence of the task _graph. 392 run tests to check the dependency graph is acceptable 393 """ 394 logging.info("Validating Task Network") 395 if not nx.is_directed_acyclic_graph(self._graph): 396 raise doot.errors.TrackingError("Network isn't a DAG") 397 398 failures = [] 399 for node, data in self.nodes.items(): 400 match node: 401 case TaskName_p() as x if x == self._tracker._root_node: # Ignore the root 402 pass 403 case TaskName_p(): 404 if not data.get(API.EXPANDED, False): # every node is expanded 405 failures.append(f"{node} is not expanded") 406 if not node.uuid(): # every node is uniq 407 failures.append(f"{node} is not unique") 408 if node not in self._tracker.specs: # every node has a spec 409 failures.append(f"{node} has no backing spec") 410 case Artifact_i(): 411 if not data.get(API.EXPANDED, False): # Every node is expanded 412 failures.append(f"{node} is not expanded") 413 if (TaskArtifact.Wild.glob in node 414 and not bool(self._graph.pred[node])): 415 failures.append(f"{node} has no concrete predecessors") 416 else: 417 if not self._tracker.is_valid: 418 raise doot.errors.TrackingError("Network is not marked as valid") 419 420 if not bool(failures): 421 return True 422 423 if strict: 424 raise doot.errors.TrackingError("Errors in network", failures) 425 else: 426 logging.warning("Failures in network: %s", failures) 427 return False
428
[docs] 429 def concrete_edges(self, name:Concrete[TaskName_p|TaskArtifact]) -> ChainGuard: 430 """ get the concrete edges of a task. 431 ie: the ones in the task _graph, not the abstract ones in the spec. 432 """ 433 assert(name in self.nodes) 434 preds = self.pred[name] 435 succ = self.succ[name] 436 return ChainGuard({ 437 "pred" : {"tasks": [x for x in preds if isinstance(x, TaskName)], 438 "_tracker._registry.artifacts": {"abstract": [x for x in preds if isinstance(x, TaskArtifact) and not x.is_concrete()], 439 "concrete": [x for x in preds if isinstance(x, TaskArtifact) and x.is_concrete()]}}, 440 "succ" : {"tasks": [x for x in succ if isinstance(x, TaskName) and x is not self._tracker._root_node], 441 "_tracker._registry.artifacts": {"abstract": [x for x in succ if isinstance(x, TaskArtifact) and not x.is_concrete()], 442 "concrete": [x for x in succ if isinstance(x, TaskArtifact) and x.is_concrete()]}}, 443 "root" : self._tracker._root_node in succ, 444 })
445
[docs] 446 def report_tree(self) -> None: 447 """ Use networkx + plt's graph drawing to inspect the constructed graph """ 448 mapping : dict[TaskName_p|Artifact_i, str] 449 if not show_graph: 450 return 451 452 mapping = {} 453 count = 0 454 for x in self._graph.nodes: 455 match x: 456 case Artifact_i(): 457 mapping[x] = str(x) 458 case TaskName_p(): 459 mapping[x] = f"{x[:]}.{count}" 460 count += 1 461 462 mapping[self._tracker._root_node] = cast("str", self._tracker._root_node) 463 undir = nx.Graph(self._graph) 464 undir = nx.relabel_nodes(undir, mapping) 465 466 sub = undir.subgraph(nx.node_connected_component(undir, self._tracker._root_node)) 467 nx.draw(sub, pos=nx.bfs_layout(sub, self._tracker._root_node), **DRAW_OPTIONS) 468 plt.show()
469 470##--| 471
[docs] 472@Mixin(_Expansion_m, _Validation_m) 473class TrackNetwork: 474 """ The _graph of concrete tasks and their dependencies """ 475 # TODO use this instaed of _tracker._registry 476 _tracker : API.WorkflowTracker_p 477 _graph : nx.DiGraph[Concrete[TaskName_p]|TaskArtifact] 478 479 non_expanded : set[TaskName_p|Artifact_i] 480 481 def __init__(self, *, tracker:API.WorkflowTracker_p) -> None: 482 match tracker: 483 case API.WorkflowTracker_p(): 484 self._tracker = tracker 485 case x: 486 raise TypeError(type(x)) 487 self._graph = nx.DiGraph() 488 self.non_expanded = set() 489 self._add_node(self._tracker._root_node) # type: ignore[attr-defined] 490 491 ##--| properties 492
[docs] 493 @property 494 def nodes(self) -> dict: 495 return self._graph.nodes
496
[docs] 497 @property 498 def edges(self) -> dict: 499 return self._graph.edges
500
[docs] 501 @property 502 def pred(self) -> dict: 503 return self._graph.pred
504
[docs] 505 @property 506 def adj(self) -> dict: 507 return self._graph.adj
508
[docs] 509 @property 510 def succ(self) -> dict: 511 return self._graph.succ
512 513 ##--| dunders 514 515 def __len__(self) -> int: 516 return len(self._graph.nodes) 517 518 def __contains__(self, other:Concrete[TaskName_p]|TaskArtifact) -> bool: 519 return other in self._graph