Skip to content

Commit 47845a4

Browse files
committed
test:atom_polarizability also need to reshape to adapt current implementation
1 parent 0c87625 commit 47845a4

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

deepmd/utils/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,8 @@ def _load_single_data(
905905
# data should be 2D here: [natoms, ndof]
906906
data = data.reshape([natoms, -1])
907907
data = data[idx_map, :]
908+
else:
909+
data = data.reshape([ndof])
908910

909911
# Atomic: return [natoms, ndof] or flattened hessian above
910912
# Non-atomic: return [ndof]

source/tests/pt/test_loss_tensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
DataRequirementItem,
2525
)
2626

27-
from ..seed import (
28-
GLOBAL_SEED,
29-
)
27+
# from ..seed import (
28+
# GLOBAL_SEED,
29+
# )
30+
31+
GLOBAL_SEED = 7
3032

3133
CUR_DIR = os.path.dirname(__file__)
3234

@@ -57,7 +59,7 @@ def get_single_batch(dataset, index=None):
5759
if key in np_batch.keys():
5860
np_batch[key] = np.expand_dims(np_batch[key], axis=0)
5961
pt_batch[key] = torch.as_tensor(np_batch[key], device=env.DEVICE)
60-
if key in ["coord", "atom_dipole"]:
62+
if key in ["coord", "atom_dipole", "atom_polarizability"]:
6163
np_batch[key] = np_batch[key].reshape(1, -1)
6264
np_batch["natoms"] = np_batch["natoms"][0]
6365
return np_batch, pt_batch

0 commit comments

Comments
 (0)