Skip to content

Commit 7c36020

Browse files
committed
add split_last_iter_valid_ratio
Signed-off-by: zjgemi <[email protected]>
1 parent 5f263d3 commit 7c36020

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

dpgen2/op/run_dp_train.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import glob
22
import json
33
import logging
4+
import math
45
import os
6+
import random
57
import shutil
68
from pathlib import (
79
Path,
@@ -197,6 +199,10 @@ def execute(
197199
valid_data = ip["valid_data"]
198200
iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1])
199201
iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
202+
if config["split_last_iter_valid_ratio"] is not None:
203+
train_systems, valid_systems = split_valid(iter_data_new_exp, config["split_last_iter_valid_ratio"])
204+
iter_data_new_exp = train_systems
205+
valid_data = append_valid_data(config, valid_data, valid_systems)
200206
iter_data_exp = iter_data_old_exp + iter_data_new_exp
201207
work_dir = Path(task_name)
202208
init_model_with_finetune = config["init_model_with_finetune"]
@@ -517,6 +523,7 @@ def training_args():
517523
doc_head = "Head to use in the multitask training"
518524
doc_init_model_with_finetune = "Use finetune for init model"
519525
doc_train_args = "Extra arguments for dp train"
526+
doc_split_last_iter_valid_ratio = "Ratio of valid data if split data of last iter"
520527
return [
521528
Argument(
522529
"command",
@@ -618,6 +625,13 @@ def training_args():
618625
default="",
619626
doc=doc_train_args,
620627
),
628+
Argument(
629+
"split_last_iter_valid_ratio",
630+
float,
631+
optional=True,
632+
default=None,
633+
doc=doc_split_last_iter_valid_ratio,
634+
),
621635
]
622636

623637
@staticmethod
@@ -672,4 +686,75 @@ def _expand_all_multi_sys_to_sys(list_multi_sys):
672686
return all_sys_dirs
673687

674688

689+
def _split_frame_indices(nframes: int, valid_ratio: float):
    """Randomly partition ``range(nframes)`` into train and validation indices.

    Returns
    -------
    train_indices : list of int
        Indices kept for training, in ascending order.
    valid_indices : list of int
        Indices held out for validation, in the order drawn by
        ``random.sample``.
    """
    expected = nframes * valid_ratio
    nvalid = math.floor(expected)
    # Stochastic rounding: round up with probability equal to the fractional
    # part, so that on average the validation share matches valid_ratio even
    # for systems with only a few frames.
    if random.random() < expected - nvalid:
        nvalid += 1
    valid_indices = random.sample(range(nframes), nvalid)
    held_out = set(valid_indices)
    # Build the complement in ascending order instead of relying on the
    # iteration order of a set difference.
    train_indices = [idx for idx in range(nframes) if idx not in held_out]
    return train_indices, valid_indices


def _dump_split_system(multi_systems, target: str, mixed_type: bool) -> None:
    """Write one split (train or valid) of a single system to ``target``."""
    if mixed_type:
        # The multisystem is loaded from one dir, thus we can safely keep one dir
        tmp_dir = "%s.tmp" % target
        multi_systems.to_deepmd_npy_mixed(tmp_dir)
        fs = os.listdir(tmp_dir)
        assert len(fs) == 1, "expected exactly one sub-directory in %s, got %s" % (
            tmp_dir,
            fs,
        )
        # Promote the single dumped sub-directory to the target path and
        # drop the now-empty temporary directory.
        os.rename(os.path.join(tmp_dir, fs[0]), target)
        os.rmdir(tmp_dir)
    else:
        multi_systems[0].to_deepmd_npy(target)


def split_valid(systems: List[str], valid_ratio: float):
    """Split each system's frames randomly into a training and a validation part.

    Every system directory is loaded with dpdata (as ``deepmd/npy/mixed`` when
    it contains ``*/real_atom_types.npy``, otherwise as ``deepmd/npy``), its
    frames are split at random, and the two parts are written to
    ``train_data/<system>`` and ``valid_data/<system>`` relative to the
    current working directory. A part that ends up with no frames is not
    written and not reported.

    Parameters
    ----------
    systems : List[str]
        Paths of the system directories to split.
    valid_ratio : float
        Fraction of frames of each system to hold out for validation; must
        lie within ``[0, 1]``.

    Returns
    -------
    train_systems : list of str
        Paths of the written training systems.
    valid_systems : list of str
        Paths of the written validation systems.

    Raises
    ------
    ValueError
        If ``valid_ratio`` is outside ``[0, 1]``.
    """
    # Fail fast with a clear message instead of an opaque error from
    # random.sample after some systems may already have been written.
    if not 0.0 <= valid_ratio <= 1.0:
        raise ValueError(
            "split_last_iter_valid_ratio must be within [0, 1], got %r"
            % (valid_ratio,)
        )

    train_systems = []
    valid_systems = []
    for system in systems:
        mixed_type = len(glob.glob("%s/*/real_atom_types.npy" % system)) > 0
        loaded = dpdata.MultiSystems()
        if mixed_type:
            loaded.load_systems_from_file(system, fmt="deepmd/npy/mixed")
        else:
            loaded.append(dpdata.LabeledSystem(system, fmt="deepmd/npy"))

        train_multi_systems = dpdata.MultiSystems()
        valid_multi_systems = dpdata.MultiSystems()
        for s in loaded:
            train_indices, valid_indices = _split_frame_indices(len(s), valid_ratio)
            if valid_indices:
                valid_multi_systems.append(s.sub_system(valid_indices))
            if train_indices:
                train_multi_systems.append(s.sub_system(train_indices))

        if len(train_multi_systems) > 0:
            target = "train_data/" + system
            _dump_split_system(train_multi_systems, target, mixed_type)
            train_systems.append(target)

        if len(valid_multi_systems) > 0:
            target = "valid_data/" + system
            _dump_split_system(valid_multi_systems, target, mixed_type)
            valid_systems.append(target)

    return train_systems, valid_systems
741+
742+
743+
def append_valid_data(config, valid_data, valid_systems):
    """Return ``valid_data`` extended with ``valid_systems``.

    The inputs are never modified in place; when there is something to add,
    a new container is returned.

    Parameters
    ----------
    config : dict
        Training config. ``config["multitask"]`` selects the layout of
        ``valid_data``; when it is truthy, ``config["head"]`` names the entry
        that receives the new systems.
    valid_data : list, dict or None
        Existing validation data: a list of systems for single-task training,
        or a dict mapping head name to a list of systems for multitask
        training. May be empty or ``None``.
    valid_systems : list
        Validation systems to add.

    Returns
    -------
    list, dict or None
        ``valid_data`` itself (unchanged) when ``valid_systems`` is empty,
        otherwise a new list (single-task) or a new dict (multitask)
        containing the existing entries followed by ``valid_systems``.
    """
    if not valid_systems:
        return valid_data

    if config["multitask"]:
        head = config["head"]
        # Copy the mapping and the list of the affected head so the caller's
        # objects stay untouched.
        merged = dict(valid_data) if valid_data else {}
        merged[head] = list(merged.get(head) or []) + list(valid_systems)
        return merged

    existing = list(valid_data) if valid_data else []
    return existing + list(valid_systems)
758+
759+
675760
# Module-level alias of RunDPTrain.training_args.
config_args = RunDPTrain.training_args

0 commit comments

Comments
 (0)