1#!/usr/bin/env python3
2"""
3The core implementation of the ChainGuard object,
4which is then extended with mixins.
5"""
6
7# Imports:
8from __future__ import annotations
9
10# ##-- stdlib imports
11import datetime
12import enum
13import functools as ftz
14import itertools as itz
15import logging as logmod
16import pathlib as pl
17import re
18import time
19import types
20from uuid import UUID, uuid1
21
22# ##-- end stdlib imports
23
24from .errors import GuardedAccessError
25from .mixins.access_m import super_get, super_set
26from . import _interface as API # noqa: N812
27from jgdv import Proto
28
29# ##-- types
30# isort: off
31import abc
32import collections.abc
33from typing import TYPE_CHECKING, cast, assert_type, assert_never
34from typing import Generic, NewType
35# Protocols:
36from typing import Protocol, runtime_checkable
37# Typing Decorators:
38from typing import no_type_check, final, override, overload
39
40if TYPE_CHECKING:
41 from collections.abc import ItemsView, KeysView, ValuesView
42 from ._interface import TomlTypes
43 from jgdv import Maybe
44 from typing import Final
45 from typing import ClassVar, Any, LiteralString
46 from typing import Never, Self, Literal
47 from typing import TypeGuard
48 from collections.abc import Iterable, Iterator, Callable, Generator
49 from collections.abc import Sequence, Mapping, MutableMapping, Hashable
50
51##--|
52
53# isort: on
54# ##-- end types
55
56##-- logging
57logging = logmod.getLogger(__name__)
58##-- end logging
59
60type InputData = dict
61TABLE_K : Final[str] = "__table"
62INDEX_K : Final[str] = "__index"
63MUTABLE_K : Final[str] = "__mutable"
64ROOT_STR : Final[str] = "<root>"
65USCORE : Final[str] = "_"
66DASH : Final[str] = "-"
67MUTABLE : Final[str] = "__mutable"
68##--|
69
[docs]
70class GuardBase(dict):
71 """
72 Provides access to toml data (ChainGuard.load(apath))
73 but as attributes (data.a.path.in.the.data)
74 instead of key access (data['a']['path']['in']['the']['data'])
75
76 while also providing typed, guarded access:
77 data.on_fail("test", str | int).a.path.that.may.exist()
78
79 while it can then report missing paths:
80 data.report_defaulted() -> ['a.path.that.may.exist.<str|int>']
81 """
82
83 def __init__(self, data:Maybe[InputData]=None, *, index:Maybe[Iterable[int|str]]=None, mutable:bool=False) -> None:
84 super().__init__()
85 super_set(self, TABLE_K, data or {})
86 super_set(self, INDEX_K, tuple(index or [ROOT_STR]))
87 super_set(self, MUTABLE_K, mutable)
88
89 @override
90 def __repr__(self) -> str:
91 match self._table():
92 case dict() as d:
93 return f"<{self.__class__.__name__}:{list(d.keys())}>"
94 case d:
95 return f"<{self.__class__.__name__}:{d}>"
96
97 @override
98 def __eq__(self, other:object) -> bool:
99 match other:
100 case GuardBase() as base:
101 return self._table() == base._table()
102 case dict() as adict:
103 return self._table() == adict
104 case _:
105 return False
106
107 @override
108 def __hash__(self) -> int: # type: ignore[override]
109 return hash(self._table())
110
111 @override
112 def __len__(self) -> int:
113 return len(self._table())
114
115 def __call__(self) -> None:
116 msg = "Don't call a ChainGuard, call a GuardProxy using methods like .on_fail"
117 raise GuardedAccessError(msg)
118
119 @override
120 def __iter__(self) -> Iterator:
121 return iter(getattr(self, TABLE_K).keys())
122
123 @override
124 def __contains__(self, _key: object) -> bool:
125 match _key:
126 case str():
127 return _key in self.keys() or _key.replace("_","-") in self.keys()
128 case x:
129 return x in self.keys()
130
131
132 @override
133 def __setattr__(self, attr:str, value:Any) -> None:
134 if not getattr(self, MUTABLE):
135 raise TypeError()
136 super_set(self, attr, value)
137
138 def __getattr__(self, attr:str) -> Any: # noqa: ANN401
139 return self.__getitem__(attr)
140
141 @override
142 def __getitem__(self, keys:int|str|list[str]|tuple[int|str, ...]) -> Any:
143 table : dict
144 curr : dict
145 ##--|
146 match keys:
147 case tuple():
148 pass
149 case int()|str():
150 keys = (keys, )
151 case x:
152 raise TypeError(type(x))
153
154
155 table = self._table()
156 curr = table
157 for k in keys:
158 match k:
159 case str() if k in curr:
160 pass
161 case str() if (k:=k.replace(USCORE, DASH)) in curr:
162 pass
163 case str():
164 index_s = ".".join(map(str, self._index(k)))
165 available = " ".join(table.keys())
166 msg = f"{index_s} not found, available: [{available}]"
167 raise GuardedAccessError(msg)
168 case int() if k < len(curr):
169 pass
170 case int():
171 raise GuardedAcccessError("tried to access a list of wrong length")
172
173 match curr.get(k, None):
174 case dict() as result:
175 curr = result
176 case result:
177 curr = result
178 else:
179 match curr:
180 case dict():
181 return type(self)(curr, index=self._index(keys))
182 case [*xs] if all(isinstance(x, dict) for x in xs):
183 index = self._index(keys)
184 return [type(self)(x, index=index) for x in xs]
185 case [*xs]:
186 return xs
187 case x:
188 return x
189
[docs]
190 @override
191 def get(self, key:str, default:Maybe=None) -> Maybe:
192 if key in self:
193 return self.__getitem__(key)
194
195 return default
196 ##--|
[docs]
197 def _index(self, sub:Maybe[int|str|tuple[int|str, ...]]=None) -> tuple[int|str, ...]:
198 match sub:
199 case None:
200 return super_get(self, INDEX_K)[:]
201 case int()|str() as x:
202 return (*super_get(self, INDEX_K), x)
203 case [*xs]:
204 return (*super_get(self, INDEX_K), *xs)
205 case x:
206 raise TypeError(type(x))
207
[docs]
208 def _table(self) -> dict:
209 return super_get(self, TABLE_K)
210
[docs]
211 @override
212 def keys(self) -> KeysView[str]: # type: ignore[override]
213 table = super_get(self, TABLE_K)
214 return table.keys()
215
[docs]
216 @override
217 def items(self) -> ItemsView: # type: ignore[override]
218 match super_get(self, TABLE_K):
219 case dict() as val:
220 return val.items()
221 case list() as val:
222 return {self._index()[-1]: val}.items()
223 case GuardBase() as val:
224 return val.items()
225 case x:
226 msg = "Unknown table type"
227 raise TypeError(msg, x)
228
[docs]
229 @override
230 def values(self) -> list|ValuesView: # type: ignore[override]
231 match super_get(self, TABLE_K):
232 case dict() as val:
233 return val.values()
234 case list() as val:
235 return val
236 case _:
237 raise TypeError()
238
[docs]
239 @override
240 def update(self, *args) -> Never: # type: ignore[override] # noqa: ANN002
241 msg = "ChainGuards are immutable"
242 raise NotImplementedError(msg)