Skip to content

Commit 93bf0b4

Browse files
Simulated covariates (#37)
* Covariates can be included in simulations.
1 parent 7a44af6 commit 93bf0b4

File tree

6 files changed

+376
-13
lines changed

6 files changed

+376
-13
lines changed

src/causal_validation/config.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
)
55
import datetime as dt
66

7+
from jaxtyping import Float
8+
import typing as tp
9+
710
import numpy as np
11+
from scipy.stats import halfcauchy
812

913
from causal_validation.types import (
1014
Number,
@@ -20,9 +24,38 @@ class WeightConfig:
2024

2125
@dataclass(kw_only=True)
2226
class Config:
27+
"""Configuration for causal data generation.
28+
29+
Args:
30+
n_control_units (int): Number of control units in the synthetic dataset.
31+
n_pre_intervention_timepoints (int): Number of time points before intervention.
32+
n_post_intervention_timepoints (int): Number of time points after intervention.
33+
n_covariates (Optional[int]): Number of covariates. Defaults to None.
34+
covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates
35+
D is n_control_units and K is n_covariates. Defaults to None. If it is set
36+
to None while n_covariates is provided, covariate_means will be generated
37+
randomly from Normal distribution.
38+
covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for
39+
covariates. D is n_control_units and K is n_covariates. Defaults to None.
40+
If it is set to None while n_covariates is provided, covariate_stds
41+
will be generated randomly from Half-Cauchy distribution.
42+
covariate_coeffs (Optional[np.ndarray]): Linear regression
43+
coefficients to map covariates to output observations. K is n_covariates.
44+
Defaults to None.
45+
global_mean (Number): Global mean for data generation. Defaults to 20.0.
46+
global_scale (Number): Global scale for data generation. Defaults to 0.2.
47+
start_date (dt.date): Start date for time series. Defaults to 2023-01-01.
48+
seed (int): Random seed for reproducibility. Defaults to 123.
49+
weights_cfg (WeightConfig): Configuration for unit weights. Defaults to
50+
UniformWeights.
51+
"""
2352
n_control_units: int
2453
n_pre_intervention_timepoints: int
2554
n_post_intervention_timepoints: int
55+
n_covariates: tp.Optional[int] = None
56+
covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None
57+
covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None
58+
covariate_coeffs: tp.Optional[np.ndarray] = None
2659
global_mean: Number = 20.0
2760
global_scale: Number = 0.2
2861
start_date: dt.date = dt.date(year=2023, month=1, day=1)
@@ -31,3 +64,29 @@ class Config:
3164

3265
def __post_init__(self):
3366
self.rng = np.random.RandomState(self.seed)
67+
if self.covariate_means is not None:
68+
assert self.covariate_means.shape == (self.n_control_units,
69+
self.n_covariates)
70+
71+
if self.covariate_stds is not None:
72+
assert self.covariate_stds.shape == (self.n_control_units,
73+
self.n_covariates)
74+
75+
if (self.n_covariates is not None) & (self.covariate_means is None):
76+
self.covariate_means = self.rng.normal(
77+
loc=0.0, scale=5.0, size=(self.n_control_units,
78+
self.n_covariates)
79+
)
80+
81+
if (self.n_covariates is not None) & (self.covariate_stds is None):
82+
self.covariate_stds = (
83+
halfcauchy.rvs(scale=0.5,
84+
size=(self.n_control_units,
85+
self.n_covariates),
86+
random_state=self.rng)
87+
)
88+
89+
if (self.n_covariates is not None) & (self.covariate_coeffs is None):
90+
self.covariate_coeffs = self.rng.normal(
91+
loc=0.0, scale=5.0, size=self.n_covariates
92+
)

src/causal_validation/simulate.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,40 @@ def _simulate_base_obs(
2929
obs = key.normal(
3030
loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units)
3131
)
32-
Xtr = obs[: config.n_pre_intervention_timepoints, :]
33-
Xte = obs[config.n_pre_intervention_timepoints :, :]
34-
ytr = weights.weight_obs(Xtr)
35-
yte = weights.weight_obs(Xte)
36-
data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date)
32+
33+
if config.n_covariates is not None:
34+
Xtr_ = obs[: config.n_pre_intervention_timepoints, :]
35+
Xte_ = obs[config.n_pre_intervention_timepoints :, :]
36+
37+
covariates = key.normal(
38+
loc=config.covariate_means,
39+
scale=config.covariate_stds,
40+
size=(n_timepoints, n_units, config.n_covariates)
41+
)
42+
43+
Ptr = covariates[:config.n_pre_intervention_timepoints, :, :]
44+
Pte = covariates[config.n_pre_intervention_timepoints:, :, :]
45+
46+
Xtr = Xtr_ + Ptr @ config.covariate_coeffs
47+
Xte = Xte_ + Pte @ config.covariate_coeffs
48+
49+
ytr = weights.weight_contr(Xtr)
50+
yte = weights.weight_contr(Xte)
51+
52+
Rtr = weights.weight_contr(Ptr)
53+
Rte = weights.weight_contr(Pte)
54+
55+
data = Dataset(
56+
Xtr, Xte, ytr, yte, _start_date=config.start_date,
57+
Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte
58+
)
59+
else:
60+
Xtr = obs[: config.n_pre_intervention_timepoints, :]
61+
Xte = obs[config.n_pre_intervention_timepoints :, :]
62+
63+
ytr = weights.weight_contr(Xtr)
64+
yte = weights.weight_contr(Xte)
65+
66+
data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date)
67+
3768
return data

src/causal_validation/weights.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,23 @@
1111
if tp.TYPE_CHECKING:
1212
from causal_validation.config import WeightConfig
1313

14+
# Constants for array dimensions
15+
_NDIM_2D = 2
16+
_NDIM_3D = 3
17+
1418

1519
@dataclass
1620
class AbstractWeights(BaseObject):
1721
name: str = "Abstract Weights"
1822

19-
def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]:
23+
def _get_weights(
24+
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
25+
) -> Float[np.ndarray, "D 1"]:
2026
raise NotImplementedError("Please implement `_get_weights` in all subclasses.")
2127

22-
def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]:
28+
def get_weights(
29+
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
30+
) -> Float[np.ndarray, "D 1"]:
2331
weights = self._get_weights(obs)
2432

2533
np.testing.assert_almost_equal(
@@ -28,21 +36,31 @@ def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]
2836
assert min(weights >= 0), "Weights should be non-negative"
2937
return weights
3038

31-
def __call__(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]:
32-
return self.weight_obs(obs)
39+
def __call__(
40+
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
41+
) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]:
42+
return self.weight_contr(obs)
3343

34-
def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]:
44+
def weight_contr(
45+
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
46+
) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]:
3547
weights = self.get_weights(obs)
3648

37-
weighted_obs = obs @ weights
49+
if obs.ndim == _NDIM_2D:
50+
weighted_obs = obs @ weights
51+
elif obs.ndim == _NDIM_3D:
52+
weighted_obs = np.einsum("n d k, d i -> n i k", obs, weights)
53+
3854
return weighted_obs
3955

4056

4157
@dataclass
4258
class UniformWeights(AbstractWeights):
4359
name: str = "Uniform Weights"
4460

45-
def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]:
61+
def _get_weights(
62+
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
63+
) -> Float[np.ndarray, "D 1"]:
4664
n_units = obs.shape[1]
4765
return np.repeat(1.0 / n_units, repeats=n_units).reshape(-1, 1)
4866

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import numpy as np
2+
from hypothesis import given, strategies as st
3+
4+
from causal_validation.config import Config
5+
6+
7+
@given(
8+
n_units=st.integers(min_value=1, max_value=10),
9+
n_pre=st.integers(min_value=1, max_value=20),
10+
n_post=st.integers(min_value=1, max_value=20)
11+
)
12+
def test_config_basic_initialization(n_units, n_pre, n_post):
13+
cfg = Config(
14+
n_control_units=n_units,
15+
n_pre_intervention_timepoints=n_pre,
16+
n_post_intervention_timepoints=n_post
17+
)
18+
assert cfg.n_control_units == n_units
19+
assert cfg.n_pre_intervention_timepoints == n_pre
20+
assert cfg.n_post_intervention_timepoints == n_post
21+
assert cfg.n_covariates is None
22+
assert cfg.covariate_means is None
23+
assert cfg.covariate_stds is None
24+
assert cfg.covariate_coeffs is None
25+
26+
27+
@given(
28+
n_units=st.integers(min_value=1, max_value=5),
29+
n_pre=st.integers(min_value=1, max_value=10),
30+
n_post=st.integers(min_value=1, max_value=10),
31+
n_covariates=st.integers(min_value=1, max_value=3),
32+
seed=st.integers(min_value=1, max_value=1000)
33+
)
34+
def test_config_with_covariates_auto_generation(
35+
n_units, n_pre, n_post, n_covariates, seed
36+
):
37+
cfg = Config(
38+
n_control_units=n_units,
39+
n_pre_intervention_timepoints=n_pre,
40+
n_post_intervention_timepoints=n_post,
41+
n_covariates=n_covariates,
42+
seed=seed
43+
)
44+
assert cfg.n_covariates == n_covariates
45+
assert cfg.covariate_means.shape == (n_units, n_covariates)
46+
assert cfg.covariate_stds.shape == (n_units, n_covariates)
47+
assert cfg.covariate_coeffs.shape == (n_covariates,)
48+
assert np.all(cfg.covariate_stds >= 0)
49+
50+
51+
@given(
52+
n_units=st.integers(min_value=1, max_value=3),
53+
n_covariates=st.integers(min_value=1, max_value=3)
54+
)
55+
def test_config_with_explicit_covariate_means(n_units, n_covariates):
56+
means = np.random.random((n_units, n_covariates))
57+
cfg = Config(
58+
n_control_units=n_units,
59+
n_pre_intervention_timepoints=10,
60+
n_post_intervention_timepoints=5,
61+
n_covariates=n_covariates,
62+
covariate_means=means
63+
)
64+
np.testing.assert_array_equal(cfg.covariate_means, means)
65+
66+
67+
@given(
68+
n_units=st.integers(min_value=1, max_value=3),
69+
n_covariates=st.integers(min_value=1, max_value=3)
70+
)
71+
def test_config_with_explicit_covariate_stds(n_units, n_covariates):
72+
stds = np.random.random((n_units, n_covariates)) + 0.1
73+
cfg = Config(
74+
n_control_units=n_units,
75+
n_pre_intervention_timepoints=10,
76+
n_post_intervention_timepoints=5,
77+
n_covariates=n_covariates,
78+
covariate_stds=stds
79+
)
80+
np.testing.assert_array_equal(cfg.covariate_stds, stds)

0 commit comments

Comments
 (0)