22
33
44# standard library
5- from typing import Any , Callable , overload
6-
5+ from dataclasses import replace
6+ from typing import Any , ForwardRef , Literal , Optional , overload
77
88# dependencies
9- from xarray import DataArray , Dataset
10- from .typing import DataClass , DataClassOf , PAny , TDataArray , TDataset , TXarray
9+ from dataspecs import ID , ROOT , Spec , Specs
10+ from numpy import asarray , array
11+ from typing_extensions import get_args , get_origin
12+ from xarray import DataArray , Dataset , Variable
13+ from .typing import (
14+ DataClass ,
15+ DataClassOf ,
16+ Factory ,
17+ HashDict ,
18+ PAny ,
19+ TAny ,
20+ TDataArray ,
21+ TDataset ,
22+ TXarray ,
23+ Tag ,
24+ )
25+
26+
27+ # type hints
28+ Attrs = HashDict [Any ]
29+ Vars = HashDict [Variable ]
1130
1231
1332@overload
@@ -24,7 +43,7 @@ def asdataarray(
2443 obj : DataClass [PAny ],
2544 / ,
2645 * ,
27- factory : Callable [..., TDataArray ],
46+ factory : Factory [ TDataArray ],
2847) -> TDataArray : ...
2948
3049
@@ -56,7 +75,7 @@ def asdataset(
5675 obj : DataClass [PAny ],
5776 / ,
5877 * ,
59- factory : Callable [..., TDataset ],
78+ factory : Factory [ TDataset ],
6079) -> TDataset : ...
6180
6281
@@ -88,10 +107,84 @@ def asxarray(
88107 obj : DataClass [PAny ],
89108 / ,
90109 * ,
91- factory : Callable [..., TXarray ],
110+ factory : Factory [ TXarray ],
92111) -> TXarray : ...
93112
94113
95114def asxarray (obj : Any , / , * , factory : Any = None ) -> Any :
96115 """Create a DataArray/set object from a dataclass object."""
97116 ...
117+
118+
119+ def get_attrs (specs : Specs [Spec [Any ]], / , * , at : ID = ROOT ) -> Attrs :
120+ """Create attributes from data specs."""
121+ attrs : Attrs = {}
122+
123+ for spec in specs [at .children ][Tag .ATTR ]:
124+ options = specs [spec .id .children ]
125+ factory = maybe (options [Tag .FACTORY ].unique ).data or identity
126+ name = maybe (options [Tag .NAME ].unique ).data or spec .id .name
127+
128+ if Tag .MULTIPLE not in spec .tags :
129+ spec = replace (spec , data = {name : spec .data })
130+
131+ for name , data in spec [HashDict [Any ]].data .items ():
132+ attrs [name ] = factory (data )
133+
134+ return attrs
135+
136+
137+ def get_vars (specs : Specs [Spec [Any ]], of : Tag , / , * , at : ID = ROOT ) -> Vars :
138+ """Create variables of given tag from data specs."""
139+ vars : Vars = {}
140+
141+ for spec in specs [at .children ][of ]:
142+ options = specs [spec .id .children ]
143+ attrs = get_attrs (specs , at = spec .id )
144+ factory = maybe (options [Tag .FACTORY ].unique ).data or Variable
145+ name = maybe (options [Tag .NAME ].unique ).data or spec .id .name
146+
147+ if (type_ := maybe (options [Tag .DIMS ].unique ).type ) is None :
148+ raise RuntimeError ("Could not find any data spec for dims." )
149+ elif get_origin (type_ ) is tuple :
150+ dims = tuple (str (unwrap (arg )) for arg in get_args (type_ ))
151+ else :
152+ dims = (str (unwrap (type_ )),)
153+
154+ if (type_ := maybe (options [Tag .DTYPE ].unique ).type ) is None :
155+ raise RuntimeError ("Could not find any data spec for dims." )
156+ elif type_ is type (None ) or type_ is Any :
157+ dtype = None
158+ else :
159+ dtype = unwrap (type_ )
160+
161+ if Tag .MULTIPLE not in spec .tags :
162+ spec = replace (spec , data = {name : spec .data })
163+
164+ for name , data in spec [HashDict [Any ]].data .items ():
165+ if not (data := asarray (data , dtype )).ndim :
166+ data = array (data , ndmin = len (dims ))
167+
168+ vars [name ] = factory (attrs = attrs , data = data , dims = dims )
169+
170+ return vars
171+
172+
173+ def identity (obj : TAny , / ) -> TAny :
174+ """Identity function used for the default factory."""
175+ return obj
176+
177+
178+ def maybe (obj : Optional [Spec [Any ]], / ) -> Spec [Any ]:
179+ """Return a dummy (``None``-filled) data spec if an object is not one."""
180+ return Spec (ROOT , (), None , None ) if obj is None else obj
181+
182+
183+ def unwrap (obj : Any , / ) -> Any :
184+ """Unwrap if an object is a literal or a forward reference."""
185+ if get_origin (obj ) is Literal :
186+ return args [0 ] if len (args := get_args (obj )) == 1 else obj
187+ elif isinstance (obj , ForwardRef ):
188+ return obj .__forward_arg__
189+ else :
190+ return obj
0 commit comments