Source code for doot.control.tracker.naive_tracker

  1#!/usr/bin/env python3
  2"""
  3Abstract Specs: A[n]
  4Concrete Specs: C[n]
  5Task:           T[n]
  6
  7  Expansion: ∀x ∈ C[n].depends_on => A[x] -> C[x]
  8  Head: C[1].depends_on[A[n].$head$] => A[n] -> C[n], A[n].head -> C[n].head, connect
  9
 10"""
 11# Imports:
 12from __future__ import annotations
 13
 14# ##-- stdlib imports
 15import datetime
 16import enum
 17import functools as ftz
 18import itertools as itz
 19import logging as logmod
 20import pathlib as pl
 21import re
 22import time
 23import types
 24import weakref
 25from collections import defaultdict
 26from itertools import chain, cycle
 27from uuid import UUID, uuid1
 28
 29# ##-- end stdlib imports
 30
 31# ##-- 3rd party imports
 32from jgdv import Proto
 33
 34# ##-- end 3rd party imports
 35
 36# ##-- 1st party imports
 37import doot
 38import doot.errors
 39from doot.workflow.factory import SubTaskFactory, TaskFactory
 40from doot.workflow import (ActionSpec, DootTask, InjectSpec, RelationSpec,
 41                           TaskArtifact, TaskName, TaskSpec)
 42from doot.workflow._interface import (CLI_K, MUST_INJECT_K, Artifact_i,
 43                                      ArtifactStatus_e, InjectSpec_i,
 44                                      RelationSpec_i, Task_i, Task_p,
 45                                      TaskName_p, TaskSpec_i, TaskStatus_e,
 46                                      DelayedSpec)
 47
 48# ##-- end 1st party imports
 49
 50# ##-| Local
 51from . import _interface as API # noqa: N812
 52from ._base import Tracker_abs
 53from ._interface import WorkflowTracker_p
 54from .network import TrackNetwork
 55from .queue import TrackQueue
 56from .registry import TrackRegistry
 57
 58# # End of Imports.
 59
 60# ##-- types
 61# isort: off
 62import abc
 63import collections.abc
 64from typing import TYPE_CHECKING, cast, assert_type, assert_never
 65from typing import Generic, NewType
 66# Protocols:
 67from typing import Protocol, runtime_checkable
 68# Typing Decorators:
 69from typing import no_type_check, final, override, overload
 70
 71if TYPE_CHECKING:
 72    from jgdv import Maybe
 73    from typing import Final
 74    from typing import ClassVar, Any, LiteralString
 75    from typing import Never, Self, Literal
 76    from typing import TypeGuard
 77    from collections.abc import Iterable, Iterator, Callable, Generator
 78    from collections.abc import Sequence, Mapping, MutableMapping, Hashable
 79    from networkx import DiGraph
 80
 81    from doot.workflow._interface import TaskFactory_p, SubTaskFactory_p
 82    type Abstract[T] = T
 83    type Concrete[T] = T
 84
 85##--|
 86# isort: on
 87# ##-- end types
 88
 89##-- logging
 90logging    = logmod.getLogger(__name__)
 91##-- end logging
 92
 93assert(isinstance(TrackRegistry, API.Registry_p))
 94##--|
 95
[docs] 96@Proto(WorkflowTracker_p) 97class NaiveTracker(Tracker_abs): 98 """ Specific implementations for the default naive tracker """ 99 _registry : TrackRegistry 100
[docs] 101 def next_for(self, target:Maybe[str|TaskName_p]=None) -> Maybe[Task_p|Artifact_i]: 102 """ ask for the next task that can be performed 103 104 Returns a Task or Artifact that needs to be executed or created 105 Returns None if it loops too many times trying to find a target, 106 or if theres nothing left in the queue 107 108 """ 109 count : int 110 focus : TaskName_p|Artifact_i 111 idx : int 112 result : Maybe[Task_p|Artifact_i] 113 status : TaskStatus_e|ArtifactStatus_e 114 x : Any 115 116 logging.info("[Next.For] (Active: %s)", len(self.active)) 117 if not self.is_valid: 118 raise doot.errors.TrackingError("Network is in an invalid internal_state") 119 120 if target and target not in self.active: 121 self.queue(target, silent=True) 122 123 idx, count = 0, API.MAX_LOOP 124 result = None 125 while (result is None) and bool(self._queue) and 0 < (count:=count-1) and (idx:=idx+1): 126 focus = self._deque() 127 status, _ = self.get_status(target=focus) 128 logging.debug("[Next.For.%-3s]: %s : %s", idx, status, focus) 129 130 match focus: 131 case x if x not in self.active: 132 continue 133 case TaskName_p(): 134 result = self._next_for_task(focus) 135 case Artifact_i(): 136 result = self._next_for_artifact(focus) 137 case x: # Error otherwise 138 raise doot.errors.TrackingError("Unknown task focus", x) 139 140 else: 141 logging.info("[Next.For] <- %s", result) 142 return result
143
[docs] 144 def _next_for_task(self, focus:TaskName_p) -> Maybe[Task_p]: # noqa: PLR0912, PLR0915 145 """ logic for handling a dequed task """ 146 x : Any 147 status, _ = self.get_status(target=focus) # type: ignore[attr-defined] 148 match status: 149 case TaskStatus_e.DEAD: 150 # Clear internal_state 151 self.specs[focus].task = TaskStatus_e.DEAD 152 self.active.remove(focus) 153 assert(focus not in self.active) 154 case TaskStatus_e.DISABLED: 155 self.active.remove(focus) 156 case TaskStatus_e.TEARDOWN: 157 # Queue cleanup tasks 158 for succ, _ in self._successor_states_of(focus): 159 match self.queue(succ): 160 case TaskName() as x if x.is_cleanup(): 161 # make the cleanup task early, to apply shared internal_state 162 assert(isinstance(focus, TaskName)) 163 self._instantiate(x, parent=focus, task=True) 164 case _: 165 pass 166 else: 167 # TODO for cleanup succ, move focus.internal_state -> succ.internal_state 168 self.set_status(focus, TaskStatus_e.DEAD) # type: ignore[attr-defined] 169 case TaskStatus_e.SUCCESS: 170 self.queue(focus, status=TaskStatus_e.TEARDOWN) 171 case TaskStatus_e.FAILED: # propagate failure 172 self.queue(focus, status=TaskStatus_e.TEARDOWN) 173 case TaskStatus_e.HALTED: # remove and propagate halted status 174 self.queue(focus, status=TaskStatus_e.TEARDOWN) 175 case TaskStatus_e.SKIPPED: 176 self.queue(focus, status=TaskStatus_e.DEAD) 177 case TaskStatus_e.RUNNING: 178 self.queue(focus) 179 case TaskStatus_e.READY: # return the task if its ready 180 self.queue(focus, status=TaskStatus_e.RUNNING) 181 return cast("Task_p", self.specs[focus].task) 182 case TaskStatus_e.WAIT: # Add dependencies of a task to the stack 183 waiting : bool = False 184 deps_of_focus = self._dependency_states_of(focus) 185 for dep, dep_status in deps_of_focus: 186 if dep_status in API.SUCCESS_STATUSES: 187 continue 188 self.queue(dep) 189 waiting = True 190 else: 191 match waiting: 192 case False: 193 self.queue(focus, status=TaskStatus_e.READY) 194 case True: 195 logging.debug("[Next.For] Task Blocked: %s on : %s", focus, deps_of_focus) 196 self.queue(focus) 197 case TaskStatus_e.INIT: 198 self.queue(focus, status=TaskStatus_e.WAIT) 199 case TaskStatus_e.DEFINED: 200 self._instantiate(focus, task=True) 201 self.queue(focus) 202 case TaskStatus_e.DECLARED: 203 self.queue(focus, status=TaskStatus_e.DEFINED) 204 case TaskStatus_e.NAMED: 205 logging.warning("A Name only was queued, it has no backing in the tracker: %s", focus) 206 case x: # Error otherwise 207 raise doot.errors.TrackingError("Unknown task internal_state", x) 208 ##--| 209 return None
210
[docs] 211 def _next_for_artifact(self, focus:Artifact_i) -> Maybe[Artifact_i]: # noqa: PLR0912 212 """ logic for handling a dequed artifact """ 213 status, _ = self.get_status(target=focus) # type: ignore[attr-defined] 214 match status: 215 case ArtifactStatus_e.EXISTS: 216 # TODO artifact Exists, queue its dependents and *don't* add the artifact back in 217 pass 218 case ArtifactStatus_e.STALE: 219 for pred, _ in self._dependency_states_of(focus): 220 self.queue(pred) 221 case ArtifactStatus_e.DECLARED if bool(focus): 222 self.queue(focus) 223 case ArtifactStatus_e.DECLARED: # Add dependencies of an artifact to the stack 224 deps : list[tuple] = self._dependency_states_of(focus) 225 match deps: 226 case [] if not focus.is_concrete(): 227 self.queue(focus) 228 case []: 229 assert(not bool(focus)) 230 path = doot.locs[focus] 231 # Returns the artifact, the runner can try to create 232 # it, then override the halt 233 return focus 234 case [*xs]: 235 logging.info("[Next.For] Artifact Blocked, queuing producer tasks, count: %s", len(xs)) 236 case x: 237 raise TypeError(type(x)) 238 239 for dep, dep_state in deps: 240 if dep_state in API.SUCCESS_STATUSES: 241 continue 242 self.queue(dep) 243 else: 244 # No need to requeue, as tasks will check for the artifacts themselves 245 pass 246 247 case x: # Error otherwise 248 raise doot.errors.TrackingError("Unknown task internal_state", x) 249 250 ##--| 251 return None
252
[docs] 253 @override 254 def _instantiate(self, target:TaskName_p|RelationSpec_i, *args:Any, task:bool=False, **kwargs:Any) -> Maybe[TaskName_p]: 255 """ extends base instantiation to add late injection for tasks """ 256 parent : TaskName_p 257 result : Maybe[TaskName_p] 258 ##--| 259 parent = kwargs.pop("parent", None) 260 match super()._instantiate(target, *args, task=task, **kwargs): 261 case TaskName_p() as result if task: 262 self._apply_injections(result, parent=parent) 263 case result: 264 pass 265 266 return result
267 268 ##--| internal 269
[docs] 270 def _dependency_states_of(self, focus:TaskName_p|Artifact_i) -> list[tuple]: 271 return [(x, self.get_status(target=x)[0]) for x in self._network.pred[focus]]
272
[docs] 273 def _successor_states_of(self, focus:TaskName_p|Artifact_i) -> list[tuple]: 274 return [(x, self.get_status(target=x)[0]) for x in self._network.succ[focus]]
275
[docs] 276 def _deque(self) -> TaskName_p|Artifact_i: 277 focus = self._queue.deque_entry() 278 match self.specs.get(focus, focus): # type: ignore[arg-type] 279 case None | API.SpecMeta_d(task=None) | TaskName_p(): 280 pass 281 case API.SpecMeta_d(task=Task_p() as task) if task.priority < self._min_priority: 282 logging.warning("[Deque] Halting (Min Priority) : %s", focus[:]) 283 self.set_status(focus, TaskStatus_e.HALTED) 284 case API.SpecMeta_d(task=Task_p() as task): 285 prior = task.priority 286 task.priority = 1 287 logging.debug("[Deque] %s -> %s : %s", prior, task.priority, focus[:]) 288 case TaskArtifact() as focus: # type: ignore[misc] 289 focus.priority -= 1 # type: ignore[union-attr] 290 291 return focus
292 293 ##--| utils 294
[docs] 295 def get_status(self, *, target:Maybe[Concrete[TaskName_p]|Artifact_i]=None) -> tuple[TaskStatus_e|ArtifactStatus_e, int]: 296 return self._registry.get_status(target)
297
[docs] 298 def set_status(self, task:Concrete[TaskName_p]|Artifact_i|Task_p, internal_state:TaskStatus_e) -> bool: 299 return self._registry.set_status(task, internal_state) # type: ignore[attr-defined]
300
[docs] 301 def _apply_injections(self, name:TaskName_p, *, parent:Maybe[TaskName_p]=None) -> None: 302 """ After a task is created, values can be injected into it. 303 these include, in order: 304 - parent internal_state, 305 - cli params 306 - instantiator internal_state injection 307 """ 308 x : Any 309 meta : API.SpecMeta_d 310 task : Task_p 311 idx : int = 0 312 ##--| 313 match self.specs[name]: 314 case API.SpecMeta_d(task=Task_p() as task) as meta: 315 pass 316 case x: 317 raise TypeError(type(x)) 318 ##--| Get parent data (for cleanup tasks 319 match self._get_parent_data(parent): 320 case None: 321 pass 322 case dict() as pdata: 323 task.internal_state.update(pdata) 324 ##--| apply CLI params 325 match self._get_cli_data(name): 326 case None: 327 pass 328 case dict() as cdata: 329 # Apply CLI passed params, but only as the default 330 # So if override values have been injected, they are preferred 331 for x,y in cdata.items(): 332 task.internal_state.setdefault(x, y) 333 334 ##--| apply late injections 335 match self._get_inject_data(name): 336 case None: 337 pass 338 case dict() as idata: 339 task.internal_state.update(idata) 340 341 ##--| validate 342 if CLI_K in task.internal_state: # type: ignore[attr-defined] 343 del task.internal_state[CLI_K] # type: ignore[attr-defined] 344 match task.spec.extra.on_fail([])[MUST_INJECT_K](): # type: ignore[attr-defined] 345 case []: 346 pass 347 case [*xs] if bool(missing:=[x for x in xs if x not in task.internal_state]): # type: ignore[attr-defined] 348 raise doot.errors.TrackingError("Task did not receive required injections", task.name, xs, task.internal_state.keys()) # type: ignore[attr-defined] 349 350 ##--| prep actions 351 task.prepare_actions()
352
[docs] 353 def _get_parent_data(self, parent:Maybe[TaskName_p]=None) -> Maybe[dict]: 354 match self.specs.get(parent, None): # type: ignore[arg-type] 355 case None: 356 return None 357 case API.SpecMeta_d(task=Task_p() as p_task): 358 return dict(p_task.internal_state)
359
[docs] 360 def _get_cli_data(self, name:TaskName_p) -> Maybe[dict]: 361 idx = 0 362 target = name.pop()[:,:] 363 return doot.args.on_fail({}).subs[target][idx]['args']()
364
[docs] 365 def _get_inject_data(self, name:TaskName_p) -> Maybe[dict]: 366 inj_control : TaskName_p 367 inj : InjectSpec_i 368 meta = self.specs[name] 369 370 match meta.injection_source: 371 case None: 372 inj_control = None 373 case TaskName_p() as inj_control, _ if inj_control not in self.specs: 374 raise ValueError("Late Injection source is not a task", inj_control) 375 case TaskName_p() as inj_control, InjectSpec_i() as inj: 376 pass 377 378 match self.specs.get(inj_control, None): # type: ignore[arg-type] 379 case None: 380 return None 381 case API.SpecMeta_d(task=Task_p() as control): 382 return inj.apply_from_state(control) 383 case x: 384 raise TypeError(type(x))