1+ import typing as tp
2+
3+ from azcausal .core .effect import Effect
4+ from azcausal .core .error import (
5+ Bootstrap ,
6+ Error ,
7+ JackKnife ,
8+ )
9+ from azcausal .core .estimator import Estimator
10+ from azcausal .core .result import Result as _Result
11+ from azcausal .estimators .panel import (
12+ did ,
13+ sdid ,
14+ )
15+ from hypothesis import (
16+ given ,
17+ settings ,
18+ strategies as st ,
19+ )
20+ import numpy as np
21+
22+ from causal_validation .estimator import Result
23+ from causal_validation .estimator .utils import AZCausalWrapper
24+
25+ from causal_validation .testing import (
26+ TestConstants ,
27+ simulate_data ,
28+ )
29+
30+ MODELS = [did .DID (), sdid .SDID ()]
31+ MODEL_ERROR = [
32+ (did .DID (), None ),
33+ (sdid .SDID (), None ),
34+ (sdid .SDID (), Bootstrap ()),
35+ (sdid .SDID (), JackKnife ()),
36+ ]
37+
38+
39+ @given (
40+ model_error = st .sampled_from (MODEL_ERROR ),
41+ n_control = st .integers (min_value = 2 , max_value = 5 ),
42+ n_pre_treatment = st .integers (min_value = 1 , max_value = 50 ),
43+ n_post_treatment = st .integers (min_value = 1 , max_value = 50 ),
44+ seed = st .integers (min_value = 1 , max_value = 100 ),
45+ )
46+ @settings (max_examples = 10 )
47+ def test_call (
48+ model_error : tp .Union [Estimator , Error ],
49+ n_control : int ,
50+ n_pre_treatment : int ,
51+ n_post_treatment : int ,
52+ seed : int ,
53+ ):
54+ D = np .zeros ((n_pre_treatment + n_post_treatment , n_control + 1 ))
55+ D [n_pre_treatment :, - 1 ] = 1
56+ constants = TestConstants (
57+ TREATMENT_ASSIGNMENTS = D
58+ )
59+ data = simulate_data (global_mean = 10.0 , seed = seed , constants = constants )
60+ model = AZCausalWrapper (* model_error )
61+ result = model (data )
62+
63+ assert isinstance (result , Result )
64+ assert isinstance (result .effect , Effect )
65+ assert not np .isnan (result .effect .value )
66+ assert isinstance (model ._az_result , _Result )
67+ assert np .all (data .treated_unit_outputs == result .observed )
68+ assert (
69+ result .observed .shape == result .counterfactual .shape == result .synthetic .shape
70+ )
0 commit comments