|
1 | 1 | from dataclasses import ( |
2 | 2 | dataclass, |
3 | | - field, |
4 | 3 | ) |
5 | 4 | import datetime as dt |
6 | 5 | import typing as tp |
|
11 | 10 |
|
12 | 11 | from causal_validation.types import ( |
13 | 12 | Number, |
14 | | - WeightTypes, |
| 13 | + TreatedSimulationTypes, |
15 | 14 | ) |
16 | | -from causal_validation.weights import UniformWeights |
17 | | - |
18 | | - |
19 | | -@dataclass(kw_only=True, frozen=True) |
20 | | -class WeightConfig: |
21 | | - weight_type: "WeightTypes" = field(default_factory=UniformWeights) |
22 | 15 |
|
23 | 16 |
|
24 | 17 | @dataclass(kw_only=True) |
25 | 18 | class Config: |
26 | 19 | """Configuration for causal data generation. |
27 | 20 |
|
28 | 21 | Args: |
29 | | - n_control_units (int): Number of control units in the synthetic dataset. |
30 | | - n_pre_intervention_timepoints (int): Number of time points before intervention. |
31 | | - n_post_intervention_timepoints (int): Number of time points after intervention. |
| 22 | + treatment_assignments (Float[np.ndarray, "T N"]): Treatment assignments for T |
| 23 | + time steps and N units. Only supported with binary assignments. |
| 24 | + treated_simulation_type ("TreatedSimulationTypes"): Treated units can be |
| 25 | + simulated either "independent" of control units or "control-weighted", |
| 26 | + where waiting scheme is controlled by Dirichlet concentration parameter. |
| 27 | + Set to "control-weighted" by default. |
| 28 | + dirichlet_concentration (Number): Dirichlet parameters are set to a vector of |
| 29 | + dirichlet_concentration with length number of control units. This parameter |
| 30 | + controls how dense and sparse the generated weights are. Set to 1 by default |
| 31 | + and in effect only if treated_simulation_type is "control-weighted". |
32 | 32 | n_covariates (Optional[int]): Number of covariates. Defaults to None. |
33 | | - covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates |
34 | | - D is n_control_units and K is n_covariates. Defaults to None. If it is set |
35 | | - to None while n_covariates is provided, covariate_means will be generated |
| 33 | + covariate_means (Optional[np.ndarray]): Normal dist. mean values for covariates. |
| 34 | + The lenght must be n_covariates. Defaults to None. If it is set to |
| 35 | + None while n_covariates is provided, covariate_means will be generated |
36 | 36 | randomly from Normal distribution. |
37 | | - covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for |
38 | | - covariates. D is n_control_units and K is n_covariates. Defaults to None. |
39 | | - If it is set to None while n_covariates is provided, covariate_stds |
40 | | - will be generated randomly from Half-Cauchy distribution. |
| 37 | + covariate_stds (Optional[np.ndarray]): Normal dist. std values for covariates. |
| 38 | + The lenght must be n_covariates. Defaults to None. If it is set to |
| 39 | + None while n_covariates is provided, covariate_stds will be generated |
| 40 | + randomly from Half-Cauchy distribution. |
41 | 41 | covariate_coeffs (Optional[np.ndarray]): Linear regression |
42 | 42 | coefficients to map covariates to output observations. K is n_covariates. |
43 | 43 | Defaults to None. |
44 | 44 | global_mean (Number): Global mean for data generation. Defaults to 20.0. |
45 | 45 | global_scale (Number): Global scale for data generation. Defaults to 0.2. |
46 | 46 | start_date (dt.date): Start date for time series. Defaults to 2023-01-01. |
47 | 47 | seed (int): Random seed for reproducibility. Defaults to 123. |
48 | | - weights_cfg (WeightConfig): Configuration for unit weights. Defaults to |
49 | | - UniformWeights. |
| 48 | + weights (Optional[list[np.ndarray]]): Length num of treateds list of weights. |
| 49 | + Each element is length num of control, indicating how to weigh control |
| 50 | + units to generate treated. |
50 | 51 | """ |
51 | 52 |
|
52 | | - n_control_units: int |
53 | | - n_pre_intervention_timepoints: int |
54 | | - n_post_intervention_timepoints: int |
| 53 | + treatment_assignments: Float[np.ndarray, "T N"] |
| 54 | + treated_simulation_type: "TreatedSimulationTypes" = "control-weighted" |
| 55 | + dirichlet_concentration: Number = 1.0 |
55 | 56 | n_covariates: tp.Optional[int] = None |
56 | | - covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None |
57 | | - covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None |
| 57 | + covariate_means: tp.Optional[np.ndarray] = None |
| 58 | + covariate_stds: tp.Optional[np.ndarray] = None |
58 | 59 | covariate_coeffs: tp.Optional[np.ndarray] = None |
59 | 60 | global_mean: Number = 20.0 |
60 | 61 | global_scale: Number = 0.2 |
61 | 62 | start_date: dt.date = dt.date(year=2023, month=1, day=1) |
62 | 63 | seed: int = 123 |
63 | | - weights_cfg: WeightConfig = field(default_factory=WeightConfig) |
| 64 | + weights: tp.Optional[list[np.ndarray]] = None |
64 | 65 |
|
65 | 66 | def __post_init__(self): |
66 | 67 | self.rng = np.random.RandomState(self.seed) |
67 | 68 | if self.covariate_means is not None: |
68 | | - assert self.covariate_means.shape == ( |
69 | | - self.n_control_units, |
70 | | - self.n_covariates, |
71 | | - ) |
| 69 | + assert self.covariate_means.shape == (self.n_covariates,) |
72 | 70 |
|
73 | 71 | if self.covariate_stds is not None: |
74 | | - assert self.covariate_stds.shape == ( |
75 | | - self.n_control_units, |
76 | | - self.n_covariates, |
77 | | - ) |
| 72 | + assert self.covariate_stds.shape == (self.n_covariates,) |
78 | 73 |
|
79 | 74 | if (self.n_covariates is not None) & (self.covariate_means is None): |
80 | 75 | self.covariate_means = self.rng.normal( |
81 | | - loc=0.0, scale=5.0, size=(self.n_control_units, self.n_covariates) |
| 76 | + loc=0.0, scale=5.0, size=(self.n_covariates) |
82 | 77 | ) |
83 | 78 |
|
84 | 79 | if (self.n_covariates is not None) & (self.covariate_stds is None): |
85 | 80 | self.covariate_stds = halfcauchy.rvs( |
86 | 81 | scale=0.5, |
87 | | - size=(self.n_control_units, self.n_covariates), |
| 82 | + size=(self.n_covariates), |
88 | 83 | random_state=self.rng, |
89 | 84 | ) |
90 | 85 |
|
91 | 86 | if (self.n_covariates is not None) & (self.covariate_coeffs is None): |
92 | 87 | self.covariate_coeffs = self.rng.normal( |
93 | 88 | loc=0.0, scale=5.0, size=self.n_covariates |
94 | 89 | ) |
| 90 | + |
| 91 | + n_units = self.treatment_assignments.shape[1] |
| 92 | + treated_units = [ |
| 93 | + i for i in range(n_units) if any(self.treatment_assignments[:, i] != 0) |
| 94 | + ] |
| 95 | + n_treated_units = len(treated_units) |
| 96 | + n_control_units = n_units - n_treated_units |
| 97 | + |
| 98 | + if self.treated_simulation_type == "control-weighted": |
| 99 | + if self.weights is None: |
| 100 | + self.weights = [ |
| 101 | + self.rng.dirichlet( |
| 102 | + self.dirichlet_concentration * np.ones(n_control_units) |
| 103 | + ) |
| 104 | + for _ in range(n_treated_units) |
| 105 | + ] |
| 106 | + else: |
| 107 | + assert len(self.weights) == n_treated_units |
| 108 | + assert all([len(w) == n_control_units for w in self.weights]) |
0 commit comments