Source code for doot.control.loaders.task

  1#!/usr/bin/env python3
  2"""
  3
  4"""
  5# ruff: noqa: FBT003
  6# Imports:
  7from __future__ import annotations
  8
  9# ##-- stdlib imports
 10import datetime
 11import enum
 12import functools as ftz
 13import importlib
 14import itertools as itz
 15import logging as logmod
 16import re
 17import time
 18import types
 19import typing
 20from collections import ChainMap, defaultdict
 21from uuid import UUID, uuid1
 22
 23# ##-- end stdlib imports
 24
 25# ##-- 3rd party imports
 26from jgdv import Proto
 27from jgdv.debugging.timing import TimeCtx
 28from jgdv.structs.chainguard import ChainGuard
 29from jgdv.structs.locator.errors import LocationError, StrangError
 30from jgdv.structs.strang import CodeReference
 31from pydantic import ValidationError
 32
 33# ##-- end 3rd party imports
 34
 35# ##-- 1st party imports
 36import doot
 37import doot.errors
 38from doot.workflow import TaskName
 39
 40# ##-- end 1st party imports
 41
 42# ##-| Local
 43from . import _interface as API#  noqa: N812
 44from doot.workflow.factory import TaskFactory
 45from ._interface import TaskLoader_p
 46
 47# # End of Imports.
 48
 49# ##-- types
 50# isort: off
 51import abc
 52import collections.abc
 53from typing import TYPE_CHECKING, cast, assert_type, assert_never
 54from typing import Generic, NewType
 55# Protocols:
 56from typing import Protocol, runtime_checkable
 57# Typing Decorators:
 58from typing import no_type_check, final, override, overload
 59
 60if TYPE_CHECKING:
 61    from doot.workflow._interface import TaskName_p, TaskSpec_i, TaskFactory_p
 62    import pathlib as pl
 63    from jgdv import Maybe
 64    from typing import Final
 65    from typing import ClassVar, Any, LiteralString
 66    from typing import Never, Self, Literal
 67    from typing import TypeGuard
 68    from collections.abc import Iterable, Iterator, Callable, Generator
 69    from collections.abc import Sequence, Mapping, MutableMapping, Hashable
 70
 71# isort: on
 72# ##-- end types
 73
 74##-- logging
 75logging = logmod.getLogger(__name__)
 76##-- end logging
 77
 78##--| vars
 # File suffix used by TaskLoader._load_specs_from_path to pick task files out of a directory.
 79TOML_SUFFIX            : Final[str]   = ".toml"
 # Read from doot.config once, at import time. Copied onto each TaskLoader in __init__,
 # where it decides whether TaskLoader.load raises on recorded failures.
 # NOTE(review): presumably on_fail(False) supplies False when the key is absent -- confirm against ChainGuard.
 80exit_on_load_failures  : Final[bool]  = doot.config.on_fail(False).shutdown.exit_on_load_failures()
 # Read from doot.config once, at import time. When truthy, TaskLoader._build_task_specs
 # lets a task replace an already registered task of the same name.
 81allow_overloads        : Final[bool]  = doot.config.on_fail(False, bool).allow_overloads()
 82
 83##--| util
[docs] 84def apply_group_and_source(group, source, x): # noqa: ANN201, ANN001 85 """ insert the group and source into a task definition dict 86 87 a task is: 88 [[tasks.GROUP]]: 89 name = TASKNAME 90 ... 91 92 So the group isn't actually part of the dict. 93 This fn adds it in, plus where the dict came from 94 95 """ 96 match x: 97 case ChainGuard(): 98 x = dict(x.items()) 99 x['group'] = x.get('group', group) 100 if 'sources' not in x: 101 x['sources'] = [] 102 x['sources'].append(str(source)) 103 case dict(): 104 x['group'] = x.get('group', group) 105 if 'sources' not in x: 106 x['sources'] = [] 107 x['sources'].append(str(source)) 108 return x
109 110##--|
[docs] 111@Proto(TaskLoader_p) 112class TaskLoader: 113 """ 114 load toml defined tasks, and create doot.structs.TaskSpecs of them 115 """ 116 tasks : dict[str|TaskName_p, TaskSpec_i] 117 failures : dict[str|pl.Path, list] 118 cmd_names : set[str] 119 task_builders : dict[str,Any] 120 extra : Maybe[ChainGuard|dict] 121 exit_on_load_failures : bool 122 factory : TaskFactory_p 123 124 def __init__(self): 125 self.tasks = {} 126 self.failures = defaultdict(list) 127 self.cmd_names = set() 128 self.task_builders = dict() 129 self.extra = None 130 self.exit_on_load_failures = exit_on_load_failures 131 self.factory = TaskFactory() 132
[docs] 133 def setup(self, plugins:ChainGuard, extra:Maybe[ChainGuard]=None) -> Self: 134 logging.debug("---- Registering Task Builders") 135 match plugins.get("command", []): 136 case [*xs]: 137 self.cmd_names = {x.name for x in xs} 138 case x: 139 raise TypeError(type(x)) 140 self.tasks = {} 141 self.plugins = plugins 142 self.task_builders = {} 143 self.failures = defaultdict(list) 144 for plugin in ChainGuard(plugins).on_fail([]).task(): # type: ignore[attr-defined] 145 if plugin.name in self.task_builders: 146 logging.warning("Conflicting Task Builder Type Name: %s: %s / %s", 147 plugin.name, 148 self.task_builders[plugin.name], 149 plugin) 150 continue 151 152 try: 153 self.task_builders[plugin.name] = plugin.load() 154 logging.info("Registered Task Builder short name: %s", plugin.name) 155 except ModuleNotFoundError as err: 156 logging.warning("Bad Task Builder Plugin Specified: %s", plugin) 157 else: 158 logging.debug("Registered Task Builders: %s", self.task_builders.keys()) 159 160 match extra: # { group : [task_dicts] } 161 case None: 162 self.extra = {} 163 case list(): 164 self.extra = {"_": extra} 165 case dict() | ChainGuard(): 166 self.extra = ChainGuard(extra).on_fail({}).tasks() # type: ignore[attr-defined] 167 logging.debug("Task Loader Setup with %s extra tasks", len(self.extra)) 168 return self
169
[docs] 170 def load(self) -> ChainGuard: 171 assert(hasattr(doot.report, "gen")) 172 173 def loc_wrapper(xs:str) -> list[pl.Path]: 174 return [doot.locs[x] for x in xs] 175 176 177 logging.info("---- Loading Tasks from Config Files") 178 with TimeCtx(logger=logging) as timer: 179 logging.debug("Loading Tasks from Config files") 180 for source in doot.configs_loaded_from: # type: ignore[attr-defined] 181 try: 182 source_data : ChainGuard = ChainGuard.load(source) # type: ignore[attr-defined] 183 task_specs = source_data.on_fail({}).tasks() # type: ignore[attr-defined] 184 except OSError as err: 185 logging.exception("Failed to Load Config File: %s : %s", source, err.args) 186 continue 187 else: 188 raw = self._get_raw_specs_from_data(task_specs, source) 189 self._build_task_specs(raw) 190 191 if self.extra: 192 logging.debug("Loading Tasks from extra values") 193 raw = self._get_raw_specs_from_data(self.extra, "(extra)") 194 self._build_task_specs(raw) 195 196 task_sources = doot.config.on_fail([doot.locs[".tasks"]], list).startup.sources.tasks.sources(wrapper=loc_wrapper) # type: ignore[index, union-attr] 197 logging.debug("Loading tasks from sources: %s", [str(x) for x in task_sources]) 198 for path in task_sources: 199 self._load_specs_from_path(path) 200 201 logging.info("---- Loading Tasks took: %s", timer.total_s) 202 203 204 match self.failures: 205 case dict() if bool(self.failures) and self.exit_on_load_failures: 206 # After everything is loaded, raise a total failure if necessary 207 raise doot.errors.StructLoadError("Loading Tasks Failed", self.failures) 208 case dict() if bool(self.failures): 209 doot.report.gen.user("!!!! 
Loading Tasks Failed: %s", len(self.failures)) 210 doot.report.gen.user("") 211 for x,msgs in self.failures.items(): 212 doot.report.gen.user("- %s:", x) 213 for y in msgs: 214 doot.report.gen.user("-- %s", y) 215 else: 216 doot.report.gen.user("") 217 else: 218 doot.report.gen.user("Continuing...") 219 220 221 222 logging.debug("Task List Size: %s", len(self.tasks)) 223 logging.debug("Task List Names: %s", list(self.tasks.keys())) 224 return ChainGuard(self.tasks) # type: ignore[arg-type]
225
[docs] 226 def _get_raw_specs_from_data(self, data:dict, source:pl.Path|Literal['(extra)']) -> list[dict]: 227 """ extract raw task descriptions from a toplevel tasks dict, with no format checking. 228 expects the dict to be { group_key : [ task_dict ] } 229 """ 230 raw_specs : list = [] 231 # Load from doot.toml task specs 232 for group, d in data.items(): 233 if not isinstance(d, list): 234 logging.warning("Unexpected task specification format: %s : %s", group, d) 235 else: 236 raw_specs += map(ftz.partial(apply_group_and_source, group, source), d) 237 238 logging.info("Loaded Tasks from: %s", source) 239 return raw_specs
240
[docs] 241 def _load_specs_from_path(self, path:pl.Path) -> None: 242 """ load a config file defined task_sources of tasks """ 243 data : ChainGuard 244 assert(hasattr(doot, "verify_config_version")) 245 targets = [] 246 if path.is_dir(): 247 targets += [x for x in path.iterdir() if x.suffix == TOML_SUFFIX] 248 elif path.is_file(): 249 targets.append(path) 250 else: 251 assert(not path.exists()) 252 253 for task_file in targets: 254 logging.info("Loading Tasks from: %s", task_file) 255 try: 256 data = ChainGuard.load(task_file) # type: ignore[attr-defined] 257 doot.verify_config_version(data.on_fail(None).doot_version(), source=task_file) # type: ignore[attr-defined] 258 except OSError as err: 259 self.failures[task_file].append(str(err)) 260 except doot.errors.VersionMismatchError as err: 261 if "startup" not in data: 262 # startup designates a config file, which is handled in main 263 self.failures[task_file].append("Version mismatch") 264 else: 265 doot.update_global_task_state(data, source=task_file) # type: ignore[attr-defined] 266 267 raw_specs : list = [] 268 for group, val in data.on_fail({}).tasks().items(): # type: ignore[attr-defined] 269 # sets 'group' for each task if it hasn't been set already 270 raw_specs += map(ftz.partial(apply_group_and_source, group, task_file), val) 271 272 self._build_task_specs(raw_specs, source=task_file) 273 self._load_location_updates(data.on_fail([]).locations(), task_file) # type: ignore[attr-defined]
274
[docs] 275 def _build_task_specs(self, specs:list[dict], source:Maybe[str|pl.Path]=None) -> None: # noqa: PLR0912 276 """ 277 convert raw dicts into TaskSpec objects 278 279 """ 280 logging.info("---- Building Task Specs (%s Current, %s Potential) ", len(self.tasks), len(specs)) 281 source = source or "<Sourceless>" 282 283 def _allow_registration(task_name:TaskName_p|str) -> bool: 284 """ precondition to check for overrides/name conflicts """ 285 logging.info("Checking: %s", task_name) 286 if allow_overloads: 287 return True 288 return task_name not in self.tasks 289 290 for spec in specs: 291 logging.info("Processing: %s", spec['name']) 292 task_alias = "task" 293 task_spec = None 294 try: 295 match spec: 296 case {"name": task_name, "ctor": CodeReference() as ctor}: 297 task_spec = self.factory.build(spec) 298 case {"name": task_name, "ctor": str() as task_alias} if task_alias in self.task_builders: 299 spec['ctor'] = CodeReference(self.task_builders[task_alias]) 300 task_spec = self.factory.build(spec) 301 case {"name": task_name}: 302 task_spec = self.factory.build(spec) 303 case _: # Else complain 304 raise doot.errors.StructLoadError("Task Spec missing, at least, needs at least a name and ctor", spec, spec['sources'][0] ) 305 except ValidationError as err: 306 for suberr in err.errors(): 307 locs = ", ".join(suberr['loc']) 308 self.failures[source].append(f"({locs}) : '{suberr['input']}' :- {suberr['msg']}") 309 except StrangError as err: 310 self.failures[source].append(err) 311 except LocationError as err: 312 self.failures[source].append(err) 313 except ModuleNotFoundError as err: 314 self.failures[source].append(err) 315 except AttributeError as err: 316 self.failures[source].append(err) 317 except ValueError as err: 318 self.failures[source].append(err) 319 except TypeError as err: 320 self.failures[source].append(err) 321 except ImportError as err: 322 self.failures[source].append(err) 323 else: 324 assert(task_spec is not None) 325 if 
_allow_registration(task_spec.name): # complain on overload 326 logging.info("Registering Task: %s", task_spec.name) 327 self.tasks[task_spec.name] = task_spec 328 else: 329 logging.warning("Current Tasks: %s", self.tasks) 330 _err = doot.errors.StructLoadError("Task Name Overloaded", task_name) 331 self.failures[source].append(_err)
332
[docs] 333 def _load_location_updates(self, data:list[ChainGuard], source:str|pl.Path) -> None: 334 logging.debug("Loading Location Updates: %s", source) 335 for group in data: 336 try: 337 doot.locs.Current.update(group, strict=False) 338 except KeyError as err: 339 doot.report.gen.warn("Locations Already Defined: %s : %s", err.args, source) 340 except TypeError as err: 341 doot.report.gen.warn("Location failed to validate: %s : %s", err.args, source) 342 except LocationError as err: 343 doot.report.gen.warn("%s : %s", str(err), source)