1+ from __future__ import annotations
2+
13import warnings
24
35from collections .abc import Iterator
46from copy import deepcopy
57from dataclasses import dataclass , fields
6- from typing import Generic , Self , TypeVar
8+ from typing import TYPE_CHECKING , Generic , Self , TypeVar
79
810from pymc_extras .statespace .core import PyMCStateSpace
911from pymc_extras .statespace .utils .constants import (
1517 SHOCK_DIM ,
1618)
1719
20+ if TYPE_CHECKING :
21+ from pymc_extras .statespace .models .structural .core import Component
22+
1823
1924@dataclass (frozen = True )
2025class Property :
@@ -62,7 +67,7 @@ def __getitem__(self, key: str) -> T:
6267 def __contains__ (self , key : object ) -> bool :
6368 return key in self ._index
6469
65- def __iter__ (self ) -> Iterator [str ]:
70+ def __iter__ (self ) -> Iterator [T ]:
6671 return iter (self .items )
6772
6873 def __len__ (self ) -> int :
@@ -71,11 +76,24 @@ def __len__(self) -> int:
7176 def __str__ (self ) -> str :
7277 return f"{ self .key_field } s: { list (self ._index .keys ())} "
7378
79+ def add (self , new_item : T ):
80+ return type (self )([* self .items , new_item ])
81+
82+ def merge (self , other : Self , allow_duplicates : bool = False ) -> Self :
83+ if not isinstance (other , type (self )):
84+ raise TypeError (f"Cannot merge { type (other ).__name__ } with { type (self ).__name__ } " )
85+
86+ overlapping = set (self .names ) & set (other .names )
87+ if overlapping and not allow_duplicates :
88+ raise ValueError (f"Duplicate names found: { overlapping } " )
89+
90+ return type (self )(list (self .items ) + list (other .items ))
91+
7492 @property
7593 def names (self ) -> tuple [str , ...]:
7694 return tuple (self ._index .keys ())
7795
78- def copy (self ) -> " Info[T]" :
96+ def copy (self ) -> Info [T ]:
7997 return deepcopy (self )
8098
8199
@@ -92,21 +110,6 @@ class ParameterInfo(Info[Parameter]):
92110 def __init__ (self , parameters : list [Parameter ]):
93111 super ().__init__ (items = tuple (parameters ), key_field = "name" )
94112
95- def add (self , parameter : Parameter ) -> "ParameterInfo" :
96- # return a new ParameterInfo with parameter appended
97- return ParameterInfo (parameters = [* list (self .items ), parameter ])
98-
99- def merge (self , other : "ParameterInfo" , allow_duplicates : bool = False ) -> "ParameterInfo" :
100- """Combine parameters from two ParameterInfo objects."""
101- if not isinstance (other , ParameterInfo ):
102- raise TypeError (f"Cannot merge { type (other ).__name__ } with ParameterInfo" )
103-
104- overlapping = set (self .names ) & set (other .names )
105- if overlapping and not allow_duplicates :
106- raise ValueError (f"Duplicate parameter names found: { overlapping } " )
107-
108- return ParameterInfo (parameters = list (self .items ) + list (other .items ))
109-
110113
111114@dataclass (frozen = True )
112115class Data (Property ):
@@ -132,21 +135,6 @@ def exogenous_names(self) -> tuple[str, ...]:
132135 def __str__ (self ) -> str :
133136 return f"data: { [d .name for d in self .items ]} \n needs exogenous data: { self .needs_exogenous_data } "
134137
135- def add (self , data : Data ) -> "DataInfo" :
136- # return a new DataInfo with data appended
137- return DataInfo (data = [* list (self .items ), data ])
138-
139- def merge (self , other : "DataInfo" , allow_duplicates : bool = False ) -> "DataInfo" :
140- """Combine data from two DataInfo objects."""
141- if not isinstance (other , DataInfo ):
142- raise TypeError (f"Cannot merge { type (other ).__name__ } with DataInfo" )
143-
144- overlapping = set (self .names ) & set (other .names )
145- if overlapping and not allow_duplicates :
146- raise ValueError (f"Duplicate data names found: { overlapping } " )
147-
148- return DataInfo (data = list (self .items ) + list (other .items ))
149-
150138
151139@dataclass (frozen = True )
152140class Coord (Property ):
@@ -169,7 +157,7 @@ def __str__(self) -> str:
169157
170158 @classmethod
171159 def default_coords_from_model (
172- cls , model : PyMCStateSpace
160+ cls , model : PyMCStateSpace | Component
173161 ) -> (
174162 Self
175163 ): # TODO: Need to figure out how to include Component type was causing circular import issues
@@ -192,21 +180,6 @@ def default_coords_from_model(
192180 def to_dict (self ):
193181 return {coord .dimension : coord .labels for coord in self .items if len (coord .labels ) > 0 }
194182
195- def add (self , coord : Coord ) -> "CoordInfo" :
196- # return a new CoordInfo with data appended
197- return CoordInfo (coords = [* list (self .items ), coord ])
198-
199- def merge (self , other : "CoordInfo" , allow_duplicates : bool = False ) -> "CoordInfo" :
200- """Combine data from two CoordInfo objects."""
201- if not isinstance (other , CoordInfo ):
202- raise TypeError (f"Cannot merge { type (other ).__name__ } with CoordInfo" )
203-
204- overlapping = set (self .names ) & set (other .names )
205- if overlapping and not allow_duplicates :
206- raise ValueError (f"Duplicate coord names found: { overlapping } " )
207-
208- return CoordInfo (coords = list (self .items ) + list (other .items ))
209-
210183
211184@dataclass (frozen = True )
212185class State (Property ):
@@ -237,11 +210,7 @@ def observed_state_names(self) -> tuple[State, ...]:
237210 def unobserved_state_names (self ) -> tuple [State , ...]:
238211 return tuple (s .name for s in self .items if not s .observed )
239212
240- def add (self , state : State ) -> "StateInfo" :
241- # return a new StateInfo with state appended
242- return StateInfo (states = [* list (self .items ), state ])
243-
244- def merge (self , other : "StateInfo" , allow_duplicates : bool = False ) -> "StateInfo" :
213+ def merge (self , other : StateInfo , allow_duplicates : bool = False ) -> StateInfo :
245214 """Combine states from two StateInfo objects."""
246215 if not isinstance (other , StateInfo ):
247216 raise TypeError (f"Cannot merge { type (other ).__name__ } with StateInfo" )
@@ -270,18 +239,3 @@ class Shock(Property):
270239class ShockInfo (Info [Shock ]):
271240 def __init__ (self , shocks : list [Shock ]):
272241 super ().__init__ (items = tuple (shocks ), key_field = "name" )
273-
274- def add (self , shock : Shock ) -> "ShockInfo" :
275- # return a new ShockInfo with shock appended
276- return ShockInfo (shocks = [* list (self .items ), shock ])
277-
278- def merge (self , other : "ShockInfo" , allow_duplicates : bool = False ) -> "ShockInfo" :
279- """Combine shocks from two ShockInfo objects."""
280- if not isinstance (other , ShockInfo ):
281- raise TypeError (f"Cannot merge { type (other ).__name__ } with ShockInfo" )
282-
283- overlapping = set (self .names ) & set (other .names )
284- if overlapping and not allow_duplicates :
285- raise ValueError (f"Duplicate shock names found: { overlapping } " )
286-
287- return ShockInfo (shocks = list (self .items ) + list (other .items ))
0 commit comments