Skip to content

Commit 0a47c2e

Browse files
Patch non-repetitive sampling "inv_pop_f" (#300)
Use non-repetitive sampling instead of the original repetitive sampling when using `"candi_sel_prob": "inv_pop_f"`: `random.choices` -> `numpy.random.choice(replace=False)`. The original behavior could pick a large portion of repeated long-tail, low-frequency samples (the longer the tail, the worse the case), causing tens of percent of repeated downstream fp calculations and, moreover, amplifying the noise in labels from these high-force configurations. The non-repeated sampling renormalizes the probabilities after screening out each picked sample. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved candidate selection process to ensure unique selections when limiting the number of candidates, preventing duplicates in the output. * **Tests** * Updated tests to reflect changes in the candidate selection method and to ensure correct probability handling and uniqueness of selected candidates. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8aaa7ca commit 0a47c2e

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

dpgen2/exploration/report/report_adaptive_lower.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -446,11 +446,13 @@ def _get_candidates_inv_pop_f(
446446
self.candi_picked = [(ii[0], ii[1]) for ii in self.candi]
447447
if max_nframes is not None and max_nframes < len(self.candi_picked):
448448
prob = self._choice_prob_inv_pop_f(self.candi_picked)
449-
ret = random.choices(
450-
self.candi_picked,
451-
weights=prob,
452-
k=max_nframes,
449+
indices = np.random.choice(
450+
len(self.candi_picked),
451+
size=max_nframes,
452+
replace=False,
453+
p=prob / np.sum(prob),
453454
)
455+
ret = [self.candi_picked[i] for i in indices]
454456
else:
455457
ret = self.candi_picked
456458
return ret

tests/exploration/test_report_adaptive_lower.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,30 +198,32 @@ def test_f_inv_pop(self):
198198
)
199199

200200
def faked_choices(
201-
candi,
202-
weights=None,
203-
k=0,
201+
a, # numb_candi
202+
size=None, # numb_select
203+
replace=False, # non-repeative sampling
204+
p=None, # normalized prob
204205
):
205206
# hist: 2bins, 0.1-0.4 5candi, 0.4-0.7 7candi
206207
# only return those with mdf 0.1-0.4
207-
self.assertEqual(len(weights), 12)
208-
self.assertEqual(len(candi), 12)
209-
ret = []
210-
for ii in range(len(candi)):
208+
candi = ter.candi_picked
209+
self.assertEqual(a, 12)
210+
self.assertEqual(len(p), 12)
211+
ret_indices = []
212+
for ii in range(a):
211213
tidx, fidx = candi[ii]
212214
this_mdf = md_f[tidx][fidx]
213215
if this_mdf < 0.4:
214-
self.assertAlmostEqual(weights[ii], 1.0 / 5.0)
215-
ret.append(candi[ii])
216+
self.assertAlmostEqual(p[ii], 0.1) # 1/5 / 2.0
217+
ret_indices.append(ii)
216218
else:
217-
self.assertAlmostEqual(weights[ii], 1.0 / 7.0)
218-
return ret
219+
self.assertAlmostEqual(p[ii], 1.0 / 14.0) # 1/7 / 2.0
220+
return ret_indices
219221

220222
ter.record(model_devi)
221223
self.assertEqual(ter.candi, expected_cand)
222224
self.assertEqual(ter.accur, expected_accu)
223225
self.assertEqual(set(ter.failed), expected_fail)
224-
with mock.patch("random.choices", faked_choices):
226+
with mock.patch("numpy.random.choice", faked_choices):
225227
picked = ter.get_candidate_ids(11)
226228
self.assertFalse(ter.converged([]))
227229
self.assertEqual(len(picked), 2)

0 commit comments

Comments
 (0)