Skip to content

Commit 0dc4c3d

Browse files
committed
Fix linting errors in tests.
1 parent 6682af4 commit 0dc4c3d

File tree

3 files changed

+42
-74
lines changed

3 files changed

+42
-74
lines changed

tests/test_causal_validation/test_data.py

Lines changed: 40 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_drop_unit(n_pre: int, n_post: int, n_control: int):
199199
assert reduced_data.Xte.shape == desired_shape_Xte
200200
assert reduced_data.ytr.shape == desired_shape_ytr
201201
assert reduced_data.yte.shape == desired_shape_yte
202-
202+
203203
assert reduced_data.counterfactual == data.counterfactual
204204
assert reduced_data.synthetic == data.synthetic
205205
assert reduced_data._name == data._name
@@ -305,19 +305,19 @@ def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control:
305305
N_CONTROL=n_control,
306306
)
307307
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
308-
308+
309309
assert data.counterfactual is None
310310
assert data.synthetic is None
311-
311+
312312
counterfactual_vals = np.random.randn(n_post, 1)
313313
synthetic_vals = np.random.randn(n_post, 1)
314-
314+
315315
data_with_attrs = Dataset(
316316
data.Xtr, data.Xte, data.ytr, data.yte, data._start_date,
317317
data.Ptr, data.Pte, data.Rtr, data.Rte,
318318
counterfactual_vals, synthetic_vals, "test_dataset"
319319
)
320-
320+
321321
np.testing.assert_array_equal(data_with_attrs.counterfactual, counterfactual_vals)
322322
np.testing.assert_array_equal(data_with_attrs.synthetic, synthetic_vals)
323323
assert data_with_attrs.name == "test_dataset"
@@ -336,17 +336,17 @@ def test_inflate_method(n_pre: int, n_post: int, n_control: int):
336336
N_CONTROL=n_control,
337337
)
338338
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
339-
340-
inflation_vals = np.ones((n_post, 1)) * 1.1
339+
340+
inflation_vals = np.ones((n_post, 1)) * 1.1
341341
inflated_data = data.inflate(inflation_vals)
342-
342+
343343
np.testing.assert_array_equal(inflated_data.Xtr, data.Xtr)
344344
np.testing.assert_array_equal(inflated_data.ytr, data.ytr)
345345
np.testing.assert_array_equal(inflated_data.Xte, data.Xte)
346-
346+
347347
expected_yte = data.yte * inflation_vals
348348
np.testing.assert_array_equal(inflated_data.yte, expected_yte)
349-
349+
350350
np.testing.assert_array_equal(inflated_data.counterfactual, data.yte)
351351

352352

@@ -363,12 +363,12 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int):
363363
N_CONTROL=n_control,
364364
)
365365
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
366-
366+
367367
control_units = data.control_units
368368
expected_control = np.vstack([data.Xtr, data.Xte])
369369
np.testing.assert_array_equal(control_units, expected_control)
370370
assert control_units.shape == (n_pre + n_post, n_control)
371-
371+
372372
treated_units = data.treated_units
373373
expected_treated = np.vstack([data.ytr, data.yte])
374374
np.testing.assert_array_equal(treated_units, expected_treated)
@@ -416,14 +416,16 @@ def test_dataset_container(seeds: tp.List[int], to_name: bool):
416416
n_control=st.integers(min_value=2, max_value=20),
417417
)
418418
@settings(max_examples=5)
419-
def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_control: int):
419+
def test_covariate_properties_without_covariates(
420+
n_pre: int, n_post: int, n_control: int
421+
):
420422
constants = TestConstants(
421423
N_POST_TREATMENT=n_post,
422424
N_PRE_TREATMENT=n_pre,
423425
N_CONTROL=n_control,
424426
)
425427
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
426-
428+
427429
assert data.has_covariates is False
428430
assert data.control_covariates is None
429431
assert data.treated_covariates is None
@@ -437,79 +439,45 @@ def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_cont
437439
n_post=st.integers(min_value=10, max_value=50),
438440
n_control=st.integers(min_value=2, max_value=10),
439441
n_covariates=st.integers(min_value=1, max_value=5),
440-
Xtr=st.data(),
441-
Xte=st.data(),
442-
ytr=st.data(),
443-
yte=st.data(),
444-
Ptr=st.data(),
445-
Pte=st.data(),
446-
Rtr=st.data(),
447-
Rte=st.data(),
442+
seed=st.integers(min_value=1, max_value=10000),
448443
)
449444
@settings(max_examples=5)
450-
def test_covariate_properties_with_covariates(n_pre: int,
451-
n_post: int,
452-
n_control: int,
453-
n_covariates: int,
454-
Xtr,
455-
Xte,
456-
ytr,
457-
yte,
458-
Ptr,
459-
Pte,
460-
Rtr,
461-
Rte):
462-
463-
Xtr = Xtr.draw(st.lists(st.floats(min_value=-10, max_value=10),
464-
min_size=n_pre*n_control, max_size=n_pre*n_control))
465-
Xtr = np.array(Xtr).reshape(n_pre, n_control)
466-
467-
Xte = Xte.draw(st.lists(st.floats(min_value=-10, max_value=10),
468-
min_size=n_post*n_control, max_size=n_post*n_control))
469-
Xte = np.array(Xte).reshape(n_post, n_control)
470-
471-
ytr = ytr.draw(st.lists(st.floats(min_value=-10, max_value=10),
472-
min_size=n_pre, max_size=n_pre))
473-
ytr = np.array(ytr).reshape(n_pre, 1)
474-
475-
yte = yte.draw(st.lists(st.floats(min_value=-10, max_value=10),
476-
min_size=n_post, max_size=n_post))
477-
yte = np.array(yte).reshape(n_post, 1)
478-
479-
Ptr = Ptr.draw(st.lists(st.floats(min_value=-10, max_value=10),
480-
min_size=n_pre*n_control*n_covariates, max_size=n_pre*n_control*n_covariates))
481-
Ptr = np.array(Ptr).reshape(n_pre, n_control, n_covariates)
482-
483-
Pte = Pte.draw(st.lists(st.floats(min_value=-10, max_value=10),
484-
min_size=n_post*n_control*n_covariates, max_size=n_post*n_control*n_covariates))
485-
Pte = np.array(Pte).reshape(n_post, n_control, n_covariates)
486-
487-
Rtr = Rtr.draw(st.lists(st.floats(min_value=-10, max_value=10),
488-
min_size=n_pre*n_covariates, max_size=n_pre*n_covariates))
489-
Rtr = np.array(Rtr).reshape(n_pre, 1, n_covariates)
490-
491-
Rte = Rte.draw(st.lists(st.floats(min_value=-10, max_value=10),
492-
min_size=n_post*n_covariates, max_size=n_post*n_covariates))
493-
Rte = np.array(Rte).reshape(n_post, 1, n_covariates)
494-
445+
def test_covariate_properties_with_covariates(
446+
n_pre: int,
447+
n_post: int,
448+
n_control: int,
449+
n_covariates: int,
450+
seed: int,
451+
):
452+
rng = np.random.RandomState(seed)
453+
454+
Xtr = rng.uniform(-10, 10, (n_pre, n_control))
455+
Xte = rng.uniform(-10, 10, (n_post, n_control))
456+
ytr = rng.uniform(-10, 10, (n_pre, 1))
457+
yte = rng.uniform(-10, 10, (n_post, 1))
458+
Ptr = rng.uniform(-10, 10, (n_pre, n_control, n_covariates))
459+
Pte = rng.uniform(-10, 10, (n_post, n_control, n_covariates))
460+
Rtr = rng.uniform(-10, 10, (n_pre, 1, n_covariates))
461+
Rte = rng.uniform(-10, 10, (n_post, 1, n_covariates))
462+
495463
data = Dataset(Xtr, Xte, ytr, yte, dt.date(2023, 1, 1), Ptr, Pte, Rtr, Rte)
496-
464+
497465
assert data.n_covariates == n_covariates
498466
assert data.has_covariates is True
499-
467+
500468
control_covariates = data.control_covariates
501469
expected_control_cov = np.vstack([Ptr, Pte])
502470
np.testing.assert_array_equal(control_covariates, expected_control_cov)
503471
assert control_covariates.shape == (n_pre + n_post, n_control, n_covariates)
504-
472+
505473
treated_covariates = data.treated_covariates
506474
expected_treated_cov = np.vstack([Rtr, Rte])
507475
np.testing.assert_array_equal(treated_covariates, expected_treated_cov)
508476
assert treated_covariates.shape == (n_pre + n_post, 1, n_covariates)
509-
477+
510478
pre_cov = data.pre_intervention_covariates
511479
assert pre_cov == (Ptr, Rtr)
512-
480+
513481
post_cov = data.post_intervention_covariates
514482
assert post_cov == (Pte, Rte)
515483

tests/test_causal_validation/test_validation/test_placebo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def test_schema_coerce():
3030
df = PlaceboSchema.example()
3131
cols = df.columns
3232
for col in cols:
33-
if not col in ["Model", "Dataset"]:
33+
if col not in ["Model", "Dataset"]:
3434
df[col] = np.ceil((df[col]))
3535
PlaceboSchema.validate(df)
3636

tests/test_causal_validation/test_validation/test_rmspe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_schema_coerce():
3535
df = RMSPESchema.example()
3636
cols = df.columns
3737
for col in cols:
38-
if not col in ["Model", "Dataset"]:
38+
if col not in ["Model", "Dataset"]:
3939
df[col] = np.ceil((df[col]))
4040
RMSPESchema.validate(df)
4141

0 commit comments

Comments
 (0)