Skip to content

Commit 82ef81a

Browse files
authored
Generate 9 samples for reifying one sample with symbolic dimensions (#418)
* support checking model redundancy * revert change of vision_model_test * reformat python code. * reformat bert_model_test.py and utils.py * minor fix * fix failed check by comparing directories after os.path.realpath() * fix bugs in check_validate.sh * set dynamic=False in single_device_runner.py * reset graph hash * add robustness code for generating input tensor constraints * Introduce input_tensor_constraints.py using shape propagation logic. * support dimension generalization for torch.Tensor.view and torch.Tensor.reshape * 1) support dimension generalization for torch.Tensor.expand(); 2) fix bugs in generalization for torch.Tensor.view and torch.Tensor.reshape * dimension_generalization_passes * Refactored DimensionGeneralizationPass.__init__ to accept argument dim_axes_pairs, enabling targeted configuration for specific use cases * save dimension generalization pass names into graph_net.json * Generalize sequence dimension * more dimension generalization passes for token dimension * refactor parse_sole_graph_module * Enhance the performance of the input_tensor_constraints.py file generation process. * minor fix * more dimension generalization pass * fix some hint bugs * generate input_tensor_constraints.py * statisticize in tensor symbolic shapes * remove empty samples * update symbolic dimension reifier in graph_net.json * generate 9 samples for reifying one sample with symbolic dimensions
1 parent 0b25342 commit 82ef81a

File tree

4 files changed

+134
-17
lines changed

4 files changed

+134
-17
lines changed

graph_net/dimension_generalizer.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from pathlib import Path
1212
from dataclasses import asdict
1313
import graph_net.graph_net_json_file_util as gn_json
14+
from collections import OrderedDict
15+
import copy
16+
from graph_net.hash_util import get_sha256_hash
1417

1518

1619
class ApplyDimGenPasses:
@@ -49,7 +52,12 @@ def __call__(self, rel_model_path):
4952
output_dir = Path(self.config["output_dir"])
5053
output_dir.mkdir(parents=True, exist_ok=True)
5154
generalized_model_path = output_dir / rel_model_path
52-
if self.config["resume"] and (generalized_model_path / "model.py").exists():
55+
if (
56+
self.config["resume"]
57+
and generalized_model_path.exists()
58+
and generalized_model_path.is_dir()
59+
and len(list(generalized_model_path.iterdir())) > 0
60+
):
5361
return
5462
tensor_metas = self._get_tensor_metas(model_path)
5563
tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
@@ -64,6 +72,7 @@ def __call__(self, rel_model_path):
6472
)
6573
dim_axes_pairs = self._get_dim_axes_pairs(dyn_dim_cstrs)
6674
if len(dim_axes_pairs) == 0:
75+
print("No symbolic dims found. {model_path=}")
6776
return
6877

6978
def get_generalized():
@@ -74,19 +83,80 @@ def get_generalized():
7483
inputs=inputs,
7584
)
7685

77-
with get_generalized() as generalized_model_py_path:
78-
self._save_generalized_model_path(rel_model_path, generalized_model_py_path)
86+
with get_generalized() as tmp_model_py_path:
87+
from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
88+
triples = self._get_reified_tensor_metas(from_model_path, dyn_dim_cstrs)
89+
for symbol2example_value, cur_tensor_metas, cur_dyn_dim_cstrs in triples:
90+
to_model_path = self._get_to_model_path(
91+
rel_model_path, symbol2example_value
92+
)
93+
print(f"{str(to_model_path)=}")
94+
self._copy_sample_model_path(from_model_path, to_model_path)
95+
self._save_generalized_model_path(to_model_path, tmp_model_py_path)
96+
self._save_tensor_metas_as_weight_meta(to_model_path, cur_tensor_metas)
97+
self._save_dyn_dim_cstrs(to_model_path, cur_dyn_dim_cstrs)
7998

8099
self._check_num_handled_models()
81100

82-
def _save_generalized_model_path(self, rel_model_path, generalized_model_py_path):
83-
from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
84-
to_model_path = Path(self.config["output_dir"]) / rel_model_path
85-
print(f"{str(to_model_path)=}")
101+
def _get_reified_tensor_metas(self, from_model_path, dyn_dim_cstrs):
    """Yield one reified variant of the sample per tuple of concrete dims.

    The tensor metas are loaded once from ``from_model_path``; the symbols
    and their candidate dim tuples come from
    ``_get_symbols_and_reified_dims``.

    Yields:
        ``(symbol2example_value, tensor_metas, dyn_dim_cstrs)`` triples,
        where ``symbol2example_value`` is an ``OrderedDict`` pairing each
        symbol with its value in the current tuple, and the other two are
        deep copies updated for that mapping.
    """
    base_tensor_metas = self._get_tensor_metas(str(from_model_path))
    symbols, reified_dims = self._get_symbols_and_reified_dims(
        from_model_path, dyn_dim_cstrs
    )
    for concrete_dims in reified_dims:
        symbol2example_value = OrderedDict(zip(symbols, concrete_dims))
        # Work on deep copies so every yielded triple is independent of the
        # inputs and of the triples yielded for other dim tuples.
        reified_cstrs = copy.deepcopy(dyn_dim_cstrs)
        reified_tensor_metas = copy.deepcopy(base_tensor_metas)
        reified_cstrs.update_symbol2example_value(symbol2example_value)
        update_tensor_metas_by_dyn_dim_cstr(reified_tensor_metas, reified_cstrs)
        yield symbol2example_value, reified_tensor_metas, reified_cstrs
113+
114+
def _get_symbols_and_reified_dims(self, from_model_path, dyn_dim_cstrs):
115+
json_value = gn_json.read_json(str(from_model_path))
116+
reifier_name = json_value[gn_json.kSymbolicDimensionReifier]
117+
from graph_net.torch.sym_dim_reifiers.reifier_mgr import get_reifier
118+
119+
reifier_class = get_reifier(reifier_name)
120+
reifier_instance = reifier_class(str(from_model_path))
121+
assert reifier_instance.match
122+
symbols2reified_dims = reifier_instance.reify()
123+
assert len(symbols2reified_dims) == 1
124+
symbols, reified_dims = next(iter(symbols2reified_dims.items()))
125+
assert tuple(symbols) == tuple(dyn_dim_cstrs.symbols)
126+
assert all(len(symbols) == len(dims) for dims in reified_dims)
127+
return symbols, reified_dims
128+
129+
def _save_dyn_dim_cstrs(self, to_model_path, dyn_dim_cstrs):
130+
cstr_code = dyn_dim_cstrs.serialize_to_py_str()
131+
(to_model_path / "input_tensor_constraints.py").write_text(cstr_code)
132+
133+
def _save_tensor_metas_as_weight_meta(self, to_model_path, tensor_metas):
    """Write every tensor meta, newline-separated, to ``weight_meta.py``.

    Each element of ``tensor_metas`` is serialized with its own
    ``serialize_to_py_str()``; the file is created (or overwritten) directly
    under ``to_model_path``.
    """
    serialized_metas = [meta.serialize_to_py_str() for meta in tensor_metas]
    weight_meta_file = to_model_path / "weight_meta.py"
    weight_meta_file.write_text("\n".join(serialized_metas))
138+
139+
def _get_to_model_path(self, rel_model_path, symbol2example_value):
140+
sym_dim_str = "_".join(
141+
f"{sym_name}_{dim}"
142+
for symbol, dim in symbol2example_value.items()
143+
for sym_name in [symbol.name]
144+
)
145+
sub_module_name = f"{os.path.basename(rel_model_path)}__{sym_dim_str}"
146+
to_model_path = (
147+
Path(self.config["output_dir"]) / rel_model_path / sub_module_name
148+
)
149+
return to_model_path
150+
151+
def _copy_sample_model_path(self, from_model_path, to_model_path):
    """Copy the whole sample directory tree into ``to_model_path``.

    The destination is created first (including parents); the copy then
    merges into it because ``dirs_exist_ok=True`` is passed.
    """
    to_model_path.mkdir(parents=True, exist_ok=True)
    source_dir, target_dir = Path(from_model_path), Path(to_model_path)
    shutil.copytree(source_dir, target_dir, dirs_exist_ok=True)
88-
generalized_model_py_code = Path(generalized_model_py_path).read_text()
154+
155+
def _save_generalized_model_path(self, to_model_path, tmp_model_py_path):
    """Store the generalized model code and its hash in ``to_model_path``.

    Reads the code from ``tmp_model_py_path``, writes it to ``model.py``, and
    writes the SHA-256 hex digest of that same text to ``graph_hash.txt``.
    """
    model_code = Path(tmp_model_py_path).read_text()
    model_file = to_model_path / "model.py"
    model_file.write_text(model_code)
    # The hash is computed from the text that was just written to model.py.
    code_digest = get_sha256_hash(model_code)
    hash_file = to_model_path / "graph_hash.txt"
    hash_file.write_text(code_digest)
90160

91161
def _get_dim_axes_pairs(self, dyn_dim_cstrs):
92162
sym_input_shapes = dyn_dim_cstrs.get_sorted_symbolic_input_shapes()

graph_net/hash_util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import hashlib
2+
3+
4+
def get_sha256_hash(content):
    """Return the SHA-256 hex digest of ``content``.

    Args:
        content: A ``str`` (hashed as its UTF-8 encoding) or a bytes-like
            object (hashed as-is).

    Returns:
        The 64-character lowercase hexadecimal digest.
    """
    # str.encode() already defaulted to UTF-8; it is spelled out here so the
    # digest visibly does not depend on any locale setting.
    data = content.encode("utf-8") if isinstance(content, str) else content
    return hashlib.sha256(data).hexdigest()

graph_net/tools/batch_apply_dim_gen_passes.sh renamed to graph_net/tools/apply_dim_gen_passes.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ config_json_str=$(cat <<EOF
1111
"handler_path": "$GRAPH_NET_ROOT/dimension_generalizer.py",
1212
"handler_class_name": "ApplyDimGenPasses",
1313
"handler_config": {
14-
"resume": true,
14+
"resume": false,
1515
"output_dir": "/tmp/dimension_generalized_samples",
1616
"model_path_prefix": "$GRAPH_NET_ROOT/../",
1717
"dimension_generalizer_filepath": "$GRAPH_NET_ROOT/torch/static_to_dynamic.py",
1818
"dimension_generalizer_class_name": "StaticToDynamic",
19-
"limits_handled_models": 9999999,
19+
"limits_handled_models": 10,
2020
"last_model_log_file": "/tmp/a.py"
2121
}
2222
}

graph_net/torch/sym_dim_reifiers/naive_cv_sym_dim_reifier.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def match(self) -> bool:
1919
return sym_shapes_str in self._get_map_cv_sym_shapes_str2reifier()
2020

2121
def reify(self):
    """Run the reifier registered for this sample's symbolic input shapes.

    The serialized symbolic input shapes of ``self.dyn_dim_cstrs`` are the
    lookup key into the class-level reifier map; the selected callable is
    invoked with ``self`` and its result is returned unchanged.

    Raises:
        AssertionError: if ``self.match()`` is false.
    """
    assert self.match()
    shapes_key = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()
    reifier_by_shapes = self._get_map_cv_sym_shapes_str2reifier()
    selected_reifier = reifier_by_shapes[shapes_key]
    return selected_reifier(self)
@@ -40,29 +40,69 @@ def _get_map_cv_sym_shapes_str2reifier(cls):
4040
return cls.g_cv_sym_shapes_str2reifier
4141

4242
def reify_s0_s1(self):
    """Return nine ``[S0, S1]`` example pairs keyed by the ``(S0, S1)`` tuple.

    The pairs are the full cross product of three S0 values and three S1
    values, with S0 varying slowest.
    """
    symbol_pair = (sympy.Symbol("S0"), sympy.Symbol("S1"))
    s0_values = (1, 32, 128)
    s1_values = (224, 256, 384)
    return {
        symbol_pair: [[s0, s1] for s0 in s0_values for s1 in s1_values],
    }
4757

4858
def reify_vit_related_hw_s0(self):
    """Return nine single-element example dims for ``S0``, keyed by ``(S0,)``."""
    s0_values = (128, 192, 224, 256, 336, 384, 448, 512, 640)
    symbols = (sympy.Symbol("S0"),)
    # Each entry is a one-element list because there is a single symbol.
    return {symbols: [[value] for value in s0_values]}
5272

5373
def reify_mmseg_related_batch_s0(self):
    """Return nine single-element example dims for ``S0``, keyed by ``(S0,)``."""
    s0_values = (1, 2, 4, 8, 12, 16, 24, 32, 64)
    symbols = (sympy.Symbol("S0"),)
    # Each entry is a one-element list because there is a single symbol.
    return {symbols: [[value] for value in s0_values]}
5777

5878
def reify_timm_related_big_batch_s0(self):
    """Return nine single-element example dims for ``S0``, keyed by ``(S0,)``."""
    s0_values = (1, 4, 8, 16, 32, 64, 128, 256, 512)
    symbols = (sympy.Symbol("S0"),)
    # Each entry is a one-element list because there is a single symbol.
    return {symbols: [[value] for value in s0_values]}
6292

6393
def reify_mmpose_related_big_batch_s0(self):
    """Return nine single-element example dims for ``S0``, keyed by ``(S0,)``."""
    s0_values = (1, 4, 8, 16, 32, 64, 128, 256, 512)
    symbols = (sympy.Symbol("S0"),)
    # Each entry is a one-element list because there is a single symbol.
    return {symbols: [[value] for value in s0_values]}
67107

68108
def reify_mmpose_related_s0_s1_s2(self):

0 commit comments

Comments
 (0)