Skip to content

Commit a83c98d

Browse files
authored
fix(pt): fix not used sys_probs (#4353)
`sys_probs` was not used in pt, because its priority was lower than that of `auto_prob`, while `auto_prob` always has its default values. See #4346 (reply in thread) . <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new sampler selection function for improved data loading flexibility. - **Bug Fixes** - Streamlined logic for obtaining data samplers, enhancing maintainability. - **Tests** - Added end-to-end tests for sampler functionality, ensuring accuracy with system probabilities and automatic styles. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 320c7fd commit a83c98d

File tree

3 files changed

+44
-16
lines changed

3 files changed

+44
-16
lines changed

deepmd/pt/train/training.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
)
4848
from deepmd.pt.utils.dataloader import (
4949
BufferedIterator,
50-
get_weighted_sampler,
50+
get_sampler_from_params,
5151
)
5252
from deepmd.pt.utils.env import (
5353
DEVICE,
@@ -160,19 +160,7 @@ def get_opt_param(params):
160160

161161
def get_data_loader(_training_data, _validation_data, _training_params):
162162
def get_dataloader_and_buffer(_data, _params):
163-
if "auto_prob" in _training_params["training_data"]:
164-
_sampler = get_weighted_sampler(
165-
_data, _params["training_data"]["auto_prob"]
166-
)
167-
elif "sys_probs" in _training_params["training_data"]:
168-
_sampler = get_weighted_sampler(
169-
_data,
170-
_params["training_data"]["sys_probs"],
171-
sys_prob=True,
172-
)
173-
else:
174-
_sampler = get_weighted_sampler(_data, "prob_sys_size")
175-
163+
_sampler = get_sampler_from_params(_data, _params)
176164
if _sampler is None:
177165
log.warning(
178166
"Sampler not specified!"
@@ -193,14 +181,16 @@ def get_dataloader_and_buffer(_data, _params):
193181
return _dataloader, _data_buffered
194182

195183
training_dataloader, training_data_buffered = get_dataloader_and_buffer(
196-
_training_data, _training_params
184+
_training_data, _training_params["training_data"]
197185
)
198186

199187
if _validation_data is not None:
200188
(
201189
validation_dataloader,
202190
validation_data_buffered,
203-
) = get_dataloader_and_buffer(_validation_data, _training_params)
191+
) = get_dataloader_and_buffer(
192+
_validation_data, _training_params["validation_data"]
193+
)
204194
valid_numb_batch = _training_params["validation_data"].get(
205195
"numb_btch", 1
206196
)

deepmd/pt/utils/dataloader.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,19 @@ def get_weighted_sampler(training_data, prob_style, sys_prob=False):
306306
with torch.device("cpu"):
307307
sampler = WeightedRandomSampler(probs, len_sampler, replacement=True)
308308
return sampler
309+
310+
311+
def get_sampler_from_params(_data, _params):
312+
if (
313+
"sys_probs" in _params and _params["sys_probs"] is not None
314+
): # use sys_probs first
315+
_sampler = get_weighted_sampler(
316+
_data,
317+
_params["sys_probs"],
318+
sys_prob=True,
319+
)
320+
elif "auto_prob" in _params:
321+
_sampler = get_weighted_sampler(_data, _params["auto_prob"])
322+
else:
323+
_sampler = get_weighted_sampler(_data, "prob_sys_size")
324+
return _sampler

source/tests/pt/test_sampler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from deepmd.pt.utils.dataloader import (
1616
DpLoaderSet,
17+
get_sampler_from_params,
1718
get_weighted_sampler,
1819
)
1920
from deepmd.tf.common import (
@@ -105,6 +106,27 @@ def test_sys_probs(self) -> None:
105106
dp_probs = np.array(self.dp_dataset.sys_probs)
106107
self.assertTrue(np.allclose(my_probs, dp_probs))
107108

109+
def test_sys_probs_end2end(self):
110+
sys_probs = [0.1, 0.4, 0.5]
111+
_params = {
112+
"sys_probs": sys_probs,
113+
"auto_prob": "prob_sys_size",
114+
} # use sys_probs first
115+
sampler = get_sampler_from_params(self.my_dataset, _params)
116+
my_probs = np.array(sampler.weights)
117+
self.dp_dataset.set_sys_probs(sys_probs=sys_probs)
118+
dp_probs = np.array(self.dp_dataset.sys_probs)
119+
self.assertTrue(np.allclose(my_probs, dp_probs))
120+
121+
def test_auto_prob_sys_size_ext_end2end(self):
122+
auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8"
123+
_params = {"sys_probs": None, "auto_prob": auto_prob_style} # use auto_prob
124+
sampler = get_sampler_from_params(self.my_dataset, _params)
125+
my_probs = np.array(sampler.weights)
126+
self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style)
127+
dp_probs = np.array(self.dp_dataset.sys_probs)
128+
self.assertTrue(np.allclose(my_probs, dp_probs))
129+
108130

109131
if __name__ == "__main__":
110132
unittest.main()

0 commit comments

Comments
 (0)