Skip to content

Commit 8b591b3

Browse files
committed
#237 Add helper dataclass converters
1 parent 30b82e2 commit 8b591b3

File tree

1 file changed

+95
-4
lines changed

1 file changed

+95
-4
lines changed

xarray_dataclasses/api.py

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,29 @@
22

33

44
# standard library
5-
from typing import Any, Callable, overload
6-
5+
from dataclasses import replace
6+
from typing import Any, Callable, ForwardRef, Hashable, Literal, Optional, overload
77

88
# dependencies
9-
from xarray import DataArray, Dataset
10-
from .typing import DataClass, DataClassOf, PAny, TDataArray, TDataset, TXarray
9+
from dataspecs import ID, ROOT, Spec, Specs
10+
from numpy import asarray, array
11+
from typing_extensions import get_args, get_origin
12+
from xarray import DataArray, Dataset, Variable
13+
from .typing import (
14+
DataClass,
15+
DataClassOf,
16+
PAny,
17+
TAny,
18+
TDataArray,
19+
TDataset,
20+
TXarray,
21+
Tag,
22+
)
23+
24+
25+
# type hints
26+
Attrs = dict[Hashable, Any]
27+
Vars = dict[Hashable, Variable]
1128

1229

1330
@overload
@@ -95,3 +112,77 @@ def asxarray(
95112
def asxarray(obj: Any, /, *, factory: Any = None) -> Any:
96113
"""Create a DataArray/set object from a dataclass object."""
97114
...
115+
116+
117+
def get_attrs(specs: Specs[Spec[Any]], /, *, at: ID = ROOT) -> Attrs:
118+
"""Create attributes from data specs."""
119+
attrs: Attrs = {}
120+
121+
for spec in specs[at.children][Tag.ATTR]:
122+
options = specs[spec.id.children]
123+
factory = maybe(options[Tag.FACTORY].unique).data or identity
124+
name = maybe(options[Tag.NAME].unique).data or spec.id.name
125+
126+
if Tag.MULTIPLE not in spec.tags:
127+
spec = replace(spec, data={name: spec.data})
128+
129+
for name, data in spec[dict[Hashable, Any]].data.items():
130+
attrs[name] = factory(data)
131+
132+
return attrs
133+
134+
135+
def get_vars(specs: Specs[Spec[Any]], of: Tag, /, *, at: ID = ROOT) -> Vars:
136+
"""Create variables of given tag from data specs."""
137+
vars: Vars = {}
138+
139+
for spec in specs[at.children][of]:
140+
options = specs[spec.id.children]
141+
attrs = get_attrs(specs, at=spec.id)
142+
factory = maybe(options[Tag.FACTORY].unique).data or Variable
143+
name = maybe(options[Tag.NAME].unique).data or spec.id.name
144+
145+
if (type_ := maybe(options[Tag.DIMS].unique).type) is None:
146+
raise RuntimeError("Could not find any data spec for dims.")
147+
elif get_origin(type_) is tuple:
148+
dims = tuple(str(unwrap(arg)) for arg in get_args(type_))
149+
else:
150+
dims = (str(unwrap(type_)),)
151+
152+
if (type_ := maybe(options[Tag.DTYPE].unique).type) is None:
153+
raise RuntimeError("Could not find any data spec for dims.")
154+
elif type_ is type(None) or type_ is Any:
155+
dtype = None
156+
else:
157+
dtype = unwrap(type_)
158+
159+
if Tag.MULTIPLE not in spec.tags:
160+
spec = replace(spec, data={name: spec.data})
161+
162+
for name, data in spec[dict[Hashable, Any]].data.items():
163+
if not (data := asarray(data, dtype)).ndim:
164+
data = array(data, ndmin=len(dims))
165+
166+
vars[name] = factory(attrs=attrs, data=data, dims=dims)
167+
168+
return vars
169+
170+
171+
def identity(obj: TAny, /) -> TAny:
172+
"""Identity function used for the default factory."""
173+
return obj
174+
175+
176+
def maybe(obj: Optional[Spec[Any]], /) -> Spec[Any]:
177+
"""Return a dummy (``None``-filled) data spec if an object is not one."""
178+
return Spec(ROOT, (), None, None) if obj is None else obj
179+
180+
181+
def unwrap(obj: Any, /) -> Any:
182+
"""Unwrap if an object is a literal or a forward reference."""
183+
if get_origin(obj) is Literal:
184+
return args[0] if len(args := get_args(obj)) == 1 else obj
185+
elif isinstance(obj, ForwardRef):
186+
return obj.__forward_arg__
187+
else:
188+
return obj

0 commit comments

Comments
 (0)