Skip to content
Merged
58 changes: 46 additions & 12 deletions source/tests/pd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def eval_model(
if spins is not None:
assert isinstance(spins, paddle.Tensor), err_msg
assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list)
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
if isinstance(atom_types, paddle.Tensor):
atom_types = (
atom_types.clone().detach().to(dtype=paddle.int32, device=DEVICE)
)
else:
atom_types = paddle.to_tensor(atom_types, dtype=paddle.int32, place=DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
Expand All @@ -101,28 +106,57 @@ def eval_model(
else:
natoms = len(atom_types[0])

coord_input = paddle.to_tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
spin_input = None
if spins is not None:
spin_input = paddle.to_tensor(
spins.reshape([-1, natoms, 3]),
if isinstance(coords, paddle.Tensor):
coord_input = (
coords.reshape([-1, natoms, 3])
.clone()
.detach()
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE)
)
else:
coord_input = paddle.to_tensor(
coords.reshape([-1, natoms, 3]),
dtype=GLOBAL_PD_FLOAT_PRECISION,
place=DEVICE,
)
spin_input = None
if spins is not None:
if isinstance(spins, paddle.Tensor):
spin_input = (
spins.reshape([-1, natoms, 3])
.clone()
.detach()
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE)
)
else:
spin_input = paddle.to_tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PD_FLOAT_PRECISION,
place=DEVICE,
)
has_spin = getattr(model, "has_spin", False)
if callable(has_spin):
has_spin = has_spin()
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
if isinstance(atom_types, paddle.Tensor):
type_input = atom_types.clone().detach().to(dtype=paddle.int64, device=DEVICE)
else:
type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = paddle.to_tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
if isinstance(cells, paddle.Tensor):
box_input = (
cells.reshape([-1, 3, 3])
.clone()
.detach()
.to(dtype=GLOBAL_PD_FLOAT_PRECISION, device=DEVICE)
)
else:
box_input = paddle.to_tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION, place=DEVICE
)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
Expand Down
60 changes: 48 additions & 12 deletions source/tests/pt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def eval_model(
if spins is not None:
assert isinstance(spins, torch.Tensor), err_msg
assert isinstance(atom_types, torch.Tensor) or isinstance(atom_types, list)
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
if isinstance(atom_types, torch.Tensor):
atom_types = (
atom_types.clone().detach().to(dtype=torch.int32, device=DEVICE)
)
else:
atom_types = torch.tensor(atom_types, dtype=torch.int32, device=DEVICE)
elif isinstance(coords, np.ndarray):
if cells is not None:
assert isinstance(cells, np.ndarray), err_msg
Expand All @@ -101,28 +106,59 @@ def eval_model(
else:
natoms = len(atom_types[0])

coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
spin_input = None
if spins is not None:
spin_input = torch.tensor(
spins.reshape([-1, natoms, 3]),
if isinstance(coords, torch.Tensor):
coord_input = (
coords.reshape([-1, natoms, 3])
.clone()
.detach()
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
)
else:
coord_input = torch.tensor(
coords.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
spin_input = None
if spins is not None:
if isinstance(spins, torch.Tensor):
spin_input = (
spins.reshape([-1, natoms, 3])
.clone()
.detach()
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
)
else:
spin_input = torch.tensor(
spins.reshape([-1, natoms, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
has_spin = getattr(model, "has_spin", False)
if callable(has_spin):
has_spin = has_spin()
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
if isinstance(atom_types, torch.Tensor):
type_input = atom_types.clone().detach().to(dtype=torch.long, device=DEVICE)
else:
type_input = torch.tensor(atom_types, dtype=torch.long, device=DEVICE)
box_input = None
if cells is None:
pbc = False
else:
pbc = True
box_input = torch.tensor(
cells.reshape([-1, 3, 3]), dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
)
if isinstance(cells, torch.Tensor):
box_input = (
cells.reshape([-1, 3, 3])
.clone()
.detach()
.to(dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE)
)
else:
box_input = torch.tensor(
cells.reshape([-1, 3, 3]),
dtype=GLOBAL_PT_FLOAT_PRECISION,
device=DEVICE,
)
num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)

for ii in range(num_iter):
Expand Down
28 changes: 18 additions & 10 deletions source/tests/pt/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,18 @@ def test_calculator(self) -> None:
atomic_numbers = [1, 1, 1, 8, 8]
idx_perm = [1, 0, 4, 3, 2]

# Convert tensors to numpy for ASE compatibility
cell_np = cell.numpy()
coord_np = coord.numpy()

prec = 1e-10
low_prec = 1e-4

ase_atoms0 = Atoms(
numbers=atomic_numbers,
positions=coord,
positions=coord_np,
# positions=[tuple(item) for item in coordinate],
cell=cell,
cell=cell_np,
calculator=self.calculator,
pbc=True,
)
Expand All @@ -83,9 +87,9 @@ def test_calculator(self) -> None:

ase_atoms1 = Atoms(
numbers=[atomic_numbers[i] for i in idx_perm],
positions=coord[idx_perm, :],
positions=coord_np[idx_perm, :],
# positions=[tuple(item) for item in coordinate],
cell=cell,
cell=cell_np,
calculator=self.calculator,
pbc=True,
)
Expand Down Expand Up @@ -141,19 +145,23 @@ def test_calculator(self) -> None:
generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED)
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
coord = torch.matmul(coord, cell)
fparam = torch.IntTensor([1, 2])
aparam = torch.IntTensor([[1], [0], [2], [1], [0]])
fparam = torch.IntTensor([1, 2]).numpy()
aparam = torch.IntTensor([[1], [0], [2], [1], [0]]).numpy()
atomic_numbers = [1, 1, 1, 8, 8]
idx_perm = [1, 0, 4, 3, 2]

# Convert tensors to numpy for ASE compatibility
cell_np = cell.numpy()
coord_np = coord.numpy()

prec = 1e-10
low_prec = 1e-4

ase_atoms0 = Atoms(
numbers=atomic_numbers,
positions=coord,
positions=coord_np,
# positions=[tuple(item) for item in coordinate],
cell=cell,
cell=cell_np,
calculator=self.calculator,
pbc=True,
)
Expand All @@ -166,9 +174,9 @@ def test_calculator(self) -> None:

ase_atoms1 = Atoms(
numbers=[atomic_numbers[i] for i in idx_perm],
positions=coord[idx_perm, :],
positions=coord_np[idx_perm, :],
# positions=[tuple(item) for item in coordinate],
cell=cell,
cell=cell_np,
calculator=self.calculator,
pbc=True,
)
Expand Down