Skip to content

Commit 0c5ab07

Browse files
cherryWangYpre-commit-ci[bot]njzjz
authored
Add pt compress commad line (#4300)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced model compression functionality, allowing users to compress models directly via the command-line interface. - Added a new command option `"compress"` to trigger model compression. - Enhanced help messages and examples for the `"compress"` command to clarify usage with different backends. - Added a comprehensive JSON configuration file for model compression parameters. - Improved handling of compression parameters within descriptor classes for better organization and efficiency. - **Bug Fixes** - Improved error handling for unsupported file formats during model loading. - **Tests** - Introduced a suite of unit tests to evaluate the functionality of model compression, ensuring accuracy and performance across different configurations. - Enhanced tests for loading model parameters to ensure all required attributes are correctly handled. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent 3701566 commit 0c5ab07

File tree

25 files changed

+1003
-199
lines changed

25 files changed

+1003
-199
lines changed

deepmd/main.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,29 +424,30 @@ def main_parser() -> argparse.ArgumentParser:
424424
parser_compress = subparsers.add_parser(
425425
"compress",
426426
parents=[parser_log, parser_mpi_log],
427-
help="(Supported backend: TensorFlow) compress a model",
427+
help="Compress a model",
428428
formatter_class=RawTextArgumentDefaultsHelpFormatter,
429429
epilog=textwrap.dedent(
430430
"""\
431431
examples:
432432
dp compress
433-
dp compress -i graph.pb -o compressed.pb
433+
dp --tf compress -i frozen_model.pb -o compressed_model.pb
434+
dp --pt compress -i frozen_model.pth -o compressed_model.pth
434435
"""
435436
),
436437
)
437438
parser_compress.add_argument(
438439
"-i",
439440
"--input",
440-
default="frozen_model.pb",
441+
default="frozen_model",
441442
type=str,
442-
help="The original frozen model, which will be compressed by the code",
443+
help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
443444
)
444445
parser_compress.add_argument(
445446
"-o",
446447
"--output",
447-
default="frozen_model_compressed.pb",
448+
default="frozen_model_compressed",
448449
type=str,
449-
help="The compressed model",
450+
help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
450451
)
451452
parser_compress.add_argument(
452453
"-s",

deepmd/pt/entrypoints/compress.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
4+
import torch
5+
6+
from deepmd.pt.model.model import (
7+
get_model,
8+
)
9+
10+
11+
def enable_compression(
12+
input_file: str,
13+
output: str,
14+
stride: float = 0.01,
15+
extrapolate: int = 5,
16+
check_frequency: int = -1,
17+
):
18+
saved_model = torch.jit.load(input_file, map_location="cpu")
19+
model_def_script = json.loads(saved_model.model_def_script)
20+
model = get_model(model_def_script)
21+
model.load_state_dict(saved_model.state_dict())
22+
23+
model.enable_compression(
24+
extrapolate,
25+
stride,
26+
stride * 10,
27+
check_frequency,
28+
)
29+
30+
model = torch.jit.script(model)
31+
torch.jit.save(model, output)

deepmd/pt/entrypoints/main.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
from deepmd.pt.cxx_op import (
3939
ENABLE_CUSTOMIZED_OP,
4040
)
41+
from deepmd.pt.entrypoints.compress import (
42+
enable_compression,
43+
)
4144
from deepmd.pt.infer import (
4245
inference,
4346
)
@@ -346,10 +349,14 @@ def train(
346349
# save min_nbor_dist
347350
if min_nbor_dist is not None:
348351
if not multi_task:
349-
trainer.model.min_nbor_dist = min_nbor_dist
352+
trainer.model.min_nbor_dist = torch.tensor(
353+
min_nbor_dist, dtype=torch.float64, device=DEVICE
354+
)
350355
else:
351356
for model_item in min_nbor_dist:
352-
trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
357+
trainer.model[model_item].min_nbor_dist = torch.tensor(
358+
min_nbor_dist[model_item], dtype=torch.float64, device=DEVICE
359+
)
353360
trainer.run()
354361

355362

@@ -549,6 +556,16 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
549556
model_branch=FLAGS.model_branch,
550557
output=FLAGS.output,
551558
)
559+
elif FLAGS.command == "compress":
560+
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
561+
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
562+
enable_compression(
563+
input_file=FLAGS.input,
564+
output=FLAGS.output,
565+
stride=FLAGS.step,
566+
extrapolate=FLAGS.extrapolate,
567+
check_frequency=FLAGS.frequency,
568+
)
552569
else:
553570
raise RuntimeError(f"Invalid command {FLAGS.command}!")
554571

deepmd/pt/model/descriptor/se_a.py

Lines changed: 53 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import numpy as np
1111
import torch
12+
import torch.nn as nn
1213

1314
from deepmd.dpmodel.utils.seed import (
1415
child_seed,
@@ -437,10 +438,6 @@ def update_sel(
437438
class DescrptBlockSeA(DescriptorBlock):
438439
ndescrpt: Final[int]
439440
__constants__: ClassVar[list] = ["ndescrpt"]
440-
lower: dict[str, int]
441-
upper: dict[str, int]
442-
table_data: dict[str, torch.Tensor]
443-
table_config: list[Union[int, float]]
444441

445442
def __init__(
446443
self,
@@ -500,13 +497,6 @@ def __init__(
500497
self.register_buffer("mean", mean)
501498
self.register_buffer("stddev", stddev)
502499

503-
# add for compression
504-
self.compress = False
505-
self.lower = {}
506-
self.upper = {}
507-
self.table_data = {}
508-
self.table_config = []
509-
510500
ndim = 1 if self.type_one_side else 2
511501
filter_layers = NetworkCollection(
512502
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
@@ -529,6 +519,21 @@ def __init__(
529519
for param in self.parameters():
530520
param.requires_grad = trainable
531521

522+
# add for compression
523+
self.compress = False
524+
self.compress_info = nn.ParameterList(
525+
[
526+
nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))
527+
for _ in range(len(self.filter_layers.networks))
528+
]
529+
)
530+
self.compress_data = nn.ParameterList(
531+
[
532+
nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))
533+
for _ in range(len(self.filter_layers.networks))
534+
]
535+
)
536+
532537
def get_rcut(self) -> float:
533538
"""Returns the cut-off radius."""
534539
return self.rcut
@@ -667,16 +672,39 @@ def reinit_exclude(
667672

668673
def enable_compression(
669674
self,
670-
table_data,
671-
table_config,
672-
lower,
673-
upper,
675+
table_data: dict[str, torch.Tensor],
676+
table_config: list[Union[int, float]],
677+
lower: dict[str, int],
678+
upper: dict[str, int],
674679
) -> None:
680+
for embedding_idx, ll in enumerate(self.filter_layers.networks):
681+
if self.type_one_side:
682+
ii = embedding_idx
683+
ti = -1
684+
else:
685+
# ti: center atom type, ii: neighbor type...
686+
ii = embedding_idx // self.ntypes
687+
ti = embedding_idx % self.ntypes
688+
if self.type_one_side:
689+
net = "filter_-1_net_" + str(ii)
690+
else:
691+
net = "filter_" + str(ti) + "_net_" + str(ii)
692+
info_ii = torch.as_tensor(
693+
[
694+
lower[net],
695+
upper[net],
696+
upper[net] * table_config[0],
697+
table_config[1],
698+
table_config[2],
699+
table_config[3],
700+
],
701+
dtype=self.prec,
702+
device="cpu",
703+
)
704+
tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec)
705+
self.compress_data[embedding_idx] = tensor_data_ii
706+
self.compress_info[embedding_idx] = info_ii
675707
self.compress = True
676-
self.table_data = table_data
677-
self.table_config = table_config
678-
self.lower = lower
679-
self.upper = upper
680708

681709
def forward(
682710
self,
@@ -724,7 +752,9 @@ def forward(
724752
)
725753
# nfnl x nnei
726754
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)
727-
for embedding_idx, ll in enumerate(self.filter_layers.networks):
755+
for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate(
756+
zip(self.filter_layers.networks, self.compress_data, self.compress_info)
757+
):
728758
if self.type_one_side:
729759
ii = embedding_idx
730760
ti = -1
@@ -751,23 +781,11 @@ def forward(
751781
ss = rr[:, :, :1]
752782

753783
if self.compress:
754-
if self.type_one_side:
755-
net = "filter_-1_net_" + str(ii)
756-
else:
757-
net = "filter_" + str(ti) + "_net_" + str(ii)
758-
info = [
759-
self.lower[net],
760-
self.upper[net],
761-
self.upper[net] * self.table_config[0],
762-
self.table_config[1],
763-
self.table_config[2],
764-
self.table_config[3],
765-
]
766784
ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf
767-
tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec)
785+
768786
gr = torch.ops.deepmd.tabulate_fusion_se_a(
769-
tensor_data.contiguous(),
770-
torch.tensor(info, dtype=self.prec, device="cpu").contiguous(),
787+
compress_data_ii.contiguous(),
788+
compress_info_ii.cpu().contiguous(),
771789
ss.contiguous(),
772790
rr.contiguous(),
773791
self.filter_neuron[-1],

deepmd/pt/model/descriptor/se_atten.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,6 @@ def tabulate_fusion_se_atten(
7171

7272
@DescriptorBlock.register("se_atten")
7373
class DescrptBlockSeAtten(DescriptorBlock):
74-
lower: dict[str, int]
75-
upper: dict[str, int]
76-
table_data: dict[str, torch.Tensor]
77-
table_config: list[Union[int, float]]
78-
7974
def __init__(
8075
self,
8176
rcut: float,
@@ -202,14 +197,6 @@ def __init__(
202197
ln_eps = 1e-5
203198
self.ln_eps = ln_eps
204199

205-
# add for compression
206-
self.compress = False
207-
self.is_sorted = False
208-
self.lower = {}
209-
self.upper = {}
210-
self.table_data = {}
211-
self.table_config = []
212-
213200
if isinstance(sel, int):
214201
sel = [sel]
215202

@@ -282,6 +269,16 @@ def __init__(
282269
self.filter_layers_strip = filter_layers_strip
283270
self.stats = None
284271

272+
# add for compression
273+
self.compress = False
274+
self.is_sorted = False
275+
self.compress_info = nn.ParameterList(
276+
[nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
277+
)
278+
self.compress_data = nn.ParameterList(
279+
[nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
280+
)
281+
285282
def get_rcut(self) -> float:
286283
"""Returns the cut-off radius."""
287284
return self.rcut
@@ -431,11 +428,21 @@ def enable_compression(
431428
lower,
432429
upper,
433430
) -> None:
431+
net = "filter_net"
432+
self.compress_info[0] = torch.as_tensor(
433+
[
434+
lower[net],
435+
upper[net],
436+
upper[net] * table_config[0],
437+
table_config[1],
438+
table_config[2],
439+
table_config[3],
440+
],
441+
dtype=self.prec,
442+
device="cpu",
443+
)
444+
self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
434445
self.compress = True
435-
self.table_data = table_data
436-
self.table_config = table_config
437-
self.lower = lower
438-
self.upper = upper
439446

440447
def forward(
441448
self,
@@ -544,15 +551,6 @@ def forward(
544551
xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg)
545552
elif self.tebd_input_mode in ["strip"]:
546553
if self.compress:
547-
net = "filter_net"
548-
info = [
549-
self.lower[net],
550-
self.upper[net],
551-
self.upper[net] * self.table_config[0],
552-
self.table_config[1],
553-
self.table_config[2],
554-
self.table_config[3],
555-
]
556554
ss = ss.reshape(-1, 1)
557555
# nfnl x nnei x ng
558556
# gg_s = self.filter_layers.networks[0](ss)
@@ -569,14 +567,12 @@ def forward(
569567
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
570568
# nfnl x nnei x ng
571569
# gg = gg_s * gg_t + gg_s
572-
tensor_data = self.table_data[net].to(gg_t.device).to(dtype=self.prec)
573-
info_tensor = torch.tensor(info, dtype=self.prec, device="cpu")
574570
gg_t = gg_t.reshape(-1, gg_t.size(-1))
575571
# Convert all tensors to the required precision at once
576572
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
577573
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
578-
tensor_data.contiguous(),
579-
info_tensor.contiguous(),
574+
self.compress_data[0].contiguous(),
575+
self.compress_info[0].cpu().contiguous(),
580576
ss.contiguous(),
581577
rr.contiguous(),
582578
gg_t.contiguous(),

0 commit comments

Comments
 (0)