Skip to content

Commit 579de99

Browse files
committed
Linting and documentation fix
1 parent 99ea63c commit 579de99

File tree

10 files changed

+499
-125
lines changed

10 files changed

+499
-125
lines changed

docs/examples/azcausal.ipynb

Lines changed: 129 additions & 21 deletions
Large diffs are not rendered by default.

docs/examples/basic.ipynb

Lines changed: 155 additions & 37 deletions
Large diffs are not rendered by default.

docs/examples/placebo_test.ipynb

Lines changed: 182 additions & 22 deletions
Large diffs are not rendered by default.

docs/index.md

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ trend varies across each of the 10 control units.
1717
from causal_validation import Config, simulate
1818
from causal_validation.effects import StaticEffect
1919
from causal_validation.plotters import plot
20-
from causal_validation.transforms import Trend, Periodic
20+
from causal_validation.transforms import Trend
2121
from causal_validation.transforms.parameter import UnitVaryingParameter
22+
import numpy as np
2223
from scipy.stats import norm
2324

24-
cfg = Config(
25-
n_control_units=10,
26-
n_pre_intervention_timepoints=60,
27-
n_post_intervention_timepoints=30,
28-
)
25+
# Treatment assignment matrix
26+
D = np.zeros((90, 11)) # 90 time points, 11 units
27+
D[60:, -1] = 1 # Last unit treated after 60 time points
28+
29+
cfg = Config(treatment_assignments=D)
2930

3031
# Simulate the base observation
3132
base_data = simulate(cfg)
@@ -38,6 +39,8 @@ trended_data = trend_component(base_data)
3839
# Simulate a 5% lift in the treated unit's post-intervention data
3940
effect = StaticEffect(0.05)
4041
inflated_data = effect(trended_data)
42+
43+
plot(inflated_data)
4144
```
4245

4346
![Gaussian process posterior.](static/imgs/readme_fig.png)
@@ -50,6 +53,7 @@ combination with AZCausal by the following.
5053

5154
```python
5255
from azcausal.estimators.panel.sdid import SDID
56+
from causal_validation.estimator.utils import AZCausalWrapper
5357
from causal_validation.validation.placebo import PlaceboTest
5458

5559
model = AZCausalWrapper(model=SDID())

src/causal_validation/config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from dataclasses import (
22
dataclass,
3-
field,
43
)
54
import datetime as dt
65
import typing as tp

src/causal_validation/transforms/noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
class Noise(AdditiveOutputTransform):
2424
"""
2525
Transform the treated units by adding TimeAndUnitVaryingParameter noise terms
26-
sampled from a specified sampling distribution. By default, the sampling distribution
27-
is Normal with 0 loc and 0.1 scale.
26+
sampled from a specified sampling distribution. By default, the sampling
27+
distribution is Normal with 0 loc and 0.1 scale.
2828
"""
2929

3030
noise_dist: TimeAndUnitVaryingParameter = field(

tests/test_causal_validation/test_data.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,18 @@
1-
from copy import deepcopy
2-
import datetime as dt
31
import string
42
import typing as tp
53

6-
from azcausal.estimators.panel.did import DID
74
from hypothesis import (
85
given,
96
settings,
107
strategies as st,
118
)
129
import numpy as np
1310
import pandas as pd
14-
from pandas.core.indexes.datetimes import DatetimeIndex
15-
import pytest
1611

1712
from causal_validation.data import (
1813
Dataset,
1914
DatasetContainer,
2015
)
21-
from causal_validation.types import InterventionTypes
2216

2317

2418
@given(
@@ -58,8 +52,9 @@ def test_dataset(T: int, N: int, K: int, seed: int):
5852
== (T * np.ones(N) - data1.n_post_intervention).tolist()
5953
)
6054

61-
assert data1.n_treated_units == 2
62-
assert data1.n_control_units == N - 2
55+
N_TREATED = 2
56+
assert data1.n_treated_units == N_TREATED
57+
assert data1.n_control_units == N - N_TREATED
6358
assert data1.treated_unit_indices == [2, 3]
6459
assert data1.control_unit_indices == [0, 1] + list(range(4, N))
6560

@@ -92,7 +87,8 @@ def test_dataset_to_df(T: int, N: int, K: int, seed: int):
9287
df1 = data1.to_df()
9388
assert isinstance(df1, pd.DataFrame)
9489
assert df1.shape == (T, N * (K + 2))
95-
assert df1.columns.nlevels == 2
90+
EXPECTED_COLUMN_LEVELS = 2
91+
assert df1.columns.nlevels == EXPECTED_COLUMN_LEVELS
9692
assert df1.columns[0] == ("U0", "Y")
9793
assert df1.columns[1] == ("U0", "D")
9894
assert df1.columns[2] == ("U0", "X0")
@@ -114,7 +110,7 @@ def test_dataset_to_df(T: int, N: int, K: int, seed: int):
114110
df2 = data2.to_df()
115111
assert isinstance(df2, pd.DataFrame)
116112
assert df2.shape == (T, N * 2)
117-
assert df2.columns.nlevels == 2
113+
assert df2.columns.nlevels == EXPECTED_COLUMN_LEVELS
118114
assert df2.columns[0] == ("U0", "Y")
119115
assert df2.columns[1] == ("U0", "D")
120116
assert df2.columns[2] == ("U1", "Y")
@@ -198,13 +194,11 @@ def test_dataset_container(seeds: tp.List[int], to_name: bool, T: int, N: int):
198194
K=st.integers(min_value=1, max_value=10),
199195
seed=st.integers(min_value=1, max_value=100),
200196
use_bernoulli=st.booleans(),
201-
include_X=st.booleans(),
202197
)
203198
@settings(max_examples=10)
204-
def test_inflate(
205-
T: int, N: int, K: int, seed: int, use_bernoulli: bool, include_X: bool
206-
):
199+
def test_inflate(T: int, N: int, K: int, seed: int, use_bernoulli: bool):
207200
rng = np.random.RandomState(seed)
201+
include_X = bool(rng.binomial(1, 0.5))
208202
Y = rng.randn(T, N)
209203
D = rng.binomial(1, 0.3, (T, N)) if use_bernoulli else np.abs(rng.randn(T, N))
210204
X = rng.randn(T, N, K) if include_X else None

tests/test_causal_validation/test_effect.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,12 @@
2020
min_value=-0.5, max_value=0.5, allow_nan=False, allow_infinity=False
2121
),
2222
seed=st.integers(min_value=1, max_value=100),
23-
use_bernoulli=st.booleans(),
24-
include_X=st.booleans(),
2523
)
2624
@settings(max_examples=10)
27-
def test_static_effect(
28-
T: int,
29-
N: int,
30-
K: int,
31-
effect: float,
32-
seed: int,
33-
use_bernoulli: bool,
34-
include_X: bool,
35-
):
25+
def test_static_effect(T: int, N: int, K: int, effect: float, seed: int):
3626
rng = np.random.RandomState(seed)
27+
use_bernoulli = bool(rng.binomial(1, 0.5))
28+
include_X = bool(rng.binomial(1, 0.5))
3729
Y = rng.randn(T, N)
3830
D = rng.binomial(1, 0.3, (T, N)) if use_bernoulli else np.abs(rng.randn(T, N))
3931
X = rng.randn(T, N, K) if include_X else None
@@ -53,7 +45,6 @@ def test_static_effect(
5345

5446

5547
@given(
56-
T=st.integers(min_value=2, max_value=50),
5748
N=st.integers(min_value=2, max_value=50),
5849
K=st.integers(min_value=1, max_value=10),
5950
mean_effect=st.floats(
@@ -63,21 +54,19 @@ def test_static_effect(
6354
min_value=0.01, max_value=0.2, allow_nan=False, allow_infinity=False
6455
),
6556
seed=st.integers(min_value=1, max_value=100),
66-
use_bernoulli=st.booleans(),
67-
include_X=st.booleans(),
6857
)
6958
@settings(max_examples=10)
7059
def test_random_effect(
71-
T: int,
7260
N: int,
7361
K: int,
7462
mean_effect: float,
7563
stddev_effect: float,
7664
seed: int,
77-
use_bernoulli: bool,
78-
include_X: bool,
7965
):
8066
rng = np.random.RandomState(seed)
67+
use_bernoulli = bool(rng.binomial(1, 0.5))
68+
include_X = bool(rng.binomial(1, 0.5))
69+
T = 50
8170
Y = rng.randn(T, N)
8271
D = rng.binomial(1, 0.3, (T, N)) if use_bernoulli else np.abs(rng.randn(T, N))
8372
X = rng.randn(T, N, K) if include_X else None

tests/test_causal_validation/test_simulate.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,10 @@ def test_simulate_control_weighted(n_units, n_timepoints, seed):
163163
)
164164
data = simulate(cfg)
165165

166+
N_TREATED = 2
166167
assert data.Y.shape == (n_timepoints, n_units)
167-
assert data.n_control_units == n_units - 2
168-
assert data.n_treated_units == 2
168+
assert data.n_control_units == n_units - N_TREATED
169+
assert data.n_treated_units == N_TREATED
169170
assert np.all(data.Y[:, :-2] @ cfg.weights[0] == data.Y[:, -2])
170171
assert np.all(data.Y[:, :-2] @ cfg.weights[1] == data.Y[:, -1])
171172

@@ -208,9 +209,10 @@ def test_simulate_control_weighted_with_covariates(
208209
)
209210
data = simulate(cfg)
210211

212+
N_TREATED = 2
211213
assert data.Y.shape == (n_timepoints, n_units)
212-
assert data.n_control_units == n_units - 2
213-
assert data.n_treated_units == 2
214+
assert data.n_control_units == n_units - N_TREATED
215+
assert data.n_treated_units == N_TREATED
214216
assert np.all(data.Y[:, :-2] @ cfg.weights[0] == data.Y[:, -2])
215217
assert np.all(data.Y[:, :-2] @ cfg.weights[1] == data.Y[:, -1])
216218
X_treated1 = np.einsum("ijk,j->ik", data.X[:, :-2, :], cfg.weights[0])

tests/test_causal_validation/test_transforms/test_periodic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_varying_parameters():
161161
data_slots = constants.DATA_SLOTS
162162
base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, constants=constants)
163163
base_data_transform = periodic_transform(base_data)
164-
for i, slot in enumerate(param_slots):
164+
for slot in param_slots:
165165
setattr(
166166
periodic_transform,
167167
slot,

0 commit comments

Comments
 (0)