Skip to content

Commit 92e333f

Browse files
committed
Made the necessary changes to get the regression component test to pass using the new dataclasses API.
1 parent c6a48fc commit 92e333f

File tree

6 files changed

+293
-93
lines changed

6 files changed

+293
-93
lines changed

pymc_extras/statespace/core/properties.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Generic, Self, TypeVar
44

55
from pymc_extras.statespace.core import PyMCStateSpace
6-
from pymc_extras.statespace.models.structural.core import Component
76
from pymc_extras.statespace.utils.constants import (
87
ALL_STATE_AUX_DIM,
98
ALL_STATE_DIM,
@@ -87,6 +86,21 @@ class ParameterInfo(Info[Parameter]):
8786
def __init__(self, parameters: list[Parameter]):
    """Freeze ``parameters`` into a tuple and hand it to the base initializer.

    The base class is told to use ``"name"`` as the key field.
    """
    frozen_items = tuple(parameters)
    super().__init__(items=frozen_items, key_field="name")
8988

89+
def add(self, parameter: Parameter) -> "ParameterInfo":
    """Return a new ParameterInfo holding the current items followed by ``parameter``.

    Only ``self.items`` is read; a fresh container is built for the result.
    """
    extended = [*self.items, parameter]
    return ParameterInfo(parameters=extended)
92+
93+
def merge(self, other: "ParameterInfo") -> "ParameterInfo":
    """Return a new ParameterInfo with this object's items followed by ``other``'s.

    Raises
    ------
    TypeError
        If ``other`` is not a ParameterInfo.
    ValueError
        If any entry of ``names`` is present in both objects.
    """
    if not isinstance(other, ParameterInfo):
        raise TypeError(f"Cannot merge {type(other).__name__} with ParameterInfo")

    shared_names = set(self.names) & set(other.names)
    if shared_names:
        raise ValueError(f"Duplicate parameter names found: {shared_names}")

    combined = [*self.items, *other.items]
    return ParameterInfo(parameters=combined)
103+
90104

91105
@dataclass(frozen=True)
92106
class Data(Property):
@@ -108,6 +122,21 @@ def needs_exogenous_data(self) -> bool:
108122
def __str__(self) -> str:
    """Render the item names and the ``needs_exogenous_data`` flag on two lines."""
    data_names = [d.name for d in self.items]
    return f"data: {data_names}\nneeds exogenous data: {self.needs_exogenous_data}"
110124

125+
def add(self, data: Data) -> "DataInfo":
    """Return a new DataInfo holding the current items followed by ``data``.

    Only ``self.items`` is read; a fresh container is built for the result.
    """
    extended = [*self.items, data]
    return DataInfo(data=extended)
128+
129+
def merge(self, other: "DataInfo") -> "DataInfo":
    """Return a new DataInfo with this object's items followed by ``other``'s.

    Raises
    ------
    TypeError
        If ``other`` is not a DataInfo.
    ValueError
        If any entry of ``names`` is present in both objects.
    """
    if not isinstance(other, DataInfo):
        raise TypeError(f"Cannot merge {type(other).__name__} with DataInfo")

    shared_names = set(self.names) & set(other.names)
    if shared_names:
        raise ValueError(f"Duplicate data names found: {shared_names}")

    combined = [*self.items, *other.items]
    return DataInfo(data=combined)
139+
111140

112141
@dataclass(frozen=True)
113142
class Coord(Property):
@@ -129,7 +158,11 @@ def __str__(self) -> str:
129158
return base
130159

131160
@classmethod
132-
def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self:
161+
def default_coords_from_model(
162+
cls, model: PyMCStateSpace
163+
) -> (
164+
Self
165+
): # TODO: Need to figure out how to include Component type was causing circular import issues
133166
states = tuple(model.state_names)
134167
obs_states = tuple(model.observed_state_names)
135168
shocks = tuple(model.shock_names)
@@ -149,6 +182,21 @@ def default_coords_from_model(cls, model: Component | PyMCStateSpace) -> Self:
149182
def to_dict(self):
    """Map each coord's ``dimension`` to its ``labels``, skipping coords with no labels."""
    non_empty = {}
    for coord in self.items:
        # Keep the explicit length test so any sized label container is handled alike.
        if len(coord.labels) > 0:
            non_empty[coord.dimension] = coord.labels
    return non_empty
151184

185+
def add(self, coord: Coord) -> "CoordInfo":
    """Return a new CoordInfo holding the current items followed by ``coord``.

    Only ``self.items`` is read; a fresh container is built for the result.
    """
    extended = [*self.items, coord]
    return CoordInfo(coords=extended)
188+
189+
def merge(self, other: "CoordInfo") -> "CoordInfo":
    """Return a new CoordInfo with this object's coords followed by ``other``'s.

    Raises
    ------
    TypeError
        If ``other`` is not a CoordInfo.
    ValueError
        If any entry of ``names`` is present in both objects.
    """
    if not isinstance(other, CoordInfo):
        raise TypeError(f"Cannot merge {type(other).__name__} with CoordInfo")

    shared_names = set(self.names) & set(other.names)
    if shared_names:
        raise ValueError(f"Duplicate coord names found: {shared_names}")

    combined = [*self.items, *other.items]
    return CoordInfo(coords=combined)
199+
152200

153201
@dataclass(frozen=True)
154202
class State(Property):
@@ -171,6 +219,21 @@ def __str__(self) -> str:
171219
def observed_states(self) -> tuple[State, ...]:
    """Return, in stored order, the items whose ``observed`` attribute is truthy."""
    return tuple(filter(lambda s: s.observed, self.items))
173221

222+
def add(self, state: State) -> "StateInfo":
    """Return a new StateInfo holding the current items followed by ``state``.

    Only ``self.items`` is read; a fresh container is built for the result.
    """
    extended = [*self.items, state]
    return StateInfo(states=extended)
225+
226+
def merge(self, other: "StateInfo") -> "StateInfo":
    """Return a new StateInfo with this object's states followed by ``other``'s.

    Raises
    ------
    TypeError
        If ``other`` is not a StateInfo.
    ValueError
        If any entry of ``names`` is present in both objects.
    """
    if not isinstance(other, StateInfo):
        raise TypeError(f"Cannot merge {type(other).__name__} with StateInfo")

    shared_names = set(self.names) & set(other.names)
    if shared_names:
        raise ValueError(f"Duplicate state names found: {shared_names}")

    combined = [*self.items, *other.items]
    return StateInfo(states=combined)
236+
174237

175238
@dataclass(frozen=True)
176239
class Shock(Property):
@@ -181,3 +244,18 @@ class Shock(Property):
181244
class ShockInfo(Info[Shock]):
    """Info container specialised for Shock items."""

    def __init__(self, shocks: list[Shock]):
        """Freeze ``shocks`` into a tuple and pass ``"name"`` as the key field."""
        frozen_items = tuple(shocks)
        super().__init__(items=frozen_items, key_field="name")

    def add(self, shock: Shock) -> "ShockInfo":
        """Return a new ShockInfo holding the current items followed by ``shock``."""
        extended = [*self.items, shock]
        return ShockInfo(shocks=extended)

    def merge(self, other: "ShockInfo") -> "ShockInfo":
        """Return a new ShockInfo with this object's shocks followed by ``other``'s.

        Raises
        ------
        TypeError
            If ``other`` is not a ShockInfo.
        ValueError
            If any entry of ``names`` is present in both objects.
        """
        if not isinstance(other, ShockInfo):
            raise TypeError(f"Cannot merge {type(other).__name__} with ShockInfo")

        shared_names = set(self.names) & set(other.names)
        if shared_names:
            raise ValueError(f"Duplicate shock names found: {shared_names}")

        combined = [*self.items, *other.items]
        return ShockInfo(shocks=combined)

pymc_extras/statespace/models/structural/components/regression.py

Lines changed: 102 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
from pytensor import tensor as pt
44

5+
from pymc_extras.statespace.core.properties import (
6+
Coord,
7+
CoordInfo,
8+
Data,
9+
DataInfo,
10+
Parameter,
11+
ParameterInfo,
12+
Shock,
13+
ShockInfo,
14+
State,
15+
StateInfo,
16+
)
517
from pymc_extras.statespace.models.structural.core import Component
618
from pymc_extras.statespace.utils.constants import TIME_DIM
719

@@ -194,64 +206,110 @@ def make_symbolic_graph(self) -> None:
194206
row_idx, col_idx = np.diag_indices(self.k_states)
195207
self.ssm["state_cov", row_idx, col_idx] = sigma_beta.ravel() ** 2
196208

197-
def populate_component_properties(self) -> None:
209+
def _set_parameters(self) -> None:
    """Build ``self.param_info`` and mirror its names onto ``self.param_names``.

    A ``beta_<name>`` parameter is always registered. It gains a leading
    ``endog_<name>`` dimension only when the effective number of endogenous
    series is greater than one. When ``self.innovations`` is truthy, a
    ``sigma_beta_<name>`` parameter with a "Positive" constraint is added
    after it.
    """
    k_endog = self.k_endog
    effective_endog = 1 if self.share_states else k_endog
    states_per_endog = self.k_states // effective_endog
    state_dim = f"state_{self.name}"

    if effective_endog > 1:
        beta_shape = (effective_endog, states_per_endog)
        beta_dims = (f"endog_{self.name}", state_dim)
    else:
        beta_shape = (states_per_endog,)
        beta_dims = (state_dim,)

    parameters = [
        Parameter(
            name=f"beta_{self.name}",
            shape=beta_shape,
            dims=beta_dims,
            constraints=None,
        )
    ]

    if self.innovations:
        parameters.append(
            Parameter(
                name=f"sigma_beta_{self.name}",
                shape=(states_per_endog,),
                dims=(state_dim,),
                constraints="Positive",
            )
        )

    self.param_info = ParameterInfo(parameters=parameters)
    self.param_names = self.param_info.names
249+
def _set_data(self) -> None:
    """Register the single exogenous data entry ``data_<name>``.

    The entry has shape ``(None, k_states // effective_endog)`` and dims
    ``(TIME_DIM, "state_<name>")``. It is stored in ``self.data_info``, and
    the resulting names are mirrored onto ``self.data_names``.
    """
    k_endog = self.k_endog
    effective_endog = 1 if self.share_states else k_endog
    states_per_endog = self.k_states // effective_endog

    exogenous = Data(
        name=f"data_{self.name}",
        shape=(None, states_per_endog),
        dims=(TIME_DIM, f"state_{self.name}"),
        is_exogenous=True,
    )

    self.data_info = DataInfo(data=[exogenous])
    self.data_names = self.data_info.names
261+
262+
def _set_shocks(self) -> None:
# Builds a ShockInfo with one Shock per shock name and mirrors its names onto self.shock_names.
# With share_states the names are "<state_name>_shared"; otherwise self.state_names is reused as-is.
# NOTE(review): lines after an "NNN-" marker look like the pre-commit code removed by this diff,
# and lines after an "NNN+" marker the code it adds -- confirm against the repository.
203263
if self.share_states:
204-
self.shock_names = [f"{state_name}_shared" for state_name in self.state_names]
264+
shock_names = [f"{state_name}_shared" for state_name in self.state_names]
205265
else:
206-
self.shock_names = self.state_names
266+
shock_names = self.state_names
207267

208-
self.param_names = [f"beta_{self.name}"]
209-
self.data_names = [f"data_{self.name}"]
210-
self.param_dims = {
211-
f"beta_{self.name}": (f"endog_{self.name}", f"state_{self.name}")
212-
if k_endog_effective > 1
213-
else (f"state_{self.name}",)
214-
}
268+
# Wrap the plain names in Shock objects, then publish the names from the ShockInfo.
self.shock_info = ShockInfo(shocks=[Shock(name=name) for name in shock_names])
269+
self.shock_names = self.shock_info.names
215270

216-
base_names = self.state_names
271+
def _set_states(self) -> None:
# Saves the current state names as self.base_names, then replaces self.state_names with
# decorated names taken from a StateInfo whose states are all created with observed=True.
# With share_states each name becomes "<name>[<self.name>_shared]" (shared=True); otherwise one
# "<name>[<obs_name>]" entry is produced per observed state name and base name (shared=False).
# NOTE(review): lines after an "NNN-" marker look like the pre-commit code removed by this diff,
# and lines after an "NNN+" marker the code it adds -- confirm against the repository.
272+
# Keep the undecorated names on the instance; _set_coords reads self.base_names.
self.base_names = self.state_names
217273

218274
if self.share_states:
219-
self.state_names = [f"{name}[{self.name}_shared]" for name in base_names]
275+
state_names = [f"{name}[{self.name}_shared]" for name in self.base_names]
276+
self.state_info = StateInfo(
277+
states=[State(name=name, observed=True, shared=True) for name in state_names]
278+
)
279+
self.state_names = self.state_info.names
220280
else:
221-
self.state_names = [
281+
state_names = [
222282
f"{name}[{obs_name}]"
223283
for obs_name in self.observed_state_names
224-
for name in base_names
284+
for name in self.base_names
225285
]
286+
self.state_info = StateInfo(
287+
states=[State(name=name, observed=True, shared=False) for name in state_names]
288+
)
289+
self.state_names = self.state_info.names
226290

227-
self.param_info = {
228-
f"beta_{self.name}": {
229-
"shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,),
230-
"constraints": None,
231-
"dims": (f"endog_{self.name}", f"state_{self.name}")
232-
if k_endog_effective > 1
233-
else (f"state_{self.name}",),
234-
},
235-
}
236-
237-
self.data_info = {
238-
f"data_{self.name}": {
239-
"shape": (None, k_states),
240-
"dims": (TIME_DIM, f"state_{self.name}"),
241-
},
242-
}
243-
self.coords = {
244-
f"state_{self.name}": base_names,
245-
f"endog_{self.name}": self.observed_state_names,
246-
}
291+
def _set_coords(self) -> None:
# Creates two Coord objects and stores them as a CoordInfo on self.coords:
#   - dimension "state_<self.name>", labelled with self.base_names
#   - dimension "endog_<self.name>", labelled with self.observed_state_names
# Reads self.base_names, which is assigned in _set_states.
# NOTE(review): lines after an "NNN-" marker look like the pre-commit code removed by this diff,
# and lines after an "NNN+" marker the code it adds -- confirm against the repository.
292+
regression_state_coord = Coord(
293+
dimension=f"state_{self.name}", labels=[state for state in self.base_names]
294+
)
295+
endogenous_state_coord = Coord(
296+
dimension=f"endog_{self.name}", labels=[state for state in self.observed_state_names]
297+
)
247298

248-
if self.innovations:
249-
self.param_names += [f"sigma_beta_{self.name}"]
250-
self.param_dims[f"sigma_beta_{self.name}"] = (f"state_{self.name}",)
251-
self.param_info[f"sigma_beta_{self.name}"] = {
252-
"shape": (k_states,),
253-
"constraints": "Positive",
254-
"dims": (f"state_{self.name}",)
255-
if k_endog_effective == 1
256-
else (f"endog_{self.name}", f"state_{self.name}"),
257-
}
299+
self.coords = CoordInfo(coords=[regression_state_coord, endogenous_state_coord])
300+
301+
def populate_component_properties(self) -> None:
    """Populate parameter, data, shock, state and coordinate info, in that order.

    The order matters. ``_set_shocks`` reads ``self.state_names`` before
    ``_set_states`` replaces them, and ``_set_coords`` reads
    ``self.base_names``, which ``_set_states`` assigns.
    """
    setters = (
        self._set_parameters,
        self._set_data,
        self._set_shocks,
        self._set_states,
        self._set_coords,
    )
    for setter in setters:
        setter()

0 commit comments

Comments
 (0)