1#!/usr/bin/env python3
2"""
3Abstract Specs: A[n]
4Concrete Specs: C[n]
5Task: T[n]
6
7 Expansion: ∀x ∈ C[n].depends_on => A[x] -> C[x]
8 Head: C[1].depends_on[A[n].$head$] => A[n] -> C[n], A[n].head -> C[n].head, connect
9
10"""
11# Imports:
12from __future__ import annotations
13
14# ##-- stdlib imports
15import datetime
16import enum
17import functools as ftz
18import itertools as itz
19import logging as logmod
20import pathlib as pl
21import re
22import time
23import types
24import weakref
25from collections import defaultdict
26from itertools import chain, cycle
27from uuid import UUID, uuid1
28
29# ##-- end stdlib imports
30
31# ##-- 3rd party imports
32from jgdv import Proto
33
34# ##-- end 3rd party imports
35
36# ##-- 1st party imports
37import doot
38import doot.errors
39from doot.workflow.factory import SubTaskFactory, TaskFactory
40from doot.workflow import (ActionSpec, DootTask, InjectSpec, RelationSpec,
41 TaskArtifact, TaskName, TaskSpec)
42from doot.workflow._interface import (CLI_K, MUST_INJECT_K, Artifact_i,
43 ArtifactStatus_e, InjectSpec_i,
44 RelationSpec_i, Task_i, Task_p,
45 TaskName_p, TaskSpec_i, TaskStatus_e,
46 DelayedSpec)
47
48# ##-- end 1st party imports
49
50# ##-| Local
51from . import _interface as API # noqa: N812
52from ._base import Tracker_abs
53from ._interface import WorkflowTracker_p
54from .network import TrackNetwork
55from .queue import TrackQueue
56from .registry import TrackRegistry
57
58# # End of Imports.
59
60# ##-- types
61# isort: off
62import abc
63import collections.abc
64from typing import TYPE_CHECKING, cast, assert_type, assert_never
65from typing import Generic, NewType
66# Protocols:
67from typing import Protocol, runtime_checkable
68# Typing Decorators:
69from typing import no_type_check, final, override, overload
70
71if TYPE_CHECKING:
72 from jgdv import Maybe
73 from typing import Final
74 from typing import ClassVar, Any, LiteralString
75 from typing import Never, Self, Literal
76 from typing import TypeGuard
77 from collections.abc import Iterable, Iterator, Callable, Generator
78 from collections.abc import Sequence, Mapping, MutableMapping, Hashable
79 from networkx import DiGraph
80
81 from doot.workflow._interface import TaskFactory_p, SubTaskFactory_p
82 type Abstract[T] = T
83 type Concrete[T] = T
84
85##--|
86# isort: on
87# ##-- end types
88
##-- logging
# Module-level logger, named after this module.
# Note the name: `logging` here is a Logger instance, not the stdlib module
# (the stdlib module is imported as `logmod`), so `logging.info(...)` in this
# file is a call on this logger.
logging = logmod.getLogger(__name__)
##-- end logging

# Import-time sanity check of TrackRegistry against API.Registry_p.
# NOTE(review): being an `assert`, this check is skipped when Python runs with -O.
assert(isinstance(TrackRegistry, API.Registry_p))
94##--|
95
[docs]
@Proto(WorkflowTracker_p)
class NaiveTracker(Tracker_abs):
    """ Specific implementations for the default naive tracker

    Extends Tracker_abs with queue stepping (`next_for`), per-status
    handling of dequeued tasks and artifacts, and late injection of
    state into freshly instantiated tasks.

    NOTE(review): decorated with Proto(WorkflowTracker_p); presumably this
    declares/checks conformance to that protocol - confirm against jgdv.
    """
    # Registry that get_status / set_status delegate to
    _registry : TrackRegistry
100
[docs]
101 def next_for(self, target:Maybe[str|TaskName_p]=None) -> Maybe[Task_p|Artifact_i]:
102 """ ask for the next task that can be performed
103
104 Returns a Task or Artifact that needs to be executed or created
105 Returns None if it loops too many times trying to find a target,
106 or if theres nothing left in the queue
107
108 """
109 count : int
110 focus : TaskName_p|Artifact_i
111 idx : int
112 result : Maybe[Task_p|Artifact_i]
113 status : TaskStatus_e|ArtifactStatus_e
114 x : Any
115
116 logging.info("[Next.For] (Active: %s)", len(self.active))
117 if not self.is_valid:
118 raise doot.errors.TrackingError("Network is in an invalid internal_state")
119
120 if target and target not in self.active:
121 self.queue(target, silent=True)
122
123 idx, count = 0, API.MAX_LOOP
124 result = None
125 while (result is None) and bool(self._queue) and 0 < (count:=count-1) and (idx:=idx+1):
126 focus = self._deque()
127 status, _ = self.get_status(target=focus)
128 logging.debug("[Next.For.%-3s]: %s : %s", idx, status, focus)
129
130 match focus:
131 case x if x not in self.active:
132 continue
133 case TaskName_p():
134 result = self._next_for_task(focus)
135 case Artifact_i():
136 result = self._next_for_artifact(focus)
137 case x: # Error otherwise
138 raise doot.errors.TrackingError("Unknown task focus", x)
139
140 else:
141 logging.info("[Next.For] <- %s", result)
142 return result
143
[docs]
144 def _next_for_task(self, focus:TaskName_p) -> Maybe[Task_p]: # noqa: PLR0912, PLR0915
145 """ logic for handling a dequed task """
146 x : Any
147 status, _ = self.get_status(target=focus) # type: ignore[attr-defined]
148 match status:
149 case TaskStatus_e.DEAD:
150 # Clear internal_state
151 self.specs[focus].task = TaskStatus_e.DEAD
152 self.active.remove(focus)
153 assert(focus not in self.active)
154 case TaskStatus_e.DISABLED:
155 self.active.remove(focus)
156 case TaskStatus_e.TEARDOWN:
157 # Queue cleanup tasks
158 for succ, _ in self._successor_states_of(focus):
159 match self.queue(succ):
160 case TaskName() as x if x.is_cleanup():
161 # make the cleanup task early, to apply shared internal_state
162 assert(isinstance(focus, TaskName))
163 self._instantiate(x, parent=focus, task=True)
164 case _:
165 pass
166 else:
167 # TODO for cleanup succ, move focus.internal_state -> succ.internal_state
168 self.set_status(focus, TaskStatus_e.DEAD) # type: ignore[attr-defined]
169 case TaskStatus_e.SUCCESS:
170 self.queue(focus, status=TaskStatus_e.TEARDOWN)
171 case TaskStatus_e.FAILED: # propagate failure
172 self.queue(focus, status=TaskStatus_e.TEARDOWN)
173 case TaskStatus_e.HALTED: # remove and propagate halted status
174 self.queue(focus, status=TaskStatus_e.TEARDOWN)
175 case TaskStatus_e.SKIPPED:
176 self.queue(focus, status=TaskStatus_e.DEAD)
177 case TaskStatus_e.RUNNING:
178 self.queue(focus)
179 case TaskStatus_e.READY: # return the task if its ready
180 self.queue(focus, status=TaskStatus_e.RUNNING)
181 return cast("Task_p", self.specs[focus].task)
182 case TaskStatus_e.WAIT: # Add dependencies of a task to the stack
183 waiting : bool = False
184 deps_of_focus = self._dependency_states_of(focus)
185 for dep, dep_status in deps_of_focus:
186 if dep_status in API.SUCCESS_STATUSES:
187 continue
188 self.queue(dep)
189 waiting = True
190 else:
191 match waiting:
192 case False:
193 self.queue(focus, status=TaskStatus_e.READY)
194 case True:
195 logging.debug("[Next.For] Task Blocked: %s on : %s", focus, deps_of_focus)
196 self.queue(focus)
197 case TaskStatus_e.INIT:
198 self.queue(focus, status=TaskStatus_e.WAIT)
199 case TaskStatus_e.DEFINED:
200 self._instantiate(focus, task=True)
201 self.queue(focus)
202 case TaskStatus_e.DECLARED:
203 self.queue(focus, status=TaskStatus_e.DEFINED)
204 case TaskStatus_e.NAMED:
205 logging.warning("A Name only was queued, it has no backing in the tracker: %s", focus)
206 case x: # Error otherwise
207 raise doot.errors.TrackingError("Unknown task internal_state", x)
208 ##--|
209 return None
210
[docs]
211 def _next_for_artifact(self, focus:Artifact_i) -> Maybe[Artifact_i]: # noqa: PLR0912
212 """ logic for handling a dequed artifact """
213 status, _ = self.get_status(target=focus) # type: ignore[attr-defined]
214 match status:
215 case ArtifactStatus_e.EXISTS:
216 # TODO artifact Exists, queue its dependents and *don't* add the artifact back in
217 pass
218 case ArtifactStatus_e.STALE:
219 for pred, _ in self._dependency_states_of(focus):
220 self.queue(pred)
221 case ArtifactStatus_e.DECLARED if bool(focus):
222 self.queue(focus)
223 case ArtifactStatus_e.DECLARED: # Add dependencies of an artifact to the stack
224 deps : list[tuple] = self._dependency_states_of(focus)
225 match deps:
226 case [] if not focus.is_concrete():
227 self.queue(focus)
228 case []:
229 assert(not bool(focus))
230 path = doot.locs[focus]
231 # Returns the artifact, the runner can try to create
232 # it, then override the halt
233 return focus
234 case [*xs]:
235 logging.info("[Next.For] Artifact Blocked, queuing producer tasks, count: %s", len(xs))
236 case x:
237 raise TypeError(type(x))
238
239 for dep, dep_state in deps:
240 if dep_state in API.SUCCESS_STATUSES:
241 continue
242 self.queue(dep)
243 else:
244 # No need to requeue, as tasks will check for the artifacts themselves
245 pass
246
247 case x: # Error otherwise
248 raise doot.errors.TrackingError("Unknown task internal_state", x)
249
250 ##--|
251 return None
252
[docs]
253 @override
254 def _instantiate(self, target:TaskName_p|RelationSpec_i, *args:Any, task:bool=False, **kwargs:Any) -> Maybe[TaskName_p]:
255 """ extends base instantiation to add late injection for tasks """
256 parent : TaskName_p
257 result : Maybe[TaskName_p]
258 ##--|
259 parent = kwargs.pop("parent", None)
260 match super()._instantiate(target, *args, task=task, **kwargs):
261 case TaskName_p() as result if task:
262 self._apply_injections(result, parent=parent)
263 case result:
264 pass
265
266 return result
267
268 ##--| internal
269
[docs]
270 def _dependency_states_of(self, focus:TaskName_p|Artifact_i) -> list[tuple]:
271 return [(x, self.get_status(target=x)[0]) for x in self._network.pred[focus]]
272
[docs]
273 def _successor_states_of(self, focus:TaskName_p|Artifact_i) -> list[tuple]:
274 return [(x, self.get_status(target=x)[0]) for x in self._network.succ[focus]]
275
[docs]
276 def _deque(self) -> TaskName_p|Artifact_i:
277 focus = self._queue.deque_entry()
278 match self.specs.get(focus, focus): # type: ignore[arg-type]
279 case None | API.SpecMeta_d(task=None) | TaskName_p():
280 pass
281 case API.SpecMeta_d(task=Task_p() as task) if task.priority < self._min_priority:
282 logging.warning("[Deque] Halting (Min Priority) : %s", focus[:])
283 self.set_status(focus, TaskStatus_e.HALTED)
284 case API.SpecMeta_d(task=Task_p() as task):
285 prior = task.priority
286 task.priority = 1
287 logging.debug("[Deque] %s -> %s : %s", prior, task.priority, focus[:])
288 case TaskArtifact() as focus: # type: ignore[misc]
289 focus.priority -= 1 # type: ignore[union-attr]
290
291 return focus
292
293 ##--| utils
294
[docs]
295 def get_status(self, *, target:Maybe[Concrete[TaskName_p]|Artifact_i]=None) -> tuple[TaskStatus_e|ArtifactStatus_e, int]:
296 return self._registry.get_status(target)
297
[docs]
298 def set_status(self, task:Concrete[TaskName_p]|Artifact_i|Task_p, internal_state:TaskStatus_e) -> bool:
299 return self._registry.set_status(task, internal_state) # type: ignore[attr-defined]
300
[docs]
301 def _apply_injections(self, name:TaskName_p, *, parent:Maybe[TaskName_p]=None) -> None:
302 """ After a task is created, values can be injected into it.
303 these include, in order:
304 - parent internal_state,
305 - cli params
306 - instantiator internal_state injection
307 """
308 x : Any
309 meta : API.SpecMeta_d
310 task : Task_p
311 idx : int = 0
312 ##--|
313 match self.specs[name]:
314 case API.SpecMeta_d(task=Task_p() as task) as meta:
315 pass
316 case x:
317 raise TypeError(type(x))
318 ##--| Get parent data (for cleanup tasks
319 match self._get_parent_data(parent):
320 case None:
321 pass
322 case dict() as pdata:
323 task.internal_state.update(pdata)
324 ##--| apply CLI params
325 match self._get_cli_data(name):
326 case None:
327 pass
328 case dict() as cdata:
329 # Apply CLI passed params, but only as the default
330 # So if override values have been injected, they are preferred
331 for x,y in cdata.items():
332 task.internal_state.setdefault(x, y)
333
334 ##--| apply late injections
335 match self._get_inject_data(name):
336 case None:
337 pass
338 case dict() as idata:
339 task.internal_state.update(idata)
340
341 ##--| validate
342 if CLI_K in task.internal_state: # type: ignore[attr-defined]
343 del task.internal_state[CLI_K] # type: ignore[attr-defined]
344 match task.spec.extra.on_fail([])[MUST_INJECT_K](): # type: ignore[attr-defined]
345 case []:
346 pass
347 case [*xs] if bool(missing:=[x for x in xs if x not in task.internal_state]): # type: ignore[attr-defined]
348 raise doot.errors.TrackingError("Task did not receive required injections", task.name, xs, task.internal_state.keys()) # type: ignore[attr-defined]
349
350 ##--| prep actions
351 task.prepare_actions()
352
[docs]
353 def _get_parent_data(self, parent:Maybe[TaskName_p]=None) -> Maybe[dict]:
354 match self.specs.get(parent, None): # type: ignore[arg-type]
355 case None:
356 return None
357 case API.SpecMeta_d(task=Task_p() as p_task):
358 return dict(p_task.internal_state)
359
[docs]
360 def _get_cli_data(self, name:TaskName_p) -> Maybe[dict]:
361 idx = 0
362 target = name.pop()[:,:]
363 return doot.args.on_fail({}).subs[target][idx]['args']()
364
[docs]
365 def _get_inject_data(self, name:TaskName_p) -> Maybe[dict]:
366 inj_control : TaskName_p
367 inj : InjectSpec_i
368 meta = self.specs[name]
369
370 match meta.injection_source:
371 case None:
372 inj_control = None
373 case TaskName_p() as inj_control, _ if inj_control not in self.specs:
374 raise ValueError("Late Injection source is not a task", inj_control)
375 case TaskName_p() as inj_control, InjectSpec_i() as inj:
376 pass
377
378 match self.specs.get(inj_control, None): # type: ignore[arg-type]
379 case None:
380 return None
381 case API.SpecMeta_d(task=Task_p() as control):
382 return inj.apply_from_state(control)
383 case x:
384 raise TypeError(type(x))