diff --git a/graph_net/constraint_util.py b/graph_net/constraint_util.py index af456489d..d58bb8fc7 100644 --- a/graph_net/constraint_util.py +++ b/graph_net/constraint_util.py @@ -12,7 +12,6 @@ import tempfile import shutil from pathlib import Path -import json from dataclasses import asdict @@ -187,12 +186,14 @@ def _save_model_to_log_file(self, model_path): shutil.copy(Path(model_path) / "model.py", log_file) def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path): - from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses + from graph_net.graph_net_json_file_util import ( + kDimensionGeneralizationPasses, + update_json, + ) - graph_net_json_file_path = Path(f"{model_path}/graph_net.json") - graph_net_json = json.loads(graph_net_json_file_path.read_text()) - graph_net_json[kDimensionGeneralizationPasses] = list(dim_gen_pass_names) - graph_net_json_file_path.write_text(json.dumps(graph_net_json)) + update_json( + model_path, kDimensionGeneralizationPasses, list(dim_gen_pass_names) + ) def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path): cstr_code = dyn_dim_cstr.serialize_to_py_str() diff --git a/graph_net/dimension_generalizer.py b/graph_net/dimension_generalizer.py new file mode 100644 index 000000000..fb336022d --- /dev/null +++ b/graph_net/dimension_generalizer.py @@ -0,0 +1,176 @@ +import logging +from graph_net.dynamic_dim_constraints import DynamicDimConstraints +from graph_net.imp_util import load_module +from graph_net.tensor_meta import TensorMeta +import functools +import sys +import os +from contextlib import contextmanager +import tempfile +import shutil +from pathlib import Path +from dataclasses import asdict +import graph_net.graph_net_json_file_util as gn_json + + +class ApplyDimGenPasses: + def __init__(self, config=None): + if config is None: + config = {} + self.config = self._make_config(**config) + self.num_handled_models = 0 + + def _make_config( + self, + output_dir: str, + dimension_generalizer_filepath=None, + dimension_generalizer_class_name="StaticToDynamic", + dimension_generalizer_config=None, + model_path_prefix="", + resume=False, + last_model_log_file=None, + limits_handled_models=None, + ): + if dimension_generalizer_config is None: + dimension_generalizer_config = {} + return { + "resume": resume, + "output_dir": output_dir, + "model_path_prefix": model_path_prefix, + "dimension_generalizer_filepath": dimension_generalizer_filepath, + "dimension_generalizer_class_name": dimension_generalizer_class_name, + "dimension_generalizer_config": dimension_generalizer_config, + "last_model_log_file": last_model_log_file, + "limits_handled_models": limits_handled_models, + } + + def __call__(self, rel_model_path): + model_path = os.path.join(self.config["model_path_prefix"], rel_model_path) + output_dir = Path(self.config["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + generalized_model_path = output_dir / rel_model_path + if self.config["resume"] and (generalized_model_path / "model.py").exists(): + return + tensor_metas = self._get_tensor_metas(model_path) + tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas] + dim_gen_pass_names = self._get_dim_gen_pass_names(model_path) + dim_generalizer = self._get_dimension_generalizer(dim_gen_pass_names) + inputs = dim_generalizer.create_inputs_by_metas( + module=self._get_model(model_path), + tensor_meta_attrs_list=tensor_meta_attrs_list, + ) + dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file( + os.path.join(model_path, "input_tensor_constraints.py") + ) + dim_axes_pairs = self._get_dim_axes_pairs(dyn_dim_cstrs) + if len(dim_axes_pairs) == 0: + return + + def get_generalized(): + return self._get_generalized_model_py_file_path( + dim_generalizer=dim_generalizer, + dim_axes_pairs=dim_axes_pairs, + model_path=model_path, + inputs=inputs, + ) + + with get_generalized() as generalized_model_py_path: + self._save_generalized_model_path(rel_model_path, generalized_model_py_path) + + self._check_num_handled_models() + + def _save_generalized_model_path(self, rel_model_path, generalized_model_py_path): + from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path + to_model_path = Path(self.config["output_dir"]) / rel_model_path + print(f"{str(to_model_path)=}") + to_model_path.mkdir(parents=True, exist_ok=True) + shutil.copytree(Path(from_model_path), Path(to_model_path), dirs_exist_ok=True) + generalized_model_py_code = Path(generalized_model_py_path).read_text() + (to_model_path / "model.py").write_text(generalized_model_py_code) + + def _get_dim_axes_pairs(self, dyn_dim_cstrs): + sym_input_shapes = dyn_dim_cstrs.get_sorted_symbolic_input_shapes() + return [ + (dim, axes) + for symbol in dyn_dim_cstrs.symbols + for dim in [dyn_dim_cstrs.symbol2example_value[symbol]] + for axes in [ + [ + axis + for shape in sym_input_shapes + for axis, sym_or_dim in enumerate(shape) + if sym_or_dim == symbol + ] + ] + ] + + def _get_dim_gen_pass_names(self, model_path): + json_value = gn_json.read_json(model_path) + return json_value.get(gn_json.kDimensionGeneralizationPasses, []) + + def _check_num_handled_models(self): + self.num_handled_models += 1 + limits = self.config["limits_handled_models"] + if limits is None: + return + if self.num_handled_models < limits: + return + print("`num_handled_models` exceeds config `limits_handled_models`") + sys.exit(0) + + def _get_dimension_generalizer(self, dim_gen_pass_names): + assert self.config["dimension_generalizer_filepath"] is not None + decorator_cls = getattr( + load_module(self.config["dimension_generalizer_filepath"]), + self.config["dimension_generalizer_class_name"], + ) + config = {"pass_names": dim_gen_pass_names} + dim_generalizer = decorator_cls(config) + return dim_generalizer + + def _get_model(self, model_path): + py_module = load_module(os.path.join(model_path, "model.py")) + GraphModule = getattr(py_module, "GraphModule") + GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__ + return GraphModule() + + @contextmanager + def _get_generalized_model_py_file_path( + self, dim_generalizer, dim_axes_pairs, model_path, inputs + ): + model = self._get_model(model_path) + dim_gen_pass = dim_generalizer(model, dim_axes_pairs) + logging.warning("before need_rewrite") + need_rewrite = dim_gen_pass.need_rewrite(inputs) + logging.warning("after need_rewrite") + if not need_rewrite: + yield os.path.join(model_path, "model.py") + return + logging.warning("before rewrite") + graph_module = dim_gen_pass.rewrite(inputs) + logging.warning("after rewrite") + with tempfile.TemporaryDirectory() as tmp_dir: + shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True) + dim_gen_pass.save_graph_module(graph_module, tmp_dir) + yield os.path.join(tmp_dir, "model.py") + + def _get_tensor_metas(self, model_path): + make = TensorMeta.unserialize_from_py_file + return [ + *make(os.path.join(model_path, "input_meta.py")), + *make(os.path.join(model_path, "weight_meta.py")), + ] + + +def update_tensor_metas_by_dyn_dim_cstr( + tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints +): + input_shapes = dyn_dim_cstr.get_reified_input_shapes() + assert len(tensor_metas) == len(input_shapes) + for i, tensor_meta in enumerate(tensor_metas): + tensor_meta.shape = input_shapes[i] + if tensor_meta.data is not None: + assert isinstance(tensor_meta.data, (list, tuple)) + size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1) + doubled_data = [*tensor_meta.data, *tensor_meta.data] + tensor_meta.data = doubled_data[:size] diff --git a/graph_net/dynamic_dim_constraints.py b/graph_net/dynamic_dim_constraints.py index b415ee6ae..606b9c0a3 100644 --- a/graph_net/dynamic_dim_constraints.py +++ b/graph_net/dynamic_dim_constraints.py @@ -23,6 +23,21 @@ class DynamicDimConstraints: input_shapes: list[(tuple[sympy.Expr | int], str)] kInputShapes = "dynamic_dim_constraint_input_shapes" + def serialize_symbolic_input_shapes_to_str(self): + input_shapes = self.get_sorted_symbolic_input_shapes() + input_shapes_str = str(input_shapes).replace(" ", "") + return input_shapes_str + + def get_sorted_symbolic_input_shapes(self): + return sorted( + [ + tuple(shape) + for shape, name in self.input_shapes + if any(isinstance(dim, sympy.Expr) for dim in shape) + ], + key=str, + ) + @classmethod def make_by_named_inputs(cls, named_shapes): return cls( diff --git a/graph_net/graph_net_json_file_util.py b/graph_net/graph_net_json_file_util.py index 1f38a55ca..6627f5be4 100644 --- a/graph_net/graph_net_json_file_util.py +++ b/graph_net/graph_net_json_file_util.py @@ -1 +1,17 @@ +from pathlib import Path +import json + kDimensionGeneralizationPasses = "dimension_generalization_passes" +kSymbolicDimensionReifier = "symbolic_dimension_reifier" + + +def read_json(model_path): + graph_net_json_file_path = Path(f"{model_path}/graph_net.json") + return json.loads(graph_net_json_file_path.read_text()) + + +def update_json(model_path, field, value): + graph_net_json_file_path = Path(f"{model_path}/graph_net.json") + graph_net_json = json.loads(graph_net_json_file_path.read_text()) + graph_net_json[field] = value + graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4)) diff --git a/graph_net/tools/_get_in_tensor_symbolic_shapes.py b/graph_net/tools/_get_in_tensor_symbolic_shapes.py index bd53cebfd..d3e2223af 100644 --- a/graph_net/tools/_get_in_tensor_symbolic_shapes.py +++ b/graph_net/tools/_get_in_tensor_symbolic_shapes.py @@ -1,15 +1,16 @@ from pathlib import Path from graph_net.dynamic_dim_constraints import DynamicDimConstraints -import sympy +import graph_net.graph_net_json_file_util as gn_json class GetInTensorSymbolicShapes: def __init__(self, config): self.config = self.make_config(**config) - def make_config(self, model_path_prefix): + def make_config(self, model_path_prefix, ignore_reified=True): return { "model_path_prefix": model_path_prefix, + "ignore_reified": ignore_reified, } def __call__(self, model_path): @@ -18,17 +19,21 @@ def __call__(self, model_path): if not input_tensor_cstr_filepath.exists(): print(f"get-in-tensor-symbolic-shapes None {model_path}") return + if self.config["ignore_reified"] and self._found_reified_dims( + str(original_model_path) + ): + print(f"get-in-tensor-symbolic-shapes {model_path}") + return dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file( str(input_tensor_cstr_filepath) ) - dyn_dim_cstrs.symbol2example_value = {} - dyn_dim_cstrs.input_shapes = sorted( - [ - tuple(shape) - for shape, name in dyn_dim_cstrs.input_shapes - if any(isinstance(dim, sympy.Expr) for dim in shape) - ], - key=str, - ) - input_shapes_str = str(dyn_dim_cstrs.input_shapes).replace(" ", "") + input_shapes_str = str(dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str()) print(f"get-in-tensor-symbolic-shapes {input_shapes_str} {model_path}") + + def _found_reified_dims(self, model_path): + json = gn_json.read_json(model_path) + + if gn_json.kSymbolicDimensionReifier not in json: + return False + + return json[gn_json.kSymbolicDimensionReifier] is not None diff --git a/graph_net/tools/_update_sym_dim_reifier.py b/graph_net/tools/_update_sym_dim_reifier.py new file mode 100644 index 000000000..7eb59e3dc --- /dev/null +++ b/graph_net/tools/_update_sym_dim_reifier.py @@ -0,0 +1,58 @@ +from pathlib import Path +from graph_net.imp_util import load_module +import graph_net.graph_net_json_file_util as gn_json + + +class UpdateSymDimReifier: + def __init__(self, config): + self.config = self.make_config(**config) + + def make_config( + self, + model_path_prefix, + reifier_factory_path, + reifier_factory_class_name, + reifier_factory_config=None, + resume=True, + ): + if reifier_factory_config is None: + reifier_factory_config = {} + return { + "reifier_factory_path": reifier_factory_path, + "reifier_factory_class_name": reifier_factory_class_name, + "reifier_factory_config": reifier_factory_config, + "model_path_prefix": model_path_prefix, + "resume": resume, + } + + def __call__(self, model_path): + model_path_obj = Path(self.config["model_path_prefix"]) / model_path + model_path = str(model_path_obj) + input_tensor_cstr_filepath = model_path_obj / "input_tensor_constraints.py" + if not input_tensor_cstr_filepath.exists(): + return + if self.config["resume"] and self._found_reified_dims(model_path): + return + reifier_factory_class = self._get_reifier_factory_class() + reifier_factory_instance = reifier_factory_class( + config=self.config["reifier_factory_config"], model_path=model_path + ) + matched_reifier_name = reifier_factory_instance.get_matched_reifier_name() + if matched_reifier_name is None: + return + assert isinstance(matched_reifier_name, str), f"{type(matched_reifier_name)=}" + gn_json.update_json( + model_path, gn_json.kSymbolicDimensionReifier, matched_reifier_name + ) + + def _get_reifier_factory_class(self): + py_module = load_module(self.config["reifier_factory_path"]) + return getattr(py_module, self.config["reifier_factory_class_name"]) + + def _found_reified_dims(self, model_path): + json = gn_json.read_json(model_path) + + if gn_json.kSymbolicDimensionReifier not in json: + return False + + return json[gn_json.kSymbolicDimensionReifier] is not None diff --git a/graph_net/tools/batch_apply_dim_gen_passes.sh b/graph_net/tools/batch_apply_dim_gen_passes.sh new file mode 100755 index 000000000..3ec4161db --- /dev/null +++ b/graph_net/tools/batch_apply_dim_gen_passes.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( +os.path.dirname(graph_net.__file__))") + +# input model path +# model_runnable_predicator=ShapePropagatablePredicator +model_runnable_predicator=ModelRunnablePredicator +config_json_str=$(cat < bool: + return os.path.basename(__file__)[:-3] + + def match(self) -> bool: + if self.dyn_dim_cstrs is None: + return False + sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str() + return sym_shapes_str in self._get_map_cv_sym_shapes_str2reifier() + + def reify(self): + assert self.need_reify() + sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str() + reifier = self._get_map_cv_sym_shapes_str2reifier()[sym_shapes_str] + return reifier(self) + + @classmethod + def _get_map_cv_sym_shapes_str2reifier(cls): + if not hasattr(cls, "g_cv_sym_shapes_str2reifier"): + cls.g_cv_sym_shapes_str2reifier = { + "[(S0,3,S1,S1)]": cls.reify_s0_s1, + "[(1,3,S0,S0)]": cls.reify_vit_related_hw_s0, + "[(S0,3,512,512)]": cls.reify_mmseg_related_batch_s0, + "[(S0,3,224,224)]": cls.reify_timm_related_big_batch_s0, + "[(S0,3,256,192)]": cls.reify_mmpose_related_big_batch_s0, + "[(S0,3,256,256)]": cls.reify_mmpose_related_big_batch_s0, + "[(S0,3,S1,S2)]": cls.reify_mmpose_related_s0_s1_s2, + "[(1,S0,3,S1,S1)]": cls.reify_vivit_related_s0_s1, + } + return cls.g_cv_sym_shapes_str2reifier + + def reify_s0_s1(self): + return { + sympy.Symbol("S0"): [1, 32, 128], + sympy.Symbol("S1"): [224, 256, 384], + } + + def reify_vit_related_hw_s0(self): + return { + (sympy.Symbol("S0"),): [128, 192, 224, 256, 336, 384, 448, 512, 640], + } + + def reify_mmseg_related_batch_s0(self): + return { + (sympy.Symbol("S0"),): [1, 2, 4, 8, 12, 16, 24, 32, 64], + } + + def reify_timm_related_big_batch_s0(self): + return { + (sympy.Symbol("S0"),): [1, 4, 8, 16, 32, 64, 128, 256, 512], + } + + def reify_mmpose_related_big_batch_s0(self): + return { + (sympy.Symbol("S0"),): [1, 4, 8, 16, 32, 64, 128, 256, 512], + } + + def reify_mmpose_related_s0_s1_s2(self): + S0S1S2 = (sympy.Symbol("S0"), sympy.Symbol("S1"), sympy.Symbol("S2")) + return { + S0S1S2: [ + [1, 256, 192], + [4, 128, 128], + [1, 384, 288], + [8, 256, 256], + [2, 512, 512], + [64, 96, 96], + [16, 256, 192], + [32, 192, 256], + [8, 480, 640], + ], + } + + def reify_vivit_related_s0_s1(self): + S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1")) + return { + S0S1: [ + [8, 112], + [16, 112], + [32, 112], + [8, 224], + [16, 224], + [64, 112], + [4, 448], + [8, 384], + [32, 224], + ], + } diff --git a/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py b/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py new file mode 100644 index 000000000..e706edb24 --- /dev/null +++ b/graph_net/torch/sym_dim_reifiers/naive_nlp_sym_dim_reifier.py @@ -0,0 +1,52 @@ +from graph_net.torch.sym_dim_reifiers.reify_util import get_dynamic_dim_constraints +from graph_net.torch.sym_dim_reifiers.reifier_base import ReifierBase +import os +import sympy + + +class ConcreteReifier(ReifierBase): + def __init__(self, model_path: str, **kwargs): + super().__init__(model_path) + self.dyn_dim_cstrs = get_dynamic_dim_constraints(model_path) + + def get_reifier_name(self) -> bool: + return os.path.basename(__file__)[:-3] + + def match(self) -> bool: + if self.dyn_dim_cstrs is None: + return False + sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str() + return sym_shapes_str in self._get_map_nlp_sym_shapes_str2reifier() + + def reify(self): + assert self.need_reify() + sym_shapes_str = self.dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str() + reifier = self._get_map_nlp_sym_shapes_str2reifier()[sym_shapes_str] + return reifier(self) + + @classmethod + def _get_map_nlp_sym_shapes_str2reifier(cls): + if not hasattr(cls, "g_nlp_sym_shapes_str2reifier"): + cls.g_nlp_sym_shapes_str2reifier = { + "[(S0,1),(S0,S1),(S0,S1)]": cls.reify_batch_s0_seq_s1, + "[(S0,S1),(S0,S1),(S0,S1)]": cls.reify_batch_s0_seq_s1, + "[(S0,S1),(S0,S1)]": cls.reify_batch_s0_seq_s1, + "[(S0,S1,768)]": cls.reify_batch_s0_seq_s1, + } + return cls.g_nlp_sym_shapes_str2reifier + + def reify_batch_s0_seq_s1(self): + S0S1 = (sympy.Symbol("S0"), sympy.Symbol("S1")) + return { + S0S1: [ + [1, 64], + [1, 512], + [16, 128], + [32, 64], + [8, 256], + [4, 512], + [2, 1024], + [64, 128], + [128, 64], + ] + } diff --git a/graph_net/torch/sym_dim_reifiers/reifier_base.py b/graph_net/torch/sym_dim_reifiers/reifier_base.py new file mode 100644 index 000000000..5fea97563 --- /dev/null +++ b/graph_net/torch/sym_dim_reifiers/reifier_base.py @@ -0,0 +1,12 @@ +class ReifierBase: + def __init__(self, model_path: str): + self.model_path = model_path + + def get_reifier_name(self) -> bool: + raise NotImplementedError() + + def match(self) -> bool: + raise NotImplementedError() + + def reify(self): + raise NotImplementedError() diff --git a/graph_net/torch/sym_dim_reifiers/reifier_mgr.py b/graph_net/torch/sym_dim_reifiers/reifier_mgr.py new file mode 100644 index 000000000..57e06ca6a --- /dev/null +++ b/graph_net/torch/sym_dim_reifiers/reifier_mgr.py @@ -0,0 +1,12 @@ +from graph_net.imp_util import load_module +from graph_net.torch.sym_dim_reifiers.reifier_base import ReifierBase +import os + + +def get_reifier(reifier_name) -> ReifierBase: + import graph_net.torch.sym_dim_reifiers as sym_dim_reifiers + + py_module = load_module( + f"{os.path.dirname(sym_dim_reifiers.__file__)}/{reifier_name}.py" + ) + return py_module.ConcreteReifier diff --git a/graph_net/torch/sym_dim_reifiers/reify_util.py b/graph_net/torch/sym_dim_reifiers/reify_util.py new file mode 100644 index 000000000..a86e804a8 --- /dev/null +++ b/graph_net/torch/sym_dim_reifiers/reify_util.py @@ -0,0 +1,12 @@ +from pathlib import Path +from graph_net.dynamic_dim_constraints import DynamicDimConstraints + + +def get_dynamic_dim_constraints(model_path: str): + original_model_path = Path(model_path) + input_tensor_cstr_filepath = original_model_path / "input_tensor_constraints.py" + if not input_tensor_cstr_filepath.exists(): + return None + return DynamicDimConstraints.unserialize_from_py_file( + str(input_tensor_cstr_filepath) + )