Skip to content

Commit ec6e903

Browse files
chore(pt): Change the type of do_message_passing from int to bool in DeepPotPT and DeepSpinPT classes (#4391)
Fix #4366. * Update the type of `do_message_passing` to `bool` in the `DeepPotPT` class and `init` method in `source/api_cc/include/DeepPotPT.h` and `source/api_cc/src/DeepPotPT.cc` * Update the type of `do_message_passing` to `bool` in the `DeepSpinPT` class and `init` method in `source/api_cc/include/DeepSpinPT.h` and `source/api_cc/src/DeepSpinPT.cc` <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced error handling for exceptions from the PyTorch library in both `DeepPotPT` and `DeepSpinPT` classes. - Simplified boolean checks for message passing in the `compute` methods of both classes. - **Bug Fixes** - Improved robustness in the `DeepPotPT` constructor to prevent resource leaks during initialization. - **Documentation** - Updated method signatures to reflect changes in parameter types and structures. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dbf450f commit ec6e903

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

source/api_cc/include/DeepPotPT.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ class DeepPotPT : public DeepPotBackend {
335335
NeighborListData nlist_data;
336336
int max_num_neighbors;
337337
int gpu_id;
338-
int do_message_passing; // 1:dpa2 model 0:others
338+
bool do_message_passing; // 1:dpa2 model 0:others
339339
bool gpu_enabled;
340340
at::Tensor firstneigh_tensor;
341341
c10::optional<torch::Tensor> mapping_tensor;

source/api_cc/include/DeepSpinPT.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class DeepSpinPT : public DeepSpinBackend {
257257
NeighborListData nlist_data;
258258
int max_num_neighbors;
259259
int gpu_id;
260-
int do_message_passing; // 1:dpa2 model 0:others
260+
bool do_message_passing; // 1:dpa2 model 0:others
261261
bool gpu_enabled;
262262
at::Tensor firstneigh_tensor;
263263
c10::optional<torch::Tensor> mapping_tensor;

source/api_cc/src/DeepPotPT.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
171171
nlist_data.copy_from_nlist(lmp_list);
172172
nlist_data.shuffle_exclude_empty(fwd_map);
173173
nlist_data.padding();
174-
if (do_message_passing == 1) {
174+
if (do_message_passing) {
175175
int nswap = lmp_list.nswap;
176176
torch::Tensor sendproc_tensor =
177177
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
@@ -234,7 +234,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
234234
.to(device);
235235
}
236236
c10::Dict<c10::IValue, c10::IValue> outputs =
237-
(do_message_passing == 1)
237+
(do_message_passing)
238238
? module
239239
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
240240
firstneigh_tensor, mapping_tensor, fparam_tensor,

source/api_cc/src/DeepSpinPT.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
179179
nlist_data.copy_from_nlist(lmp_list);
180180
nlist_data.shuffle_exclude_empty(fwd_map);
181181
nlist_data.padding();
182-
if (do_message_passing == 1) {
182+
if (do_message_passing) {
183183
int nswap = lmp_list.nswap;
184184
torch::Tensor sendproc_tensor =
185185
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
@@ -234,7 +234,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
234234
.to(device);
235235
}
236236
c10::Dict<c10::IValue, c10::IValue> outputs =
237-
(do_message_passing == 1)
237+
(do_message_passing)
238238
? module
239239
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
240240
spin_wrapped_Tensor, firstneigh_tensor,

0 commit comments

Comments
 (0)