1+ from hypothesis import (
2+ given ,
3+ strategies as st ,
4+ )
15import numpy as np
2- from hypothesis import given , strategies as st
36
47from causal_validation .config import Config
58
69
710@given (
811 n_units = st .integers (min_value = 1 , max_value = 10 ),
912 n_pre = st .integers (min_value = 1 , max_value = 20 ),
10- n_post = st .integers (min_value = 1 , max_value = 20 )
13+ n_post = st .integers (min_value = 1 , max_value = 20 ),
1114)
1215def test_config_basic_initialization (n_units , n_pre , n_post ):
1316 cfg = Config (
1417 n_control_units = n_units ,
1518 n_pre_intervention_timepoints = n_pre ,
16- n_post_intervention_timepoints = n_post
19+ n_post_intervention_timepoints = n_post ,
1720 )
1821 assert cfg .n_control_units == n_units
1922 assert cfg .n_pre_intervention_timepoints == n_pre
@@ -29,7 +32,7 @@ def test_config_basic_initialization(n_units, n_pre, n_post):
2932 n_pre = st .integers (min_value = 1 , max_value = 10 ),
3033 n_post = st .integers (min_value = 1 , max_value = 10 ),
3134 n_covariates = st .integers (min_value = 1 , max_value = 3 ),
32- seed = st .integers (min_value = 1 , max_value = 1000 )
35+ seed = st .integers (min_value = 1 , max_value = 1000 ),
3336)
3437def test_config_with_covariates_auto_generation (
3538 n_units , n_pre , n_post , n_covariates , seed
@@ -39,7 +42,7 @@ def test_config_with_covariates_auto_generation(
3942 n_pre_intervention_timepoints = n_pre ,
4043 n_post_intervention_timepoints = n_post ,
4144 n_covariates = n_covariates ,
42- seed = seed
45+ seed = seed ,
4346 )
4447 assert cfg .n_covariates == n_covariates
4548 assert cfg .covariate_means .shape == (n_units , n_covariates )
@@ -50,7 +53,7 @@ def test_config_with_covariates_auto_generation(
5053
5154@given (
5255 n_units = st .integers (min_value = 1 , max_value = 3 ),
53- n_covariates = st .integers (min_value = 1 , max_value = 3 )
56+ n_covariates = st .integers (min_value = 1 , max_value = 3 ),
5457)
5558def test_config_with_explicit_covariate_means (n_units , n_covariates ):
5659 means = np .random .random ((n_units , n_covariates ))
@@ -59,14 +62,14 @@ def test_config_with_explicit_covariate_means(n_units, n_covariates):
5962 n_pre_intervention_timepoints = 10 ,
6063 n_post_intervention_timepoints = 5 ,
6164 n_covariates = n_covariates ,
62- covariate_means = means
65+ covariate_means = means ,
6366 )
6467 np .testing .assert_array_equal (cfg .covariate_means , means )
6568
6669
6770@given (
6871 n_units = st .integers (min_value = 1 , max_value = 3 ),
69- n_covariates = st .integers (min_value = 1 , max_value = 3 )
72+ n_covariates = st .integers (min_value = 1 , max_value = 3 ),
7073)
7174def test_config_with_explicit_covariate_stds (n_units , n_covariates ):
7275 stds = np .random .random ((n_units , n_covariates )) + 0.1
@@ -75,6 +78,6 @@ def test_config_with_explicit_covariate_stds(n_units, n_covariates):
7578 n_pre_intervention_timepoints = 10 ,
7679 n_post_intervention_timepoints = 5 ,
7780 n_covariates = n_covariates ,
78- covariate_stds = stds
81+ covariate_stds = stds ,
7982 )
8083 np .testing .assert_array_equal (cfg .covariate_stds , stds )
0 commit comments