Skip to content

Commit 99ea63c

Browse files
committed
Format changes
1 parent a669f7b commit 99ea63c

File tree

25 files changed

+264
-218
lines changed

25 files changed

+264
-218
lines changed

src/causal_validation/config.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,10 @@ class Config:
6767
def __post_init__(self):
6868
self.rng = np.random.RandomState(self.seed)
6969
if self.covariate_means is not None:
70-
assert self.covariate_means.shape == (
71-
self.n_covariates,
72-
)
70+
assert self.covariate_means.shape == (self.n_covariates,)
7371

7472
if self.covariate_stds is not None:
75-
assert self.covariate_stds.shape == (
76-
self.n_covariates,
77-
)
73+
assert self.covariate_stds.shape == (self.n_covariates,)
7874

7975
if (self.n_covariates is not None) & (self.covariate_means is None):
8076
self.covariate_means = self.rng.normal(
@@ -94,7 +90,9 @@ def __post_init__(self):
9490
)
9591

9692
n_units = self.treatment_assignments.shape[1]
97-
treated_units = [i for i in range(n_units) if any(self.treatment_assignments[:, i] != 0)]
93+
treated_units = [
94+
i for i in range(n_units) if any(self.treatment_assignments[:, i] != 0)
95+
]
9896
n_treated_units = len(treated_units)
9997
n_control_units = n_units - n_treated_units
10098

@@ -108,9 +106,4 @@ def __post_init__(self):
108106
]
109107
else:
110108
assert len(self.weights) == n_treated_units
111-
assert all(
112-
[
113-
len(w) == n_control_units
114-
for w in self.weights
115-
]
116-
)
109+
assert all([len(w) == n_control_units for w in self.weights])

src/causal_validation/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def _get_columns(self) -> tp.List[str]:
266266

267267
def _get_index(self, start_date: dt.date) -> DatetimeIndex:
268268
return pd.date_range(start=start_date, freq="D", periods=self.n_timepoints)
269-
269+
270270
def inflate(self, inflation_vals: Float[np.ndarray, "T N"]) -> Dataset:
271271
"""
272272
Inflate the outputs Y by inflation_vals that are multiplicative factors

src/causal_validation/effects.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ class StaticEffect(AbstractEffect, _StaticEffect):
4545
effect (float): Rate effect to be applied, i.e., 0.3 = 30% lift.
4646
name (str): Name for the effect. 'Static Effect' by default.
4747
"""
48+
4849
effect: float
4950
name: str = "Static Effect"
5051

5152
def get_effect(self, data: Dataset, **kwargs) -> Float[np.ndarray, "T N"]:
52-
return np.ones(data.D.shape) + data.D*self.effect
53+
return np.ones(data.D.shape) + data.D * self.effect
5354

5455

5556
@dataclass
@@ -64,6 +65,7 @@ class RandomEffect(AbstractEffect, _RandomEffect):
6465
stddev_effect (float): Rate effect std. dev. to be applied.
6566
name (str): Name for the effect. 'Random Effect' by default.
6667
"""
68+
6769
mean_effect: float
6870
stddev_effect: float
6971
name: str = "Random Effect"
@@ -72,11 +74,9 @@ def get_effect(
7274
self, data: Dataset, key: np.random.RandomState
7375
) -> Float[np.ndarray, "T N"]:
7476
effect_sample = key.normal(
75-
loc=self.mean_effect,
76-
scale=self.stddev_effect,
77-
size=data.D.shape
77+
loc=self.mean_effect, scale=self.stddev_effect, size=data.D.shape
7878
)
79-
return np.ones(data.D.shape) + data.D*effect_sample
79+
return np.ones(data.D.shape) + data.D * effect_sample
8080

8181

8282
# Placeholder for now.

src/causal_validation/estimator/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ class Result:
1111
effect: Effect
1212
counterfactual: Float[NPArray, "N 1"]
1313
synthetic: Float[NPArray, "N 1"]
14-
observed: Float[NPArray, "N 1"]
14+
observed: Float[NPArray, "N 1"]

src/causal_validation/estimator/utils.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,41 @@
33

44
from azcausal.core.error import Error
55
from azcausal.core.estimator import Estimator
6-
from azcausal.core.panel import Panel, CausalPanel
6+
from azcausal.core.panel import (
7+
CausalPanel,
8+
Panel,
9+
)
710
from azcausal.core.result import Result as _Result
11+
from azcausal.util import to_panels
812
from jaxtyping import Float
913
import pandas as pd
1014

1115
from causal_validation.data import Dataset
1216
from causal_validation.estimator import Result
1317
from causal_validation.types import NPArray
14-
from azcausal.util import to_panels
18+
1519

1620
def to_azcausal(dataset: Dataset) -> Panel:
1721
if dataset.n_treated_units != 1:
1822
raise ValueError("Only one treated unit is supported.")
1923
time_index = dataset.full_index
2024
unit_cols = dataset._get_columns()
21-
25+
2226
data = []
2327
for time_idx in range(dataset.n_timepoints):
2428
for unit_idx, unit in enumerate(unit_cols):
25-
data.append({
26-
'variable': unit,
27-
'time': time_index[time_idx],
28-
'value': dataset.Y[time_idx, unit_idx],
29-
'treated': int(dataset.D[time_idx, unit_idx])
30-
})
31-
32-
df_data = pd.DataFrame(data)
29+
data.append(
30+
{
31+
"variable": unit,
32+
"time": time_index[time_idx],
33+
"value": dataset.Y[time_idx, unit_idx],
34+
"treated": int(dataset.D[time_idx, unit_idx]),
35+
}
36+
)
37+
38+
df_data = pd.DataFrame(data)
3339
panels = to_panels(df_data, "time", "variable", ["value", "treated"])
34-
ctypes = dict(
35-
outcome="value", time="time", unit="variable", intervention="treated"
36-
)
40+
ctypes = dict(outcome="value", time="time", unit="variable", intervention="treated")
3741
panel = CausalPanel(panels).setup(**ctypes)
3842
return panel
3943

src/causal_validation/plotters.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,39 @@ def plot(
3232
Y_control = data.control_unit_outputs
3333
Y_treated = data.treated_unit_outputs
3434
idx = data.full_index
35-
35+
3636
if ax is None:
3737
_, ax = plt.subplots(figsize=(6, 3), tight_layout=True)
38-
38+
3939
ax.plot(idx, Y_control, color=cols[0], label="Control", alpha=0.5)
40-
40+
4141
for i, unit_idx in enumerate(data.treated_unit_indices):
42-
unit_color = cols[1] if len(data.treated_unit_indices) == 1 else cols[1 + i % (len(cols) - 2)]
43-
unit_label = "Treated" if len(data.treated_unit_indices) == 1 else f"Treated {unit_idx}"
42+
unit_color = (
43+
cols[1]
44+
if len(data.treated_unit_indices) == 1
45+
else cols[1 + i % (len(cols) - 2)]
46+
)
47+
unit_label = (
48+
"Treated" if len(data.treated_unit_indices) == 1 else f"Treated {unit_idx}"
49+
)
4450
ax.plot(idx, Y_treated[:, i], color=unit_color, label=unit_label)
45-
51+
4652
treatment_date = data.treatment_date(unit_idx)
4753
if treatment_date is not None:
4854
line_color = cols[2] if len(data.treated_unit_indices) == 1 else unit_color
49-
line_label = "Intervention" if len(data.treated_unit_indices) == 1 else f"Intervention {unit_idx}"
50-
ax.axvline(x=treatment_date, color=line_color, label=line_label, linestyle="--", alpha=0.7)
51-
55+
line_label = (
56+
"Intervention"
57+
if len(data.treated_unit_indices) == 1
58+
else f"Intervention {unit_idx}"
59+
)
60+
ax.axvline(
61+
x=treatment_date,
62+
color=line_color,
63+
label=line_label,
64+
linestyle="--",
65+
alpha=0.7,
66+
)
67+
5268
ax.xaxis.set_major_formatter(
5369
mdates.ConciseDateFormatter(ax.xaxis.get_major_locator())
5470
)

src/causal_validation/simulate.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ def simulate(config: Config, key: tp.Optional[np.random.RandomState] = None) ->
2020
)
2121
return base_data
2222

23+
2324
def _simulate_with_independent_treated_units(
2425
config: Config, key: np.random.RandomState
2526
) -> Dataset:
2627
n_timepoints, n_units = config.treatment_assignments.shape
27-
28+
2829
Y = key.normal(
2930
loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units)
3031
)
@@ -49,11 +50,12 @@ def _simulate_with_independent_treated_units(
4950

5051
return data
5152

53+
5254
def _simulate_with_control_weighted_treated_units(
5355
config: Config, key: np.random.RandomState
5456
) -> Dataset:
5557
n_timepoints, n_units = config.treatment_assignments.shape
56-
58+
5759
Y = np.zeros((n_timepoints, n_units))
5860
if config.n_covariates is not None:
5961
X = np.zeros((n_timepoints, n_units, config.n_covariates))
@@ -71,9 +73,9 @@ def _simulate_with_control_weighted_treated_units(
7173
treated_unit_indices = data_void.treated_unit_indices
7274

7375
Y_control = key.normal(
74-
loc=config.global_mean,
75-
scale=config.global_scale,
76-
size=(n_timepoints, n_control_units)
76+
loc=config.global_mean,
77+
scale=config.global_scale,
78+
size=(n_timepoints, n_control_units),
7779
)
7880

7981
if config.n_covariates is not None:
@@ -100,4 +102,4 @@ def _simulate_with_control_weighted_treated_units(
100102
_start_date=config.start_date,
101103
)
102104

103-
return data
105+
return data

src/causal_validation/testing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from dataclasses import dataclass
2-
from jaxtyping import Float
32
import typing as tp
3+
4+
from jaxtyping import Float
45
import numpy as np
56

67
from causal_validation.config import Config
78
from causal_validation.data import Dataset
89
from causal_validation.simulate import simulate
9-
1010
from causal_validation.types import (
1111
Number,
1212
TreatedSimulationTypes,
@@ -26,7 +26,7 @@ class TestConstants:
2626

2727
def __post_init__(self):
2828
if self.TREATMENT_ASSIGNMENTS is None:
29-
D = np.zeros((10,5))
29+
D = np.zeros((10, 5))
3030
D[6:, 2] = 1
3131
D[8:, 3] = 1
3232
self.TREATMENT_ASSIGNMENTS = D

src/causal_validation/transforms/base.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,7 @@ def apply_values(
6464
) -> Dataset:
6565
Y = deepcopy(data.Y)
6666
Y = Y + transform_vals
67-
return Dataset(
68-
Y,
69-
data.D,
70-
data.X,
71-
data._start_date,
72-
data._name
73-
)
67+
return Dataset(Y, data.D, data.X, data._start_date, data._name)
7468

7569

7670
@dataclass(kw_only=True)
@@ -82,13 +76,7 @@ def apply_values(
8276
) -> Dataset:
8377
Y = deepcopy(data.Y)
8478
Y = Y * transform_vals
85-
return Dataset(
86-
Y,
87-
data.D,
88-
data.X,
89-
data._start_date,
90-
data._name
91-
)
79+
return Dataset(Y, data.D, data.X, data._start_date, data._name)
9280

9381

9482
@dataclass(kw_only=True)
@@ -100,10 +88,4 @@ def apply_values(
10088
) -> Dataset:
10189
X = deepcopy(data.X)
10290
X = X + transform_vals
103-
return Dataset(
104-
data.Y,
105-
data.D,
106-
X,
107-
data._start_date,
108-
data._name
109-
)
91+
return Dataset(data.Y, data.D, X, data._start_date, data._name)

src/causal_validation/transforms/noise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_values(self, data: Dataset) -> Float[np.ndarray, "T N"]:
3737
noise_treatment = self.noise_dist.get_value(
3838
n_units=data.n_treated_units, n_timepoints=data.n_timepoints
3939
)
40-
noise[:,data.treated_unit_indices] = noise_treatment
40+
noise[:, data.treated_unit_indices] = noise_treatment
4141
return noise
4242

4343

0 commit comments

Comments
 (0)