2121
2222@dataclass
2323class Dataset :
24+ """A causal inference dataset containing pre/post intervention observations
25+ and optional associated covariates.
26+
27+ Attributes:
28+ Xtr: Pre-intervention control unit observations (N x D)
29+ Xte: Post-intervention control unit observations (M x D)
30+ ytr: Pre-intervention treated unit observations (N x 1)
31+ yte: Post-intervention treated unit observations (M x 1)
32+ _start_date: Start date for time indexing
33+ Ptr: Pre-intervention control unit covariates (N x D x F)
34+ Pte: Post-intervention control unit covariates (M x D x F)
35+ Rtr: Pre-intervention treated unit covariates (N x 1 x F)
36+ Rte: Post-intervention treated unit covariates (M x 1 x F)
37+ counterfactual: Optional counterfactual outcomes (M x 1)
        synthetic: Optional synthetic control outcomes (M x 1).
            This is a weighted combination of control units that
            minimizes a distance-based error w.r.t. the treated
            unit in the pre-intervention period.
42+ _name: Optional name identifier for the dataset
43+ """
2444 Xtr : Float [np .ndarray , "N D" ]
2545 Xte : Float [np .ndarray , "M D" ]
2646 ytr : Float [np .ndarray , "N 1" ]
2747 yte : Float [np .ndarray , "M 1" ]
2848 _start_date : dt .date
49+ Ptr : tp .Optional [Float [np .ndarray , "N D F" ]] = None
50+ Pte : tp .Optional [Float [np .ndarray , "M D F" ]] = None
51+ Rtr : tp .Optional [Float [np .ndarray , "N 1 F" ]] = None
52+ Rte : tp .Optional [Float [np .ndarray , "M 1 F" ]] = None
2953 counterfactual : tp .Optional [Float [np .ndarray , "M 1" ]] = None
3054 synthetic : tp .Optional [Float [np .ndarray , "M 1" ]] = None
3155 _name : str = None
3256
def __post_init__(self):
    """Derive ``has_covariates`` and reject partially specified covariates.

    The four covariate arrays (``Ptr``, ``Pte``, ``Rtr``, ``Rte``) must be
    supplied together or not at all. ``has_covariates`` is set to ``True``
    only when all four are present.

    Raises:
        AssertionError: If some, but not all, covariate arrays are given.
    """
    covariates = {
        "Ptr": self.Ptr,
        "Pte": self.Pte,
        "Rtr": self.Rtr,
        "Rte": self.Rte,
    }
    missing = [name for name, cov in covariates.items() if cov is None]
    self.has_covariates = not missing
    # Raised explicitly instead of via ``assert`` so the check still runs
    # under ``python -O``; the exception type is kept for compatibility.
    if missing and len(missing) != len(covariates):
        raise AssertionError(
            "Covariates must be provided for all of Ptr, Pte, Rtr and Rte "
            f"or for none of them; missing: {', '.join(missing)}."
        )
62+
3363 def to_df (
3464 self , index_start : str = dt .date (year = 2023 , month = 1 , day = 1 )
3565 ) -> pd .DataFrame :
@@ -59,6 +89,13 @@ def n_units(self) -> int:
def n_timepoints(self) -> int:
    """Return the number of post-intervention plus pre-intervention points."""
    post_count = self.n_post_intervention
    pre_count = self.n_pre_intervention
    return post_count + pre_count
6191
@property
def n_covariates(self) -> int:
    """Number of covariates, or 0 when the dataset has none.

    Read from the third axis (index 2) of ``Ptr``.
    """
    if not self.has_covariates:
        return 0
    return self.Ptr.shape[2]
98+
@property
def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units}"]:
    """Control-unit observations with ``Xtr`` stacked on top of ``Xte``."""
    stacked = np.vstack((self.Xtr, self.Xte))
    return stacked
@@ -67,6 +104,26 @@ def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units}
def treated_units(self) -> Float[np.ndarray, "{self.n_timepoints} 1"]:
    """Treated-unit observations with ``ytr`` stacked on top of ``yte``."""
    stacked = np.vstack((self.ytr, self.yte))
    return stacked
69106
@property
def control_covariates(
    self,
) -> tp.Optional[
    Float[np.ndarray, "{self.n_timepoints} {self.n_units} {self.n_covariates}"]
]:
    """Control-unit covariates with ``Ptr`` stacked on top of ``Pte``.

    Returns ``None`` when the dataset has no covariates.
    """
    if not self.has_covariates:
        return None
    return np.vstack((self.Ptr, self.Pte))
117+
@property
def treated_covariates(
    self,
) -> tp.Optional[Float[np.ndarray, "{self.n_timepoints} 1 {self.n_covariates}"]]:
    """Treated-unit covariates with ``Rtr`` stacked on top of ``Rte``.

    Returns ``None`` when the dataset has no covariates.
    """
    if not self.has_covariates:
        return None
    return np.vstack((self.Rtr, self.Rte))
126+
70127 @property
71128 def pre_intervention_obs (
72129 self ,
@@ -79,6 +136,32 @@ def post_intervention_obs(
79136 ) -> tp .Tuple [Float [np .ndarray , "M D" ], Float [np .ndarray , "M 1" ]]:
80137 return self .Xte , self .yte
81138
@property
def pre_intervention_covariates(
    self,
) -> tp.Optional[
    tp.Tuple[
        Float[np.ndarray, "N D F"], Float[np.ndarray, "N 1 F"],
    ]
]:
    """Pre-intervention covariates as the pair ``(Ptr, Rtr)``.

    Returns ``None`` when the dataset has no covariates.
    """
    if not self.has_covariates:
        return None
    return (self.Ptr, self.Rtr)
151+
@property
def post_intervention_covariates(
    self,
) -> tp.Optional[
    tp.Tuple[
        Float[np.ndarray, "M D F"], Float[np.ndarray, "M 1 F"],
    ]
]:
    """Post-intervention covariates as the pair ``(Pte, Rte)``.

    Returns ``None`` when the dataset has no covariates.
    """
    if not self.has_covariates:
        return None
    return (self.Pte, self.Rte)
164+
@property
def full_index(self) -> DatetimeIndex:
    """Index produced by ``_get_index`` for the dataset's ``_start_date``."""
    start = self._start_date
    return self._get_index(start)
@@ -97,7 +180,12 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex:
97180 return self .full_index
98181
def _get_columns(self) -> tp.List[str]:
    """Build column labels: treated unit, control units, then covariates.

    Returns:
        ``["T", "C0", ..., "C{n_units - 1}"]``, followed by
        ``"F0", ..., "F{n_covariates - 1}"`` when the dataset has covariates.
    """
    # Labels carry no surrounding whitespace, consistent with "T", so they
    # can be matched by exact name.
    colnames = ["T"]
    colnames.extend(f"C{i}" for i in range(self.n_units))
    if self.has_covariates:
        colnames.extend(f"F{i}" for i in range(self.n_covariates))
    return colnames
102190
103191 def _get_index (self , start_date : dt .date ) -> DatetimeIndex :
@@ -116,7 +204,10 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset:
116204 Xtr , ytr = [deepcopy (i ) for i in self .pre_intervention_obs ]
117205 Xte , yte = [deepcopy (i ) for i in self .post_intervention_obs ]
118206 inflated_yte = yte * inflation_vals
119- return Dataset (Xtr , Xte , ytr , inflated_yte , self ._start_date , yte )
207+ return Dataset (
208+ Xtr , Xte , ytr , inflated_yte , self ._start_date ,
209+ self .Ptr , self .Pte , self .Rtr , self .Rte , yte , self .synthetic , self ._name
210+ )
120211
121212 def __eq__ (self , other : Dataset ) -> bool :
122213 ytr = np .allclose (self .ytr , other .ytr )
@@ -151,14 +242,21 @@ def _slots(self) -> tp.Dict[str, int]:
def drop_unit(self, idx: int) -> Dataset:
    """Return a new ``Dataset`` without the control unit at column ``idx``.

    The unit is removed along axis 1 of ``Xtr`` and ``Xte``, and of ``Ptr``
    and ``Pte`` when those are not ``None``. All other fields are passed
    through to the new instance as they are.
    """

    def _without_unit(arr):
        # Drop the idx-th entry along the unit axis (axis 1).
        return np.delete(arr, [idx], axis=1)

    controls_pre = _without_unit(self.Xtr)
    controls_post = _without_unit(self.Xte)
    covs_pre = None if self.Ptr is None else _without_unit(self.Ptr)
    covs_post = None if self.Pte is None else _without_unit(self.Pte)
    return Dataset(
        Xtr=controls_pre,
        Xte=controls_post,
        ytr=self.ytr,
        yte=self.yte,
        _start_date=self._start_date,
        Ptr=covs_pre,
        Pte=covs_post,
        Rtr=self.Rtr,
        Rte=self.Rte,
        counterfactual=self.counterfactual,
        synthetic=self.synthetic,
        _name=self._name,
    )
163261
164262 def to_placebo_data (self , to_treat_idx : int ) -> Dataset :
@@ -212,5 +310,7 @@ def reassign_treatment(
212310 Xtr = data .Xtr
213311 Xte = data .Xte
214312 return Dataset (
215- Xtr , Xte , ytr , yte , data ._start_date , data .counterfactual , data .synthetic
313+ Xtr , Xte , ytr , yte , data ._start_date ,
314+ data .Ptr , data .Pte , data .Rtr , data .Rte ,
315+ data .counterfactual , data .synthetic , data ._name
216316 )
0 commit comments