Skip to content

Commit e6a16e2

Browse files
authored
Update test_compiler and range_decomposer_validator_backend (#416)
* Fix * Optimize typical_sequence_decomposer_test * change the entry of naive_graph_decomposer from graph_net.torch.run_model to graph_net.model_path_handler * update test_compiler and validator backend to support config and model_list
1 parent c86aceb commit e6a16e2

File tree

4 files changed

+88
-38
lines changed

4 files changed

+88
-38
lines changed

graph_net/torch/backend/range_decomposer_backend.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,6 @@
99
import graph_net
1010

1111

12-
def convert_to_dict(config_str):
13-
if config_str is None:
14-
return {}
15-
config_str = base64.b64decode(config_str).decode("utf-8")
16-
config = json.loads(config_str)
17-
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
18-
return config
19-
20-
2112
def encode_config(config: Dict[str, Any]) -> str:
2213
json_str = json.dumps(config)
2314
return base64.b64encode(json_str.encode("utf-8")).decode("utf-8")
@@ -34,7 +25,7 @@ def __init__(self):
3425
self.graph_net_root = Path(graph_net.__file__).parent
3526

3627
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
37-
config = convert_to_dict(self.config)
28+
config = self.config
3829
workspace_path = Path(config["workspace_path"])
3930
chain_style = config["chain_style"]
4031

graph_net/torch/backend/range_decomposer_validator_backend.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
import torch
22
import torch.nn as nn
33
import os
4-
import sys
5-
import inspect
64
import importlib.util
7-
import itertools
8-
from typing import List, Tuple, Dict, Any, Callable
5+
from typing import List
96

107

118
class ComposedModel(nn.Module):
@@ -14,18 +11,14 @@ def __init__(self, subgraph: List[nn.Module]):
1411
self.subgraphs = nn.ModuleList(subgraph)
1512

1613
def forward(self, **kwargs):
17-
subgraph_intput = {
18-
key.replace("L", "l_l", 1): value
19-
for key, value in kwargs.items()
20-
if key.startswith("L")
21-
}
22-
2314
output = None
24-
for subgraph in self.subgraphs:
15+
for i, subgraph in enumerate(self.subgraphs):
16+
print(f"{i=} subgraph begin")
2517
if output is None:
26-
output = subgraph(**subgraph_intput)
18+
output = subgraph(**kwargs)
2719
else:
2820
output = subgraph(*output)
21+
print(f"{i=} subgraph end")
2922

3023
return output
3124

@@ -43,10 +36,20 @@ def _load_model_instance(self, path: str, device: str) -> torch.nn.Module:
4336
instance = ModelClass().to(device)
4437
return instance
4538

39+
def _make_config(self, decomposed_root, decomposed_model_name_suffix="_decomposed"):
40+
return {
41+
"decomposed_root": decomposed_root,
42+
"decomposed_model_name_suffix": decomposed_model_name_suffix,
43+
}
44+
4645
def __call__(self, model: torch.nn.Module) -> torch.nn.Module:
46+
config = self._make_config(**self.config)
4747
model_file_path = model.__class__.__graph_net_file_path__
4848
model_dir = os.path.dirname(model_file_path)
49-
decomposed_parent_dir = model_dir + "_decomposed"
49+
model_name = os.path.basename(model_dir)
50+
decomposed_parent_dir = os.path.join(
51+
config["decomposed_root"], f"{model_name}_decomposed"
52+
)
5053
subgraph_paths = []
5154
for name in sorted(os.listdir(decomposed_parent_dir)):
5255
full_path = os.path.join(decomposed_parent_dir, name)

graph_net/torch/naive_graph_decomposer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def _make_config(
101101
post_extract_process_path=None,
102102
post_extract_process_class_name=None,
103103
post_extract_process_config=None,
104+
model_path_prefix="",
104105
**kwargs,
105106
):
106107
if post_extract_process_config is None:
@@ -119,9 +120,11 @@ def _make_config(
119120
"post_extract_process_path": post_extract_process_path,
120121
"post_extract_process_class_name": post_extract_process_class_name,
121122
"post_extract_process_config": post_extract_process_config,
123+
"model_path_prefix": model_path_prefix,
122124
}
123125

124-
def __call__(self, model_path):
126+
def __call__(self, rel_model_path):
127+
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
125128
config = {
126129
k: v
127130
for k, v in self.config.items()
@@ -171,7 +174,9 @@ def __init__(
171174
dynamic=False,
172175
mut_graph_codes=[],
173176
placeholder_auto_rename=False,
174-
workspace_path=self.config["output_dir"],
177+
workspace_path=os.path.join(
178+
self.config["output_dir"], f"{parent_graph_name}_decomposed"
179+
),
175180
)
176181
self.filter = self.make_filter(self.config)
177182
self.post_extract_process = self.make_post_extract_process(self.config)

graph_net/torch/test_compiler.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
from . import utils
22
import argparse
33
import importlib.util
4-
import inspect
54
import torch
65
from pathlib import Path
7-
from typing import Type, Any, List, Dict, Callable
6+
from typing import Type
87
import sys
98
import os
109
import os.path
11-
from dataclasses import dataclass
12-
from contextlib import contextmanager
13-
import time
10+
import traceback
1411
import json
1512
import random
1613
import numpy as np
1714
import platform
15+
import base64
1816
from graph_net.torch.backend.graph_compiler_backend import GraphCompilerBackend
1917
from graph_net.torch.backend.tvm_backend import TvmBackend
2018
from graph_net.torch.backend.xla_backend import XlaBackend
@@ -27,7 +25,6 @@
2725
from graph_net.torch.backend.range_decomposer_validator_backend import (
2826
RangeDecomposerValidatorBackend,
2927
)
30-
from graph_net.test_compiler_util import generate_allclose_configs
3128
from graph_net import test_compiler_util
3229
from graph_net import path_utils
3330

@@ -68,7 +65,7 @@ def get_compile_framework_version(args):
6865
return torch.__version__
6966
elif args.compiler in ["tvm", "xla", "tensorrt", "bladedisc"]:
7067
# Assuming compiler object has a version attribute
71-
return f"{args.compiler.capitalize()} {compiler.version}"
68+
return f"{args.compiler.capitalize()} {args.compiler.version}"
7269
return "unknown"
7370

7471

@@ -94,11 +91,20 @@ def load_class_from_file(
9491
return model_class
9592

9693

94+
def convert_to_dict(config_str):
95+
if config_str is None:
96+
return {}
97+
config_str = base64.b64decode(config_str).decode("utf-8")
98+
config = json.loads(config_str)
99+
assert isinstance(config, dict), f"config should be a dict. {config_str=}"
100+
return config
101+
102+
97103
def get_compiler_backend(args) -> GraphCompilerBackend:
98104
assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
99105
backend = registry_backend[args.compiler]
100106
if args.config is not None:
101-
backend.config = args.config
107+
backend.config = convert_to_dict(args.config)
102108
return backend
103109

104110

@@ -203,11 +209,13 @@ def test_single_model(args):
203209
runtime_seed = 1024
204210
eager_failure = False
205211
expected_out = None
206-
eager_types = []
207212
eager_time_stats = {}
208213

209214
try:
210-
eager_model_call = lambda: model(**input_dict)
215+
216+
def eager_model_call():
217+
return model(**input_dict)
218+
211219
expected_out, eager_time_stats = measure_performance(
212220
eager_model_call, args, compiler
213221
)
@@ -221,13 +229,15 @@ def test_single_model(args):
221229

222230
compiled_failure = False
223231
compiled_model = None
224-
compiled_types = []
225232
compiled_time_stats = {}
226233

227234
try:
228235
compiled_model = compiler(model)
229236
torch.manual_seed(runtime_seed)
230-
compiled_model_call = lambda: compiled_model(**input_dict)
237+
238+
def compiled_model_call():
239+
return compiled_model(**input_dict)
240+
231241
compiled_out, compiled_time_stats = measure_performance(
232242
compiled_model_call, args, compiler
233243
)
@@ -239,6 +249,8 @@ def test_single_model(args):
239249
except (TypeError, RuntimeError) as e:
240250
print(f"Compiled model execution failed: {str(e)}", file=sys.stderr)
241251
compiled_failure = True
252+
print("\n--- Full Traceback ---")
253+
traceback.print_exc()
242254

243255
if eager_failure:
244256
print(f"{args.log_prompt} [Result] status: failed", file=sys.stderr, flush=True)
@@ -381,6 +393,7 @@ def test_multi_models(args):
381393
f"--warmup {args.warmup}",
382394
f"--trials {args.trials}",
383395
f"--log-prompt {args.log_prompt}",
396+
f"--config {args.config}",
384397
]
385398
)
386399
cmd_ret = os.system(cmd)
@@ -398,7 +411,37 @@ def test_multi_models(args):
398411
print(f"- {model_path}", file=sys.stderr, flush=True)
399412

400413

414+
def test_multi_models_with_prefix(args):
415+
assert os.path.isdir(args.model_path_prefix)
416+
assert os.path.isfile(args.allow_list)
417+
test_samples = test_compiler_util.get_allow_samples(args.allow_list)
418+
py_module_name = os.path.splitext(os.path.basename(__file__))[0]
419+
for rel_model_path in test_samples:
420+
model_path = os.path.join(args.model_path_prefix, rel_model_path)
421+
if not os.path.exists(model_path):
422+
continue
423+
if not os.path.exists(os.path.join(model_path, "model.py")):
424+
continue
425+
cmd = " ".join(
426+
[
427+
sys.executable,
428+
f"-m graph_net.torch.{py_module_name}",
429+
f"--model-path {model_path}",
430+
f"--compiler {args.compiler}",
431+
f"--device {args.device}",
432+
f"--warmup {args.warmup}",
433+
f"--trials {args.trials}",
434+
f"--log-prompt {args.log_prompt}",
435+
f"--config {args.config}",
436+
]
437+
)
438+
os.system(cmd)
439+
440+
401441
def main(args):
442+
if args.model_path_prefix is not None:
443+
test_multi_models_with_prefix(args)
444+
return
402445
assert os.path.isdir(args.model_path)
403446

404447
initalize_seed = 123
@@ -415,7 +458,8 @@ def main(args):
415458
parser.add_argument(
416459
"--model-path",
417460
type=str,
418-
required=True,
461+
required=False,
462+
default=None,
419463
help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
420464
)
421465
parser.add_argument(
@@ -452,6 +496,13 @@ def main(args):
452496
default=None,
453497
help="Path to samples list, each line contains a sample path",
454498
)
499+
parser.add_argument(
500+
"--model-path-prefix",
501+
type=str,
502+
required=False,
503+
default=None,
504+
help="Prefix path to model path list",
505+
)
455506
parser.add_argument(
456507
"--config",
457508
type=str,

0 commit comments

Comments
 (0)