Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions source/api_cc/include/DeepPotPT.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <torch/script.h>
Expand Down Expand Up @@ -335,7 +334,7 @@ class DeepPotPT : public DeepPotBackend {
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
int do_message_passing; // 1:dpa2 model 0:others
bool do_message_passing; // 1:dpa2 model 0:others
bool gpu_enabled;
at::Tensor firstneigh_tensor;
c10::optional<torch::Tensor> mapping_tensor;
Expand Down
3 changes: 1 addition & 2 deletions source/api_cc/include/DeepSpinPT.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <torch/script.h>
Expand Down Expand Up @@ -257,7 +256,7 @@ class DeepSpinPT : public DeepSpinBackend {
NeighborListData nlist_data;
int max_num_neighbors;
int gpu_id;
int do_message_passing; // 1:dpa2 model 0:others
bool do_message_passing; // 1:dpa2 model 0:others
bool gpu_enabled;
at::Tensor firstneigh_tensor;
c10::optional<torch::Tensor> mapping_tensor;
Expand Down
5 changes: 2 additions & 3 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#ifdef BUILD_PYTORCH
#include "DeepPotPT.h"

Expand Down Expand Up @@ -171,7 +170,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing == 1) {
if (do_message_passing) {
int nswap = lmp_list.nswap;
torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
Expand Down Expand Up @@ -234,7 +233,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
(do_message_passing == 1)
(do_message_passing)
? module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
firstneigh_tensor, mapping_tensor, fparam_tensor,
Expand Down
5 changes: 2 additions & 3 deletions source/api_cc/src/DeepSpinPT.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#ifdef BUILD_PYTORCH
#include "DeepSpinPT.h"

Expand Down Expand Up @@ -179,7 +178,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing == 1) {
if (do_message_passing) {
int nswap = lmp_list.nswap;
torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
Expand Down Expand Up @@ -234,7 +233,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
(do_message_passing == 1)
(do_message_passing)
? module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
spin_wrapped_Tensor, firstneigh_tensor,
Expand Down
Loading