Skip to content

Commit 228acff

Browse files
committed
1. added add and merge methods to base class
2. created tests for add and merge methods 3. added utility to convert from snake to pascal and integrated it in error messaging
1 parent a183c71 commit 228acff

File tree

4 files changed

+56
-71
lines changed

4 files changed

+56
-71
lines changed

pymc_extras/statespace/core/properties.py

Lines changed: 23 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
2+
13
import warnings
24

35
from collections.abc import Iterator
46
from copy import deepcopy
57
from dataclasses import dataclass, fields
6-
from typing import Generic, Self, TypeVar
8+
from typing import TYPE_CHECKING, Generic, Self, TypeVar
79

810
from pymc_extras.statespace.core import PyMCStateSpace
911
from pymc_extras.statespace.utils.constants import (
@@ -15,6 +17,9 @@
1517
SHOCK_DIM,
1618
)
1719

20+
if TYPE_CHECKING:
21+
from pymc_extras.statespace.models.structural.core import Component
22+
1823

1924
@dataclass(frozen=True)
2025
class Property:
@@ -62,7 +67,7 @@ def __getitem__(self, key: str) -> T:
6267
def __contains__(self, key: object) -> bool:
6368
return key in self._index
6469

65-
def __iter__(self) -> Iterator[str]:
70+
def __iter__(self) -> Iterator[T]:
6671
return iter(self.items)
6772

6873
def __len__(self) -> int:
@@ -71,11 +76,24 @@ def __len__(self) -> int:
7176
def __str__(self) -> str:
7277
return f"{self.key_field}s: {list(self._index.keys())}"
7378

79+
def add(self, new_item: T):
80+
return type(self)([*self.items, new_item])
81+
82+
def merge(self, other: Self, allow_duplicates: bool = False) -> Self:
83+
if not isinstance(other, type(self)):
84+
raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")
85+
86+
overlapping = set(self.names) & set(other.names)
87+
if overlapping and not allow_duplicates:
88+
raise ValueError(f"Duplicate names found: {overlapping}")
89+
90+
return type(self)(list(self.items) + list(other.items))
91+
7492
@property
7593
def names(self) -> tuple[str, ...]:
7694
return tuple(self._index.keys())
7795

78-
def copy(self) -> "Info[T]":
96+
def copy(self) -> Info[T]:
7997
return deepcopy(self)
8098

8199

@@ -92,21 +110,6 @@ class ParameterInfo(Info[Parameter]):
92110
def __init__(self, parameters: list[Parameter]):
93111
super().__init__(items=tuple(parameters), key_field="name")
94112

95-
def add(self, parameter: Parameter) -> "ParameterInfo":
96-
# return a new ParameterInfo with parameter appended
97-
return ParameterInfo(parameters=[*list(self.items), parameter])
98-
99-
def merge(self, other: "ParameterInfo", allow_duplicates: bool = False) -> "ParameterInfo":
100-
"""Combine parameters from two ParameterInfo objects."""
101-
if not isinstance(other, ParameterInfo):
102-
raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo")
103-
104-
overlapping = set(self.names) & set(other.names)
105-
if overlapping and not allow_duplicates:
106-
raise ValueError(f"Duplicate parameter names found: {overlapping}")
107-
108-
return ParameterInfo(parameters=list(self.items) + list(other.items))
109-
110113

111114
@dataclass(frozen=True)
112115
class Data(Property):
@@ -132,21 +135,6 @@ def exogenous_names(self) -> tuple[str, ...]:
132135
def __str__(self) -> str:
133136
return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"
134137

135-
def add(self, data: Data) -> "DataInfo":
136-
# return a new DataInfo with data appended
137-
return DataInfo(data=[*list(self.items), data])
138-
139-
def merge(self, other: "DataInfo", allow_duplicates: bool = False) -> "DataInfo":
140-
"""Combine data from two DataInfo objects."""
141-
if not isinstance(other, DataInfo):
142-
raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo")
143-
144-
overlapping = set(self.names) & set(other.names)
145-
if overlapping and not allow_duplicates:
146-
raise ValueError(f"Duplicate data names found: {overlapping}")
147-
148-
return DataInfo(data=list(self.items) + list(other.items))
149-
150138

151139
@dataclass(frozen=True)
152140
class Coord(Property):
@@ -169,7 +157,7 @@ def __str__(self) -> str:
169157

170158
@classmethod
171159
def default_coords_from_model(
172-
cls, model: PyMCStateSpace
160+
cls, model: PyMCStateSpace | Component
173161
) -> (
174162
Self
175163
): # TODO: Need to figure out how to include Component type was causing circular import issues
@@ -192,21 +180,6 @@ def default_coords_from_model(
192180
def to_dict(self):
193181
return {coord.dimension: coord.labels for coord in self.items if len(coord.labels) > 0}
194182

195-
def add(self, coord: Coord) -> "CoordInfo":
196-
# return a new CoordInfo with data appended
197-
return CoordInfo(coords=[*list(self.items), coord])
198-
199-
def merge(self, other: "CoordInfo", allow_duplicates: bool = False) -> "CoordInfo":
200-
"""Combine data from two CoordInfo objects."""
201-
if not isinstance(other, CoordInfo):
202-
raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo")
203-
204-
overlapping = set(self.names) & set(other.names)
205-
if overlapping and not allow_duplicates:
206-
raise ValueError(f"Duplicate coord names found: {overlapping}")
207-
208-
return CoordInfo(coords=list(self.items) + list(other.items))
209-
210183

211184
@dataclass(frozen=True)
212185
class State(Property):
@@ -237,11 +210,7 @@ def observed_state_names(self) -> tuple[State, ...]:
237210
def unobserved_state_names(self) -> tuple[State, ...]:
238211
return tuple(s.name for s in self.items if not s.observed)
239212

240-
def add(self, state: State) -> "StateInfo":
241-
# return a new StateInfo with state appended
242-
return StateInfo(states=[*list(self.items), state])
243-
244-
def merge(self, other: "StateInfo", allow_duplicates: bool = False) -> "StateInfo":
213+
def merge(self, other: StateInfo, allow_duplicates: bool = False) -> StateInfo:
245214
"""Combine states from two StateInfo objects."""
246215
if not isinstance(other, StateInfo):
247216
raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")
@@ -270,18 +239,3 @@ class Shock(Property):
270239
class ShockInfo(Info[Shock]):
271240
def __init__(self, shocks: list[Shock]):
272241
super().__init__(items=tuple(shocks), key_field="name")
273-
274-
def add(self, shock: Shock) -> "ShockInfo":
275-
# return a new ShockInfo with shock appended
276-
return ShockInfo(shocks=[*list(self.items), shock])
277-
278-
def merge(self, other: "ShockInfo", allow_duplicates: bool = False) -> "ShockInfo":
279-
"""Combine shocks from two ShockInfo objects."""
280-
if not isinstance(other, ShockInfo):
281-
raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo")
282-
283-
overlapping = set(self.names) & set(other.names)
284-
if overlapping and not allow_duplicates:
285-
raise ValueError(f"Duplicate shock names found: {overlapping}")
286-
287-
return ShockInfo(shocks=list(self.items) + list(other.items))

pymc_extras/statespace/models/structural/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
ALL_STATE_DIM,
3535
LONG_MATRIX_NAMES,
3636
)
37+
from pymc_extras.statespace.utils.message_tools import snake_to_pascal
3738

3839
_log = logging.getLogger(__name__)
3940
floatX = config.floatX
@@ -815,8 +816,10 @@ def _combine_property(self, other, name, allow_duplicates=True):
815816
)
816817

817818
if not is_dataclass(self_prop):
819+
# TODO: This works right now because we are only passing <foo>_info info names into _combine_property
820+
# If we don't follow that schema moving forward this will break.
818821
raise TypeError(
819-
f"All component properties are expected to be dataclasses, but found {type(self_prop)}"
822+
f"Component properties are expected to be {snake_to_pascal(name)}, but found {type(self_prop)}"
820823
f"for property {name} of {self} and {type(other_prop)} for {other}'"
821824
)
822825

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import re
2+
3+
4+
def snake_to_pascal(s: str) -> str:
5+
return re.sub(r"(?:^|_)([a-z])", lambda m: m.group(1).upper(), s)

tests/statespace/core/test_properties.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_data_info_needs_exogenous_and_str():
6868
def test_coord_info_make_defaults_from_component_and_types():
6969
class DummyComponent:
7070
state_names = ["x1", "x2"]
71-
observed_state_names = ["x2"]
71+
observed_states = ["x2"]
7272
shock_names = ["eps1"]
7373

7474
ci = CoordInfo.default_coords_from_model(DummyComponent())
@@ -117,3 +117,26 @@ def test_info_is_iterable_and_unpackable():
117117

118118
a, b = info.items
119119
assert a.name == "p1" and b.name == "p2"
120+
121+
122+
def test_info_add_method():
123+
a_param = Parameter(name="a", shape=(1,), dims=("dim",))
124+
param_info = ParameterInfo(parameters=[a_param])
125+
126+
b_param = Parameter(name="b", shape=(1,), dims=("dim",))
127+
128+
new_param_info = param_info.add(new_item=b_param)
129+
130+
assert new_param_info.names == ("a", "b")
131+
132+
133+
def test_info_merge_method():
134+
a_param = Parameter(name="a", shape=(1,), dims=("dim",))
135+
a_param_info = ParameterInfo(parameters=[a_param])
136+
137+
b_param = Parameter(name="b", shape=(1,), dims=("dim",))
138+
b_param_info = ParameterInfo(parameters=[b_param])
139+
140+
new_param_info = a_param_info.merge(b_param_info)
141+
142+
assert new_param_info.names == ("a", "b")

0 commit comments

Comments
 (0)