@@ -199,7 +199,7 @@ def test_drop_unit(n_pre: int, n_post: int, n_control: int):
199199 assert reduced_data .Xte .shape == desired_shape_Xte
200200 assert reduced_data .ytr .shape == desired_shape_ytr
201201 assert reduced_data .yte .shape == desired_shape_yte
202-
202+
203203 assert reduced_data .counterfactual == data .counterfactual
204204 assert reduced_data .synthetic == data .synthetic
205205 assert reduced_data ._name == data ._name
@@ -305,19 +305,19 @@ def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control:
305305 N_CONTROL = n_control ,
306306 )
307307 data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
308-
308+
309309 assert data .counterfactual is None
310310 assert data .synthetic is None
311-
311+
312312 counterfactual_vals = np .random .randn (n_post , 1 )
313313 synthetic_vals = np .random .randn (n_post , 1 )
314-
314+
315315 data_with_attrs = Dataset (
316316 data .Xtr , data .Xte , data .ytr , data .yte , data ._start_date ,
317317 data .Ptr , data .Pte , data .Rtr , data .Rte ,
318318 counterfactual_vals , synthetic_vals , "test_dataset"
319319 )
320-
320+
321321 np .testing .assert_array_equal (data_with_attrs .counterfactual , counterfactual_vals )
322322 np .testing .assert_array_equal (data_with_attrs .synthetic , synthetic_vals )
323323 assert data_with_attrs .name == "test_dataset"
@@ -336,17 +336,17 @@ def test_inflate_method(n_pre: int, n_post: int, n_control: int):
336336 N_CONTROL = n_control ,
337337 )
338338 data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
339-
340- inflation_vals = np .ones ((n_post , 1 )) * 1.1
339+
340+ inflation_vals = np .ones ((n_post , 1 )) * 1.1
341341 inflated_data = data .inflate (inflation_vals )
342-
342+
343343 np .testing .assert_array_equal (inflated_data .Xtr , data .Xtr )
344344 np .testing .assert_array_equal (inflated_data .ytr , data .ytr )
345345 np .testing .assert_array_equal (inflated_data .Xte , data .Xte )
346-
346+
347347 expected_yte = data .yte * inflation_vals
348348 np .testing .assert_array_equal (inflated_data .yte , expected_yte )
349-
349+
350350 np .testing .assert_array_equal (inflated_data .counterfactual , data .yte )
351351
352352
@@ -363,12 +363,12 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int):
363363 N_CONTROL = n_control ,
364364 )
365365 data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
366-
366+
367367 control_units = data .control_units
368368 expected_control = np .vstack ([data .Xtr , data .Xte ])
369369 np .testing .assert_array_equal (control_units , expected_control )
370370 assert control_units .shape == (n_pre + n_post , n_control )
371-
371+
372372 treated_units = data .treated_units
373373 expected_treated = np .vstack ([data .ytr , data .yte ])
374374 np .testing .assert_array_equal (treated_units , expected_treated )
@@ -416,14 +416,16 @@ def test_dataset_container(seeds: tp.List[int], to_name: bool):
416416 n_control = st .integers (min_value = 2 , max_value = 20 ),
417417)
418418@settings (max_examples = 5 )
419- def test_covariate_properties_without_covariates (n_pre : int , n_post : int , n_control : int ):
419+ def test_covariate_properties_without_covariates (
420+ n_pre : int , n_post : int , n_control : int
421+ ):
420422 constants = TestConstants (
421423 N_POST_TREATMENT = n_post ,
422424 N_PRE_TREATMENT = n_pre ,
423425 N_CONTROL = n_control ,
424426 )
425427 data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
426-
428+
427429 assert data .has_covariates is False
428430 assert data .control_covariates is None
429431 assert data .treated_covariates is None
@@ -437,79 +439,45 @@ def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_cont
437439 n_post = st .integers (min_value = 10 , max_value = 50 ),
438440 n_control = st .integers (min_value = 2 , max_value = 10 ),
439441 n_covariates = st .integers (min_value = 1 , max_value = 5 ),
440- Xtr = st .data (),
441- Xte = st .data (),
442- ytr = st .data (),
443- yte = st .data (),
444- Ptr = st .data (),
445- Pte = st .data (),
446- Rtr = st .data (),
447- Rte = st .data (),
442+ seed = st .integers (min_value = 1 , max_value = 10000 ),
448443)
449444@settings (max_examples = 5 )
450- def test_covariate_properties_with_covariates (n_pre : int ,
451- n_post : int ,
452- n_control : int ,
453- n_covariates : int ,
454- Xtr ,
455- Xte ,
456- ytr ,
457- yte ,
458- Ptr ,
459- Pte ,
460- Rtr ,
461- Rte ):
462-
463- Xtr = Xtr .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
464- min_size = n_pre * n_control , max_size = n_pre * n_control ))
465- Xtr = np .array (Xtr ).reshape (n_pre , n_control )
466-
467- Xte = Xte .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
468- min_size = n_post * n_control , max_size = n_post * n_control ))
469- Xte = np .array (Xte ).reshape (n_post , n_control )
470-
471- ytr = ytr .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
472- min_size = n_pre , max_size = n_pre ))
473- ytr = np .array (ytr ).reshape (n_pre , 1 )
474-
475- yte = yte .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
476- min_size = n_post , max_size = n_post ))
477- yte = np .array (yte ).reshape (n_post , 1 )
478-
479- Ptr = Ptr .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
480- min_size = n_pre * n_control * n_covariates , max_size = n_pre * n_control * n_covariates ))
481- Ptr = np .array (Ptr ).reshape (n_pre , n_control , n_covariates )
482-
483- Pte = Pte .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
484- min_size = n_post * n_control * n_covariates , max_size = n_post * n_control * n_covariates ))
485- Pte = np .array (Pte ).reshape (n_post , n_control , n_covariates )
486-
487- Rtr = Rtr .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
488- min_size = n_pre * n_covariates , max_size = n_pre * n_covariates ))
489- Rtr = np .array (Rtr ).reshape (n_pre , 1 , n_covariates )
490-
491- Rte = Rte .draw (st .lists (st .floats (min_value = - 10 , max_value = 10 ),
492- min_size = n_post * n_covariates , max_size = n_post * n_covariates ))
493- Rte = np .array (Rte ).reshape (n_post , 1 , n_covariates )
494-
445+ def test_covariate_properties_with_covariates (
446+ n_pre : int ,
447+ n_post : int ,
448+ n_control : int ,
449+ n_covariates : int ,
450+ seed : int ,
451+ ):
452+ rng = np .random .RandomState (seed )
453+
454+ Xtr = rng .uniform (- 10 , 10 , (n_pre , n_control ))
455+ Xte = rng .uniform (- 10 , 10 , (n_post , n_control ))
456+ ytr = rng .uniform (- 10 , 10 , (n_pre , 1 ))
457+ yte = rng .uniform (- 10 , 10 , (n_post , 1 ))
458+ Ptr = rng .uniform (- 10 , 10 , (n_pre , n_control , n_covariates ))
459+ Pte = rng .uniform (- 10 , 10 , (n_post , n_control , n_covariates ))
460+ Rtr = rng .uniform (- 10 , 10 , (n_pre , 1 , n_covariates ))
461+ Rte = rng .uniform (- 10 , 10 , (n_post , 1 , n_covariates ))
462+
495463 data = Dataset (Xtr , Xte , ytr , yte , dt .date (2023 , 1 , 1 ), Ptr , Pte , Rtr , Rte )
496-
464+
497465 assert data .n_covariates == n_covariates
498466 assert data .has_covariates is True
499-
467+
500468 control_covariates = data .control_covariates
501469 expected_control_cov = np .vstack ([Ptr , Pte ])
502470 np .testing .assert_array_equal (control_covariates , expected_control_cov )
503471 assert control_covariates .shape == (n_pre + n_post , n_control , n_covariates )
504-
472+
505473 treated_covariates = data .treated_covariates
506474 expected_treated_cov = np .vstack ([Rtr , Rte ])
507475 np .testing .assert_array_equal (treated_covariates , expected_treated_cov )
508476 assert treated_covariates .shape == (n_pre + n_post , 1 , n_covariates )
509-
477+
510478 pre_cov = data .pre_intervention_covariates
511479 assert pre_cov == (Ptr , Rtr )
512-
480+
513481 post_cov = data .post_intervention_covariates
514482 assert post_cov == (Pte , Rte )
515483
0 commit comments