Skip to content

Commit a669f7b

Browse files
committed
test_models to test_az_causal
1 parent 978c819 commit a669f7b

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import typing as tp
2+
3+
from azcausal.core.effect import Effect
4+
from azcausal.core.error import (
5+
Bootstrap,
6+
Error,
7+
JackKnife,
8+
)
9+
from azcausal.core.estimator import Estimator
10+
from azcausal.core.result import Result as _Result
11+
from azcausal.estimators.panel import (
12+
did,
13+
sdid,
14+
)
15+
from hypothesis import (
16+
given,
17+
settings,
18+
strategies as st,
19+
)
20+
import numpy as np
21+
22+
from causal_validation.estimator import Result
23+
from causal_validation.estimator.utils import AZCausalWrapper
24+
25+
from causal_validation.testing import (
26+
TestConstants,
27+
simulate_data,
28+
)
29+
30+
MODELS = [did.DID(), sdid.SDID()]
31+
MODEL_ERROR = [
32+
(did.DID(), None),
33+
(sdid.SDID(), None),
34+
(sdid.SDID(), Bootstrap()),
35+
(sdid.SDID(), JackKnife()),
36+
]
37+
38+
39+
@given(
40+
model_error=st.sampled_from(MODEL_ERROR),
41+
n_control=st.integers(min_value=2, max_value=5),
42+
n_pre_treatment=st.integers(min_value=1, max_value=50),
43+
n_post_treatment=st.integers(min_value=1, max_value=50),
44+
seed=st.integers(min_value=1, max_value=100),
45+
)
46+
@settings(max_examples=10)
47+
def test_call(
48+
model_error: tp.Union[Estimator, Error],
49+
n_control: int,
50+
n_pre_treatment: int,
51+
n_post_treatment: int,
52+
seed: int,
53+
):
54+
D = np.zeros((n_pre_treatment + n_post_treatment, n_control + 1))
55+
D[n_pre_treatment:, -1] = 1
56+
constants = TestConstants(
57+
TREATMENT_ASSIGNMENTS=D
58+
)
59+
data = simulate_data(global_mean=10.0, seed=seed, constants=constants)
60+
model = AZCausalWrapper(*model_error)
61+
result = model(data)
62+
63+
assert isinstance(result, Result)
64+
assert isinstance(result.effect, Effect)
65+
assert not np.isnan(result.effect.value)
66+
assert isinstance(model._az_result, _Result)
67+
assert np.all(data.treated_unit_outputs == result.observed)
68+
assert (
69+
result.observed.shape == result.counterfactual.shape == result.synthetic.shape
70+
)

0 commit comments

Comments
 (0)