66#include < tensorflow/c/c_api.h>
77#include < tensorflow/c/eager/c_api.h>
88
9+ #include < array>
10+ #include < cstdint>
911#include < cstdio>
1012#include < iostream>
1113#include < numeric>
@@ -228,6 +230,13 @@ void deepmd::DeepPotJAX::init(const std::string& model,
228230 status = TF_NewStatus ();
229231
230232 sessionopts = TF_NewSessionOptions ();
233+ int num_intra_nthreads, num_inter_nthreads;
234+ get_env_nthreads (num_intra_nthreads, num_inter_nthreads);
235+ // https://github.com/Neargye/hello_tf_c_api/blob/51516101cf59408a6bb456f7e5f3c6628e327b3a/src/tf_utils.cpp#L400-L401
236+ std::array<std::uint8_t , 4 > config = {
237+ {0x10 , static_cast <std::uint8_t >(num_intra_nthreads), 0x28 ,
238+ static_cast <std::uint8_t >(num_inter_nthreads)}};
239+ TF_SetConfig (sessionopts, config.data (), config.size (), status);
231240 TF_Buffer* runopts = NULL ;
232241
233242 const char * tags = " serve" ;
@@ -250,8 +259,8 @@ void deepmd::DeepPotJAX::init(const std::string& model,
250259#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
251260 int gpu_num;
252261 DPGetDeviceCount (gpu_num); // check current device environment
253- DPErrcheck ( DPSetDevice ( gpu_rank % gpu_num));
254- if (gpu_num > 0 ) {
262+ if (gpu_num > 0 && gpu_rank >= 0 ) {
263+ DPErrcheck ( DPSetDevice (gpu_rank % gpu_num));
255264 device = " /gpu:" + std::to_string (gpu_rank % gpu_num);
256265 } else {
257266 device = " /cpu:0" ;
@@ -300,6 +309,153 @@ deepmd::DeepPotJAX::~DeepPotJAX() {
300309 }
301310}
302311
312+ template <typename VALUETYPE>
313+ void deepmd::DeepPotJAX::compute (std::vector<ENERGYTYPE>& ener,
314+ std::vector<VALUETYPE>& force_,
315+ std::vector<VALUETYPE>& virial,
316+ std::vector<VALUETYPE>& atom_energy_,
317+ std::vector<VALUETYPE>& atom_virial_,
318+ const std::vector<VALUETYPE>& dcoord,
319+ const std::vector<int >& datype,
320+ const std::vector<VALUETYPE>& box,
321+ const std::vector<VALUETYPE>& fparam,
322+ const std::vector<VALUETYPE>& aparam_,
323+ const bool atomic) {
324+ std::vector<VALUETYPE> coord, force, aparam, atom_energy, atom_virial;
325+ std::vector<double > ener_double, force_double, virial_double,
326+ atom_energy_double, atom_virial_double;
327+ std::vector<int > atype, fwd_map, bkw_map;
328+ int nghost_real, nall_real, nloc_real;
329+ int nall = datype.size ();
330+ // nlist passed to the model
331+ int nframes = nall > 0 ? (dcoord.size () / 3 / nall) : 1 ;
332+ int nghost = 0 ;
333+
334+ select_real_atoms_coord (coord, atype, aparam, nghost_real, fwd_map, bkw_map,
335+ nall_real, nloc_real, dcoord, datype, aparam_, nghost,
336+ ntypes, nframes, daparam, nall, false );
337+
338+ if (nloc_real == 0 ) {
339+ // no real atoms, fill 0 for all outputs
340+ // this can prevent a Xla error
341+ ener.resize (nframes, 0.0 );
342+ force_.resize (static_cast <size_t >(nframes) * nall * 3 , 0.0 );
343+ virial.resize (static_cast <size_t >(nframes) * 9 , 0.0 );
344+ atom_energy_.resize (static_cast <size_t >(nframes) * nall, 0.0 );
345+ atom_virial_.resize (static_cast <size_t >(nframes) * nall * 9 , 0.0 );
346+ return ;
347+ }
348+
349+ // cast coord, fparam, and aparam to double - I think it's useless to have a
350+ // float model interface
351+ std::vector<double > coord_double (coord.begin (), coord.end ());
352+ std::vector<double > box_double (box.begin (), box.end ());
353+ std::vector<double > fparam_double (fparam.begin (), fparam.end ());
354+ std::vector<double > aparam_double (aparam.begin (), aparam.end ());
355+
356+ TFE_Op* op;
357+ if (atomic) {
358+ op = get_func_op (ctx, " call_with_atomic_virial" , func_vector, device,
359+ status);
360+ } else {
361+ op = get_func_op (ctx, " call_without_atomic_virial" , func_vector, device,
362+ status);
363+ }
364+ std::vector<TFE_TensorHandle*> input_list (5 );
365+ std::vector<TF_Tensor*> data_tensor (5 );
366+ // coord
367+ std::vector<int64_t > coord_shape = {nframes, nloc_real, 3 };
368+ input_list[0 ] =
369+ add_input (op, coord_double, coord_shape, data_tensor[0 ], status);
370+ // atype
371+ std::vector<int64_t > atype_shape = {nframes, nloc_real};
372+ input_list[1 ] = add_input (op, atype, atype_shape, data_tensor[1 ], status);
373+ // box
374+ int box_size = box_double.size () > 0 ? 3 : 0 ;
375+ std::vector<int64_t > box_shape = {nframes, box_size, box_size};
376+ input_list[2 ] = add_input (op, box_double, box_shape, data_tensor[2 ], status);
377+ // fparam
378+ std::vector<int64_t > fparam_shape = {nframes, dfparam};
379+ input_list[3 ] =
380+ add_input (op, fparam_double, fparam_shape, data_tensor[3 ], status);
381+ // aparam
382+ std::vector<int64_t > aparam_shape = {nframes, nloc_real, daparam};
383+ input_list[4 ] =
384+ add_input (op, aparam_double, aparam_shape, data_tensor[4 ], status);
385+ // execute the function
386+ int nretvals = 6 ;
387+ TFE_TensorHandle* retvals[nretvals];
388+
389+ TFE_Execute (op, retvals, &nretvals, status);
390+ check_status (status);
391+
392+ // copy data
393+ // for atom virial, the order is:
394+ // Identity_15 energy -1, -1, 1
395+ // Identity_16 energy_derv_c -1, -1, 1, 9 (may pop)
396+ // Identity_17 energy_derv_c_redu -1, 1, 9
397+ // Identity_18 energy_derv_r -1, -1, 1, 3
398+ // Identity_19 energy_redu -1, 1
399+ // Identity_20 mask (int32) -1, -1
400+ //
401+ // for no atom virial, the order is:
402+ // Identity_15 energy -1, -1, 1
403+ // Identity_16 energy_derv_c -1, 1, 9
404+ // Identity_17 energy_derv_r -1, -1, 1, 3
405+ // Identity_18 energy_redu -1, 1
406+ // Identity_19 mask (int32) -1, -1
407+ //
408+ // it seems the order is the alphabet order?
409+ // not sure whether it is safe to assume the order
410+ if (atomic) {
411+ tensor_to_vector (ener_double, retvals[4 ], status);
412+ tensor_to_vector (force_double, retvals[3 ], status);
413+ tensor_to_vector (virial_double, retvals[2 ], status);
414+ tensor_to_vector (atom_energy_double, retvals[0 ], status);
415+ tensor_to_vector (atom_virial_double, retvals[1 ], status);
416+ } else {
417+ tensor_to_vector (ener_double, retvals[3 ], status);
418+ tensor_to_vector (force_double, retvals[2 ], status);
419+ tensor_to_vector (virial_double, retvals[1 ], status);
420+ tensor_to_vector (atom_energy_double, retvals[0 ], status);
421+ }
422+
423+ // cast back to VALUETYPE
424+ ener = std::vector<ENERGYTYPE>(ener_double.begin (), ener_double.end ());
425+ force = std::vector<VALUETYPE>(force_double.begin (), force_double.end ());
426+ virial = std::vector<VALUETYPE>(virial_double.begin (), virial_double.end ());
427+ atom_energy = std::vector<VALUETYPE>(atom_energy_double.begin (),
428+ atom_energy_double.end ());
429+ atom_virial = std::vector<VALUETYPE>(atom_virial_double.begin (),
430+ atom_virial_double.end ());
431+ force.resize (static_cast <size_t >(nframes) * nall_real * 3 );
432+ atom_virial.resize (static_cast <size_t >(nframes) * nall_real * 9 );
433+
434+ // nall atom_energy is required in the C++ API;
435+ // we always forget it!
436+ atom_energy.resize (static_cast <size_t >(nframes) * nall_real, 0.0 );
437+
438+ force_.resize (static_cast <size_t >(nframes) * fwd_map.size () * 3 );
439+ atom_energy_.resize (static_cast <size_t >(nframes) * fwd_map.size ());
440+ atom_virial_.resize (static_cast <size_t >(nframes) * fwd_map.size () * 9 );
441+ select_map<VALUETYPE>(force_, force, bkw_map, 3 , nframes, fwd_map.size (),
442+ nall_real);
443+ select_map<VALUETYPE>(atom_energy_, atom_energy, bkw_map, 1 , nframes,
444+ fwd_map.size (), nall_real);
445+ select_map<VALUETYPE>(atom_virial_, atom_virial, bkw_map, 9 , nframes,
446+ fwd_map.size (), nall_real);
447+
448+ // cleanup input_list, etc
449+ for (size_t i = 0 ; i < 5 ; i++) {
450+ TFE_DeleteTensorHandle (input_list[i]);
451+ TF_DeleteTensor (data_tensor[i]);
452+ }
453+ for (size_t i = 0 ; i < nretvals; i++) {
454+ TFE_DeleteTensorHandle (retvals[i]);
455+ }
456+ TFE_DeleteOp (op);
457+ }
458+
303459template <typename VALUETYPE>
304460void deepmd::DeepPotJAX::compute (std::vector<ENERGYTYPE>& ener,
305461 std::vector<VALUETYPE>& force_,
@@ -523,7 +679,8 @@ void deepmd::DeepPotJAX::computew(std::vector<double>& ener,
523679 const std::vector<double >& fparam,
524680 const std::vector<double >& aparam,
525681 const bool atomic) {
526- throw deepmd::deepmd_exception (" not implemented" );
682+ compute (ener, force, virial, atom_energy, atom_virial, coord, atype, box,
683+ fparam, aparam, atomic);
527684}
528685void deepmd::DeepPotJAX::computew (std::vector<double >& ener,
529686 std::vector<float >& force,
@@ -536,7 +693,8 @@ void deepmd::DeepPotJAX::computew(std::vector<double>& ener,
536693 const std::vector<float >& fparam,
537694 const std::vector<float >& aparam,
538695 const bool atomic) {
539- throw deepmd::deepmd_exception (" not implemented" );
696+ compute (ener, force, virial, atom_energy, atom_virial, coord, atype, box,
697+ fparam, aparam, atomic);
540698}
541699void deepmd::DeepPotJAX::computew (std::vector<double >& ener,
542700 std::vector<double >& force,
0 commit comments