|
2 | 2 |
|
3 | 3 |
|
4 | 4 | # standard library |
5 | | -from typing import Any, Callable, overload |
6 | | - |
| 5 | +from dataclasses import replace |
| 6 | +from typing import Any, Callable, ForwardRef, Hashable, Literal, Optional, overload |
7 | 7 |
|
8 | 8 | # dependencies |
9 | | -from xarray import DataArray, Dataset |
10 | | -from .typing import DataClass, DataClassOf, PAny, TDataArray, TDataset, TXarray |
| 9 | +from dataspecs import ID, ROOT, Spec, Specs |
| 10 | +from numpy import asarray, array |
| 11 | +from typing_extensions import get_args, get_origin |
| 12 | +from xarray import DataArray, Dataset, Variable |
| 13 | +from .typing import ( |
| 14 | + DataClass, |
| 15 | + DataClassOf, |
| 16 | + PAny, |
| 17 | + TAny, |
| 18 | + TDataArray, |
| 19 | + TDataset, |
| 20 | + TXarray, |
| 21 | + Tag, |
| 22 | +) |
| 23 | + |
| 24 | + |
| 25 | +# type hints |
| 26 | +Attrs = dict[Hashable, Any] |
| 27 | +Vars = dict[Hashable, Variable] |
11 | 28 |
|
12 | 29 |
|
13 | 30 | @overload |
@@ -95,3 +112,77 @@ def asxarray( |
95 | 112 | def asxarray(obj: Any, /, *, factory: Any = None) -> Any: |
96 | 113 | """Create a DataArray/set object from a dataclass object.""" |
97 | 114 | ... |
| 115 | + |
| 116 | + |
| 117 | +def get_attrs(specs: Specs[Spec[Any]], /, *, at: ID = ROOT) -> Attrs: |
| 118 | + """Create attributes from data specs.""" |
| 119 | + attrs: Attrs = {} |
| 120 | + |
| 121 | + for spec in specs[at.children][Tag.ATTR]: |
| 122 | + options = specs[spec.id.children] |
| 123 | + factory = maybe(options[Tag.FACTORY].unique).data or identity |
| 124 | + name = maybe(options[Tag.NAME].unique).data or spec.id.name |
| 125 | + |
| 126 | + if Tag.MULTIPLE not in spec.tags: |
| 127 | + spec = replace(spec, data={name: spec.data}) |
| 128 | + |
| 129 | + for name, data in spec[dict[Hashable, Any]].data.items(): |
| 130 | + attrs[name] = factory(data) |
| 131 | + |
| 132 | + return attrs |
| 133 | + |
| 134 | + |
| 135 | +def get_vars(specs: Specs[Spec[Any]], of: Tag, /, *, at: ID = ROOT) -> Vars: |
| 136 | + """Create variables of given tag from data specs.""" |
| 137 | + vars: Vars = {} |
| 138 | + |
| 139 | + for spec in specs[at.children][of]: |
| 140 | + options = specs[spec.id.children] |
| 141 | + attrs = get_attrs(specs, at=spec.id) |
| 142 | + factory = maybe(options[Tag.FACTORY].unique).data or Variable |
| 143 | + name = maybe(options[Tag.NAME].unique).data or spec.id.name |
| 144 | + |
| 145 | + if (type_ := maybe(options[Tag.DIMS].unique).type) is None: |
| 146 | + raise RuntimeError("Could not find any data spec for dims.") |
| 147 | + elif get_origin(type_) is tuple: |
| 148 | + dims = tuple(str(unwrap(arg)) for arg in get_args(type_)) |
| 149 | + else: |
| 150 | + dims = (str(unwrap(type_)),) |
| 151 | + |
| 152 | + if (type_ := maybe(options[Tag.DTYPE].unique).type) is None: |
| 153 | + raise RuntimeError("Could not find any data spec for dims.") |
| 154 | + elif type_ is type(None) or type_ is Any: |
| 155 | + dtype = None |
| 156 | + else: |
| 157 | + dtype = unwrap(type_) |
| 158 | + |
| 159 | + if Tag.MULTIPLE not in spec.tags: |
| 160 | + spec = replace(spec, data={name: spec.data}) |
| 161 | + |
| 162 | + for name, data in spec[dict[Hashable, Any]].data.items(): |
| 163 | + if not (data := asarray(data, dtype)).ndim: |
| 164 | + data = array(data, ndmin=len(dims)) |
| 165 | + |
| 166 | + vars[name] = factory(attrs=attrs, data=data, dims=dims) |
| 167 | + |
| 168 | + return vars |
| 169 | + |
| 170 | + |
| 171 | +def identity(obj: TAny, /) -> TAny: |
| 172 | + """Identity function used for the default factory.""" |
| 173 | + return obj |
| 174 | + |
| 175 | + |
| 176 | +def maybe(obj: Optional[Spec[Any]], /) -> Spec[Any]: |
| 177 | + """Return a dummy (``None``-filled) data spec if an object is not one.""" |
| 178 | + return Spec(ROOT, (), None, None) if obj is None else obj |
| 179 | + |
| 180 | + |
| 181 | +def unwrap(obj: Any, /) -> Any: |
| 182 | + """Unwrap if an object is a literal or a forward reference.""" |
| 183 | + if get_origin(obj) is Literal: |
| 184 | + return args[0] if len(args := get_args(obj)) == 1 else obj |
| 185 | + elif isinstance(obj, ForwardRef): |
| 186 | + return obj.__forward_arg__ |
| 187 | + else: |
| 188 | + return obj |
0 commit comments