|
1 | 1 | import glob |
2 | 2 | import json |
3 | 3 | import logging |
| 4 | +import math |
4 | 5 | import os |
| 6 | +import random |
5 | 7 | import shutil |
6 | 8 | from pathlib import ( |
7 | 9 | Path, |
@@ -197,6 +199,10 @@ def execute( |
197 | 199 | valid_data = ip["valid_data"] |
198 | 200 | iter_data_old_exp = _expand_all_multi_sys_to_sys(iter_data[:-1]) |
199 | 201 | iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:]) |
| 202 | + if config["split_last_iter_valid_ratio"] is not None: |
| 203 | + train_systems, valid_systems = split_valid(iter_data_new_exp, config["split_last_iter_valid_ratio"]) |
| 204 | + iter_data_new_exp = train_systems |
| 205 | + valid_data = append_valid_data(config, valid_data, valid_systems) |
200 | 206 | iter_data_exp = iter_data_old_exp + iter_data_new_exp |
201 | 207 | work_dir = Path(task_name) |
202 | 208 | init_model_with_finetune = config["init_model_with_finetune"] |
@@ -517,6 +523,7 @@ def training_args(): |
517 | 523 | doc_head = "Head to use in the multitask training" |
518 | 524 | doc_init_model_with_finetune = "Use finetune for init model" |
519 | 525 | doc_train_args = "Extra arguments for dp train" |
| 526 | + doc_split_last_iter_valid_ratio = "Ratio of valid data if split data of last iter" |
520 | 527 | return [ |
521 | 528 | Argument( |
522 | 529 | "command", |
@@ -618,6 +625,13 @@ def training_args(): |
618 | 625 | default="", |
619 | 626 | doc=doc_train_args, |
620 | 627 | ), |
| 628 | + Argument( |
| 629 | + "split_last_iter_valid_ratio", |
| 630 | + float, |
| 631 | + optional=True, |
| 632 | + default=None, |
| 633 | + doc=doc_split_last_iter_valid_ratio, |
| 634 | + ), |
621 | 635 | ] |
622 | 636 |
|
623 | 637 | @staticmethod |
@@ -672,4 +686,75 @@ def _expand_all_multi_sys_to_sys(list_multi_sys): |
672 | 686 | return all_sys_dirs |
673 | 687 |
|
674 | 688 |
|
def _write_split_system(multi_systems, target, mixed_type):
    """Dump one half of a split system to the directory ``target``.

    Parameters
    ----------
    multi_systems
        The ``dpdata.MultiSystems`` holding the frames of this split.
    target : str
        Destination directory of the dumped system.
    mixed_type : bool
        Whether to dump with ``to_deepmd_npy_mixed`` (True) or to dump the
        first contained system with ``to_deepmd_npy`` (False).

    Raises
    ------
    RuntimeError
        If the mixed-type dump does not produce exactly one entry.
    """
    if not mixed_type:
        multi_systems[0].to_deepmd_npy(target)
        return

    # The mixed-type dump goes to a scratch directory first; its single entry
    # is then moved to `target` so the result occupies exactly one directory.
    tmp_dir = target + ".tmp"
    multi_systems.to_deepmd_npy_mixed(tmp_dir)
    entries = os.listdir(tmp_dir)
    # The multisystem is loaded from one dir, thus we can safely keep one dir.
    # Checked explicitly because `assert` is stripped when running with -O.
    if len(entries) != 1:
        raise RuntimeError(
            "expected exactly one entry in %s, found %d" % (tmp_dir, len(entries))
        )
    os.rename(os.path.join(tmp_dir, entries[0]), target)
    os.rmdir(tmp_dir)


def split_valid(systems: List[str], valid_ratio: float):
    """Randomly split the frames of each system into training and validation.

    Every system directory is loaded with dpdata. A directory is treated as
    mixed-type when it contains ``*/real_atom_types.npy``. For each loaded
    system holding ``n`` frames, ``floor(n * valid_ratio)`` frames are drawn
    at random for validation, plus one more with probability equal to the
    fractional part of ``n * valid_ratio``. The remaining frames are kept for
    training. The two parts are written to ``"train_data/" + system`` and
    ``"valid_data/" + system``, relative to the current working directory.
    A part that ends up with no frames is not written and not reported.

    Parameters
    ----------
    systems : List[str]
        Paths of the system directories to split.
    valid_ratio : float
        Fraction of frames moved to the validation set, within ``[0, 1]``.

    Returns
    -------
    tuple
        ``(train_systems, valid_systems)``: lists of the directories written
        for the training and validation parts, respectively.

    Raises
    ------
    ValueError
        If ``valid_ratio`` is not within ``[0, 1]``.
    RuntimeError
        If a mixed-type dump does not produce exactly one directory.
    """
    # Fail fast: an out-of-range ratio would otherwise surface only inside
    # random.sample, possibly after some systems were already written.
    if not 0.0 <= valid_ratio <= 1.0:
        raise ValueError(
            "valid_ratio must be within [0, 1], got %r" % (valid_ratio,)
        )

    train_systems = []
    valid_systems = []
    for system in systems:
        mixed_type = len(glob.glob("%s/*/real_atom_types.npy" % system)) > 0
        loaded = dpdata.MultiSystems()
        if mixed_type:
            loaded.load_systems_from_file(system, fmt="deepmd/npy/mixed")
        else:
            loaded.append(dpdata.LabeledSystem(system, fmt="deepmd/npy"))

        train_multi_systems = dpdata.MultiSystems()
        valid_multi_systems = dpdata.MultiSystems()
        for s in loaded:
            nframes = len(s)
            expected = nframes * valid_ratio
            nvalid = math.floor(expected)
            # Stochastic rounding keeps the expected validation share equal
            # to valid_ratio even for systems with very few frames.
            if random.random() < expected - nvalid:
                nvalid += 1
            valid_indices = random.sample(range(nframes), nvalid)
            # Sorted so the training frames keep their original order.
            train_indices = sorted(set(range(nframes)).difference(valid_indices))
            if valid_indices:
                valid_multi_systems.append(s.sub_system(valid_indices))
            if train_indices:
                train_multi_systems.append(s.sub_system(train_indices))

        if len(train_multi_systems) > 0:
            target = "train_data/" + system
            _write_split_system(train_multi_systems, target, mixed_type)
            train_systems.append(target)

        if len(valid_multi_systems) > 0:
            target = "valid_data/" + system
            _write_split_system(valid_multi_systems, target, mixed_type)
            valid_systems.append(target)

    return train_systems, valid_systems
| 741 | + |
| 742 | + |
def append_valid_data(config, valid_data, valid_systems):
    """Return ``valid_data`` extended with ``valid_systems``.

    The containers passed in are never modified: when something has to be
    added, a new list (or a new dict holding a new list) is built, so the
    caller's original validation data stays intact.

    Parameters
    ----------
    config : dict
        Training configuration. ``config["multitask"]`` selects the layout of
        ``valid_data``; when it is truthy, ``config["head"]`` names the entry
        that receives the new systems. Not read if ``valid_systems`` is empty.
    valid_data : list, dict or None
        Existing validation data: a list of systems, or, for multitask, a
        dict mapping a head name to a list of systems. May be None or empty.
    valid_systems : list
        Validation systems to add.

    Returns
    -------
    list, dict or None
        ``valid_data`` itself when ``valid_systems`` is empty, otherwise a
        new container with ``valid_systems`` appended.
    """
    if not valid_systems:
        return valid_data

    if config["multitask"]:
        head = config["head"]
        # Shallow-copy the mapping and rebuild only the list of the target
        # head, so neither the caller's dict nor its list is changed.
        merged = dict(valid_data) if valid_data else {}
        merged[head] = list(merged.get(head) or []) + list(valid_systems)
        return merged

    return list(valid_data or []) + list(valid_systems)
| 758 | + |
| 759 | + |
675 | 760 | config_args = RunDPTrain.training_args |
0 commit comments