
Commit 47b76c8

fix(pt): fix precision (#4344)
Tried to implement the decorator as in #4343, but encountered JIT errors.

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Enhanced precision handling across various descriptor classes and methods, ensuring consistent tensor operations.
  - Updated output formats in several classes to improve clarity and usability.
  - Introduced a new environment variable for stricter control over tensor precision handling.
  - Added a new parameter to the `DipoleFittingNet` class for excluding specific types.
- **Bug Fixes**
  - Removed conditions that skipped tests for "float32" data type, allowing all tests to run consistently.
- **Documentation**
  - Improved error messages for dimension mismatches and unsupported parameters, enhancing user understanding.
- **Tests**
  - Adjusted test parameters for consistency in handling `fparam` and `aparam` across multiple test cases.
  - Simplified tensor handling in tests by removing unnecessary type conversions before compression.
1 parent 698b08d · commit 47b76c8

19 files changed: +114, -64 lines
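The per-file diffs that follow all apply the same pattern: the descriptor resolves its working dtype once in `__init__` via `PRECISION_DICT[precision]`, casts the coordinate input to that dtype at the top of `forward`, and casts every tensor it returns back to `env.GLOBAL_PT_FLOAT_PRECISION`. Below is a minimal, self-contained sketch of that pattern; the `PRECISION_DICT` and `GLOBAL_PT_FLOAT_PRECISION` values here are stand-ins, not the actual `deepmd.pt.utils.env` definitions.

```python
# Minimal sketch of the cast-in / cast-out pattern used in the diffs below.
# PRECISION_DICT and GLOBAL_PT_FLOAT_PRECISION are stand-ins, not the real
# deepmd.pt.utils.env module.
import torch

PRECISION_DICT = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
}
GLOBAL_PT_FLOAT_PRECISION = torch.float64


class ToyDescriptor(torch.nn.Module):
    def __init__(self, precision: str = "float32") -> None:
        super().__init__()
        self.prec = PRECISION_DICT[precision]
        self.linear = torch.nn.Linear(3, 8, dtype=self.prec)

    def forward(self, extended_coord: torch.Tensor) -> torch.Tensor:
        # cast the input to the internal precision ...
        extended_coord = extended_coord.to(dtype=self.prec)
        g1 = self.linear(extended_coord)
        # ... and cast the output back to the global precision on return
        return g1.to(dtype=GLOBAL_PT_FLOAT_PRECISION)


coord = torch.rand(4, 3, dtype=GLOBAL_PT_FLOAT_PRECISION)
out = ToyDescriptor("float32")(coord)  # out.dtype == torch.float64
```

Casting only at the descriptor boundary lets the inner blocks run entirely in their native precision while callers always receive tensors in the global float type.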

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -22,6 +22,7 @@
     env,
 )
 from deepmd.pt.utils.env import (
+    PRECISION_DICT,
     RESERVED_PRECISON_DICT,
 )
 from deepmd.pt.utils.tabulate import (
@@ -311,6 +312,7 @@ def __init__(
             use_tebd_bias=use_tebd_bias,
             type_map=type_map,
         )
+        self.prec = PRECISION_DICT[precision]
         self.tebd_dim = tebd_dim
         self.concat_output_tebd = concat_output_tebd
         self.trainable = trainable
@@ -678,6 +680,8 @@ def forward(
             The smooth switch function. shape: nf x nloc x nnei

         """
+        # cast the input to internal precision
+        extended_coord = extended_coord.to(dtype=self.prec)
         del mapping
         nframes, nloc, nnei = nlist.shape
         nall = extended_coord.view(nframes, -1).shape[1] // 3
@@ -693,7 +697,13 @@ def forward(
         if self.concat_output_tebd:
             g1 = torch.cat([g1, g1_inp], dim=-1)

-        return g1, rot_mat, g2, h2, sw
+        return (
+            g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if g2 is not None else None,
+            h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+        )

     @classmethod
     def update_sel(
```
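The commit message notes that a decorator-based cast (as attempted in #4343) ran into JIT errors, which is why the casts above are written out per tensor, with `g2` guarded by an explicit `None` check. A minimal illustration of why the explicit check is needed under TorchScript follows; it is a sketch with an illustrative function name, not the deepmd code.

```python
# Sketch: TorchScript only narrows Optional[Tensor] to Tensor after an
# explicit None check, so the conditional cast must be spelled out per tensor.
from typing import Optional

import torch


@torch.jit.script
def cast_optional(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if x is not None:
        return x.to(dtype=torch.float64)
    return None
```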

deepmd/pt/model/descriptor/dpa2.py

Lines changed: 14 additions & 1 deletion
```diff
@@ -27,6 +27,9 @@
 from deepmd.pt.utils import (
     env,
 )
+from deepmd.pt.utils.env import (
+    PRECISION_DICT,
+)
 from deepmd.pt.utils.nlist import (
     build_multiple_neighbor_list,
     get_multiple_nlist_key,
@@ -268,6 +271,7 @@ def init_subclass_params(sub_data, sub_class):
         )
         self.concat_output_tebd = concat_output_tebd
         self.precision = precision
+        self.prec = PRECISION_DICT[self.precision]
         self.smooth = smooth
         self.exclude_types = exclude_types
         self.env_protection = env_protection
@@ -745,6 +749,9 @@ def forward(
             The smooth switch function. shape: nf x nloc x nnei

         """
+        # cast the input to internal precision
+        extended_coord = extended_coord.to(dtype=self.prec)
+
         use_three_body = self.use_three_body
         nframes, nloc, nnei = nlist.shape
         nall = extended_coord.view(nframes, -1).shape[1] // 3
@@ -810,7 +817,13 @@ def forward(
         )
         if self.concat_output_tebd:
             g1 = torch.cat([g1, g1_inp], dim=-1)
-        return g1, rot_mat, g2, h2, sw
+        return (
+            g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            g2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+        )

     @classmethod
     def update_sel(
```

deepmd/pt/model/descriptor/repformers.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -22,6 +22,9 @@
 from deepmd.pt.utils import (
     env,
 )
+from deepmd.pt.utils.env import (
+    PRECISION_DICT,
+)
 from deepmd.pt.utils.env_mat_stat import (
     EnvMatStatSe,
 )
@@ -237,6 +240,7 @@ def __init__(
         self.reinit_exclude(exclude_types)
         self.env_protection = env_protection
         self.precision = precision
+        self.prec = PRECISION_DICT[precision]
         self.trainable_ln = trainable_ln
         self.ln_eps = ln_eps
         self.epsilon = 1e-4
@@ -286,12 +290,8 @@ def __init__(
         self.layers = torch.nn.ModuleList(layers)

         wanted_shape = (self.ntypes, self.nnei, 4)
-        mean = torch.zeros(
-            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-        )
-        stddev = torch.ones(
-            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-        )
+        mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
+        stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
         self.register_buffer("mean", mean)
         self.register_buffer("stddev", stddev)
         self.stats = None
```
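The `mean`/`stddev` statistics buffers above are now allocated directly in the block's working precision rather than the global default. A generic sketch of that buffer setup, with illustrative names rather than the repformers class:

```python
# Sketch: per-type environment-matrix statistics registered as buffers in a
# configurable dtype (illustrative class, not the deepmd implementation).
import torch


class StatHolder(torch.nn.Module):
    def __init__(self, ntypes: int, nnei: int, prec: torch.dtype = torch.float32) -> None:
        super().__init__()
        wanted_shape = (ntypes, nnei, 4)
        # buffers follow the module through .to()/state_dict but receive no gradients
        self.register_buffer("mean", torch.zeros(wanted_shape, dtype=prec))
        self.register_buffer("stddev", torch.ones(wanted_shape, dtype=prec))
```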

deepmd/pt/model/descriptor/se_a.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -118,6 +118,7 @@ def __init__(
         super().__init__()
         self.type_map = type_map
         self.compress = False
+        self.prec = PRECISION_DICT[precision]
         self.sea = DescrptBlockSeA(
             rcut,
             rcut_smth,
@@ -337,7 +338,18 @@ def forward(
             The smooth switch function.

         """
-        return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)
+        # cast the input to internal precision
+        coord_ext = coord_ext.to(dtype=self.prec)
+        g1, rot_mat, g2, h2, sw = self.sea.forward(
+            nlist, coord_ext, atype_ext, None, mapping
+        )
+        return (
+            g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            None,
+            None,
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+        )

     def set_stat_mean_and_stddev(
         self,
@@ -742,7 +754,6 @@ def forward(
         )

         dmatrix = dmatrix.view(-1, self.nnei, 4)
-        dmatrix = dmatrix.to(dtype=self.prec)
         nfnl = dmatrix.shape[0]
         # pre-allocate a shape to pass jit
         xyz_scatter = torch.zeros(
@@ -811,8 +822,8 @@ def forward(
         result = result.view(nf, nloc, self.filter_neuron[-1] * self.axis_neuron)
         rot_mat = rot_mat.view([nf, nloc] + list(rot_mat.shape[1:]))  # noqa:RUF005
         return (
-            result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
-            rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            result,
+            rot_mat,
             None,
             None,
             sw,
```

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 2 additions & 8 deletions
```diff
@@ -227,12 +227,8 @@ def __init__(
         )

         wanted_shape = (self.ntypes, self.nnei, 4)
-        mean = torch.zeros(
-            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-        )
-        stddev = torch.ones(
-            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-        )
+        mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
+        stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
         self.register_buffer("mean", mean)
         self.register_buffer("stddev", stddev)
         self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
@@ -568,8 +564,6 @@ def forward(
             # nfnl x nnei x ng
             # gg = gg_s * gg_t + gg_s
             gg_t = gg_t.reshape(-1, gg_t.size(-1))
-            # Convert all tensors to the required precision at once
-            ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
             xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
                 self.compress_data[0].contiguous(),
                 self.compress_info[0].cpu().contiguous(),
```

deepmd/pt/model/descriptor/se_r.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -456,6 +456,8 @@ def forward(
             The smooth switch function.

         """
+        # cast the input to internal precision
+        coord_ext = coord_ext.to(dtype=self.prec)
         del mapping, comm_dict
         nf = nlist.shape[0]
         nloc = nlist.shape[1]
@@ -474,7 +476,6 @@ def forward(

         assert self.filter_layers is not None
         dmatrix = dmatrix.view(-1, self.nnei, 1)
-        dmatrix = dmatrix.to(dtype=self.prec)
         nfnl = dmatrix.shape[0]
         # pre-allocate a shape to pass jit
         xyz_scatter = torch.zeros(
@@ -519,7 +520,7 @@ def forward(
             None,
             None,
             None,
-            sw,
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
         )

     def set_stat_mean_and_stddev(
```

deepmd/pt/model/descriptor/se_t.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -154,6 +154,7 @@ def __init__(
         super().__init__()
         self.type_map = type_map
         self.compress = False
+        self.prec = PRECISION_DICT[precision]
         self.seat = DescrptBlockSeT(
             rcut,
             rcut_smth,
@@ -373,7 +374,18 @@ def forward(
             The smooth switch function.

         """
-        return self.seat.forward(nlist, coord_ext, atype_ext, None, mapping)
+        # cast the input to internal precision
+        coord_ext = coord_ext.to(dtype=self.prec)
+        g1, rot_mat, g2, h2, sw = self.seat.forward(
+            nlist, coord_ext, atype_ext, None, mapping
+        )
+        return (
+            g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            None,
+            None,
+            None,
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+        )

     def set_stat_mean_and_stddev(
         self,
@@ -801,7 +813,6 @@ def forward(
             protection=self.env_protection,
         )
         dmatrix = dmatrix.view(-1, self.nnei, 4)
-        dmatrix = dmatrix.to(dtype=self.prec)
         nfnl = dmatrix.shape[0]
         # pre-allocate a shape to pass jit
         result = torch.zeros(
@@ -832,8 +843,6 @@ def forward(
             env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j)
             if self.compress:
                 ebd_env_ij = env_ij.view(-1, 1)
-                ebd_env_ij = ebd_env_ij.to(dtype=self.prec)
-                env_ij = env_ij.to(dtype=self.prec)
                 res_ij = torch.ops.deepmd.tabulate_fusion_se_t(
                     compress_data_ii.contiguous(),
                     compress_info_ii.cpu().contiguous(),
@@ -853,7 +862,7 @@ def forward(
         # xyz_scatter /= (self.nnei * self.nnei)
         result = result.view(nf, nloc, self.filter_neuron[-1])
         return (
-            result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            result,
             None,
             None,
             None,
```

deepmd/pt/model/descriptor/se_t_tebd.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -161,6 +161,7 @@ def __init__(
             smooth=smooth,
             seed=child_seed(seed, 1),
         )
+        self.prec = PRECISION_DICT[precision]
         self.use_econf_tebd = use_econf_tebd
         self.type_map = type_map
         self.smooth = smooth
@@ -441,12 +442,14 @@ def forward(
             The smooth switch function. shape: nf x nloc x nnei

         """
+        # cast the input to internal precision
+        extended_coord = extended_coord.to(dtype=self.prec)
         del mapping
         nframes, nloc, nnei = nlist.shape
         nall = extended_coord.view(nframes, -1).shape[1] // 3
         g1_ext = self.type_embedding(extended_atype)
         g1_inp = g1_ext[:, :nloc, :]
-        g1, g2, h2, rot_mat, sw = self.se_ttebd(
+        g1, _, _, _, sw = self.se_ttebd(
             nlist,
             extended_coord,
             extended_atype,
@@ -456,7 +459,13 @@ def forward(
         if self.concat_output_tebd:
             g1 = torch.cat([g1, g1_inp], dim=-1)

-        return g1, rot_mat, g2, h2, sw
+        return (
+            g1.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            None,
+            None,
+            None,
+            sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+        )

     @classmethod
     def update_sel(
@@ -540,12 +549,8 @@ def __init__(
         self.reinit_exclude(exclude_types)

         wanted_shape = (self.ntypes, self.nnei, 4)
-        mean = torch.zeros(
-            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-        )
-        stddev = torch.ones(
-            wanted_shape, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
-        )
+        mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE)
+        stddev = torch.ones(wanted_shape, dtype=self.prec, device=env.DEVICE)
         self.register_buffer("mean", mean)
         self.register_buffer("stddev", stddev)
         self.tebd_dim_input = self.tebd_dim * 2
@@ -849,7 +854,7 @@ def forward(
         # nf x nl x ng
         result = res_ij.view(nframes, nloc, self.filter_neuron[-1])
         return (
-            result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
+            result,
             None,
             None,
             None,
```
deepmd/pt/model/network/mlp.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -200,7 +200,8 @@ def forward(
             The output.
         """
         ori_prec = xx.dtype
-        xx = xx.to(self.prec)
+        if not env.DP_DTYPE_PROMOTION_STRICT:
+            xx = xx.to(self.prec)
         yy = (
             torch.matmul(xx, self.matrix) + self.bias
             if self.bias is not None
@@ -215,7 +216,8 @@ def forward(
             yy += torch.concat([xx, xx], dim=-1)
         else:
             yy = yy
-        yy = yy.to(ori_prec)
+        if not env.DP_DTYPE_PROMOTION_STRICT:
+            yy = yy.to(ori_prec)
         return yy

     def serialize(self) -> dict:
```
deepmd/pt/model/task/dipole.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -180,6 +180,8 @@ def forward(
     ):
         nframes, nloc, _ = descriptor.shape
         assert gr is not None, "Must provide the rotation matrix for dipole fitting."
+        # cast the input to internal precision
+        gr = gr.to(self.prec)
         # (nframes, nloc, m1)
         out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[
             self.var_name
```
