Skip to content

Commit 76cd66a

Browse files
committed
sync mapping interface with deepmodeling#4307
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 1e6f069 commit 76cd66a

File tree

6 files changed

+49
-3
lines changed

6 files changed

+49
-3
lines changed

source/api_c/include/c_api.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ extern "C" {
1212
/** C API version. Bumped whenever the API is changed.
1313
* @since API version 22
1414
*/
15-
#define DP_C_API_VERSION 24
15+
#define DP_C_API_VERSION 25
1616

1717
/**
1818
* @brief Neighbor list.
@@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
3131
int* ilist_,
3232
int* numneigh_,
3333
int** firstneigh_);
34-
/*
34+
/**
3535
* @brief Create a new neighbor list with communication capabilities.
3636
* @details This function extends DP_NewNlist by adding support for parallel
3737
* communication, allowing the neighbor list to be used in distributed
@@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
6868
int* recvproc,
6969
void* world);
7070

71-
/*
71+
/**
7272
* @brief Set mask for a neighbor list.
7373
*
7474
* @param nl Neighbor list.
@@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
7878
**/
7979
extern void DP_NlistSetMask(DP_Nlist* nl, int mask);
8080

81+
/**
82+
* @brief Set mapping for a neighbor list.
83+
*
84+
* @param nl Neighbor list.
85+
* @param mapping mapping from all atoms to real atoms, in size nall.
86+
* @since API version 25
87+
*
88+
**/
89+
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
90+
8191
/**
8292
* @brief Delete a neighbor list.
8393
*

source/api_c/include/deepmd.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,11 @@ struct InputNlist {
863863
* @brief Set mask for this neighbor list.
864864
*/
865865
void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
866+
/**
867+
* @brief Set mapping for this neighbor list.
868+
* @param mapping mapping from all atoms to real atoms, in size nall.
869+
*/
870+
void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
866871
};
867872

868873
/**

source/api_c/src/c_api.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
4343
return new_nl;
4444
}
4545
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
46+
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
47+
nl->nl.set_mapping(mapping);
48+
}
4649
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }
4750

4851
// DP Base Model

source/lib/include/neighbor_list.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ struct InputNlist {
4444
void* world;
4545
/// mask to the neighbor index
4646
int mask = 0xFFFFFFFF;
47+
/// mapping from all atoms to real atoms, in the size of nall
48+
int* mapping = nullptr;
4749
InputNlist()
4850
: inum(0),
4951
ilist(NULL),
@@ -99,6 +101,10 @@ struct InputNlist {
99101
* @brief Set mask for this neighbor list.
100102
*/
101103
void set_mask(int mask_) { mask = mask_; };
104+
/**
105+
* @brief Set mapping for this neighbor list.
106+
*/
107+
void set_mapping(int* mapping_) { mapping = mapping_; };
102108
};
103109

104110
/**

source/lmp/fix_dplr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,14 @@ void FixDPLR::pre_force(int vflag) {
467467
int nghost = atom->nghost;
468468
int nall = nlocal + nghost;
469469

470+
// mapping (for DPA-2 JAX)
471+
std::vector<int> mapping_vec(nall, -1);
472+
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
473+
for (size_t ii = 0; ii < nall; ++ii) {
474+
mapping_vec[ii] = atom->map(atom->tag[ii]);
475+
}
476+
}
477+
470478
// if (eflag_atom) {
471479
// error->all(FLERR,"atomic energy calculation is not supported by this
472480
// fix\n");
@@ -499,6 +507,9 @@ void FixDPLR::pre_force(int vflag) {
499507
deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh,
500508
list->firstneigh);
501509
lmp_list.set_mask(NEIGHMASK);
510+
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
511+
lmp_list.set_mapping(mapping_vec.data());
512+
}
502513
// declear output
503514
vector<FLOAT_PREC> tensor;
504515
// compute

source/lmp/pair_deepmd.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,14 @@ void PairDeepMD::compute(int eflag, int vflag) {
155155
}
156156
}
157157

158+
// mapping (for DPA-2 JAX)
159+
std::vector<int> mapping_vec(nall, -1);
160+
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
161+
for (size_t ii = 0; ii < nall; ++ii) {
162+
mapping_vec[ii] = atom->map(atom->tag[ii]);
163+
}
164+
}
165+
158166
if (do_compute_aparam) {
159167
make_aparam_from_compute(daparam);
160168
} else if (aparam.size() > 0) {
@@ -198,6 +206,9 @@ void PairDeepMD::compute(int eflag, int vflag) {
198206
commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc,
199207
commdata_->recvproc, &world);
200208
lmp_list.set_mask(NEIGHMASK);
209+
if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) {
210+
lmp_list.set_mapping(mapping_vec.data());
211+
}
201212
deepmd_compat::InputNlist extend_lmp_list;
202213
if (single_model || multi_models_no_mod_devi) {
203214
// cvflag_atom is the right flag for the cvatom matrix

0 commit comments

Comments
 (0)