11from . import utils
22import argparse
33import importlib .util
4- import inspect
54import torch
65from pathlib import Path
7- from typing import Type , Any , List , Dict , Callable
6+ from typing import Type
87import sys
98import os
109import os .path
11- from dataclasses import dataclass
12- from contextlib import contextmanager
13- import time
10+ import traceback
1411import json
1512import random
1613import numpy as np
1714import platform
15+ import base64
1816from graph_net .torch .backend .graph_compiler_backend import GraphCompilerBackend
1917from graph_net .torch .backend .tvm_backend import TvmBackend
2018from graph_net .torch .backend .xla_backend import XlaBackend
2725from graph_net .torch .backend .range_decomposer_validator_backend import (
2826 RangeDecomposerValidatorBackend ,
2927)
30- from graph_net .test_compiler_util import generate_allclose_configs
3128from graph_net import test_compiler_util
3229from graph_net import path_utils
3330
@@ -68,7 +65,7 @@ def get_compile_framework_version(args):
6865 return torch .__version__
6966 elif args .compiler in ["tvm" , "xla" , "tensorrt" , "bladedisc" ]:
7067 # Assuming compiler object has a version attribute
71- return f"{ args .compiler .capitalize ()} { compiler .version } "
68+ return f"{ args .compiler .capitalize ()} { args . compiler .version } "
7269 return "unknown"
7370
7471
@@ -94,11 +91,20 @@ def load_class_from_file(
9491 return model_class
9592
9693
def convert_to_dict(config_str):
    """Decode a base64-encoded JSON object into a configuration dict.

    Args:
        config_str: Base64 text whose decoded payload is a JSON object,
            or ``None`` when no configuration was supplied.

    Returns:
        dict: The parsed configuration; an empty dict when ``config_str``
        is ``None``.

    Raises:
        ValueError: if the decoded payload parses to something other than
            a JSON object (dict).
    """
    if config_str is None:
        return {}
    config_str = base64.b64decode(config_str).decode("utf-8")
    config = json.loads(config_str)
    # Validate with an explicit raise: an `assert` would be silently
    # stripped under `python -O`, letting a non-dict config slip through.
    if not isinstance(config, dict):
        raise ValueError(f"config should be a dict. {config_str=}")
    return config
101+
102+
def get_compiler_backend(args) -> GraphCompilerBackend:
    """Return the backend registered for ``args.compiler``.

    When ``args.config`` is provided, it is decoded (base64 JSON) and
    attached to the backend as its ``config`` dict before returning.
    """
    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
    selected = registry_backend[args.compiler]
    if args.config is None:
        return selected
    selected.config = convert_to_dict(args.config)
    return selected
103109
104110
@@ -203,11 +209,13 @@ def test_single_model(args):
203209 runtime_seed = 1024
204210 eager_failure = False
205211 expected_out = None
206- eager_types = []
207212 eager_time_stats = {}
208213
209214 try :
210- eager_model_call = lambda : model (** input_dict )
215+
216+ def eager_model_call ():
217+ return model (** input_dict )
218+
211219 expected_out , eager_time_stats = measure_performance (
212220 eager_model_call , args , compiler
213221 )
@@ -221,13 +229,15 @@ def test_single_model(args):
221229
222230 compiled_failure = False
223231 compiled_model = None
224- compiled_types = []
225232 compiled_time_stats = {}
226233
227234 try :
228235 compiled_model = compiler (model )
229236 torch .manual_seed (runtime_seed )
230- compiled_model_call = lambda : compiled_model (** input_dict )
237+
238+ def compiled_model_call ():
239+ return compiled_model (** input_dict )
240+
231241 compiled_out , compiled_time_stats = measure_performance (
232242 compiled_model_call , args , compiler
233243 )
@@ -239,6 +249,8 @@ def test_single_model(args):
239249 except (TypeError , RuntimeError ) as e :
240250 print (f"Compiled model execution failed: { str (e )} " , file = sys .stderr )
241251 compiled_failure = True
252+ print ("\n --- Full Traceback ---" )
253+ traceback .print_exc ()
242254
243255 if eager_failure :
244256 print (f"{ args .log_prompt } [Result] status: failed" , file = sys .stderr , flush = True )
@@ -381,6 +393,7 @@ def test_multi_models(args):
381393 f"--warmup { args .warmup } " ,
382394 f"--trials { args .trials } " ,
383395 f"--log-prompt { args .log_prompt } " ,
396+ f"--config { args .config } " ,
384397 ]
385398 )
386399 cmd_ret = os .system (cmd )
@@ -398,7 +411,37 @@ def test_multi_models(args):
398411 print (f"- { model_path } " , file = sys .stderr , flush = True )
399412
400413
def test_multi_models_with_prefix(args):
    """Re-invoke this module once per allowed sample under a prefix directory.

    Each line of ``args.allow_list`` is a sample path relative to
    ``args.model_path_prefix``; every resolved directory that exists and
    contains a ``model.py`` is tested in a child process so one model's
    failure cannot take down the whole sweep.
    """
    assert os.path.isdir(args.model_path_prefix)
    assert os.path.isfile(args.allow_list)
    test_samples = test_compiler_util.get_allow_samples(args.allow_list)
    py_module_name = os.path.splitext(os.path.basename(__file__))[0]
    for rel_model_path in test_samples:
        model_path = os.path.join(args.model_path_prefix, rel_model_path)
        # Silently skip samples that are missing or have no model definition.
        if not os.path.exists(model_path):
            continue
        if not os.path.exists(os.path.join(model_path, "model.py")):
            continue
        cmd_parts = [
            sys.executable,
            f"-m graph_net.torch.{py_module_name}",
            f"--model-path {model_path}",
            f"--compiler {args.compiler}",
            f"--device {args.device}",
            f"--warmup {args.warmup}",
            f"--trials {args.trials}",
            f"--log-prompt {args.log_prompt}",
        ]
        # Bug fix: only forward --config when one was actually given. The
        # original always appended it, so a None config reached the child
        # as the literal string "None", which convert_to_dict() would then
        # try to base64-decode/json-parse and crash on.
        if args.config is not None:
            cmd_parts.append(f"--config {args.config}")
        # NOTE(review): os.system runs a shell string built from unquoted
        # paths; subprocess.run with an argument list would be safer if
        # model paths can contain spaces or shell metacharacters.
        os.system(" ".join(cmd_parts))
439+
440+
401441def main (args ):
442+ if args .model_path_prefix is not None :
443+ test_multi_models_with_prefix (args )
444+ return
402445 assert os .path .isdir (args .model_path )
403446
404447 initalize_seed = 123
@@ -415,7 +458,8 @@ def main(args):
415458 parser .add_argument (
416459 "--model-path" ,
417460 type = str ,
418- required = True ,
461+ required = False ,
462+ default = None ,
419463 help = "Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model" ,
420464 )
421465 parser .add_argument (
@@ -452,6 +496,13 @@ def main(args):
452496 default = None ,
453497 help = "Path to samples list, each line contains a sample path" ,
454498 )
499+ parser .add_argument (
500+ "--model-path-prefix" ,
501+ type = str ,
502+ required = False ,
503+ default = None ,
504+ help = "Prefix path to model path list" ,
505+ )
455506 parser .add_argument (
456507 "--config" ,
457508 type = str ,
0 commit comments