Skip to content

Commit 73755b3

Browse files
Copilot, njzjz, and pre-commit-ci[bot]
authored
fix(pt,pd): remove redundant tensor handling to eliminate tensor construction warnings (#4907)
This PR fixes deprecation warnings that occur when `torch.tensor()` or `paddle.to_tensor()` is called on existing tensor objects: **PyTorch warning:** ``` UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). ``` **PaddlePaddle warning:** ``` UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach(), rather than paddle.to_tensor(sourceTensor). ``` ## Root Cause The warnings were being triggered in multiple locations: 1. **PyTorch**: Test cases were passing tensor objects directly to ASE calculators, which internally convert them using `torch.tensor()` 2. **PaddlePaddle**: Similar issues in `eval_model` function and `to_paddle_tensor` utility, plus a TypeError where `tensor.to()` method was incorrectly using `place=` instead of `device=` ## Solution **For PyTorch:** - Modified test cases to convert tensor inputs to numpy arrays before passing to ASE calculators - Removed redundant tensor handling in `to_torch_tensor` utility function since the non-numpy check already handles tensors by returning them as-is **For PaddlePaddle:** - Added proper type checking in `eval_model` function to handle existing tensors with `clone().detach()` - Removed redundant tensor handling in `to_paddle_tensor` utility function, applying the same optimization as PyTorch - Fixed TypeError by changing `place=` to `device=` in all `tensor.to()` method calls (PaddlePaddle's tensor `.to()` method expects `device=` parameter, while `paddle.to_tensor()` correctly uses `place=`) ## Changes Made 1. **`source/tests/pt/test_calculator.py`**: Fixed `TestCalculator` and `TestCalculatorWithFparamAparam` to convert PyTorch tensors to numpy arrays before passing to ASE calculator 2. **`deepmd/pt/utils/utils.py`**: Removed redundant tensor-specific handling in `to_torch_tensor` function 3. 
**`source/tests/pd/common.py`**: Updated `eval_model` function with type checking for PaddlePaddle tensors and fixed `tensor.to()` method calls to use `device=` instead of `place=` 4. **`deepmd/pd/utils/utils.py`**: Removed redundant tensor-specific handling in `to_paddle_tensor` function for consistency with PyTorch Both utility functions now use a simplified approach where the `if not isinstance(xx, np.ndarray): return xx` check handles all non-numpy inputs (including tensors) by returning them unchanged, eliminating the need for separate tensor-specific code paths. This change is backward compatible and maintains the same functionality while eliminating both deprecation warnings and TypeErrors, improving code consistency between PyTorch and PaddlePaddle backends. Fixes #3790. <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: njzjz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 727ec3c commit 73755b3

File tree

3 files changed

+112
-34
lines changed

3 files changed

+112
-34
lines changed

source/tests/pd/common.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ def eval_model(
7979
if spins is not None:
8080
assert isinstance(spins, paddle.Tensor), err_msg
8181
assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list)
82-
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
82+
if isinstance(atom_types, paddle.Tensor):
83+
atom_types = (
84+
atom_types.clone().detach().to(dtype=paddle.int32, device=DEVICE)
85+
)
86+
else:
87+
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
8388
elif isinstance(coords, np.ndarray):
8489
if cells is not None:
8590
assert isinstance(cells, np.ndarray), err_msg
@@ -101,28 +106,57 @@ def eval_model(
101106
else:
102107
natoms = len(atom_types[0])
103108

104-
coord_input = paddle.to_tensor(
105-
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
106-
)
107-
spin_input = None
108-
if spins is not None:
109-
spin_input = paddle.to_tensor(
110-
spins.reshape([-1, natoms, 3]),
109+
if isinstance(coords, paddle.Tensor):
110+
coord_input = (
111+
coords.reshape([-1, natoms, 3])
112+
.clone()
113+
.detach()
114+
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE)
115+
)
116+
else:
117+
coord_input = paddle.to_tensor(
118+
coords.reshape([-1, natoms, 3]),
111119
dtype=GLOBAL_PD_FLOAT_PRECISION,
112120
place=DEVICE,
113121
)
122+
spin_input = None
123+
if spins is not None:
124+
if isinstance(spins, paddle.Tensor):
125+
spin_input = (
126+
spins.reshape([-1, natoms, 3])
127+
.clone()
128+
.detach()
129+
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE)
130+
)
131+
else:
132+
spin_input = paddle.to_tensor(
133+
spins.reshape([-1, natoms, 3]),
134+
dtype=GLOBAL_PD_FLOAT_PRECISION,
135+
place=DEVICE,
136+
)
114137
has_spin = getattr(model, "has_spin", False)
115138
if callable(has_spin):
116139
has_spin = has_spin()
117-
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
140+
if isinstance(atom_types, paddle.Tensor):
141+
type_input = atom_types.clone().detach().to(dtype=paddle.int64, device=DEVICE)
142+
else:
143+
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
118144
box_input = None
119145
if cells is None:
120146
pbc = False
121147
else:
122148
pbc = True
123-
box_input = paddle.to_tensor(
124-
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
125-
)
149+
if isinstance(cells, paddle.Tensor):
150+
box_input = (
151+
cells.reshape([-1, 3, 3])
152+
.clone()
153+
.detach()
154+
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE)
155+
)
156+
else:
157+
box_input = paddle.to_tensor(
158+
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
159+
)
126160
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
127161

128162
for ii in range(num_iter):

source/tests/pt/common.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ def eval_model(
7979
if spins is not None:
8080
assert isinstance(spins, torch.Tensor), err_msg
8181
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
82-
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
82+
if isinstance(atom_types, torch.Tensor):
83+
atom_types = (
84+
atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE)
85+
)
86+
else:
87+
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
8388
elif isinstance(coords, np.ndarray):
8489
if cells is not None:
8590
assert isinstance(cells, np.ndarray), err_msg
@@ -101,28 +106,59 @@ def eval_model(
101106
else:
102107
natoms = len(atom_types[0])
103108

104-
coord_input = torch.tensor(
105-
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
106-
)
107-
spin_input = None
108-
if spins is not None:
109-
spin_input = torch.tensor(
110-
spins.reshape([-1, natoms, 3]),
109+
if isinstance(coords, torch.Tensor):
110+
coord_input = (
111+
coords.reshape([-1, natoms, 3])
112+
.clone()
113+
.detach()
114+
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
115+
)
116+
else:
117+
coord_input = torch.tensor(
118+
coords.reshape([-1, natoms, 3]),
111119
dtype=GLOBAL_PT_FLOAT_PRECISION,
112120
device=DEVICE,
113121
)
122+
spin_input = None
123+
if spins is not None:
124+
if isinstance(spins, torch.Tensor):
125+
spin_input = (
126+
spins.reshape([-1, natoms, 3])
127+
.clone()
128+
.detach()
129+
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
130+
)
131+
else:
132+
spin_input = torch.tensor(
133+
spins.reshape([-1, natoms, 3]),
134+
dtype=GLOBAL_PT_FLOAT_PRECISION,
135+
device=DEVICE,
136+
)
114137
has_spin = getattr(model, "has_spin", False)
115138
if callable(has_spin):
116139
has_spin = has_spin()
117-
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
140+
if isinstance(atom_types, torch.Tensor):
141+
type_input = atom_types.clone().detach().to(dtype=torch.long, device=DEVICE)
142+
else:
143+
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
118144
box_input = None
119145
if cells is None:
120146
pbc = False
121147
else:
122148
pbc = True
123-
box_input = torch.tensor(
124-
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
125-
)
149+
if isinstance(cells, torch.Tensor):
150+
box_input = (
151+
cells.reshape([-1, 3, 3])
152+
.clone()
153+
.detach()
154+
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
155+
)
156+
else:
157+
box_input = torch.tensor(
158+
cells.reshape([-1, 3, 3]),
159+
dtype=GLOBAL_PT_FLOAT_PRECISION,
160+
device=DEVICE,
161+
)
126162
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
127163

128164
for ii in range(num_iter):

source/tests/pt/test_calculator.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,18 @@ def test_calculator(self) -> None:
6464
atomic_numbers = [1, 1, 1, 8, 8]
6565
idx_perm = [1, 0, 4, 3, 2]
6666

67+
# Convert tensors to numpy for ASE compatibility
68+
cell_np = cell.numpy()
69+
coord_np = coord.numpy()
70+
6771
prec = 1e-10
6872
low_prec = 1e-4
6973

7074
ase_atoms0 = Atoms(
7175
numbers=atomic_numbers,
72-
positions=coord,
76+
positions=coord_np,
7377
# positions=[tuple(item) for item in coordinate],
74-
cell=cell,
78+
cell=cell_np,
7579
calculator=self.calculator,
7680
pbc=True,
7781
)
@@ -83,9 +87,9 @@ def test_calculator(self) -> None:
8387

8488
ase_atoms1 = Atoms(
8589
numbers=[atomic_numbers[i] for i in idx_perm],
86-
positions=coord[idx_perm, :],
90+
positions=coord_np[idx_perm, :],
8791
# positions=[tuple(item) for item in coordinate],
88-
cell=cell,
92+
cell=cell_np,
8993
calculator=self.calculator,
9094
pbc=True,
9195
)
@@ -141,19 +145,23 @@ def test_calculator(self) -> None:
141145
generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED)
142146
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
143147
coord = torch.matmul(coord, cell)
144-
fparam = torch.IntTensor([1, 2])
145-
aparam = torch.IntTensor([[1], [0], [2], [1], [0]])
148+
fparam = torch.IntTensor([1, 2]).numpy()
149+
aparam = torch.IntTensor([[1], [0], [2], [1], [0]]).numpy()
146150
atomic_numbers = [1, 1, 1, 8, 8]
147151
idx_perm = [1, 0, 4, 3, 2]
148152

153+
# Convert tensors to numpy for ASE compatibility
154+
cell_np = cell.numpy()
155+
coord_np = coord.numpy()
156+
149157
prec = 1e-10
150158
low_prec = 1e-4
151159

152160
ase_atoms0 = Atoms(
153161
numbers=atomic_numbers,
154-
positions=coord,
162+
positions=coord_np,
155163
# positions=[tuple(item) for item in coordinate],
156-
cell=cell,
164+
cell=cell_np,
157165
calculator=self.calculator,
158166
pbc=True,
159167
)
@@ -166,9 +174,9 @@ def test_calculator(self) -> None:
166174

167175
ase_atoms1 = Atoms(
168176
numbers=[atomic_numbers[i] for i in idx_perm],
169-
positions=coord[idx_perm, :],
177+
positions=coord_np[idx_perm, :],
170178
# positions=[tuple(item) for item in coordinate],
171-
cell=cell,
179+
cell=cell_np,
172180
calculator=self.calculator,
173181
pbc=True,
174182
)

0 commit comments

Comments
 (0)