Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ea77602
support checking model redundancy
lixinqi Jul 31, 2025
c3b3ea9
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Jul 31, 2025
c21cb49
revert change of vision_model_test
lixinqi Jul 31, 2025
ca9017f
reformat python code.
lixinqi Jul 31, 2025
52cc34d
reformat bert_model_test.py and utils.py
lixinqi Jul 31, 2025
d8c6213
minor fix
lixinqi Jul 31, 2025
6bd1370
fix failed check by comparing directories after os.path.realpath()
lixinqi Aug 4, 2025
5db0b63
merge paddle repo develop
lixinqi Aug 4, 2025
165ae4b
fix bugs in check_validate.sh
lixinqi Aug 4, 2025
6059328
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 4, 2025
d2d9e0b
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 5, 2025
2cfa175
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 5, 2025
3a75ddd
set dynamic=False in single_device_runner.py
lixinqi Aug 5, 2025
b394760
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 6, 2025
868b686
reset graph hash
lixinqi Aug 6, 2025
4c473cf
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 6, 2025
5530841
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 6, 2025
0caf96a
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 6, 2025
718cd39
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 6, 2025
d46c810
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 7, 2025
985d3cc
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 7, 2025
782858e
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 11, 2025
a7acb51
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 14, 2025
b49edb4
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 15, 2025
6be9482
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Aug 30, 2025
f8524bb
backup code for multi_dim_size
lixinqi Sep 6, 2025
7b950b9
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Sep 6, 2025
956fb0d
Merge branch 'develop' of github.com:PaddlePaddle/GraphNet into develop
lixinqi Sep 15, 2025
60fb21b
merge develop
lixinqi Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions graph_net/torch/constraint_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import copy


def get_all_symbol_names(constraint_attrs_list):
    """Collect the unique symbolic dimension names used by any shape.

    Args:
        constraint_attrs_list: iterable of attr dicts, each with a "shape"
            entry whose elements are either ints (concrete dims) or dicts
            carrying at least a "symbol_name" key.

    Returns:
        List of symbol names, deduplicated, in first-appearance order.
    """
    # A dict preserves insertion order (Python 3.7+) and gives O(1)
    # membership tests, replacing the original O(n^2) list scan.
    unique_symbol_names = {}
    for constraint_attrs in constraint_attrs_list:
        for dim in constraint_attrs["shape"]:
            if isinstance(dim, int):
                continue
            assert isinstance(dim, dict)
            unique_symbol_names.setdefault(dim["symbol_name"], None)
    return list(unique_symbol_names)


def reify_symboli_dims(constraint_attrs_list, symbol_names):
    """Replace each symbolic dim whose name appears in ``symbol_names`` with
    its concrete ``example_value``; ints and other symbolic dims pass through.

    Returns a deep copy; the input list is left unmodified.
    """

    def resolve(dim):
        if isinstance(dim, int):
            return dim
        assert isinstance(dim, dict)
        if dim["symbol_name"] in symbol_names:
            return dim["example_value"]
        return dim

    copied = copy.deepcopy(constraint_attrs_list)
    for attrs in copied:
        attrs["shape"] = [resolve(d) for d in attrs["shape"]]
    return copied


def modify_dim_example_value(constraint_attrs_list, symbol_name, modifier):
    """Apply ``modifier`` to the example value of every symbolic dim named
    ``symbol_name``.

    Bug fix: the original ignored ``symbol_name`` entirely and applied the
    modifier to *every* symbolic dim; now only matching dims are changed,
    which is what the signature promises.

    Returns a deep copy; the input list is left unmodified.
    """

    def modify_dim(dim):
        if isinstance(dim, int):
            return
        assert isinstance(dim, dict)
        # Only touch the symbol the caller asked for; other symbols keep
        # their example values.
        if dim["symbol_name"] == symbol_name:
            dim["example_value"] = modifier(dim["example_value"])

    constraint_attrs_list = copy.deepcopy(constraint_attrs_list)
    for constraint_attrs in constraint_attrs_list:
        for dim in constraint_attrs["shape"]:
            modify_dim(dim)
    return constraint_attrs_list


def symbolic_dims_all_reified(constraint_attrs_list):
    """Return True when every dim in every shape is already a concrete int
    (i.e. no symbolic dim dicts remain); True for an empty list."""
    return all(
        isinstance(dim, int)
        for constraint_attrs in constraint_attrs_list
        for dim in constraint_attrs["shape"]
    )
140 changes: 140 additions & 0 deletions graph_net/torch/generate_constraint_proposal_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from graph_net.torch import utils
import argparse
import torch
import logging
from pathlib import Path
from typing import Type, Any
import sys
from graph_net.torch.imp_util import load_class_from_file
import hashlib
from contextlib import contextmanager
import json
import inspect
import imp_util
import record_util
import copy


def main(args):
    """Generate a constraint proposal file for the model at args.model_path.

    Pipeline: read input/weight meta -> keep tensor-typed forward() inputs
    -> attach symbolic shape constraints -> dump as python source to
    args.output_path.
    """
    path = args.model_path
    param_attrs_by_name = _get_name2input_param_attrs(path)
    annotated_params = _get_name_and_annotation_types(path)
    meta_attrs = _get_input_name_and_meta_attrs(param_attrs_by_name, annotated_params)
    constraint_attrs = _get_input_name_and_constraint_attrs(meta_attrs)
    _dump_input_name_and_constraint_attrs(constraint_attrs, args.output_path)


def _dump_input_name_and_constraint_attrs(input_name_and_constraint_attrs, output_path):
    """Serialize the constraint attrs as python class definitions and write
    them to ``output_path``."""
    attrs_only = [attrs for _, attrs in input_name_and_constraint_attrs]
    py_code = record_util.serialize_to_py_code(
        attrs_only, class_prefix="ProgramInputConstraint"
    )
    print(f"{output_path=}")
    with open(output_path, "w") as f:
        f.write(py_code)


def _get_input_name_and_constraint_attrs(input_name_and_meta_attrs):
seq_no = 0
dim2seq = {}

def find_or_new_seq(dim):
nonlocal seq_no
nonlocal dim2seq
if dim in dim2seq:
return dim2seq[dim]
ret = seq_no
dim2seq[dim] = ret
seq_no += 1
return ret

def make_symoblic_shape(shape):
return type(shape)(
[
symbolic_dim_desc
for dim in shape
for dim_seq_no in [find_or_new_seq(dim)]
for symbolic_dim_desc in [
{"symbol_name": f"s{dim_seq_no}", "example_value": dim}
]
]
)

def make_constraint_attrs(attrs):
attrs = copy.deepcopy(attrs)
attrs["shape"] = make_symoblic_shape(attrs["shape"])
return attrs

return [
(name, symbolic_attrs)
for name, attrs in input_name_and_meta_attrs
for symbolic_attrs in [make_constraint_attrs(attrs)]
]


def _get_input_name_and_meta_attrs(name2input_param_attrs, name_and_annotation_types):
def constructed_from_self(name):
return name.find("self_") != -1

def is_tensor_type(annotation_type):
return annotation_type is torch.Tensor

ret = [
(name, meta_attr)
for name, annotation_type in name_and_annotation_types
if is_tensor_type(annotation_type)
if not constructed_from_self(name)
for meta_attr in [name2input_param_attrs[name]]
]
assert len(ret) > 0
return ret


def _get_name_and_annotation_types(model_path):
    """Return (param_name, annotation) pairs for GraphModule.forward as
    declared in ``<model_path>/model.py``."""
    model_class = load_class_from_file(
        f"{model_path}/model.py", class_name="GraphModule"
    )
    forward_spec = inspect.getfullargspec(model_class.forward)
    return list(forward_spec.annotations.items())


def _get_name2input_param_attrs(model_path):
    """Map each input/weight name to its attr dict, collected from the
    model folder's input_meta.py and weight_meta.py (in that order; later
    duplicates overwrite earlier ones, as before)."""

    def iter_meta_classes():
        for meta_file_name in ("input_meta.py", "weight_meta.py"):
            meta_path = f"{model_path}/{meta_file_name}"
            for _, cls in imp_util.load_name_and_classes_from_file(meta_path):
                yield cls

    name2attrs = {}
    for cls in iter_meta_classes():
        attrs = record_util.make_attrs_from_class(cls)
        name2attrs[attrs["name"]] = attrs
    return name2attrs


if __name__ == "__main__":
    # CLI entry point: generate a constraint proposal file for one model.
    parser = argparse.ArgumentParser(description="generate constraint proposal file")
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to folder e.g '../../samples/torch/resnet18'",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="output file path",
    )
    main(args=parser.parse_args())
7 changes: 7 additions & 0 deletions graph_net/torch/hash_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import hashlib


def get_sha_hash(content):
    """Return the hex SHA-256 digest of ``content`` (a str, UTF-8 encoded)."""
    digest = hashlib.sha256(content.encode())
    return digest.hexdigest()
17 changes: 17 additions & 0 deletions graph_net/torch/imp_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import importlib.util
import inspect


def load_class_from_file(file_path: str, class_name: str):
    """Execute the module at ``file_path`` and return its attribute named
    ``class_name``, or None when no such attribute exists.

    The module is loaded under the throwaway name "unnamed" and is NOT
    registered in sys.modules.
    """
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name, None)


def load_name_and_classes_from_file(file_path):
    """Yield (name, class) pairs for every class reachable as an attribute
    of the module at ``file_path`` (inspect.getmembers sorts by name).

    The module is loaded under the throwaway name "unnamed" and is NOT
    registered in sys.modules.
    """
    spec = importlib.util.spec_from_file_location("unnamed", file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    for member in inspect.getmembers(module, inspect.isclass):
        yield member
31 changes: 31 additions & 0 deletions graph_net/torch/record_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import hash_util


def make_attrs_from_class(cls):
    """Collect the plain data attributes of ``cls`` as a dict, skipping
    dunder names and anything callable (methods, nested classes behave as
    callables too)."""
    attrs = {}
    for key, value in cls.__dict__.items():
        if key.startswith("__") or callable(value):
            continue
        attrs[key] = value
    return attrs


def serialize_to_py_code(attrs, class_prefix):
    """Render a sequence of attr dicts as python class definitions, one per
    attr, each class named ``class_prefix`` plus a content hash."""
    assert isinstance(attrs, (tuple, list))
    pieces = [_serialize_one_attr_to_py_code(attr, class_prefix) for attr in attrs]
    return "\n".join(pieces)


def _serialize_one_attr_to_py_code(attr, class_prefix):
    """Render one attr dict as a python class whose name embeds the first
    32 hex chars of the dict's SHA hash, giving each record a stable,
    content-derived identity."""
    class_suffix = hash_util.get_sha_hash(str(attr))[:32]
    indent = " " * 4
    lines = [f"class {class_prefix}{class_suffix}:"]
    lines.extend(f"{indent}{name} = {repr(value)}" for name, value in attr.items())
    body = "\n".join(lines)
    return f"{body}\n\n"
49 changes: 43 additions & 6 deletions graph_net/torch/single_device_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import utils
from graph_net.torch import utils
import argparse
import importlib.util
import inspect
Expand All @@ -10,6 +10,10 @@
from graph_net.torch.extractor import extract
import hashlib
from contextlib import contextmanager
import json
import record_util
import imp_util
import os


def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
Expand Down Expand Up @@ -62,9 +66,14 @@ def main(args):
kwargs = dict(name=args.extract_name, dynamic=False, **dump_graph_options)
model = extract(**kwargs)(model)

inputs_params = utils.load_converted_from_text(f"{model_path}")
inputs_params = utils.make_input_and_param_tensors_from_model_path(
f"{model_path}"
)
params = inputs_params["weight_info"]
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
shape_modifier = _get_shape_modifier(args)
state_dict = {
k: utils.replay_tensor(v, shape_modifier) for k, v in params.items()
}

explain = torch._dynamo.explain(model)(**state_dict)
if explain.graph_count != 1 or len(explain.break_reasons) != 0:
Expand All @@ -76,10 +85,31 @@ def main(args):
f"Graph extraction failed. The resulting graph is incomplete, broken into {explain.graph_count} subgraphs."
)

y = model(**state_dict)[0]
model(**state_dict)


def _get_shape_modifier(cli_args):
    """
    Return a shape-modifier callable ``(name, shape) -> shape``.

    Overrides are read from shape_patch.py in cli_args.model_path — NOTE:
    the previous docstring said "shape_modifiers.json", which did not match
    the code below. Falls back to an identity modifier when shape patching
    is disabled or the patch file does not exist.
    """
    if not cli_args.enable_shape_patch:
        return lambda name, shape: shape
    shape_patch_file_path = f"{cli_args.model_path}/shape_patch.py"
    if not os.path.exists(shape_patch_file_path):
        return lambda name, shape: shape
    # Each class in shape_patch.py describes one tensor's replacement shape
    # (its data attributes become one attrs dict per class).
    shape_modifier_data = [
        attrs
        for name, cls in imp_util.load_name_and_classes_from_file(shape_patch_file_path)
        for attrs in [record_util.make_attrs_from_class(cls)]
    ]
    assert isinstance(shape_modifier_data, list)
    return _make_shape_modifier_impl(shape_modifier_data)

print(torch.argmin(y), torch.argmax(y))
print(y.shape)

def _make_shape_modifier_impl(shape_modifier_data):
name2new_shape = {attrs["name"]: attrs["shape"] for attrs in shape_modifier_data}
print(f"{name2new_shape=}")
return lambda name, shape: name2new_shape[name] if name in name2new_shape else shape


if __name__ == "__main__":
Expand Down Expand Up @@ -110,5 +140,12 @@ def main(args):
default=None,
help="Extracted graph's name",
)
parser.add_argument(
    "--enable-shape-patch",
    # Bug fix: `type=bool` parses any non-empty string — including
    # "False" — as True, so the flag could never be reliably disabled
    # once passed. `store_true` gives a proper flag defaulting to False.
    action="store_true",
    help="Enable shape patching via <model-path>/shape_patch.py",
)
args = parser.parse_args()
main(args=args)
4 changes: 3 additions & 1 deletion graph_net/torch/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def get_model(args, device):


def get_input_dict(args):
inputs_params = utils.load_converted_from_text(f"{args.model_path}")
inputs_params = utils.make_input_and_param_tensors_from_model_path(
f"{args.model_path}"
)
params = inputs_params["weight_info"]
for tensor_meta in params.values():
if hasattr(tensor_meta, "device"):
Expand Down
Loading