diff --git a/graph_net/torch/backend/range_decomposer_backend.py b/graph_net/torch/backend/range_decomposer_backend.py
index 5a410c7b2..56a044cf8 100644
--- a/graph_net/torch/backend/range_decomposer_backend.py
+++ b/graph_net/torch/backend/range_decomposer_backend.py
@@ -9,15 +9,6 @@ import graph_net
 
 
-def convert_to_dict(config_str):
-    if config_str is None:
-        return {}
-    config_str = base64.b64decode(config_str).decode("utf-8")
-    config = json.loads(config_str)
-    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
-    return config
-
-
 def encode_config(config: Dict[str, Any]) -> str:
     json_str = json.dumps(config)
     return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
 
@@ -34,7 +25,7 @@ def __init__(self):
         self.graph_net_root = Path(graph_net.__file__).parent
 
     def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
-        config = convert_to_dict(self.config)
+        config = self.config
 
         workspace_path = Path(config["workspace_path"])
         chain_style = config["chain_style"]
diff --git a/graph_net/torch/backend/range_decomposer_validator_backend.py b/graph_net/torch/backend/range_decomposer_validator_backend.py
index b1f71cdc7..cc811f3e1 100644
--- a/graph_net/torch/backend/range_decomposer_validator_backend.py
+++ b/graph_net/torch/backend/range_decomposer_validator_backend.py
@@ -1,11 +1,8 @@
 import torch
 import torch.nn as nn
 import os
-import sys
-import inspect
 import importlib.util
-import itertools
-from typing import List, Tuple, Dict, Any, Callable
+from typing import List
 
 
 class ComposedModel(nn.Module):
@@ -14,18 +11,15 @@ def __init__(self, subgraph: List[nn.Module]):
         self.subgraphs = nn.ModuleList(subgraph)
 
     def forward(self, **kwargs):
-        subgraph_intput = {
-            key.replace("L", "l_l", 1): value
-            for key, value in kwargs.items()
-            if key.startswith("L")
-        }
         output = None
-        for subgraph in self.subgraphs:
+        for i, subgraph in enumerate(self.subgraphs):
+            print(f"{i=} subgraph begin")
             if output is None:
-                output = subgraph(**subgraph_intput)
+                output = subgraph(**kwargs)
             else:
                 output = subgraph(*output)
+            print(f"{i=} subgraph end")
         return output
@@ -43,10 +36,20 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
         instance = ModelClass().to(device)
         return instance
 
+    def _make_config(self, decomposed_root, decomposed_model_name_suffix="_decomposed"):
+        return {
+            "decomposed_root": decomposed_root,
+            "decomposed_model_name_suffix": decomposed_model_name_suffix,
+        }
+
     def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
+        config = self._make_config(**self.config)
         model_file_path = model.__class__.__graph_net_file_path__
         model_dir = os.path.dirname(model_file_path)
-        decomposed_parent_dir = model_dir + "_decomposed"
+        model_name = os.path.basename(model_dir)
+        decomposed_parent_dir = os.path.join(
+            config["decomposed_root"], model_name + config["decomposed_model_name_suffix"]
+        )
         subgraph_paths = []
         for name in sorted(os.listdir(decomposed_parent_dir)):
             full_path = os.path.join(decomposed_parent_dir, name)
diff --git a/graph_net/torch/naive_graph_decomposer.py b/graph_net/torch/naive_graph_decomposer.py
index b11970346..7a298c974 100644
--- a/graph_net/torch/naive_graph_decomposer.py
+++ b/graph_net/torch/naive_graph_decomposer.py
@@ -101,6 +101,7 @@ def _make_config(
         post_extract_process_path=None,
         post_extract_process_class_name=None,
         post_extract_process_config=None,
+        model_path_prefix="",
         **kwargs,
     ):
         if post_extract_process_config is None:
@@ -119,9 +120,11 @@
             "post_extract_process_path": post_extract_process_path,
             "post_extract_process_class_name": post_extract_process_class_name,
             "post_extract_process_config": post_extract_process_config,
+            "model_path_prefix": model_path_prefix,
         }
 
-    def __call__(self, model_path):
+    def __call__(self, rel_model_path):
+        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
         config = {
             k: v
             for k, v in self.config.items()
@@ -171,7 +174,9 @@ def __init__(
             dynamic=False,
             mut_graph_codes=[],
             placeholder_auto_rename=False,
-            workspace_path=self.config["output_dir"],
+            workspace_path=os.path.join(
+                self.config["output_dir"], f"{parent_graph_name}_decomposed"
+            ),
         )
         self.filter = self.make_filter(self.config)
         self.post_extract_process = self.make_post_extract_process(self.config)
diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py
index ae91c633c..60eaa48b1 100644
--- a/graph_net/torch/test_compiler.py
+++ b/graph_net/torch/test_compiler.py
@@ -1,20 +1,18 @@
 from . import utils
 import argparse
 import importlib.util
-import inspect
 import torch
 from pathlib import Path
-from typing import Type, Any, List, Dict, Callable
+from typing import Type
 import sys
 import os
 import os.path
-from dataclasses import dataclass
-from contextlib import contextmanager
-import time
+import traceback
 import json
 import random
 import numpy as np
 import platform
+import base64
 from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
 from graph_net.torch.backend.tvm_backend import TvmBackend
 from graph_net.torch.backend.xla_backend import XlaBackend
@@ -27,7 +25,6 @@
 from graph_net.torch.backend.range_decomposer_validator_backend import (
     RangeDecomposerValidatorBackend,
 )
-from graph_net.test_compiler_util import generate_allclose_configs
 from graph_net import test_compiler_util
 from graph_net import path_utils
 
@@ -68,7 +65,7 @@ def get_compile_framework_version(args):
         return torch.__version__
     elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
         # Assuming compiler object has a version attribute
-        return f"{args.compiler.capitalize()} {compiler.version}"
+        return f"{args.compiler.capitalize()} {registry_backend[args.compiler].version}"
     return "unknown"
 
 
@@ -94,11 +91,20 @@ def load_class_from_file(
     return model_class
 
 
+def convert_to_dict(config_str):
+    if config_str is None:
+        return {}
+    config_str = base64.b64decode(config_str).decode("utf-8")
+    config = json.loads(config_str)
+    assert isinstance(config, dict), f"config should be a dict. {config_str=}"
+    return config
+
+
 def get_compiler_backend(args) -> GraphCompilerBackend:
     assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
     backend = registry_backend[args.compiler]
     if args.config is not None:
-        backend.config = args.config
+        backend.config = convert_to_dict(args.config)
     return backend
 
 
@@ -203,11 +209,13 @@ def test_single_model(args):
     runtime_seed = 1024
     eager_failure = False
     expected_out = None
-    eager_types = []
     eager_time_stats = {}
 
     try:
-        eager_model_call = lambda: model(**input_dict)
+
+        def eager_model_call():
+            return model(**input_dict)
+
         expected_out, eager_time_stats = measure_performance(
             eager_model_call, args, compiler
         )
@@ -221,13 +229,15 @@
 
     compiled_failure = False
     compiled_model = None
-    compiled_types = []
     compiled_time_stats = {}
 
     try:
         compiled_model = compiler(model)
         torch.manual_seed(runtime_seed)
-        compiled_model_call = lambda: compiled_model(**input_dict)
+
+        def compiled_model_call():
+            return compiled_model(**input_dict)
+
         compiled_out, compiled_time_stats = measure_performance(
             compiled_model_call, args, compiler
         )
@@ -239,6 +249,8 @@
     except (TypeError, RuntimeError) as e:
         print(f"Compiled model execution failed: {str(e)}", file=sys.stderr)
         compiled_failure = True
+        print("\n--- Full Traceback ---", file=sys.stderr)
+        traceback.print_exc()
 
     if eager_failure:
         print(f"{args.log_prompt} [Result] status: failed", file=sys.stderr, flush=True)
@@ -381,6 +393,7 @@ def test_multi_models(args):
             f"--warmup {args.warmup}",
             f"--trials {args.trials}",
             f"--log-prompt {args.log_prompt}",
+            *([f"--config {args.config}"] if args.config is not None else []),
         ]
     )
     cmd_ret = os.system(cmd)
@@ -398,7 +411,37 @@
             print(f"- {model_path}", file=sys.stderr, flush=True)
 
 
+def test_multi_models_with_prefix(args):
+    assert os.path.isdir(args.model_path_prefix)
+    assert os.path.isfile(args.allow_list)
+    test_samples = test_compiler_util.get_allow_samples(args.allow_list)
+    py_module_name = os.path.splitext(os.path.basename(__file__))[0]
+    for rel_model_path in test_samples:
+        model_path = os.path.join(args.model_path_prefix, rel_model_path)
+        if not os.path.exists(model_path):
+            continue
+        if not os.path.exists(os.path.join(model_path, "model.py")):
+            continue
+        cmd = " ".join(
+            [
+                sys.executable,
+                f"-m graph_net.torch.{py_module_name}",
+                f"--model-path {model_path}",
+                f"--compiler {args.compiler}",
+                f"--device {args.device}",
+                f"--warmup {args.warmup}",
+                f"--trials {args.trials}",
+                f"--log-prompt {args.log_prompt}",
+                *([f"--config {args.config}"] if args.config is not None else []),
+            ]
+        )
+        os.system(cmd)
+
+
 def main(args):
+    if args.model_path_prefix is not None:
+        test_multi_models_with_prefix(args)
+        return
     assert os.path.isdir(args.model_path)
 
     initalize_seed = 123
@@ -415,7 +458,8 @@
     parser.add_argument(
         "--model-path",
         type=str,
-        required=True,
+        required=False,
+        default=None,
         help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
     )
     parser.add_argument(
@@ -452,6 +496,13 @@
         default=None,
         help="Path to samples list, each line contains a sample path",
     )
+    parser.add_argument(
+        "--model-path-prefix",
+        type=str,
+        required=False,
+        default=None,
+        help="Prefix prepended to each sample path from --allow-list",
+    )
     parser.add_argument(
         "--config",
         type=str,