feat(jax): SavedModel C++ interface (including DPA-2 supports) #4307
Merged
Commits (81 total, authored mainly by njzjz):

147400e feat: saved model C++ interface
8c6d522 model
140f3e1 update test data
a0b8074 need CPU model
e6bf59f skip memory check
6c10e8e fix
2aa6deb Apply suggestions from code review
f16dd92 Update source/api_cc/src/DeepPotJAX.cc
297ae26 debug memory leak
e64e06a add LAMMPS test
8fefce8 fix memory leak in add_input
261c7bd pass reference
4d5ccc5 delete function and retvals
d365bbc Merge branch 'savedmodel-cxx-debug-mem' into savedmodel-cxx
21fc045 no need to skip the test
660171e Merge branch 'devel' into savedmodel-cxx
d552821 Merge remote-tracking branch 'origin/devel' into savedmodel-cxx
0461248 add limitation
f26f3fe fix tf string parse
713d065 Update source/api_cc/tests/test_deeppot_jax.cc
ccb182d cast void*
8ccead6 handle zero atom
904042d Merge branch 'devel' into jax-cxx-dpa2
0f9d5c5 feat(jax): DPA-2 for LAMMPS
bad564b use the cpu model
2b165d7 fix function name
e717ba3 fix typos
f075075 nloc_real -> nall_real
58dcf2b document limation
d93d13a Merge branch 'devel' into jax-cxx-dpa1
232f7cd fix(tf): fix normalize when compressing a model converted from other …
ce9ee61 apply padding method
6b10eb7 update model
afc71cb Merge commit 'ce9ee61e71b83d2c682522706f98955dfecea98a' into jax-cxx-…
e1a2b55 Merge remote-tracking branch 'origin/devel' into reformat-jax-cxx
649f98e update base class
1cad0b2 perhaps PADDING_FACTOR doesn't need so much
239d186 use max size
37c8739 bump API version
b863c79 update model
95ad9d0 update model
b6d039f Revert "use max size"
5e2ea67 test
72a23d2 debug
edc4445 add all functions
b0808f1 Reapply "use max size"
458be34 Revert "debug"
87908c3 Revert "test"
3a0ca2d Revert "update model"
1863b27 Revert "update model"
eb549e5 cast type
8a154bd update model
4dab4fb bugfix
c4f08c8 fix OOM issue
ef70135 no nlist interface
be02814 fix skip
3c46f37 try to reduce memory
49f57bc fix skip tests
e8a99f4 also skip lammps dpa-2 tests for CUDA
8f83a28 should be fw
8c05d54 Revert "should be fw"
88be054 Revert "try to reduce memory"
5cfc83c Revert "fix OOM issue"
9af5267 set --clean-durations
01567d6 Merge branch 'devel' into savedmodel-cxx
dc4a9d7 Merge remote-tracking branch 'origin/devel' into savedmodel-cxx
1234489 add example
86d1b7a convert models at runtime
93cc440 add script path
546f7dc revert strict=False
0d51bcc revert .gitignore
fc1f90d prefer cuda's cudnn
9447603 bump cuda version
6d5b45a debug
e569ed9 fix docker name
1b3fd5e set allow_growth to True
9d95778 fix compile error
09efdd3 fix typo
39f357c call TFE_ContextOptionsSetConfig
cfff834 fix config
ca02625 Revert "debug"
The PR adds a new C++ header (249 lines) declaring the `DeepPotJAX` class:

```cpp
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <tensorflow/c/c_api.h>
#include <tensorflow/c/eager/c_api.h>

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

#include "DeepPot.h"
#include "common.h"
#include "neighbor_list.h"

namespace deepmd {
/**
 * @brief TensorFlow implementation for Deep Potential.
 **/
class DeepPotJAX : public DeepPotBase {
 public:
  /**
   * @brief DP constructor without initialization.
   **/
  DeepPotJAX();
  virtual ~DeepPotJAX();
  /**
   * @brief DP constructor with initialization.
   * @param[in] model The name of the frozen model file.
   * @param[in] gpu_rank The GPU rank. Default is 0.
   * @param[in] file_content The content of the model file. If it is not
   * empty, DP will read from the string instead of the file.
   **/
  DeepPotJAX(const std::string& model,
             const int& gpu_rank = 0,
             const std::string& file_content = "");
  /**
   * @brief Initialize the DP.
   * @param[in] model The name of the frozen model file.
   * @param[in] gpu_rank The GPU rank. Default is 0.
   * @param[in] file_content The content of the model file. If it is not
   * empty, DP will read from the string instead of the file.
   **/
  void init(const std::string& model,
            const int& gpu_rank = 0,
            const std::string& file_content = "");
  /**
   * @brief Get the cutoff radius.
   * @return The cutoff radius.
   **/
  double cutoff() const {
    assert(inited);
    return rcut;
  };
  /**
   * @brief Get the number of types.
   * @return The number of types.
   **/
  int numb_types() const {
    assert(inited);
    return ntypes;
  };
  /**
   * @brief Get the number of types with spin.
   * @return The number of types with spin.
   **/
  int numb_types_spin() const {
    assert(inited);
    return 0;
  };
  /**
   * @brief Get the dimension of the frame parameter.
   * @return The dimension of the frame parameter.
   **/
  int dim_fparam() const {
    assert(inited);
    return dfparam;
  };
  /**
   * @brief Get the dimension of the atomic parameter.
   * @return The dimension of the atomic parameter.
   **/
  int dim_aparam() const {
    assert(inited);
    return daparam;
  };
  /**
   * @brief Get the type map (element name of the atom types) of this model.
   * @param[out] type_map The type map of this model.
   **/
  void get_type_map(std::string& type_map);
  /**
   * @brief Get whether the atom dimension of aparam is nall instead of nloc.
   * @return Whether the atom dimension of aparam is nall instead of nloc.
   **/
  bool is_aparam_nall() const {
    assert(inited);
    return false;
  };
  // forward to template class
  void computew(std::vector<double>& ener,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_energy,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const std::vector<double>& fparam,
                const std::vector<double>& aparam,
                const bool atomic);
  void computew(std::vector<double>& ener,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_energy,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const std::vector<float>& fparam,
                const std::vector<float>& aparam,
                const bool atomic);
  void computew(std::vector<double>& ener,
                std::vector<double>& force,
                std::vector<double>& virial,
                std::vector<double>& atom_energy,
                std::vector<double>& atom_virial,
                const std::vector<double>& coord,
                const std::vector<int>& atype,
                const std::vector<double>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
                const std::vector<double>& fparam,
                const std::vector<double>& aparam,
                const bool atomic);
  void computew(std::vector<double>& ener,
                std::vector<float>& force,
                std::vector<float>& virial,
                std::vector<float>& atom_energy,
                std::vector<float>& atom_virial,
                const std::vector<float>& coord,
                const std::vector<int>& atype,
                const std::vector<float>& box,
                const int nghost,
                const InputNlist& inlist,
                const int& ago,
                const std::vector<float>& fparam,
                const std::vector<float>& aparam,
                const bool atomic);
  void computew_mixed_type(std::vector<double>& ener,
                           std::vector<double>& force,
                           std::vector<double>& virial,
                           std::vector<double>& atom_energy,
                           std::vector<double>& atom_virial,
                           const int& nframes,
                           const std::vector<double>& coord,
                           const std::vector<int>& atype,
                           const std::vector<double>& box,
                           const std::vector<double>& fparam,
                           const std::vector<double>& aparam,
                           const bool atomic);
  void computew_mixed_type(std::vector<double>& ener,
                           std::vector<float>& force,
                           std::vector<float>& virial,
                           std::vector<float>& atom_energy,
                           std::vector<float>& atom_virial,
                           const int& nframes,
                           const std::vector<float>& coord,
                           const std::vector<int>& atype,
                           const std::vector<float>& box,
                           const std::vector<float>& fparam,
                           const std::vector<float>& aparam,
                           const bool atomic);

 private:
  bool inited;
  // device
  std::string device;
  // the cutoff radius
  double rcut;
  // the number of types
  int ntypes;
  // the dimension of the frame parameter
  int dfparam;
  // the dimension of the atomic parameter
  int daparam;
  // type map
  std::string type_map;
  // sel
  std::vector<int64_t> sel;
  // number of neighbors
  int nnei;
  /** TF C API objects.
   * @{
   */
  TF_Graph* graph;
  TF_Status* status;
  TF_Session* session;
  TF_SessionOptions* sessionopts;
  TFE_ContextOptions* ctx_opts;
  TFE_Context* ctx;
  std::vector<TF_Function*> func_vector;
  /**
   * @}
   */
  // neighbor list data
  NeighborListData nlist_data;
  /**
   * @brief Evaluate the energy, force, virial, atomic energy, and atomic
   * virial by using this DP.
   * @param[out] ener The system energy.
   * @param[out] force The force on each atom.
   * @param[out] virial The virial.
   * @param[out] atom_energy The atomic energy.
   * @param[out] atom_virial The atomic virial.
   * @param[in] coord The coordinates of atoms. The array should be of size
   * nframes x natoms x 3.
   * @param[in] atype The atom types. The list should contain natoms ints.
   * @param[in] box The cell of the region. The array should be of size
   * nframes x 9.
   * @param[in] nghost The number of ghost atoms.
   * @param[in] lmp_list The input neighbour list.
   * @param[in] ago Update the internal neighbour list if ago is 0.
   * @param[in] fparam The frame parameter. The array can be of size
   * nframes x dim_fparam or dim_fparam; in the latter case all frames are
   * assumed to share the same fparam.
   * @param[in] aparam The atomic parameter. The array can be of size
   * nframes x natoms x dim_aparam or natoms x dim_aparam; in the latter
   * case all frames are assumed to share the same aparam.
   * @param[in] atomic Whether to compute atomic energy and virial.
   **/
  template <typename VALUETYPE>
  void compute(std::vector<ENERGYTYPE>& ener,
               std::vector<VALUETYPE>& force,
               std::vector<VALUETYPE>& virial,
               std::vector<VALUETYPE>& atom_energy,
               std::vector<VALUETYPE>& atom_virial,
               const std::vector<VALUETYPE>& coord,
               const std::vector<int>& atype,
               const std::vector<VALUETYPE>& box,
               const int nghost,
               const InputNlist& lmp_list,
               const int& ago,
               const std::vector<VALUETYPE>& fparam,
               const std::vector<VALUETYPE>& aparam,
               const bool atomic);
};
}  // namespace deepmd
```