Skip to content

Commit 7a44af6

Browse files
Covariate support for Dataset class. (#31)
* Covariate support for Dataset class
1 parent 7b4fb63 commit 7a44af6

File tree

7 files changed

+267
-19
lines changed

7 files changed

+267
-19
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ select = [
171171
"TID",
172172
"ISC",
173173
]
174-
ignore = ["F722"]
174+
ignore = ["F722", "PLW1641"]
175175

176176
[tool.ruff.format]
177177
quote-style = "double"

src/causal_validation/data.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,45 @@
2121

2222
@dataclass
2323
class Dataset:
24+
"""A causal inference dataset containing pre/post intervention observations
25+
and optional associated covariates.
26+
27+
Attributes:
28+
Xtr: Pre-intervention control unit observations (N x D)
29+
Xte: Post-intervention control unit observations (M x D)
30+
ytr: Pre-intervention treated unit observations (N x 1)
31+
yte: Post-intervention treated unit observations (M x 1)
32+
_start_date: Start date for time indexing
33+
Ptr: Pre-intervention control unit covariates (N x D x F)
34+
Pte: Post-intervention control unit covariates (M x D x F)
35+
Rtr: Pre-intervention treated unit covariates (N x 1 x F)
36+
Rte: Post-intervention treated unit covariates (M x 1 x F)
37+
counterfactual: Optional counterfactual outcomes (M x 1)
38+
synthetic: Optional synthetic control outcomes (M x 1).
39+
This is a weighted combination of control units
40+
minimizing a distance-based error w.r.t. the
41+
treated unit in the pre-intervention period.
42+
_name: Optional name identifier for the dataset
43+
"""
2444
Xtr: Float[np.ndarray, "N D"]
2545
Xte: Float[np.ndarray, "M D"]
2646
ytr: Float[np.ndarray, "N 1"]
2747
yte: Float[np.ndarray, "M 1"]
2848
_start_date: dt.date
49+
Ptr: tp.Optional[Float[np.ndarray, "N D F"]] = None
50+
Pte: tp.Optional[Float[np.ndarray, "M D F"]] = None
51+
Rtr: tp.Optional[Float[np.ndarray, "N 1 F"]] = None
52+
Rte: tp.Optional[Float[np.ndarray, "M 1 F"]] = None
2953
counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None
3054
synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None
3155
_name: str = None
3256

57+
def __post_init__(self) -> None:
    """Record whether covariates were supplied and enforce all-or-nothing.

    Sets ``self.has_covariates`` to ``True`` only when ``Ptr``, ``Pte``,
    ``Rtr`` and ``Rte`` are all provided, and to ``False`` otherwise.

    Raises:
        AssertionError: If some, but not all, of the four covariate arrays
            are provided.
    """
    covariates = (self.Ptr, self.Pte, self.Rtr, self.Rte)
    n_supplied = sum(cov is not None for cov in covariates)
    self.has_covariates = n_supplied == len(covariates)
    # Raised explicitly instead of via a bare ``assert`` so the validation is
    # not stripped under ``python -O``. The exception type is kept as
    # ``AssertionError`` so existing callers that catch it keep working.
    if 0 < n_supplied < len(covariates):
        raise AssertionError(
            "Covariates must be given for all of Ptr, Pte, Rtr and Rte, or for "
            f"none of them; got {n_supplied} of {len(covariates)}."
        )
62+
3363
def to_df(
3464
self, index_start: str = dt.date(year=2023, month=1, day=1)
3565
) -> pd.DataFrame:
@@ -59,6 +89,13 @@ def n_units(self) -> int:
5989
def n_timepoints(self) -> int:
6090
return self.n_post_intervention + self.n_pre_intervention
6191

92+
@property
def n_covariates(self) -> int:
    """Size of the third axis of ``Ptr``, or 0 when no covariates are held."""
    return self.Ptr.shape[2] if self.has_covariates else 0
98+
6299
@property
def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units}"]:
    """Control observations with ``Xtr`` rows stacked above ``Xte`` rows."""
    pre_post = (self.Xtr, self.Xte)
    return np.vstack(pre_post)
@@ -67,6 +104,26 @@ def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units}
67104
def treated_units(self) -> Float[np.ndarray, "{self.n_timepoints} 1"]:
68105
return np.vstack([self.ytr, self.yte])
69106

107+
@property
def control_covariates(
    self,
) -> tp.Optional[
    Float[np.ndarray, "{self.n_timepoints} {self.n_units} {self.n_covariates}"]
]:
    """Control covariates with ``Ptr`` stacked above ``Pte`` along the first axis.

    Returns ``None`` when the dataset holds no covariates.
    """
    if not self.has_covariates:
        return None
    return np.vstack([self.Ptr, self.Pte])
117+
118+
@property
def treated_covariates(
    self,
) -> tp.Optional[Float[np.ndarray, "{self.n_timepoints} 1 {self.n_covariates}"]]:
    """Treated covariates with ``Rtr`` stacked above ``Rte`` along the first axis.

    Returns ``None`` when the dataset holds no covariates.
    """
    if not self.has_covariates:
        return None
    return np.vstack([self.Rtr, self.Rte])
126+
70127
@property
71128
def pre_intervention_obs(
72129
self,
@@ -79,6 +136,32 @@ def post_intervention_obs(
79136
) -> tp.Tuple[Float[np.ndarray, "M D"], Float[np.ndarray, "M 1"]]:
80137
return self.Xte, self.yte
81138

139+
@property
def pre_intervention_covariates(
    self,
) -> tp.Optional[
    tp.Tuple[
        Float[np.ndarray, "N D F"],
        Float[np.ndarray, "N 1 F"],
    ]
]:
    """The ``(Ptr, Rtr)`` pair, or ``None`` when no covariates are held."""
    if not self.has_covariates:
        return None
    return self.Ptr, self.Rtr
151+
152+
@property
def post_intervention_covariates(
    self,
) -> tp.Optional[
    tp.Tuple[
        Float[np.ndarray, "M D F"],
        Float[np.ndarray, "M 1 F"],
    ]
]:
    """The ``(Pte, Rte)`` pair, or ``None`` when no covariates are held."""
    if not self.has_covariates:
        return None
    return self.Pte, self.Rte
164+
82165
@property
def full_index(self) -> DatetimeIndex:
    """Index returned by ``_get_index`` for the dataset's ``_start_date``."""
    start = self._start_date
    return self._get_index(start)
@@ -97,7 +180,12 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex:
97180
return self.full_index
98181

99182
def _get_columns(self) -> tp.List[str]:
    """Build the column labels for the dataset.

    The list starts with ``"T"``, followed by ``C0, C1, ...`` for each of the
    ``n_units`` units and, when covariates are held, ``F0, F1, ...`` for each
    of the ``n_covariates`` covariates.
    """
    labels = ["T"]
    labels.extend(f"C{unit}" for unit in range(self.n_units))
    if self.has_covariates:
        labels.extend(f"F{feature}" for feature in range(self.n_covariates))
    return labels
102190

103191
def _get_index(self, start_date: dt.date) -> DatetimeIndex:
@@ -116,7 +204,10 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset:
116204
Xtr, ytr = [deepcopy(i) for i in self.pre_intervention_obs]
117205
Xte, yte = [deepcopy(i) for i in self.post_intervention_obs]
118206
inflated_yte = yte * inflation_vals
119-
return Dataset(Xtr, Xte, ytr, inflated_yte, self._start_date, yte)
207+
return Dataset(
208+
Xtr, Xte, ytr, inflated_yte, self._start_date,
209+
self.Ptr, self.Pte, self.Rtr, self.Rte, yte, self.synthetic, self._name
210+
)
120211

121212
def __eq__(self, other: Dataset) -> bool:
122213
ytr = np.allclose(self.ytr, other.ytr)
@@ -151,14 +242,21 @@ def _slots(self) -> tp.Dict[str, int]:
151242
def drop_unit(self, idx: int) -> Dataset:
    """Return a new ``Dataset`` with control unit ``idx`` removed.

    Column ``idx`` (axis 1) is deleted from ``Xtr`` and ``Xte`` and, when
    they are not ``None``, from ``Ptr`` and ``Pte``. All remaining fields are
    passed to the new ``Dataset`` as-is.
    """

    def _without_unit(arr):
        # Remove the unit's slice along the unit axis (axis 1).
        return np.delete(arr, [idx], axis=1)

    reduced_xtr = _without_unit(self.Xtr)
    reduced_xte = _without_unit(self.Xte)
    reduced_ptr = None if self.Ptr is None else _without_unit(self.Ptr)
    reduced_pte = None if self.Pte is None else _without_unit(self.Pte)
    return Dataset(
        reduced_xtr,
        reduced_xte,
        self.ytr,
        self.yte,
        self._start_date,
        reduced_ptr,
        reduced_pte,
        self.Rtr,
        self.Rte,
        self.counterfactual,
        self.synthetic,
        self._name,
    )
163261

164262
def to_placebo_data(self, to_treat_idx: int) -> Dataset:
@@ -212,5 +310,7 @@ def reassign_treatment(
212310
Xtr = data.Xtr
213311
Xte = data.Xte
214312
return Dataset(
215-
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
313+
Xtr, Xte, ytr, yte, data._start_date,
314+
data.Ptr, data.Pte, data.Rtr, data.Rte,
315+
data.counterfactual, data.synthetic, data._name
216316
)

src/causal_validation/validation/placebo.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from dataclasses import dataclass
22
import typing as tp
33

4-
from azcausal.core.effect import Effect
54
import numpy as np
65
import pandas as pd
76
from pandera import (
@@ -11,14 +10,8 @@
1110
)
1211
from rich.progress import (
1312
Progress,
14-
ProgressBar,
15-
track,
1613
)
1714
from scipy.stats import ttest_1samp
18-
from tqdm import (
19-
tqdm,
20-
trange,
21-
)
2215

2316
from causal_validation.data import (
2417
Dataset,
@@ -108,7 +101,7 @@ def execute(self, verbose: bool = True) -> PlaceboTestResult:
108101
"[blue]Datasets", total=n_datasets, visible=verbose
109102
)
110103
unit_task = progress.add_task(
111-
f"[green]Control Units",
104+
"[green]Control Units",
112105
total=n_control,
113106
visible=verbose,
114107
)

src/causal_validation/validation/rmspe.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,14 @@
22
import typing as tp
33

44
from jaxtyping import Float
5-
import numpy as np
65
import pandas as pd
76
from pandera import (
87
Check,
98
Column,
109
DataFrameSchema,
1110
)
12-
from rich import box
1311
from rich.progress import (
1412
Progress,
15-
ProgressBar,
16-
track,
1713
)
1814

1915
from causal_validation.validation.placebo import PlaceboTest
@@ -87,7 +83,7 @@ def execute(self, verbose: bool = True) -> RMSPETestResult:
8783
"[blue]Datasets", total=n_datasets, visible=verbose
8884
)
8985
unit_task = progress.add_task(
90-
f"[green]Treatment and Control Units",
86+
"[green]Treatment and Control Units",
9187
total=n_control + 1,
9288
visible=verbose,
9389
)

0 commit comments

Comments
 (0)