1#!/usr/bin/env python3
2"""
3
4"""
5# ruff: noqa: FBT003
6# Imports:
7from __future__ import annotations
8
9# ##-- stdlib imports
10import datetime
11import enum
12import functools as ftz
13import importlib
14import itertools as itz
15import logging as logmod
16import re
17import time
18import types
19import typing
20from collections import ChainMap, defaultdict
21from uuid import UUID, uuid1
22
23# ##-- end stdlib imports
24
25# ##-- 3rd party imports
26from jgdv import Proto
27from jgdv.debugging.timing import TimeCtx
28from jgdv.structs.chainguard import ChainGuard
29from jgdv.structs.locator.errors import LocationError, StrangError
30from jgdv.structs.strang import CodeReference
31from pydantic import ValidationError
32
33# ##-- end 3rd party imports
34
35# ##-- 1st party imports
36import doot
37import doot.errors
38from doot.workflow import TaskName
39
40# ##-- end 1st party imports
41
42# ##-| Local
43from . import _interface as API# noqa: N812
44from doot.workflow.factory import TaskFactory
45from ._interface import TaskLoader_p
46
47# # End of Imports.
48
49# ##-- types
50# isort: off
51import abc
52import collections.abc
53from typing import TYPE_CHECKING, cast, assert_type, assert_never
54from typing import Generic, NewType
55# Protocols:
56from typing import Protocol, runtime_checkable
57# Typing Decorators:
58from typing import no_type_check, final, override, overload
59
60if TYPE_CHECKING:
61 from doot.workflow._interface import TaskName_p, TaskSpec_i, TaskFactory_p
62 import pathlib as pl
63 from jgdv import Maybe
64 from typing import Final
65 from typing import ClassVar, Any, LiteralString
66 from typing import Never, Self, Literal
67 from typing import TypeGuard
68 from collections.abc import Iterable, Iterator, Callable, Generator
69 from collections.abc import Sequence, Mapping, MutableMapping, Hashable
70
71# isort: on
72# ##-- end types
73
##-- logging
# Module-level logger. The stdlib module is imported as `logmod`, which frees this name to be `logging`.
logging = logmod.getLogger(__name__)
##-- end logging
77
##--| vars
# Only files with this suffix are picked up when a task source path is a directory.
TOML_SUFFIX : Final[str] = ".toml"
# The two settings below are read from doot.config once, when this module is imported.
# on_fail(False...) presumably supplies False when the key is absent — TODO confirm against jgdv ChainGuard.
# NOTE(review): later changes to doot.config will not be reflected in these values — confirm that is intended.
# Default for TaskLoader.exit_on_load_failures: raise after loading if any failures were recorded.
exit_on_load_failures : Final[bool] = doot.config.on_fail(False).shutdown.exit_on_load_failures()
# When true, a task spec may be registered under an already-loaded name, replacing the earlier spec.
allow_overloads : Final[bool] = doot.config.on_fail(False, bool).allow_overloads()
82
83##--| util
[docs]
84def apply_group_and_source(group, source, x): # noqa: ANN201, ANN001
85 """ insert the group and source into a task definition dict
86
87 a task is:
88 [[tasks.GROUP]]:
89 name = TASKNAME
90 ...
91
92 So the group isn't actually part of the dict.
93 This fn adds it in, plus where the dict came from
94
95 """
96 match x:
97 case ChainGuard():
98 x = dict(x.items())
99 x['group'] = x.get('group', group)
100 if 'sources' not in x:
101 x['sources'] = []
102 x['sources'].append(str(source))
103 case dict():
104 x['group'] = x.get('group', group)
105 if 'sources' not in x:
106 x['sources'] = []
107 x['sources'].append(str(source))
108 return x
109
110##--|
[docs]
111@Proto(TaskLoader_p)
112class TaskLoader:
113 """
114 load toml defined tasks, and create doot.structs.TaskSpecs of them
115 """
116 tasks : dict[str|TaskName_p, TaskSpec_i]
117 failures : dict[str|pl.Path, list]
118 cmd_names : set[str]
119 task_builders : dict[str,Any]
120 extra : Maybe[ChainGuard|dict]
121 exit_on_load_failures : bool
122 factory : TaskFactory_p
123
124 def __init__(self):
125 self.tasks = {}
126 self.failures = defaultdict(list)
127 self.cmd_names = set()
128 self.task_builders = dict()
129 self.extra = None
130 self.exit_on_load_failures = exit_on_load_failures
131 self.factory = TaskFactory()
132
[docs]
133 def setup(self, plugins:ChainGuard, extra:Maybe[ChainGuard]=None) -> Self:
134 logging.debug("---- Registering Task Builders")
135 match plugins.get("command", []):
136 case [*xs]:
137 self.cmd_names = {x.name for x in xs}
138 case x:
139 raise TypeError(type(x))
140 self.tasks = {}
141 self.plugins = plugins
142 self.task_builders = {}
143 self.failures = defaultdict(list)
144 for plugin in ChainGuard(plugins).on_fail([]).task(): # type: ignore[attr-defined]
145 if plugin.name in self.task_builders:
146 logging.warning("Conflicting Task Builder Type Name: %s: %s / %s",
147 plugin.name,
148 self.task_builders[plugin.name],
149 plugin)
150 continue
151
152 try:
153 self.task_builders[plugin.name] = plugin.load()
154 logging.info("Registered Task Builder short name: %s", plugin.name)
155 except ModuleNotFoundError as err:
156 logging.warning("Bad Task Builder Plugin Specified: %s", plugin)
157 else:
158 logging.debug("Registered Task Builders: %s", self.task_builders.keys())
159
160 match extra: # { group : [task_dicts] }
161 case None:
162 self.extra = {}
163 case list():
164 self.extra = {"_": extra}
165 case dict() | ChainGuard():
166 self.extra = ChainGuard(extra).on_fail({}).tasks() # type: ignore[attr-defined]
167 logging.debug("Task Loader Setup with %s extra tasks", len(self.extra))
168 return self
169
[docs]
170 def load(self) -> ChainGuard:
171 assert(hasattr(doot.report, "gen"))
172
173 def loc_wrapper(xs:str) -> list[pl.Path]:
174 return [doot.locs[x] for x in xs]
175
176
177 logging.info("---- Loading Tasks from Config Files")
178 with TimeCtx(logger=logging) as timer:
179 logging.debug("Loading Tasks from Config files")
180 for source in doot.configs_loaded_from: # type: ignore[attr-defined]
181 try:
182 source_data : ChainGuard = ChainGuard.load(source) # type: ignore[attr-defined]
183 task_specs = source_data.on_fail({}).tasks() # type: ignore[attr-defined]
184 except OSError as err:
185 logging.exception("Failed to Load Config File: %s : %s", source, err.args)
186 continue
187 else:
188 raw = self._get_raw_specs_from_data(task_specs, source)
189 self._build_task_specs(raw)
190
191 if self.extra:
192 logging.debug("Loading Tasks from extra values")
193 raw = self._get_raw_specs_from_data(self.extra, "(extra)")
194 self._build_task_specs(raw)
195
196 task_sources = doot.config.on_fail([doot.locs[".tasks"]], list).startup.sources.tasks.sources(wrapper=loc_wrapper) # type: ignore[index, union-attr]
197 logging.debug("Loading tasks from sources: %s", [str(x) for x in task_sources])
198 for path in task_sources:
199 self._load_specs_from_path(path)
200
201 logging.info("---- Loading Tasks took: %s", timer.total_s)
202
203
204 match self.failures:
205 case dict() if bool(self.failures) and self.exit_on_load_failures:
206 # After everything is loaded, raise a total failure if necessary
207 raise doot.errors.StructLoadError("Loading Tasks Failed", self.failures)
208 case dict() if bool(self.failures):
209 doot.report.gen.user("!!!! Loading Tasks Failed: %s", len(self.failures))
210 doot.report.gen.user("")
211 for x,msgs in self.failures.items():
212 doot.report.gen.user("- %s:", x)
213 for y in msgs:
214 doot.report.gen.user("-- %s", y)
215 else:
216 doot.report.gen.user("")
217 else:
218 doot.report.gen.user("Continuing...")
219
220
221
222 logging.debug("Task List Size: %s", len(self.tasks))
223 logging.debug("Task List Names: %s", list(self.tasks.keys()))
224 return ChainGuard(self.tasks) # type: ignore[arg-type]
225
[docs]
226 def _get_raw_specs_from_data(self, data:dict, source:pl.Path|Literal['(extra)']) -> list[dict]:
227 """ extract raw task descriptions from a toplevel tasks dict, with no format checking.
228 expects the dict to be { group_key : [ task_dict ] }
229 """
230 raw_specs : list = []
231 # Load from doot.toml task specs
232 for group, d in data.items():
233 if not isinstance(d, list):
234 logging.warning("Unexpected task specification format: %s : %s", group, d)
235 else:
236 raw_specs += map(ftz.partial(apply_group_and_source, group, source), d)
237
238 logging.info("Loaded Tasks from: %s", source)
239 return raw_specs
240
[docs]
241 def _load_specs_from_path(self, path:pl.Path) -> None:
242 """ load a config file defined task_sources of tasks """
243 data : ChainGuard
244 assert(hasattr(doot, "verify_config_version"))
245 targets = []
246 if path.is_dir():
247 targets += [x for x in path.iterdir() if x.suffix == TOML_SUFFIX]
248 elif path.is_file():
249 targets.append(path)
250 else:
251 assert(not path.exists())
252
253 for task_file in targets:
254 logging.info("Loading Tasks from: %s", task_file)
255 try:
256 data = ChainGuard.load(task_file) # type: ignore[attr-defined]
257 doot.verify_config_version(data.on_fail(None).doot_version(), source=task_file) # type: ignore[attr-defined]
258 except OSError as err:
259 self.failures[task_file].append(str(err))
260 except doot.errors.VersionMismatchError as err:
261 if "startup" not in data:
262 # startup designates a config file, which is handled in main
263 self.failures[task_file].append("Version mismatch")
264 else:
265 doot.update_global_task_state(data, source=task_file) # type: ignore[attr-defined]
266
267 raw_specs : list = []
268 for group, val in data.on_fail({}).tasks().items(): # type: ignore[attr-defined]
269 # sets 'group' for each task if it hasn't been set already
270 raw_specs += map(ftz.partial(apply_group_and_source, group, task_file), val)
271
272 self._build_task_specs(raw_specs, source=task_file)
273 self._load_location_updates(data.on_fail([]).locations(), task_file) # type: ignore[attr-defined]
274
[docs]
275 def _build_task_specs(self, specs:list[dict], source:Maybe[str|pl.Path]=None) -> None: # noqa: PLR0912
276 """
277 convert raw dicts into TaskSpec objects
278
279 """
280 logging.info("---- Building Task Specs (%s Current, %s Potential) ", len(self.tasks), len(specs))
281 source = source or "<Sourceless>"
282
283 def _allow_registration(task_name:TaskName_p|str) -> bool:
284 """ precondition to check for overrides/name conflicts """
285 logging.info("Checking: %s", task_name)
286 if allow_overloads:
287 return True
288 return task_name not in self.tasks
289
290 for spec in specs:
291 logging.info("Processing: %s", spec['name'])
292 task_alias = "task"
293 task_spec = None
294 try:
295 match spec:
296 case {"name": task_name, "ctor": CodeReference() as ctor}:
297 task_spec = self.factory.build(spec)
298 case {"name": task_name, "ctor": str() as task_alias} if task_alias in self.task_builders:
299 spec['ctor'] = CodeReference(self.task_builders[task_alias])
300 task_spec = self.factory.build(spec)
301 case {"name": task_name}:
302 task_spec = self.factory.build(spec)
303 case _: # Else complain
304 raise doot.errors.StructLoadError("Task Spec missing, at least, needs at least a name and ctor", spec, spec['sources'][0] )
305 except ValidationError as err:
306 for suberr in err.errors():
307 locs = ", ".join(suberr['loc'])
308 self.failures[source].append(f"({locs}) : '{suberr['input']}' :- {suberr['msg']}")
309 except StrangError as err:
310 self.failures[source].append(err)
311 except LocationError as err:
312 self.failures[source].append(err)
313 except ModuleNotFoundError as err:
314 self.failures[source].append(err)
315 except AttributeError as err:
316 self.failures[source].append(err)
317 except ValueError as err:
318 self.failures[source].append(err)
319 except TypeError as err:
320 self.failures[source].append(err)
321 except ImportError as err:
322 self.failures[source].append(err)
323 else:
324 assert(task_spec is not None)
325 if _allow_registration(task_spec.name): # complain on overload
326 logging.info("Registering Task: %s", task_spec.name)
327 self.tasks[task_spec.name] = task_spec
328 else:
329 logging.warning("Current Tasks: %s", self.tasks)
330 _err = doot.errors.StructLoadError("Task Name Overloaded", task_name)
331 self.failures[source].append(_err)
332
[docs]
333 def _load_location_updates(self, data:list[ChainGuard], source:str|pl.Path) -> None:
334 logging.debug("Loading Location Updates: %s", source)
335 for group in data:
336 try:
337 doot.locs.Current.update(group, strict=False)
338 except KeyError as err:
339 doot.report.gen.warn("Locations Already Defined: %s : %s", err.args, source)
340 except TypeError as err:
341 doot.report.gen.warn("Location failed to validate: %s : %s", err.args, source)
342 except LocationError as err:
343 doot.report.gen.warn("%s : %s", str(err), source)