
Commit ef70135

no nlist interface
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent c4f08c8 commit ef70135

4 files changed: +323 -7 lines changed

source/api_cc/include/DeepPotJAX.h

Lines changed: 38 additions & 2 deletions
@@ -22,7 +22,7 @@ class DeepPotJAX : public DeepPotBackend {
   /**
    * @brief DP constructor with initialization.
    * @param[in] model The name of the frozen model file.
-   * @param[in] gpu_rank The GPU rank. Default is 0.
+   * @param[in] gpu_rank The GPU rank. Default is 0. If < 0, use CPU.
    * @param[in] file_content The content of the model file. If it is not empty,
    *DP will read from the string instead of the file.
    **/
@@ -32,7 +32,7 @@ class DeepPotJAX : public DeepPotBackend {
   /**
    * @brief Initialize the DP.
    * @param[in] model The name of the frozen model file.
-   * @param[in] gpu_rank The GPU rank. Default is 0.
+   * @param[in] gpu_rank The GPU rank. Default is 0. If < 0, use CPU.
    * @param[in] file_content The content of the model file. If it is not empty,
    *DP will read from the string instead of the file.
    **/
@@ -208,6 +208,42 @@ class DeepPotJAX : public DeepPotBackend {
    */
   // neighbor list data
   NeighborListData nlist_data;
+  /**
+   * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial
+   *by using this DP.
+   * @param[out] ener The system energy.
+   * @param[out] force The force on each atom.
+   * @param[out] virial The virial.
+   * @param[out] atom_energy The atomic energy.
+   * @param[out] atom_virial The atomic virial.
+   * @param[in] coord The coordinates of atoms. The array should be of size
+   *nframes x natoms x 3.
+   * @param[in] atype The atom types. The list should contain natoms ints.
+   * @param[in] box The cell of the region. The array should be of size nframes
+   *x 9.
+   * @param[in] fparam The frame parameter. The array can be of size :
+   * nframes x dim_fparam.
+   * dim_fparam. Then all frames are assumed to be provided with the same
+   *fparam.
+   * @param[in] aparam The atomic parameter The array can be of size :
+   * nframes x natoms x dim_aparam.
+   * natoms x dim_aparam. Then all frames are assumed to be provided with the
+   *same aparam.
+   * @param[in] atomic Whether to compute the atomic energy and virial.
+   **/
+  template <typename VALUETYPE>
+  void compute(std::vector<ENERGYTYPE>& ener,
+               std::vector<VALUETYPE>& force,
+               std::vector<VALUETYPE>& virial,
+               std::vector<VALUETYPE>& atom_energy,
+               std::vector<VALUETYPE>& atom_virial,
+               const std::vector<VALUETYPE>& coord,
+               const std::vector<int>& atype,
+               const std::vector<VALUETYPE>& box,
+               const std::vector<VALUETYPE>& fparam,
+               const std::vector<VALUETYPE>& aparam,
+               const bool atomic);
   /**
    * @brief Evaluate the energy, force, virial, atomic energy, and atomic virial
    *by using this DP.
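Note (not part of the commit): the Doxygen comment above pins down the accepted fparam/aparam layouts for this overload. A minimal sizing sketch, assuming a hypothetical model with dim_fparam = 1 and dim_aparam = 2:

// Illustrative sizing only; the dimensions are hypothetical.
#include <vector>

int main() {
  const int nframes = 2, natoms = 3;
  const int dim_fparam = 1, dim_aparam = 2;

  // fparam: either nframes x dim_fparam, or just dim_fparam
  // (the same values are then reused for every frame).
  std::vector<double> fparam(nframes * dim_fparam, 0.5);

  // aparam: either nframes x natoms x dim_aparam, or natoms x dim_aparam
  // (again shared across frames).
  std::vector<double> aparam(nframes * natoms * dim_aparam, 0.0);

  // coord: nframes x natoms x 3; box: nframes x 9.
  std::vector<double> coord(nframes * natoms * 3, 0.0);
  std::vector<double> box(nframes * 9, 0.0);
  return 0;
}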

source/api_cc/src/DeepPotJAX.cc

Lines changed: 162 additions & 4 deletions
@@ -6,6 +6,8 @@
 #include <tensorflow/c/c_api.h>
 #include <tensorflow/c/eager/c_api.h>
 
+#include <array>
+#include <cstdint>
 #include <cstdio>
 #include <iostream>
 #include <numeric>
@@ -228,6 +230,13 @@ void deepmd::DeepPotJAX::init(const std::string& model,
   status = TF_NewStatus();
 
   sessionopts = TF_NewSessionOptions();
+  int num_intra_nthreads, num_inter_nthreads;
+  get_env_nthreads(num_intra_nthreads, num_inter_nthreads);
+  // https://github.com/Neargye/hello_tf_c_api/blob/51516101cf59408a6bb456f7e5f3c6628e327b3a/src/tf_utils.cpp#L400-L401
+  std::array<std::uint8_t, 4> config = {
+      {0x10, static_cast<std::uint8_t>(num_intra_nthreads), 0x28,
+       static_cast<std::uint8_t>(num_inter_nthreads)}};
+  TF_SetConfig(sessionopts, config.data(), config.size(), status);
   TF_Buffer* runopts = NULL;
 
   const char* tags = "serve";
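Note (not part of the diff): the four bytes handed to TF_SetConfig are a hand-serialized tensorflow.ConfigProto. 0x10 is the protobuf varint tag for field 2 (intra_op_parallelism_threads) and 0x28 the tag for field 5 (inter_op_parallelism_threads), each followed by the thread count encoded as a varint. A minimal sketch of the same encoding, valid only while the counts stay below 128 so each varint fits in one byte:

// Sketch of the hand-rolled ConfigProto serialization used above.
// Thread counts of 128 or more would need multi-byte protobuf varints
// (or a real protobuf serializer).
#include <array>
#include <cstdint>

std::array<std::uint8_t, 4> make_threading_config(std::uint8_t intra,
                                                  std::uint8_t inter) {
  // 0x10 = (2 << 3) | 0 : field 2, varint -> intra_op_parallelism_threads
  // 0x28 = (5 << 3) | 0 : field 5, varint -> inter_op_parallelism_threads
  return {{0x10, intra, 0x28, inter}};
}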
@@ -250,8 +259,8 @@ void deepmd::DeepPotJAX::init(const std::string& model,
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   int gpu_num;
   DPGetDeviceCount(gpu_num);  // check current device environment
-  DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
-  if (gpu_num > 0) {
+  if (gpu_num > 0 && gpu_rank >= 0) {
+    DPErrcheck(DPSetDevice(gpu_rank % gpu_num));
     device = "/gpu:" + std::to_string(gpu_rank % gpu_num);
   } else {
     device = "/cpu:0";
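Moving DPSetDevice inside the guard also means gpu_rank % gpu_num is no longer evaluated when gpu_num is 0. A usage sketch (not part of the commit; the header path and default arguments of the constructor documented in DeepPotJAX.h are assumed):

// Illustrative only.
#include "DeepPotJAX.h"

void example() {
  // gpu_rank < 0 now forces CPU execution ("/cpu:0").
  deepmd::DeepPotJAX dp_cpu("model.savedmodel", -1);
  // gpu_rank >= 0 picks GPU (gpu_rank % gpu_num) when a GPU is present.
  deepmd::DeepPotJAX dp_gpu("model.savedmodel", 0);
}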
@@ -300,6 +309,153 @@ deepmd::DeepPotJAX::~DeepPotJAX() {
   }
 }
 
+template <typename VALUETYPE>
+void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
+                                 std::vector<VALUETYPE>& force_,
+                                 std::vector<VALUETYPE>& virial,
+                                 std::vector<VALUETYPE>& atom_energy_,
+                                 std::vector<VALUETYPE>& atom_virial_,
+                                 const std::vector<VALUETYPE>& dcoord,
+                                 const std::vector<int>& datype,
+                                 const std::vector<VALUETYPE>& box,
+                                 const std::vector<VALUETYPE>& fparam,
+                                 const std::vector<VALUETYPE>& aparam_,
+                                 const bool atomic) {
+  std::vector<VALUETYPE> coord, force, aparam, atom_energy, atom_virial;
+  std::vector<double> ener_double, force_double, virial_double,
+      atom_energy_double, atom_virial_double;
+  std::vector<int> atype, fwd_map, bkw_map;
+  int nghost_real, nall_real, nloc_real;
+  int nall = datype.size();
+  // nlist passed to the model
+  int nframes = nall > 0 ? (dcoord.size() / 3 / nall) : 1;
+  int nghost = 0;
+
+  select_real_atoms_coord(coord, atype, aparam, nghost_real, fwd_map, bkw_map,
+                          nall_real, nloc_real, dcoord, datype, aparam_, nghost,
+                          ntypes, nframes, daparam, nall, false);
+
+  if (nloc_real == 0) {
+    // no real atoms, fill 0 for all outputs
+    // this can prevent a Xla error
+    ener.resize(nframes, 0.0);
+    force_.resize(static_cast<size_t>(nframes) * nall * 3, 0.0);
+    virial.resize(static_cast<size_t>(nframes) * 9, 0.0);
+    atom_energy_.resize(static_cast<size_t>(nframes) * nall, 0.0);
+    atom_virial_.resize(static_cast<size_t>(nframes) * nall * 9, 0.0);
+    return;
+  }
+
+  // cast coord, fparam, and aparam to double - I think it's useless to have a
+  // float model interface
+  std::vector<double> coord_double(coord.begin(), coord.end());
+  std::vector<double> box_double(box.begin(), box.end());
+  std::vector<double> fparam_double(fparam.begin(), fparam.end());
+  std::vector<double> aparam_double(aparam.begin(), aparam.end());
+
+  TFE_Op* op;
+  if (atomic) {
+    op = get_func_op(ctx, "call_with_atomic_virial", func_vector, device,
+                     status);
+  } else {
+    op = get_func_op(ctx, "call_without_atomic_virial", func_vector, device,
+                     status);
+  }
+  std::vector<TFE_TensorHandle*> input_list(5);
+  std::vector<TF_Tensor*> data_tensor(5);
+  // coord
+  std::vector<int64_t> coord_shape = {nframes, nloc_real, 3};
+  input_list[0] =
+      add_input(op, coord_double, coord_shape, data_tensor[0], status);
+  // atype
+  std::vector<int64_t> atype_shape = {nframes, nloc_real};
+  input_list[1] = add_input(op, atype, atype_shape, data_tensor[1], status);
+  // box
+  int box_size = box_double.size() > 0 ? 3 : 0;
+  std::vector<int64_t> box_shape = {nframes, box_size, box_size};
+  input_list[2] = add_input(op, box_double, box_shape, data_tensor[2], status);
+  // fparam
+  std::vector<int64_t> fparam_shape = {nframes, dfparam};
+  input_list[3] =
+      add_input(op, fparam_double, fparam_shape, data_tensor[3], status);
+  // aparam
+  std::vector<int64_t> aparam_shape = {nframes, nloc_real, daparam};
+  input_list[4] =
+      add_input(op, aparam_double, aparam_shape, data_tensor[4], status);
+  // execute the function
+  int nretvals = 6;
+  TFE_TensorHandle* retvals[nretvals];
+
+  TFE_Execute(op, retvals, &nretvals, status);
+  check_status(status);
+
+  // copy data
+  // for atom virial, the order is:
+  // Identity_15 energy -1, -1, 1
+  // Identity_16 energy_derv_c -1, -1, 1, 9 (may pop)
+  // Identity_17 energy_derv_c_redu -1, 1, 9
+  // Identity_18 energy_derv_r -1, -1, 1, 3
+  // Identity_19 energy_redu -1, 1
+  // Identity_20 mask (int32) -1, -1
+  //
+  // for no atom virial, the order is:
+  // Identity_15 energy -1, -1, 1
+  // Identity_16 energy_derv_c -1, 1, 9
+  // Identity_17 energy_derv_r -1, -1, 1, 3
+  // Identity_18 energy_redu -1, 1
+  // Identity_19 mask (int32) -1, -1
+  //
+  // it seems the order is the alphabet order?
+  // not sure whether it is safe to assume the order
+  if (atomic) {
+    tensor_to_vector(ener_double, retvals[4], status);
+    tensor_to_vector(force_double, retvals[3], status);
+    tensor_to_vector(virial_double, retvals[2], status);
+    tensor_to_vector(atom_energy_double, retvals[0], status);
+    tensor_to_vector(atom_virial_double, retvals[1], status);
+  } else {
+    tensor_to_vector(ener_double, retvals[3], status);
+    tensor_to_vector(force_double, retvals[2], status);
+    tensor_to_vector(virial_double, retvals[1], status);
+    tensor_to_vector(atom_energy_double, retvals[0], status);
+  }
+
+  // cast back to VALUETYPE
+  ener = std::vector<ENERGYTYPE>(ener_double.begin(), ener_double.end());
+  force = std::vector<VALUETYPE>(force_double.begin(), force_double.end());
+  virial = std::vector<VALUETYPE>(virial_double.begin(), virial_double.end());
+  atom_energy = std::vector<VALUETYPE>(atom_energy_double.begin(),
+                                       atom_energy_double.end());
+  atom_virial = std::vector<VALUETYPE>(atom_virial_double.begin(),
+                                       atom_virial_double.end());
+  force.resize(static_cast<size_t>(nframes) * nall_real * 3);
+  atom_virial.resize(static_cast<size_t>(nframes) * nall_real * 9);
+
+  // nall atom_energy is required in the C++ API;
+  // we always forget it!
+  atom_energy.resize(static_cast<size_t>(nframes) * nall_real, 0.0);
+
+  force_.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
+  atom_energy_.resize(static_cast<size_t>(nframes) * fwd_map.size());
+  atom_virial_.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
+  select_map<VALUETYPE>(force_, force, bkw_map, 3, nframes, fwd_map.size(),
+                        nall_real);
+  select_map<VALUETYPE>(atom_energy_, atom_energy, bkw_map, 1, nframes,
+                        fwd_map.size(), nall_real);
+  select_map<VALUETYPE>(atom_virial_, atom_virial, bkw_map, 9, nframes,
+                        fwd_map.size(), nall_real);
+
+  // cleanup input_list, etc
+  for (size_t i = 0; i < 5; i++) {
+    TFE_DeleteTensorHandle(input_list[i]);
+    TF_DeleteTensor(data_tensor[i]);
+  }
+  for (size_t i = 0; i < nretvals; i++) {
+    TFE_DeleteTensorHandle(retvals[i]);
+  }
+  TFE_DeleteOp(op);
+}
+
 template <typename VALUETYPE>
 void deepmd::DeepPotJAX::compute(std::vector<ENERGYTYPE>& ener,
                                  std::vector<VALUETYPE>& force_,
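Aside (not part of the commit): per the header documentation, the box input is 9 values per frame (a flattened 3x3 cell matrix), and the code above collapses box_size to 0 when the box vector is empty, so a non-periodic system reaches the model as a {nframes, 0, 0} tensor. A small sketch of both cases, with hypothetical cell values:

// Illustrative only: building the box argument for the compute call above.
#include <vector>

// Periodic: 9 values per frame, one row per cell vector.
std::vector<double> box_periodic = {
    12.0, 0.0,  0.0,   // first cell vector
    0.0,  12.0, 0.0,   // second cell vector
    0.0,  0.0,  12.0}; // third cell vector

// Non-periodic: pass an empty vector; box_size becomes 0.
std::vector<double> box_open;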
@@ -523,7 +679,8 @@ void deepmd::DeepPotJAX::computew(std::vector<double>& ener,
                                   const std::vector<double>& fparam,
                                   const std::vector<double>& aparam,
                                   const bool atomic) {
-  throw deepmd::deepmd_exception("not implemented");
+  compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
+          fparam, aparam, atomic);
 }
 void deepmd::DeepPotJAX::computew(std::vector<double>& ener,
                                   std::vector<float>& force,
@@ -536,7 +693,8 @@ void deepmd::DeepPotJAX::computew(std::vector<double>& ener,
                                   const std::vector<float>& fparam,
                                   const std::vector<float>& aparam,
                                   const bool atomic) {
-  throw deepmd::deepmd_exception("not implemented");
+  compute(ener, force, virial, atom_energy, atom_virial, coord, atype, box,
+          fparam, aparam, atomic);
 }
 void deepmd::DeepPotJAX::computew(std::vector<double>& ener,
                                   std::vector<double>& force,
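With the computew overloads now forwarding to the new template compute instead of throwing, the JAX backend can be driven without building a neighbor list. A minimal caller sketch, assuming the existing high-level deepmd::DeepPot API from deepmd/DeepPot.h and a hypothetical model file name:

// Illustrative only; model name is hypothetical.
#include <vector>
#include "deepmd/DeepPot.h"

int main() {
  // The backend is selected from the model file itself.
  deepmd::DeepPot dp("water_jax.savedmodel");

  // One frame of 3 atoms: coord is natoms x 3, box is 9 values (or empty).
  std::vector<double> coord = {0.0, 0.0, 0.0, 0.9, 0.0, 0.0, 0.0, 0.9, 0.0};
  std::vector<int> atype = {0, 1, 1};
  std::vector<double> box = {12.0, 0.0, 0.0, 0.0, 12.0, 0.0, 0.0, 0.0, 12.0};

  double ener;
  std::vector<double> force, virial;
  // No neighbor list argument: this path ends up in the computew overloads
  // added above.
  dp.compute(ener, force, virial, coord, atype, box);
  return 0;
}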
