21
21
import matplotlib .pyplot as plt
22
22
import numpy as np
23
23
import pandas as pd
24
+ from sklearn import utils as sku
25
+ from torch import rand
24
26
25
27
from qolmat .benchmark import missing_patterns
26
28
from qolmat .utils import data
27
29
30
+ seed = 1234
31
+ rng = sku .check_random_state (seed )
32
+
28
33
# %%
29
34
# 1. Data
30
35
# ---------------------------------------------------------------
42
47
columns = ["TEMP" , "PRES" , "DEWP" , "RAIN" , "WSPM" ]
43
48
df_data = df_data [columns ]
44
49
45
- df = data .add_holes (df_data , ratio_masked = 0.2 , mean_size = 120 )
50
+ df = data .add_holes (df_data , ratio_masked = 0.2 , mean_size = 120 , random_state = rng )
46
51
cols_to_impute = df .columns
47
52
48
53
# %%
@@ -169,8 +174,8 @@ def plot_cdf(
169
174
axs [ind ].plot (sorted_data , cdf , c = "gray" , lw = 2 , label = "original" )
170
175
171
176
for df_mask , label , color in zip (list_df_mask , labels , colors ):
172
- array_mask = df_mask .copy ()
173
- array_mask [array_mask == True ] = np .nan
177
+ array_mask = df_mask .astype ( float ). copy ()
178
+ array_mask [df_mask ] = np .nan
174
179
hole_sizes_created = get_holes_sizes_column_wise (array_mask .to_numpy ())
175
180
176
181
for ind , (hole_created , col ) in enumerate (
@@ -197,7 +202,7 @@ def plot_cdf(
197
202
# Note this class is more suited for tabular datasets.
198
203
199
204
uniform_generator = missing_patterns .UniformHoleGenerator (
200
- n_splits = 1 , subset = df .columns , ratio_masked = 0.1
205
+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , random_state = rng
201
206
)
202
207
uniform_mask = uniform_generator .split (df )[0 ]
203
208
@@ -223,7 +228,7 @@ def plot_cdf(
223
228
# :class:`~qolmat.benchmark.missing_patterns.UniformHoleGenerator` class.
224
229
225
230
geometric_generator = missing_patterns .GeometricHoleGenerator (
226
- n_splits = 1 , subset = cols_to_impute , ratio_masked = 0.1
231
+ n_splits = 1 , subset = cols_to_impute , ratio_masked = 0.1 , random_state = rng
227
232
)
228
233
geometric_mask = geometric_generator .split (df )[0 ]
229
234
@@ -249,7 +254,7 @@ def plot_cdf(
249
254
# is learned on each group: here on each station.
250
255
251
256
empirical_generator = missing_patterns .EmpiricalHoleGenerator (
252
- n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,)
257
+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,), random_state = rng
253
258
)
254
259
empirical_mask = empirical_generator .split (df )[0 ]
255
260
@@ -274,7 +279,7 @@ def plot_cdf(
274
279
# :class:`~qolmat.benchmark.missing_patterns.MultiMarkovHoleGenerator` class.
275
280
276
281
multi_markov_generator = missing_patterns .MultiMarkovHoleGenerator (
277
- n_splits = 1 , subset = df .columns , ratio_masked = 0.1
282
+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , random_state = rng
278
283
)
279
284
multi_markov_mask = multi_markov_generator .split (df )[0 ]
280
285
@@ -297,7 +302,7 @@ def plot_cdf(
297
302
# :class:`~qolmat.benchmark.missing_patterns.GroupedHoleGenerator` class.
298
303
299
304
grouped_generator = missing_patterns .GroupedHoleGenerator (
300
- n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,)
305
+ n_splits = 1 , subset = df .columns , ratio_masked = 0.1 , groups = ("station" ,), random_state = rng
301
306
)
302
307
grouped_mask = grouped_generator .split (df )[0 ]
303
308
0 commit comments