diff --git a/graph_net/dimension_generalizer.py b/graph_net/dimension_generalizer.py index fb336022d..299c42c2f 100644 --- a/graph_net/dimension_generalizer.py +++ b/graph_net/dimension_generalizer.py @@ -11,6 +11,9 @@ from pathlib import Path from dataclasses import asdict import graph_net.graph_net_json_file_util as gn_json +from collections import OrderedDict +import copy +from graph_net.hash_util import get_sha256_hash class ApplyDimGenPasses: @@ -49,7 +52,12 @@ def __call__(self, rel_model_path): output_dir = Path(self.config["output_dir"]) output_dir.mkdir(parents=True, exist_ok=True) generalized_model_path = output_dir / rel_model_path - if self.config["resume"] and (generalized_model_path / "model.py").exists(): + if ( + self.config["resume"] + and generalized_model_path.exists() + and generalized_model_path.is_dir() + and len(list(generalized_model_path.iterdir())) > 0 + ): return tensor_metas = self._get_tensor_metas(model_path) tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas] @@ -64,6 +72,7 @@ def __call__(self, rel_model_path): ) dim_axes_pairs = self._get_dim_axes_pairs(dyn_dim_cstrs) if len(dim_axes_pairs) == 0: + print("No symbolic dims found. {model_path=}") return def get_generalized(): @@ -74,19 +83,80 @@ def get_generalized(): inputs=inputs, ) - with get_generalized() as generalized_model_py_path: - self._save_generalized_model_path(rel_model_path, generalized_model_py_path) + with get_generalized() as tmp_model_py_path: + from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path + triples = self._get_reified_tensor_metas(from_model_path, dyn_dim_cstrs) + for symbol2example_value, cur_tensor_metas, cur_dyn_dim_cstrs in triples: + to_model_path = self._get_to_model_path( + rel_model_path, symbol2example_value + ) + print(f"{str(to_model_path)=}") + self._copy_sample_model_path(from_model_path, to_model_path) + self._save_generalized_model_path(to_model_path, tmp_model_py_path) + self._save_tensor_metas_as_weight_meta(to_model_path, cur_tensor_metas) + self._save_dyn_dim_cstrs(to_model_path, cur_dyn_dim_cstrs) self._check_num_handled_models() - def _save_generalized_model_path(self, rel_model_path, generalized_model_py_path): - from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path - to_model_path = Path(self.config["output_dir"]) / rel_model_path - print(f"{str(to_model_path)=}") + def _get_reified_tensor_metas(self, from_model_path, dyn_dim_cstrs): + tensor_metas = self._get_tensor_metas(str(from_model_path)) + symbols, reified_dims = self._get_symbols_and_reified_dims( + from_model_path, dyn_dim_cstrs + ) + for dims in reified_dims: + symbol2example_value = OrderedDict(list(zip(symbols, dims))) + cur_dyn_dim_cstrs = copy.deepcopy(dyn_dim_cstrs) + cur_tensor_metas = copy.deepcopy(tensor_metas) + cur_dyn_dim_cstrs.update_symbol2example_value(symbol2example_value) + update_tensor_metas_by_dyn_dim_cstr(cur_tensor_metas, cur_dyn_dim_cstrs) + yield symbol2example_value, cur_tensor_metas, cur_dyn_dim_cstrs + + def _get_symbols_and_reified_dims(self, from_model_path, dyn_dim_cstrs): + json_value = gn_json.read_json(str(from_model_path)) + reifier_name = json_value[gn_json.kSymbolicDimensionReifier] + from graph_net.torch.sym_dim_reifiers.reifier_mgr import get_reifier + + reifier_class = get_reifier(reifier_name) + reifier_instance = reifier_class(str(from_model_path)) + assert reifier_instance.match + symbols2reified_dims = reifier_instance.reify() + assert len(symbols2reified_dims) == 1 + symbols, reified_dims = next(iter(symbols2reified_dims.items())) + assert tuple(symbols) == tuple(dyn_dim_cstrs.symbols) + assert all(len(symbols) == len(dims) for dims in reified_dims) + return symbols, reified_dims + + def _save_dyn_dim_cstrs(self, to_model_path, dyn_dim_cstrs): + cstr_code = dyn_dim_cstrs.serialize_to_py_str() + (to_model_path / "input_tensor_constraints.py").write_text(cstr_code) + + def _save_tensor_metas_as_weight_meta(self, to_model_path, tensor_metas): + weight_meta_code = "\n".join( + tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas + ) + (to_model_path / "weight_meta.py").write_text(weight_meta_code) + + def _get_to_model_path(self, rel_model_path, symbol2example_value): + sym_dim_str = "_".join( + f"{sym_name}_{dim}" + for symbol, dim in symbol2example_value.items() + for sym_name in [symbol.name] + ) + sub_module_name = f"{os.path.basename(rel_model_path)}__{sym_dim_str}" + to_model_path = ( + Path(self.config["output_dir"]) / rel_model_path / sub_module_name + ) + return to_model_path + + def _copy_sample_model_path(self, from_model_path, to_model_path): to_model_path.mkdir(parents=True, exist_ok=True) shutil.copytree(Path(from_model_path), Path(to_model_path), dirs_exist_ok=True) - generalized_model_py_code = Path(generalized_model_py_path).read_text() + + def _save_generalized_model_path(self, to_model_path, tmp_model_py_path): + generalized_model_py_code = Path(tmp_model_py_path).read_text() (to_model_path / "model.py").write_text(generalized_model_py_code) + file_hash = get_sha256_hash(generalized_model_py_code) + (to_model_path / "graph_hash.txt").write_text(file_hash) def _get_dim_axes_pairs(self, dyn_dim_cstrs): sym_input_shapes = dyn_dim_cstrs.get_sorted_symbolic_input_shapes() diff --git a/graph_net/hash_util.py b/graph_net/hash_util.py new file mode 100644 index 000000000..d6b97c225 --- /dev/null +++ b/graph_net/hash_util.py @@ -0,0 +1,7 @@ +import hashlib + + +def get_sha256_hash(content): + m = hashlib.sha256() + m.update(content.encode()) + return m.hexdigest() diff --git a/graph_net/tools/batch_apply_dim_gen_passes.sh b/graph_net/tools/apply_dim_gen_passes.sh similarity index 93% rename from graph_net/tools/batch_apply_dim_gen_passes.sh rename to graph_net/tools/apply_dim_gen_passes.sh index 3ec4161db..a669f30ba 100755 --- a/graph_net/tools/batch_apply_dim_gen_passes.sh +++ b/graph_net/tools/apply_dim_gen_passes.sh @@ -11,12 +11,12 @@ config_json_str=$(cat < bool: return sym_shapes_str in self._get_map_cv_sym_shapes_str2reifier() def reify(self): - assert self.need_reify() + assert self.match() sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str() reifier = self._get_map_cv_sym_shapes_str2reifier()[sym_shapes_str] return reifier(self) @@ -40,29 +40,69 @@ def _get_map_cv_sym_shapes_str2reifier(cls): return cls.g_cv_sym_shapes_str2reifier def reify_s0_s1(self): + S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1")) return { - sympy.Symbol("S0"): [1, 32, 128], - sympy.Symbol("S1"): [224, 256, 384], + S0S1: [ + [1, 224], + [1, 256], + [1, 384], + [32, 224], + [32, 256], + [32, 384], + [128, 224], + [128, 256], + [128, 384], + ], } def reify_vit_related_hw_s0(self): return { - (sympy.Symbol("S0"),): [128, 192, 224, 256, 336, 384, 448, 512, 640], + (sympy.Symbol("S0"),): [ + [128], + [192], + [224], + [256], + [336], + [384], + [448], + [512], + [640], + ], } def reify_mmseg_related_batch_s0(self): return { - (sympy.Symbol("S0"),): [1, 2, 4, 8, 12, 16, 24, 32, 64], + (sympy.Symbol("S0"),): [[1], [2], [4], [8], [12], [16], [24], [32], [64]], } def reify_timm_related_big_batch_s0(self): return { - (sympy.Symbol("S0"),): [1, 4, 8, 16, 32, 64, 128, 256, 512], + (sympy.Symbol("S0"),): [ + [1], + [4], + [8], + [16], + [32], + [64], + [128], + [256], + [512], + ], } def reify_mmpose_related_big_batch_s0(self): return { - (sympy.Symbol("S0"),): [1, 4, 8, 16, 32, 64, 128, 256, 512], + (sympy.Symbol("S0"),): [ + [1], + [4], + [8], + [16], + [32], + [64], + [128], + [256], + [512], + ], } def reify_mmpose_related_s0_s1_s2(self):