diff --git a/graph_net/torch/constraint_util.py b/graph_net/torch/constraint_util.py new file mode 100644 index 000000000..a4b29dacd --- /dev/null +++ b/graph_net/torch/constraint_util.py @@ -0,0 +1,56 @@ +import copy + + +def get_all_symbol_names(constraint_attrs_list): +    unique_symbol_names = [] +    for constraint_attrs in constraint_attrs_list: +        for dim in constraint_attrs["shape"]: +            if isinstance(dim, int): +                continue +            assert isinstance(dim, dict) +            if dim["symbol_name"] in unique_symbol_names: +                continue +            unique_symbol_names.append(dim["symbol_name"]) + +    return unique_symbol_names + + +def reify_symbolic_dims(constraint_attrs_list, symbol_names): +    def try_reify_dim(dim): +        if isinstance(dim, int): +            return dim +        assert isinstance(dim, dict) +        if dim["symbol_name"] not in symbol_names: +            return dim +        return dim["example_value"] + +    constraint_attrs_list = copy.deepcopy(constraint_attrs_list) +    for constraint_attrs in constraint_attrs_list: +        constraint_attrs["shape"] = [ +            try_reify_dim(dim) for dim in constraint_attrs["shape"] +        ] +    return constraint_attrs_list + + +def modify_dim_example_value(constraint_attrs_list, symbol_name, modifier): +    def modify_dim(dim): +        if isinstance(dim, int): +            return +        assert isinstance(dim, dict) +        if dim["symbol_name"] != symbol_name: +            return +        dim["example_value"] = modifier(dim["example_value"]) + +    constraint_attrs_list = copy.deepcopy(constraint_attrs_list) +    for constraint_attrs in constraint_attrs_list: +        for dim in constraint_attrs["shape"]: +            modify_dim(dim) +    return constraint_attrs_list + + +def symbolic_dims_all_reified(constraint_attrs_list): +    for constraint_attrs in constraint_attrs_list: +        for dim in constraint_attrs["shape"]: +            if not isinstance(dim, int): +                return False +    return True diff --git a/graph_net/torch/generate_constraint_proposal_file.py b/graph_net/torch/generate_constraint_proposal_file.py new file mode 100644 index 000000000..9cb7a1cd0 --- /dev/null +++ b/graph_net/torch/generate_constraint_proposal_file.py @@ -0,0 +1,140 @@ +from
graph_net.torch import utils +import argparse +import torch +import logging +from pathlib import Path +from typing import Type, Any +import sys +from graph_net.torch.imp_util import load_class_from_file +import hashlib +from contextlib import contextmanager +import json +import inspect +from graph_net.torch import imp_util +from graph_net.torch import record_util +import copy + + +def main(args): +    model_path = args.model_path +    name2input_param_attrs = _get_name2input_param_attrs(model_path) +    name_and_annotation_types = _get_name_and_annotation_types(model_path) +    input_name_and_meta_attrs = _get_input_name_and_meta_attrs( +        name2input_param_attrs, name_and_annotation_types +    ) +    input_name_and_constraint_attrs = _get_input_name_and_constraint_attrs( +        input_name_and_meta_attrs +    ) +    _dump_input_name_and_constraint_attrs( +        input_name_and_constraint_attrs, args.output_path +    ) + + +def _dump_input_name_and_constraint_attrs(input_name_and_constraint_attrs, output_path): +    py_code = record_util.serialize_to_py_code( +        [attr for _, attr in input_name_and_constraint_attrs], +        class_prefix="ProgramInputConstraint", +    ) +    print(f"{output_path=}") +    with open(output_path, "w") as f: +        f.write(py_code) + + +def _get_input_name_and_constraint_attrs(input_name_and_meta_attrs): +    seq_no = 0 +    dim2seq = {} + +    def find_or_new_seq(dim): +        nonlocal seq_no +        nonlocal dim2seq +        if dim in dim2seq: +            return dim2seq[dim] +        ret = seq_no +        dim2seq[dim] = ret +        seq_no += 1 +        return ret + +    def make_symbolic_shape(shape): +        return type(shape)( +            [ +                symbolic_dim_desc +                for dim in shape +                for dim_seq_no in [find_or_new_seq(dim)] +                for symbolic_dim_desc in [ +                    {"symbol_name": f"s{dim_seq_no}", "example_value": dim} +                ] +            ] +        ) + +    def make_constraint_attrs(attrs): +        attrs = copy.deepcopy(attrs) +        attrs["shape"] = make_symbolic_shape(attrs["shape"]) +        return attrs + +    return [ +        (name, symbolic_attrs) +        for name, attrs in input_name_and_meta_attrs +        for symbolic_attrs in [make_constraint_attrs(attrs)] +    ] + + +def
_get_input_name_and_meta_attrs(name2input_param_attrs, name_and_annotation_types): + def constructed_from_self(name): + return name.find("self_") != -1 + + def is_tensor_type(annotation_type): + return annotation_type is torch.Tensor + + ret = [ + (name, meta_attr) + for name, annotation_type in name_and_annotation_types + if is_tensor_type(annotation_type) + if not constructed_from_self(name) + for meta_attr in [name2input_param_attrs[name]] + ] + assert len(ret) > 0 + return ret + + +def _get_name_and_annotation_types(model_path): + model_class = load_class_from_file( + f"{model_path}/model.py", class_name="GraphModule" + ) + annotations = inspect.getfullargspec(model_class.forward).annotations + return [(k, v) for k, v in annotations.items()] + + +def _get_name2input_param_attrs(model_path): + def get_classes(): + input_meta_file = f"{model_path}/input_meta.py" + for _, cls in imp_util.load_name_and_classes_from_file(input_meta_file): + yield cls + + weight_meta_file = f"{model_path}/weight_meta.py" + for _, cls in imp_util.load_name_and_classes_from_file(weight_meta_file): + yield cls + + return { + name: attr + for cls in get_classes() + for attr in [record_util.make_attrs_from_class(cls)] + for name in [attr["name"]] + } + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="generate constraint proposal file") + parser.add_argument( + "--model-path", + type=str, + required=True, + help="Path to folder e.g '../../samples/torch/resnet18'", + ) + parser.add_argument( + "--output-path", + type=str, + required=True, + help="output file path", + ) + args = parser.parse_args() + main(args=args) diff --git a/graph_net/torch/hash_util.py b/graph_net/torch/hash_util.py new file mode 100644 index 000000000..a8780470d --- /dev/null +++ b/graph_net/torch/hash_util.py @@ -0,0 +1,7 @@ +import hashlib + + +def get_sha_hash(content): + m = hashlib.sha256() + m.update(content.encode()) + return m.hexdigest() diff --git a/graph_net/torch/imp_util.py 
b/graph_net/torch/imp_util.py new file mode 100644 index 000000000..233c58c8e --- /dev/null +++ b/graph_net/torch/imp_util.py @@ -0,0 +1,17 @@ +import importlib.util +import inspect + + +def load_class_from_file(file_path: str, class_name: str): + spec = importlib.util.spec_from_file_location("unnamed", file_path) + unnamed = importlib.util.module_from_spec(spec) + spec.loader.exec_module(unnamed) + model_class = getattr(unnamed, class_name, None) + return model_class + + +def load_name_and_classes_from_file(file_path): + spec = importlib.util.spec_from_file_location("unnamed", file_path) + unnamed = importlib.util.module_from_spec(spec) + spec.loader.exec_module(unnamed) + yield from inspect.getmembers(unnamed, inspect.isclass) diff --git a/graph_net/torch/record_util.py b/graph_net/torch/record_util.py new file mode 100644 index 000000000..c48882231 --- /dev/null +++ b/graph_net/torch/record_util.py @@ -0,0 +1,31 @@ +import hash_util + + +def make_attrs_from_class(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") and not callable(v) + } + + +def serialize_to_py_code(attrs, class_prefix): + assert isinstance(attrs, (tuple, list)) + + ret = "\n".join( + _serialize_one_attr_to_py_code(attr, class_prefix) for attr in attrs + ) + return ret + + +def _serialize_one_attr_to_py_code(attr, class_prefix): + hash_str = hash_util.get_sha_hash(str(attr)) + hash_str = hash_str[:32] + indent = " " * 4 + ret = "\n".join( + [ + f"class {class_prefix}{hash_str}:", + *[f"{indent}{name} = {repr(value)}" for name, value in attr.items()], + ] + ) + return f"{ret}\n\n" diff --git a/graph_net/torch/single_device_runner.py b/graph_net/torch/single_device_runner.py index 784e092d8..5afed3863 100644 --- a/graph_net/torch/single_device_runner.py +++ b/graph_net/torch/single_device_runner.py @@ -1,4 +1,4 @@ -from . 
import utils +from graph_net.torch import utils import argparse import importlib.util import inspect @@ -10,6 +10,10 @@ from graph_net.torch.extractor import extract import hashlib from contextlib import contextmanager +import json +import record_util +import imp_util +import os def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]: @@ -62,9 +66,14 @@ def main(args): kwargs = dict(name=args.extract_name, dynamic=False, **dump_graph_options) model = extract(**kwargs)(model) - inputs_params = utils.load_converted_from_text(f"{model_path}") + inputs_params = utils.make_input_and_param_tensors_from_model_path( + f"{model_path}" + ) params = inputs_params["weight_info"] - state_dict = {k: utils.replay_tensor(v) for k, v in params.items()} + shape_modifier = _get_shape_modifier(args) + state_dict = { + k: utils.replay_tensor(v, shape_modifier) for k, v in params.items() + } explain = torch._dynamo.explain(model)(**state_dict) if explain.graph_count != 1 or len(explain.break_reasons) != 0: @@ -76,10 +85,31 @@ def main(args): f"Graph extraction failed. The resulting graph is incomplete, broken into {explain.graph_count} subgraphs." 
) - y = model(**state_dict)[0] + model(**state_dict) + + +def _get_shape_modifier(cli_args): +    """ +    Return a shape modifier loaded from shape_patch.py in directory cli_args.model_path +    """ +    if not cli_args.enable_shape_patch: +        return lambda name, shape: shape +    shape_patch_file_path = f"{cli_args.model_path}/shape_patch.py" +    if not os.path.exists(shape_patch_file_path): +        return lambda name, shape: shape +    shape_modifier_data = [ +        attrs +        for name, cls in imp_util.load_name_and_classes_from_file(shape_patch_file_path) +        for attrs in [record_util.make_attrs_from_class(cls)] +    ] +    assert isinstance(shape_modifier_data, list) +    return _make_shape_modifier_impl(shape_modifier_data) -    print(torch.argmin(y), torch.argmax(y)) -    print(y.shape) + +def _make_shape_modifier_impl(shape_modifier_data): +    name2new_shape = {attrs["name"]: attrs["shape"] for attrs in shape_modifier_data} +    print(f"{name2new_shape=}") +    return lambda name, shape: name2new_shape[name] if name in name2new_shape else shape if __name__ == "__main__": @@ -110,5 +140,12 @@ def main(args): default=None, help="Extracted graph's name", ) +    parser.add_argument( +        "--enable-shape-patch", +        action="store_true", +        required=False, +        default=False, +        help="Enable input-shape patching from shape_patch.py under --model-path", +    ) args = parser.parse_args() main(args=args) diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py index 5922991c5..d141647a0 100644 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -63,7 +63,9 @@ def get_model(args, device): def get_input_dict(args): -    inputs_params = utils.load_converted_from_text(f"{args.model_path}") +    inputs_params = utils.make_input_and_param_tensors_from_model_path( +        f"{args.model_path}" +    ) params = inputs_params["weight_info"] for tensor_meta in params.values(): if hasattr(tensor_meta, "device"): diff --git a/graph_net/torch/utils.py b/graph_net/torch/utils.py index a0a05fc73..8bcd14d6f 100644 --- a/graph_net/torch/utils.py +++ b/graph_net/torch/utils.py @@
-10,6 +10,10 @@ import inspect import argparse import importlib +import inspect +import math +import graph_net.torch.imp_util as imp_util +import graph_net.torch.record_util as record_util kLiteralTensorSize = 64 @@ -196,6 +200,10 @@ def process_tensor_info(tensor_info, name_prefix="example_input"): def load_converted_from_text(file_path): + return make_input_and_param_tensors_from_model_path(file_path) + + +def make_input_and_param_tensors_from_model_path(file_path): input_info = list(convert_meta_classes_to_tensors(f"{file_path}/input_meta.py")) weight_info = { @@ -210,13 +218,14 @@ def load_converted_from_text(file_path): } +def convert_tensor_meta_file_to_attrs(file_path): + for name, cls in imp_util.load_name_and_classes_from_file(file_path): + attrs = record_util.make_attrs_from_class(cls) + yield attrs + + def convert_meta_classes_to_tensors(file_path): - for name, cls in _get_classes(file_path): - attrs = { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") and not callable(v) - } + for attrs in convert_tensor_meta_file_to_attrs(file_path): data_value = None data_type = getattr(torch, attrs.get("dtype", "torch.float").split(".")[-1]) shape = attrs.get("shape", []) @@ -248,21 +257,13 @@ def convert_meta_classes_to_tensors(file_path): } -def _get_classes(file_path): - spec = importlib.util.spec_from_file_location("unnamed", file_path) - unnamed = importlib.util.module_from_spec(spec) - spec.loader.exec_module(unnamed) - yield from inspect.getmembers(unnamed, inspect.isclass) - - -def extract_dynamic_shapes(example_inputs): - pass - - -def replay_tensor(info): +def replay_tensor(info, shape_modifier=None): device = info["info"]["device"] dtype = info["info"]["dtype"] shape = info["info"]["shape"] + if shape_modifier is None: + shape_modifier = lambda name, shape: shape + shape = shape_modifier(info["name"], shape) mean = info["info"]["mean"] std = info["info"]["std"] diff --git a/samples/torchvision/resnet18/shape_patch.py 
b/samples/torchvision/resnet18/shape_patch.py new file mode 100644 index 000000000..532aecb28 --- /dev/null +++ b/samples/torchvision/resnet18/shape_patch.py @@ -0,0 +1,8 @@ +class Program_weight_tensor_meta_L_x_: + name = "L_x_" + shape = [2, 3, 224, 224] + dtype = "torch.float32" + device = "cuda:0" + mean = 0.500 + std = 0.289 + data = None