2929MAX_STRING_LENGTH = 20
3030DEFAULT_SEED = 123
3131NUM_NON_CONTROL_COLS = 2
32+ NUM_TREATED = 1
3233LARGE_N_POST = 5000
3334LARGE_N_PRE = 5000
3435
@@ -109,15 +110,15 @@ def test_indicator(n_pre_treatment: int, n_post_treatment: int):
109110 n_pre_treatment = st .integers (min_value = 1 , max_value = 50 ),
110111 n_post_treatment = st .integers (min_value = 1 , max_value = 50 ),
111112)
112- def test_to_df (n_control : int , n_pre_treatment : int , n_post_treatment : int ):
113+ def test_to_df_no_cov (n_control : int , n_pre_treatment : int , n_post_treatment : int ):
113114 constants = TestConstants (
114115 N_POST_TREATMENT = n_post_treatment ,
115116 N_PRE_TREATMENT = n_pre_treatment ,
116117 N_CONTROL = n_control ,
117118 )
118119 data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
119120
120- df = data .to_df ()
121+ df , _ = data .to_df ()
121122 assert isinstance (df , pd .DataFrame )
122123 assert df .shape == (
123124 n_pre_treatment + n_post_treatment ,
@@ -133,6 +134,47 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
133134 assert isinstance (index , DatetimeIndex )
134135 assert index [0 ].strftime ("%Y-%m-%d" ) == data ._start_date .strftime ("%Y-%m-%d" )
135136
137+ @given (
138+ n_control = st .integers (min_value = 1 , max_value = 50 ),
139+ n_pre_treatment = st .integers (min_value = 1 , max_value = 50 ),
140+ n_post_treatment = st .integers (min_value = 1 , max_value = 50 ),
141+ n_covariates = st .integers (min_value = 1 , max_value = 50 ),
142+ )
143+ def test_to_df_with_cov (n_control : int ,
144+ n_pre_treatment : int ,
145+ n_post_treatment : int ,
146+ n_covariates :int ):
147+ constants = TestConstants (
148+ N_POST_TREATMENT = n_post_treatment ,
149+ N_PRE_TREATMENT = n_pre_treatment ,
150+ N_CONTROL = n_control ,
151+ N_COVARIATES = n_covariates ,
152+ )
153+ data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
154+
155+ df_outs , df_covs = data .to_df ()
156+ assert isinstance (df_outs , pd .DataFrame )
157+ assert df_outs .shape == (
158+ n_pre_treatment + n_post_treatment ,
159+ n_control + NUM_NON_CONTROL_COLS ,
160+ )
161+
162+ assert isinstance (df_covs , pd .DataFrame )
163+ assert df_covs .shape == (
164+ n_pre_treatment + n_post_treatment ,
165+ n_covariates * (n_control + NUM_TREATED )
166+ + NUM_NON_CONTROL_COLS - NUM_TREATED ,
167+ )
168+
169+ colnames = data ._get_columns ()
170+ assert isinstance (colnames , list )
171+ assert colnames [0 ] == "T"
172+ assert len (colnames ) == n_control + 1
173+
174+ index = data .full_index
175+ assert isinstance (index , DatetimeIndex )
176+ assert index [0 ].strftime ("%Y-%m-%d" ) == data ._start_date .strftime ("%Y-%m-%d" )
177+
136178
137179@given (
138180 n_control = st .integers (min_value = 2 , max_value = 50 ),
0 commit comments