From dc1d35b5ac51dbb979b91c6c5e581911ccec23aa Mon Sep 17 00:00:00 2001 From: JewelRoam <2752594773@qq.com> Date: Thu, 4 Dec 2025 16:46:13 +0800 Subject: [PATCH 1/4] Fix --- graph_net/torch/naive_graph_decomposer.py | 10 +++++--- .../torch/typical_sequence_split_points.py | 24 +++++++++---------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/graph_net/torch/naive_graph_decomposer.py b/graph_net/torch/naive_graph_decomposer.py index 23bb5ef60..9fef28ada 100644 --- a/graph_net/torch/naive_graph_decomposer.py +++ b/graph_net/torch/naive_graph_decomposer.py @@ -31,8 +31,11 @@ def make_config( filter_config=None, post_extract_process_path=None, post_extract_process_class_name=None, + post_extract_process_config=None, **kwargs, ): + if post_extract_process_config is None: + post_extract_process_config = {} for pos in split_positions: assert isinstance( pos, int @@ -46,6 +49,7 @@ def make_config( "filter_config": filter_config if filter_config is not None else {}, "post_extract_process_path": post_extract_process_path, "post_extract_process_class_name": post_extract_process_class_name, + "post_extract_process_config": post_extract_process_config, } def __call__(self, gm: torch.fx.GraphModule, sample_inputs): @@ -112,8 +116,8 @@ def make_filter(self, config): return module.GraphFilter(config["filter_config"]) def make_post_extract_process(self, config): - if config["post_extract_process_path"] is None: - return None + if config.get("post_extract_process_path") is None: + return lambda *args, **kwargs: None module = imp_util.load_module(config["post_extract_process_path"]) cls = getattr(module, config["post_extract_process_class_name"]) - return cls(config["post_extract_process_path"]) + return cls(config["post_extract_process_config"]) diff --git a/graph_net/torch/typical_sequence_split_points.py b/graph_net/torch/typical_sequence_split_points.py index 74dae9e28..9a8fb6445 100644 --- a/graph_net/torch/typical_sequence_split_points.py +++ b/graph_net/torch/typical_sequence_split_points.py @@ -2,7 +2,7 @@ import json import os from pathlib import Path -from typing import Any, Callable, Dict, List +from typing import Any, Dict, List import torch import torch.nn as nn @@ -252,7 +252,15 @@ def _print_analysis(self, name, path, splits, total_len, full_ops): print("\n") -def main(): +def main(args): + analyzer = SplitAnalyzer(window_size=args.window_size) + results = analyzer.analyze(args.model_list, args.device) + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": parser = argparse.ArgumentParser( description="Analyze graph and calculate split points." ) @@ -278,14 +286,4 @@ def main(): help="Path to save the analysis results in JSON format.", ) args = parser.parse_args() - - analyzer = SplitAnalyzer(window_size=args.window_size) - results = analyzer.analyze(args.model_list, args.device) - - if args.output_json: - with open(args.output_json, "w") as f: - json.dump(results, f, indent=4) - - -if __name__ == "__main__": - main() + main(args) From 0b1648ae13b282ab62a13cd46102cd95c75ac436 Mon Sep 17 00:00:00 2001 From: JewelRoam <2752594773@qq.com> Date: Thu, 4 Dec 2025 18:52:26 +0800 Subject: [PATCH 2/4] Optimize typical_sequence_decomposer_test --- .../test/typical_sequence_decomposer_test.sh | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/graph_net/test/typical_sequence_decomposer_test.sh b/graph_net/test/typical_sequence_decomposer_test.sh index c32507f4f..47be4c69c 100644 --- a/graph_net/test/typical_sequence_decomposer_test.sh +++ b/graph_net/test/typical_sequence_decomposer_test.sh @@ -1,46 +1,60 @@ #!/bin/bash GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") +DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_workspace + +mkdir -p "$DECOMPOSE_PATH" -MODEL1="$GRAPH_NET_ROOT/samples/torchvision/resnet18" -MODEL2="$GRAPH_NET_ROOT/samples/torchvision/resnet34" -MODEL_LIST_FILE=$(mktemp) -echo "$MODEL1" > "$MODEL_LIST_FILE" -echo "$MODEL2" >> "$MODEL_LIST_FILE" +temp_model_list=$(mktemp) +cat "$GRAPH_NET_ROOT/graph_net/config/torch_samples_list.txt" > "$temp_model_list" python3 -m graph_net.torch.typical_sequence_split_points \ - --model-list "$MODEL_LIST_FILE" \ + --model-list "$temp_model_list" \ --device "cuda" \ --window-size 10 \ - --output-json "$GRAPH_NET_ROOT/split_results.json" + --output-json "$DECOMPOSE_PATH/split_results.json" -rm -f "$MODEL_LIST_FILE" +while IFS= read -r MODEL_PATH_IN_SAMPLES; do + if [[ -n "$MODEL_PATH_IN_SAMPLES" ]]; then + MODEL_FULL_PATH="$GRAPH_NET_ROOT/$MODEL_PATH_IN_SAMPLES" + MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES") + echo "== Decomposing $MODEL_PATH_IN_SAMPLES. ==" -MODEL_PATH_IN_SAMPLES=/torchvision/resnet18 -MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES") - -decomposer_config_json_str=$(cat < "$DECOMPOSE_PATH/${MODEL_NAME}_validation.log" 2>&1 + + echo "== Finished processing $MODEL_PATH_IN_SAMPLES. ==" + fi +done < $temp_model_list + +rm -f "$temp_model_list" -python3 -m graph_net.torch.test_compiler \ - --model-path $DECOMPOSE_PATH/$MODEL_NAME \ - --compiler range_decomposer_validator \ - --device cuda > "$DECOMPOSE_PATH/log.log" 2>&1 +cat $DECOMPOSE_PATH/*_validation.log >> $DECOMPOSE_PATH/combined.log python3 -m graph_net.plot_ESt \ - --benchmark-path $DECOMPOSE_PATH/log.log \ - --output-dir $DECOMPOSE_PATH \ \ No newline at end of file + --benchmark-path "$DECOMPOSE_PATH/combined.log" \ + --output-dir "$DECOMPOSE_PATH" \ No newline at end of file From 517b86e6ddfe5287417eb65c0021fbc0c9ac77ea Mon Sep 17 00:00:00 2001 From: JewelRoam <2752594773@qq.com> Date: Thu, 4 Dec 2025 22:43:23 +0800 Subject: [PATCH 3/4] change the entry of naive_graph_decomposer from graph_net.torch.run_model to graph_net.model_path_handler --- graph_net/test/naive_graph_decomposer_test.sh | 25 ++-- graph_net/torch/fx_graph_cache_util.py | 40 +----- graph_net/torch/fx_graph_module_util.py | 37 ++++++ graph_net/torch/fx_graph_parse_util.py | 33 +++++ graph_net/torch/naive_graph_decomposer.py | 116 +++++++++++++++--- 5 files changed, 183 insertions(+), 68 deletions(-) create mode 100644 graph_net/torch/fx_graph_module_util.py diff --git a/graph_net/test/naive_graph_decomposer_test.sh b/graph_net/test/naive_graph_decomposer_test.sh index ea9a80a59..cd23c9767 100755 --- a/graph_net/test/naive_graph_decomposer_test.sh +++ b/graph_net/test/naive_graph_decomposer_test.sh @@ -6,23 +6,20 @@ os.path.dirname(graph_net.__file__))") # input model path MODEL_NAME=resnet18 MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME -decorator_config_json_str=$(cat < 0 and set(get_input_names_from_signature()) == set( + get_input_names_from_placeholder() + ): + traced_module = _reorder_placeholders( + traced_module, get_input_names_from_signature() + ) + zip_filter_names = get_zip_filter_names() def zip_filter_names_str(): @@ -83,5 +90,31 @@ def zip_filter_names_str(): print(triple) return "" + from pathlib import Path + + Path("/tmp/a.py").write_text(traced_module.code) assert len(zip_filter_names) == 0, f"{zip_filter_names_str()=}" return traced_module + + +def _reorder_placeholders(gm, sorted_names): + sorted_names = list(sorted_names) + name2placeholder = { + node.name: node for node in gm.graph.nodes if node.op == "placeholder" + } + for i, current_placeholder_name in enumerate(sorted_names): + if i == 0: + continue + prev_node = name2placeholder[sorted_names[i - 1]] + current_node = name2placeholder[current_placeholder_name] + with gm.graph.inserting_after(prev_node): + new_node = gm.graph.placeholder(current_node.name) + # force rename + new_node.name = current_node.name + new_node.target = current_node.target + current_node.replace_all_uses_with(new_node) + name2placeholder[current_placeholder_name] = new_node + gm.graph.erase_node(current_node) + + gm.recompile() + return gm diff --git a/graph_net/torch/naive_graph_decomposer.py b/graph_net/torch/naive_graph_decomposer.py index 9fef28ada..b11970346 100644 --- a/graph_net/torch/naive_graph_decomposer.py +++ b/graph_net/torch/naive_graph_decomposer.py @@ -3,9 +3,15 @@ from graph_net.torch.decompose_util import convert_to_submodules_graph from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor import graph_net.imp_util as imp_util +from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs +from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module class GraphExtractor: + """ + Used by graph_net.torch.run_model + """ + def __init__( self, config: dict, @@ -66,29 +72,109 @@ def __call__(self, gm: torch.fx.GraphModule, sample_inputs): return rewrited_gm def get_naive_decomposer_extractor(self, submodule, seq_no): - return NaiveDecomposerExtractor(self, submodule, seq_no) + return NaiveDecomposerExtractorModule( + config=self.config, + parent_graph_name=self.name, + submodule=submodule, + seq_no=seq_no, + ) + + +class NaiveDecomposerExtractor: + """ + Used by graph_net.model_path_handler + """ + + def __init__(self, config: dict = None): + if config is None: + config = {} + self.config = self._make_config(**config) + + def _make_config( + self, + split_positions=(), + group_head_and_tail=False, + chain_style=False, + output_dir="./tmp/naive_decomposer_dir", + filter_path=None, + filter_config=None, + post_extract_process_path=None, + post_extract_process_class_name=None, + post_extract_process_config=None, + **kwargs, + ): + if post_extract_process_config is None: + post_extract_process_config = {} + for pos in split_positions: + assert isinstance( + pos, int + ), f"split_positions should be list of int, {split_positions=}" + return { + "split_positions": split_positions, + "group_head_and_tail": group_head_and_tail, + "chain_style": chain_style, + "output_dir": output_dir, + "filter_path": filter_path, + "filter_config": filter_config if filter_config is not None else {}, + "post_extract_process_path": post_extract_process_path, + "post_extract_process_class_name": post_extract_process_class_name, + "post_extract_process_config": post_extract_process_config, + } + + def __call__(self, model_path): + config = { + k: v + for k, v in self.config.items() + if k in {"split_positions", "group_head_and_tail", "chain_style"} + } + module, inputs = get_torch_module_and_inputs(model_path) + gm = parse_sole_graph_module(module, inputs) + rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph( + gm, + submodule_hook=self.get_naive_decomposer_extractor(model_path), + **config, + ) + rewrited_gm(*inputs) + + def get_naive_decomposer_extractor(self, model_path): + def fn(submodule, seq_no): + return NaiveDecomposerExtractorModule( + config=self.config, + parent_graph_name=os.path.basename(model_path), + submodule=submodule, + seq_no=seq_no, + ) + + return fn -class NaiveDecomposerExtractor(torch.nn.Module): - def __init__(self, parent_graph_extractor, submodule, seq_no): +class NaiveDecomposerExtractorModule(torch.nn.Module): + def __init__( + self, + config: dict, + parent_graph_name: str, + submodule: torch.nn.Module, + seq_no: int, + ): super().__init__() - self.parent_graph_extractor = parent_graph_extractor + self.config = config self.submodule = submodule self.seq_no = seq_no self.extracted = False - name = f"{parent_graph_extractor.name}_{self.seq_no}" - self.model_name = name + if self.seq_no is None: + self.model_name = parent_graph_name + else: + submodule_name = f"{parent_graph_name}_{self.seq_no}" + self.model_name = submodule_name self.builtin_extractor = BuiltinGraphExtractor( - name=name, + name=submodule_name, dynamic=False, mut_graph_codes=[], - placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename, - workspace_path=self.parent_graph_extractor.config["output_dir"], - ) - self.filter = self.make_filter(self.parent_graph_extractor.config) - self.post_extract_process = self.make_post_extract_process( - self.parent_graph_extractor.config + placeholder_auto_rename=False, + workspace_path=self.config["output_dir"], ) + self.filter = self.make_filter(self.config) + self.post_extract_process = self.make_post_extract_process(self.config) def forward(self, *args): if not self.extracted: @@ -104,9 +190,7 @@ def need_extract(self, gm, sample_inputs): return self.filter(gm, sample_inputs) def _post_extract_process(self): - model_path = os.path.join( - self.parent_graph_extractor.config["output_dir"], self.model_name - ) + model_path = os.path.join(self.config["output_dir"], self.model_name) return self.post_extract_process(model_path) def make_filter(self, config): From 7ae9b9bac55862b0f26812516bb9e83a76c80d02 Mon Sep 17 00:00:00 2001 From: JewelRoam <2752594773@qq.com> Date: Fri, 5 Dec 2025 00:45:03 +0800 Subject: [PATCH 4/4] update test_compiler and validator backend to support config and model_list --- .../torch/backend/range_decomposer_backend.py | 11 +-- .../range_decomposer_validator_backend.py | 29 +++---- graph_net/torch/naive_graph_decomposer.py | 9 ++- graph_net/torch/test_compiler.py | 77 +++++++++++++++---- 4 files changed, 88 insertions(+), 38 deletions(-) diff --git a/graph_net/torch/backend/range_decomposer_backend.py b/graph_net/torch/backend/range_decomposer_backend.py index 5a410c7b2..56a044cf8 100644 --- a/graph_net/torch/backend/range_decomposer_backend.py +++ b/graph_net/torch/backend/range_decomposer_backend.py @@ -9,15 +9,6 @@ import graph_net -def convert_to_dict(config_str): - if config_str is None: - return {} - config_str = base64.b64decode(config_str).decode("utf-8") - config = json.loads(config_str) - assert isinstance(config, dict), f"config should be a dict. {config_str=}" - return config - - def encode_config(config: Dict[str, Any]) -> str: json_str = json.dumps(config) return base64.b64encode(json_str.encode("utf-8")).decode("utf-8") @@ -34,7 +25,7 @@ def __init__(self): self.graph_net_root = Path(graph_net.__file__).parent def __call__(self, model: torch.nn.Module) -> torch.nn.Module: - config = convert_to_dict(self.config) + config = self.config workspace_path = Path(config["workspace_path"]) chain_style = config["chain_style"] diff --git a/graph_net/torch/backend/range_decomposer_validator_backend.py b/graph_net/torch/backend/range_decomposer_validator_backend.py index b1f71cdc7..cc811f3e1 100644 --- a/graph_net/torch/backend/range_decomposer_validator_backend.py +++ b/graph_net/torch/backend/range_decomposer_validator_backend.py @@ -1,11 +1,8 @@ import torch import torch.nn as nn import os -import sys -import inspect import importlib.util -import itertools -from typing import List, Tuple, Dict, Any, Callable +from typing import List class ComposedModel(nn.Module): @@ -14,18 +11,14 @@ def __init__(self, subgraph: List[nn.Module]): self.subgraphs = nn.ModuleList(subgraph) def forward(self, **kwargs): - subgraph_intput = { - key.replace("L", "l_l", 1): value - for key, value in kwargs.items() - if key.startswith("L") - } - output = None - for subgraph in self.subgraphs: + for i, subgraph in enumerate(self.subgraphs): + print(f"{i=} subgraph begin") if output is None: - output = subgraph(**subgraph_intput) + output = subgraph(**kwargs) else: output = subgraph(*output) + print(f"{i=} subgraph end") return output @@ -43,10 +36,20 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module: instance = ModelClass().to(device) return instance + def _make_config(self, decomposed_root, decomposed_model_name_suffix="_decomposed"): + return { + "decomposed_root": decomposed_root, + "decomposed_model_name_suffix": decomposed_model_name_suffix, + } + def __call__(self, model: torch.nn.Module) -> torch.nn.Module: + config = self._make_config(**self.config) model_file_path = model.__class__.__graph_net_file_path__ model_dir = os.path.dirname(model_file_path) - decomposed_parent_dir = model_dir + "_decomposed" + model_name = os.path.basename(model_dir) + decomposed_parent_dir = os.path.join( + config["decomposed_root"], f"{model_name}_decomposed" + ) subgraph_paths = [] for name in sorted(os.listdir(decomposed_parent_dir)): full_path = os.path.join(decomposed_parent_dir, name) diff --git a/graph_net/torch/naive_graph_decomposer.py b/graph_net/torch/naive_graph_decomposer.py index b11970346..7a298c974 100644 --- a/graph_net/torch/naive_graph_decomposer.py +++ b/graph_net/torch/naive_graph_decomposer.py @@ -101,6 +101,7 @@ def _make_config( post_extract_process_path=None, post_extract_process_class_name=None, post_extract_process_config=None, + model_path_prefix="", **kwargs, ): if post_extract_process_config is None: @@ -119,9 +120,11 @@ def _make_config( "post_extract_process_path": post_extract_process_path, "post_extract_process_class_name": post_extract_process_class_name, "post_extract_process_config": post_extract_process_config, + "model_path_prefix": model_path_prefix, } - def __call__(self, model_path): + def __call__(self, rel_model_path): + model_path = os.path.join(self.config["model_path_prefix"], rel_model_path) config = { k: v for k, v in self.config.items() @@ -171,7 +174,9 @@ def __init__( dynamic=False, mut_graph_codes=[], placeholder_auto_rename=False, - workspace_path=self.config["output_dir"], + workspace_path=os.path.join( + self.config["output_dir"], f"{parent_graph_name}_decomposed" + ), ) self.filter = self.make_filter(self.config) self.post_extract_process = self.make_post_extract_process(self.config) diff --git a/graph_net/torch/test_compiler.py b/graph_net/torch/test_compiler.py index ae91c633c..60eaa48b1 100644 --- a/graph_net/torch/test_compiler.py +++ b/graph_net/torch/test_compiler.py @@ -1,20 +1,18 @@ from . import utils import argparse import importlib.util -import inspect import torch from pathlib import Path -from typing import Type, Any, List, Dict, Callable +from typing import Type import sys import os import os.path -from dataclasses import dataclass -from contextlib import contextmanager -import time +import traceback import json import random import numpy as np import platform +import base64 from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend from graph_net.torch.backend.tvm_backend import TvmBackend from graph_net.torch.backend.xla_backend import XlaBackend @@ -27,7 +25,6 @@ from graph_net.torch.backend.range_decomposer_validator_backend import ( RangeDecomposerValidatorBackend, ) -from graph_net.test_compiler_util import generate_allclose_configs from graph_net import test_compiler_util from graph_net import path_utils @@ -68,7 +65,7 @@ def get_compile_framework_version(args): return torch.__version__ elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]: # Assuming compiler object has a version attribute - return f"{args.compiler.capitalize()} {compiler.version}" + return f"{args.compiler.capitalize()} {args.compiler.version}" return "unknown" @@ -94,11 +91,20 @@ def load_class_from_file( return model_class +def convert_to_dict(config_str): + if config_str is None: + return {} + config_str = base64.b64decode(config_str).decode("utf-8") + config = json.loads(config_str) + assert isinstance(config, dict), f"config should be a dict. {config_str=}" + return config + + def get_compiler_backend(args) -> GraphCompilerBackend: assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}" backend = registry_backend[args.compiler] if args.config is not None: - backend.config = args.config + backend.config = convert_to_dict(args.config) return backend @@ -203,11 +209,13 @@ def test_single_model(args): runtime_seed = 1024 eager_failure = False expected_out = None - eager_types = [] eager_time_stats = {} try: - eager_model_call = lambda: model(**input_dict) + + def eager_model_call(): + return model(**input_dict) + expected_out, eager_time_stats = measure_performance( eager_model_call, args, compiler ) @@ -221,13 +229,15 @@ def test_single_model(args): compiled_failure = False compiled_model = None - compiled_types = [] compiled_time_stats = {} try: compiled_model = compiler(model) torch.manual_seed(runtime_seed) - compiled_model_call = lambda: compiled_model(**input_dict) + + def compiled_model_call(): + return compiled_model(**input_dict) + compiled_out, compiled_time_stats = measure_performance( compiled_model_call, args, compiler ) @@ -239,6 +249,8 @@ def test_single_model(args): except (TypeError, RuntimeError) as e: print(f"Compiled model execution failed: {str(e)}", file=sys.stderr) compiled_failure = True + print("\n--- Full Traceback ---") + traceback.print_exc() if eager_failure: print(f"{args.log_prompt} [Result] status: failed", file=sys.stderr, flush=True) @@ -381,6 +393,7 @@ def test_multi_models(args): f"--warmup {args.warmup}", f"--trials {args.trials}", f"--log-prompt {args.log_prompt}", + f"--config {args.config}", ] ) cmd_ret = os.system(cmd) @@ -398,7 +411,37 @@ def test_multi_models(args): print(f"- {model_path}", file=sys.stderr, flush=True) +def test_multi_models_with_prefix(args): + assert os.path.isdir(args.model_path_prefix) + assert os.path.isfile(args.allow_list) + test_samples = test_compiler_util.get_allow_samples(args.allow_list) + py_module_name = os.path.splitext(os.path.basename(__file__))[0] + for rel_model_path in test_samples: + model_path = os.path.join(args.model_path_prefix, rel_model_path) + if not os.path.exists(model_path): + continue + if not os.path.exists(os.path.join(model_path, "model.py")): + continue + cmd = " ".join( + [ + sys.executable, + f"-m graph_net.torch.{py_module_name}", + f"--model-path {model_path}", + f"--compiler {args.compiler}", + f"--device {args.device}", + f"--warmup {args.warmup}", + f"--trials {args.trials}", + f"--log-prompt {args.log_prompt}", + f"--config {args.config}", + ] + ) + os.system(cmd) + + def main(args): + if args.model_path_prefix is not None: + test_multi_models_with_prefix(args) + return assert os.path.isdir(args.model_path) initalize_seed = 123 @@ -415,7 +458,8 @@ def main(args): parser.add_argument( "--model-path", type=str, - required=True, + required=False, + default=None, help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model", ) parser.add_argument( @@ -452,6 +496,13 @@ def main(args): default=None, help="Path to samples list, each line contains a sample path", ) + parser.add_argument( + "--model-path-prefix", + type=str, + required=False, + default=None, + help="Prefix path to model path list", + ) parser.add_argument( "--config", type=str,