Skip to content

Commit 8ba2f69

Browse files
Generic dataset (#43)
* Refactor Dataset class to enhance generalisation * New config and simulator supporting a wide range of data generating scnearios. * Config and simulate tests * Delete weight files and classes * Add inflate method for Dataset class and adapt StaticEffect, RandomEffect for the new treatment assignment mechanism. * Testing inflate and effect * Adapting transforms to new dataset class * Test noise transformation * Testing periodic transform * Test trend transform * Adjusting validation components according to new dataset class * Plotter modification * Remove models.py and associated tests * test_models to test_az_causal * Format changes * Linting and documentation fix
1 parent 4c3e783 commit 8ba2f69

36 files changed

+1663
-1425
lines changed

docs/examples/azcausal.ipynb

Lines changed: 129 additions & 21 deletions
Large diffs are not rendered by default.

docs/examples/basic.ipynb

Lines changed: 155 additions & 37 deletions
Large diffs are not rendered by default.

docs/examples/placebo_test.ipynb

Lines changed: 182 additions & 22 deletions
Large diffs are not rendered by default.

docs/index.md

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ trend varies across each of the 10 control units.
1717
from causal_validation import Config, simulate
1818
from causal_validation.effects import StaticEffect
1919
from causal_validation.plotters import plot
20-
from causal_validation.transforms import Trend, Periodic
20+
from causal_validation.transforms import Trend
2121
from causal_validation.transforms.parameter import UnitVaryingParameter
22+
import numpy as np
2223
from scipy.stats import norm
2324

24-
cfg = Config(
25-
n_control_units=10,
26-
n_pre_intervention_timepoints=60,
27-
n_post_intervention_timepoints=30,
28-
)
25+
# Treatment assignment matrix
26+
D = np.zeros((90, 11)) # 90 time points, 11 units
27+
D[60:, -1] = 1 # Last unit treated after 60 time points
28+
29+
cfg = Config(treatment_assignments=D)
2930

3031
# Simulate the base observation
3132
base_data = simulate(cfg)
@@ -38,6 +39,8 @@ trended_data = trend_component(base_data)
3839
# Simulate a 5% lift in the treated unit's post-intervention data
3940
effect = StaticEffect(0.05)
4041
inflated_data = effect(trended_data)
42+
43+
plot(inflated_data)
4144
```
4245

4346
![Gaussian process posterior.](static/imgs/readme_fig.png)
@@ -50,6 +53,7 @@ combination with AZCausal by the following.
5053

5154
```python
5255
from azcausal.estimators.panel.sdid import SDID
56+
from causal_validation.estimator.utils import AZCausalWrapper
5357
from causal_validation.validation.placebo import PlaceboTest
5458

5559
model = AZCausalWrapper(model=SDID())

src/causal_validation/config.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from dataclasses import (
22
dataclass,
3-
field,
43
)
54
import datetime as dt
65
import typing as tp
@@ -11,84 +10,99 @@
1110

1211
from causal_validation.types import (
1312
Number,
14-
WeightTypes,
13+
TreatedSimulationTypes,
1514
)
16-
from causal_validation.weights import UniformWeights
17-
18-
19-
@dataclass(kw_only=True, frozen=True)
20-
class WeightConfig:
21-
weight_type: "WeightTypes" = field(default_factory=UniformWeights)
2215

2316

2417
@dataclass(kw_only=True)
2518
class Config:
2619
"""Configuration for causal data generation.
2720
2821
Args:
29-
n_control_units (int): Number of control units in the synthetic dataset.
30-
n_pre_intervention_timepoints (int): Number of time points before intervention.
31-
n_post_intervention_timepoints (int): Number of time points after intervention.
22+
treatment_assignments (Float[np.ndarray, "T N"]): Treatment assignments for T
23+
time steps and N units. Only supported with binary assignments.
24+
treated_simulation_type ("TreatedSimulationTypes"): Treated units can be
25+
simulated either "independent" of control units or "control-weighted",
26+
where waiting scheme is controlled by Dirichlet concentration parameter.
27+
Set to "control-weighted" by default.
28+
dirichlet_concentration (Number): Dirichlet parameters are set to a vector of
29+
dirichlet_concentration with length number of control units. This parameter
30+
controls how dense and sparse the generated weights are. Set to 1 by default
31+
and in effect only if treated_simulation_type is "control-weighted".
3232
n_covariates (Optional[int]): Number of covariates. Defaults to None.
33-
covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates
34-
D is n_control_units and K is n_covariates. Defaults to None. If it is set
35-
to None while n_covariates is provided, covariate_means will be generated
33+
covariate_means (Optional[np.ndarray]): Normal dist. mean values for covariates.
34+
The lenght must be n_covariates. Defaults to None. If it is set to
35+
None while n_covariates is provided, covariate_means will be generated
3636
randomly from Normal distribution.
37-
covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for
38-
covariates. D is n_control_units and K is n_covariates. Defaults to None.
39-
If it is set to None while n_covariates is provided, covariate_stds
40-
will be generated randomly from Half-Cauchy distribution.
37+
covariate_stds (Optional[np.ndarray]): Normal dist. std values for covariates.
38+
The lenght must be n_covariates. Defaults to None. If it is set to
39+
None while n_covariates is provided, covariate_stds will be generated
40+
randomly from Half-Cauchy distribution.
4141
covariate_coeffs (Optional[np.ndarray]): Linear regression
4242
coefficients to map covariates to output observations. K is n_covariates.
4343
Defaults to None.
4444
global_mean (Number): Global mean for data generation. Defaults to 20.0.
4545
global_scale (Number): Global scale for data generation. Defaults to 0.2.
4646
start_date (dt.date): Start date for time series. Defaults to 2023-01-01.
4747
seed (int): Random seed for reproducibility. Defaults to 123.
48-
weights_cfg (WeightConfig): Configuration for unit weights. Defaults to
49-
UniformWeights.
48+
weights (Optional[list[np.ndarray]]): Length num of treateds list of weights.
49+
Each element is length num of control, indicating how to weigh control
50+
units to generate treated.
5051
"""
5152

52-
n_control_units: int
53-
n_pre_intervention_timepoints: int
54-
n_post_intervention_timepoints: int
53+
treatment_assignments: Float[np.ndarray, "T N"]
54+
treated_simulation_type: "TreatedSimulationTypes" = "control-weighted"
55+
dirichlet_concentration: Number = 1.0
5556
n_covariates: tp.Optional[int] = None
56-
covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None
57-
covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None
57+
covariate_means: tp.Optional[np.ndarray] = None
58+
covariate_stds: tp.Optional[np.ndarray] = None
5859
covariate_coeffs: tp.Optional[np.ndarray] = None
5960
global_mean: Number = 20.0
6061
global_scale: Number = 0.2
6162
start_date: dt.date = dt.date(year=2023, month=1, day=1)
6263
seed: int = 123
63-
weights_cfg: WeightConfig = field(default_factory=WeightConfig)
64+
weights: tp.Optional[list[np.ndarray]] = None
6465

6566
def __post_init__(self):
6667
self.rng = np.random.RandomState(self.seed)
6768
if self.covariate_means is not None:
68-
assert self.covariate_means.shape == (
69-
self.n_control_units,
70-
self.n_covariates,
71-
)
69+
assert self.covariate_means.shape == (self.n_covariates,)
7270

7371
if self.covariate_stds is not None:
74-
assert self.covariate_stds.shape == (
75-
self.n_control_units,
76-
self.n_covariates,
77-
)
72+
assert self.covariate_stds.shape == (self.n_covariates,)
7873

7974
if (self.n_covariates is not None) & (self.covariate_means is None):
8075
self.covariate_means = self.rng.normal(
81-
loc=0.0, scale=5.0, size=(self.n_control_units, self.n_covariates)
76+
loc=0.0, scale=5.0, size=(self.n_covariates)
8277
)
8378

8479
if (self.n_covariates is not None) & (self.covariate_stds is None):
8580
self.covariate_stds = halfcauchy.rvs(
8681
scale=0.5,
87-
size=(self.n_control_units, self.n_covariates),
82+
size=(self.n_covariates),
8883
random_state=self.rng,
8984
)
9085

9186
if (self.n_covariates is not None) & (self.covariate_coeffs is None):
9287
self.covariate_coeffs = self.rng.normal(
9388
loc=0.0, scale=5.0, size=self.n_covariates
9489
)
90+
91+
n_units = self.treatment_assignments.shape[1]
92+
treated_units = [
93+
i for i in range(n_units) if any(self.treatment_assignments[:, i] != 0)
94+
]
95+
n_treated_units = len(treated_units)
96+
n_control_units = n_units - n_treated_units
97+
98+
if self.treated_simulation_type == "control-weighted":
99+
if self.weights is None:
100+
self.weights = [
101+
self.rng.dirichlet(
102+
self.dirichlet_concentration * np.ones(n_control_units)
103+
)
104+
for _ in range(n_treated_units)
105+
]
106+
else:
107+
assert len(self.weights) == n_treated_units
108+
assert all([len(w) == n_control_units for w in self.weights])

0 commit comments

Comments
 (0)