Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
638 changes: 151 additions & 487 deletions graph_net/config/empty_cstr_torch_samples_list.txt

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions graph_net/config/tmp_torch_samples_list.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
samples/transformers-auto-model/microsoft_xclip-base-patch32-16-frames
68 changes: 45 additions & 23 deletions graph_net/model_path_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import traceback
import argparse
from graph_net.imp_util import load_module
import logging
import sys
import json
import base64
import subprocess

logging.basicConfig(
level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"
Expand Down Expand Up @@ -37,33 +37,49 @@ def _get_handler(args):

def main(args):
    """Entry point: dispatch to single-path or list handling.

    Exactly one of three modes runs:
      * --model-path given          -> handle that single model path;
      * --use-subprocess            -> one child process per listed path
                                       (isolates crashes between models);
      * otherwise                   -> handle every listed path in-process.
    """
    handler = _get_handler(args)
    if args.model_path is not None:
        handle_model_path(handler, args.model_path)
    elif args.use_subprocess:
        handle_model_path_list_in_subprocess(args)
    else:
        handle_model_path_list_in_current_process(handler, args)


def handle_model_path_list_in_current_process(handler, args):
    """Run the handler on every listed model path inside this process.

    A KeyboardInterrupt stops the whole sweep; any other exception is
    reported (concise message plus full traceback) and the sweep continues
    with the next model path, so one broken sample cannot abort the batch.
    """
    for model_path in _get_model_path_list(args):
        try:
            handle_model_path(handler, model_path)
        except KeyboardInterrupt:
            print("KeyboardInterrupt")
            return
        except Exception as e:
            print("--- Concise Error Message ---")
            print(e)
            print("\n--- Full Traceback ---")
            traceback.print_exc()

def handle_model_path_list_in_subprocess(args):
    """Run the handler on every listed model path, one child process each.

    Spawning a fresh interpreter per model isolates crashes/leaks between
    samples. A KeyboardInterrupt stops the sweep early.
    """
    for model_path in _get_model_path_list(args):
        # list-form argv (shell=False) avoids quoting/injection problems
        # when a model path or config contains spaces or shell metacharacters.
        cmd = [
            sys.executable,
            "-m",
            "graph_net.model_path_handler",
            "--model-path",
            model_path,
            "--handler-config",
            str(args.handler_config),
        ]
        try:
            subprocess.run(cmd)
        except KeyboardInterrupt:
            print("KeyboardInterrupt")
            return

def _get_model_paths(args):
assert args.model_path is not None or args.model_path_list is not None
if args.model_path is not None:
yield args.model_path
if args.model_path_list is not None:
with open(args.model_path_list) as f:
yield from (
clean_line
for line in f
for clean_line in [line.strip()]
if len(clean_line) > 0
if not clean_line.startswith("#")
)

def handle_model_path(handler, model_path):
    """Print which model path is being processed, then run the handler on it."""
    announcement = f"{model_path=}"
    # Flush immediately so progress is visible even through pipes/buffers.
    print(announcement, flush=True)
    handler(model_path)


def _get_model_path_list(args):
assert args.model_path is None
assert args.model_path_list is not None
with open(args.model_path_list) as f:
yield from (
clean_line
for line in f
for clean_line in [line.strip()]
if len(clean_line) > 0
if not clean_line.startswith("#")
)


if __name__ == "__main__":
Expand All @@ -89,5 +105,11 @@ def _get_model_paths(args):
default=None,
help="handler configuration string",
)
parser.add_argument(
"--use-subprocess",
action="store_true",
default=False,
help="use subprocess",
)
args = parser.parse_args()
main(args=args)
2 changes: 1 addition & 1 deletion graph_net/test/decomposer_validator_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ echo "Results saved in: $FILE_PATH/ES_result.png"
echo ""
echo "IMPORTANT: Please verify if the curve in ES_result.png is a straight line"
echo "If the curve is NOT a straight line, please check the log file: $FILE_PATH/log.log"
echo "=================================================="
echo "=================================================="
5 changes: 2 additions & 3 deletions graph_net/test/naive_graph_decomposer_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ config_json_str=$(cat <<EOF
"handler_config": {
"output_dir": "/tmp/naive_decompose_workspace",
"split_positions": [8, 16, 32],
"group_head_and_tail": true,
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
"filter_config": {}
"chain_style": true,
"group_head_and_tail": true
}
}
EOF
Expand Down
5 changes: 5 additions & 0 deletions graph_net/tools/_get_in_tensor_symbolic_shapes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sympy
from pathlib import Path
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
import graph_net.graph_net_json_file_util as gn_json
Expand Down Expand Up @@ -27,6 +28,10 @@ def __call__(self, model_path):
dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
str(input_tensor_cstr_filepath)
)
for shape, name in dyn_dim_cstrs.input_shapes:
if not any(isinstance(dim, sympy.Expr) for dim in shape):
continue
print(f"{shape=} {name=}")
input_shapes_str = str(dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str())
print(f"get-in-tensor-symbolic-shapes {input_shapes_str} {model_path}")

Expand Down
4 changes: 2 additions & 2 deletions graph_net/tools/batch_init_input_tensor_constraints.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ config_json_str=$(cat <<EOF
"non_batch_call_function_arange_plus_one_pass"
]
},
"limits_handled_models": 1,
"limits_handled_models": 999999,
"last_model_log_file": "/tmp/a.py"
}
}
EOF
)
CONFIG=$(echo $config_json_str | base64 -w 0)

python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG --use-subprocess
Empty file modified graph_net/tools/statisticize_in_tensor_symbolic_shapes.sh
100644 → 100755
Empty file.
Empty file modified graph_net/tools/update_sym_dim_reifier.sh
100644 → 100755
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def get_new_tuple_args(input_tensor_node, view_args):
input_tensor_node = node.args[0]
# Get the target shape arguments for view (e.g., 1, -1, 6, 64)
view_args = node.args[1:]
print(f"{view_args=}")
new_view_args = get_new_tuple_args(input_tensor_node, view_args)

# --- Rebuild the view node ---
Expand Down
18 changes: 16 additions & 2 deletions graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.fx as fx
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
import os
Expand All @@ -19,6 +18,21 @@ def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
return True
return False

def _node_need_rewrite(self, node) -> bool:
if not (node.op == "call_method"):
return False
if not (node.op == "expand"):
return False
input_tensor_node = node.args[0]
input_meta = input_tensor_node.meta.get("tensor_meta")
if input_meta is None:
return False
expand_args = node.args[1:]
input_shape = input_meta.shape
if not (len(expand_args) == len(input_shape)):
return False
return True

def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
"""
Fx Pass: Replaces hardcoded constants in 'expand' ops that match an input tensor dimension
Expand All @@ -31,7 +45,7 @@ def rewrite(self, traced_module: fx.GraphModule) -> fx.GraphModule:
val_map = {}

for node in traced_module.graph.nodes:
if node.op == "call_method" and node.target == "expand":
if self._node_need_rewrite(node):
# Get the input tensor node
input_tensor_node = node.args[0]
# Get the target shape arguments for expand (e.g., 1, 4, 6, 64)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def _node_need_rewrite(self, node) -> bool:
return False
if not (node.target == "view"):
return False
print(f"{self.dim=} {node.args[1:]=}")
if self.dim not in node.args[1:]:
return False
return True
Expand Down
156 changes: 146 additions & 10 deletions graph_net/torch/fx_graph_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,123 @@
import inspect


def _rename_placeholder(name):
class NamePatternMismatchDetector:
    """Detects casing/formatting mismatches between forward-signature
    parameter names and fx placeholder names.

    For each known (signature-pattern, placeholder-pattern) pair, a
    mismatch is "detected" when the signature names contain the signature
    pattern (and none contain the placeholder pattern) while the
    placeholder names contain the placeholder pattern (and none contain
    the signature pattern). Calling the instance returns a mapping
    {placeholder_pattern: signature_pattern} usable for str.replace-based
    renaming of placeholders.
    """

    # Ordered (pattern_in_signature, pattern_in_placeholder) pairs.
    # Order matters: the resulting dict preserves insertion order, and
    # replacements are applied in that order by the caller.
    _PATTERN_PAIRS = (
        ("modules_LayerNorm", "modules_layer_norm"),
        ("modules_layer_norm", "modules_LayerNorm"),
        ("modules_meta4D_layers", "modules_meta4d_layers"),
        ("modules_meta4d_layers", "modules_meta4D_layers"),
        ("modules_SelfAttention_modules", "modules_self_attention_modules"),
        ("modules_self_attention_modules", "modules_SelfAttention_modules"),
        ("modules_meta3D_layers", "modules_meta3d_layers"),
        ("modules_meta3d_layers", "modules_meta3D_layers"),
        ("modules_DenseReluDense_modules", "modules_dense_relu_dense_modules"),
        ("modules_dense_relu_dense_modules", "modules_DenseReluDense_modules"),
        ("modules_EncDecAttention_modules", "modules_enc_dec_attention_modules"),
        ("modules_HashBucketCodepointEmbedder", "modules_hash_bucket_codepoint_embedder"),
        ("modules_MBconv", "modules_mbconv"),
        ("_L_", "_l_"),
    )

    def __init__(self, names_from_signature, names_from_placeholder):
        self.names_from_signature = names_from_signature
        self.names_from_placeholder = names_from_placeholder

    def __call__(self):
        """Return {placeholder_pattern: signature_pattern} for each
        detected mismatch, in the fixed pair order above."""
        mut_pattern2replacement = {}
        # Data-driven sweep replaces the former 14 copy-pasted calls.
        for pattern_in_signature, pattern_in_placeholder in self._PATTERN_PAIRS:
            self._detect_and_collect(
                mut_pattern2replacement,
                pattern_in_signature=pattern_in_signature,
                pattern_in_placeholder=pattern_in_placeholder,
            )
        return mut_pattern2replacement

    def _detect_and_collect(
        self, mut_pattern2replacement, pattern_in_signature, pattern_in_placeholder
    ):
        # Record the replacement only when the mismatch is actually present.
        if not self._detect(pattern_in_signature, pattern_in_placeholder):
            return
        mut_pattern2replacement[pattern_in_placeholder] = pattern_in_signature

    def _detect(self, pattern_in_signature, pattern_in_placeholder):
        # Both sides must exhibit their own pattern exclusively; this keeps
        # the mapping unambiguous (no name already uses the other spelling).
        return self._check_pattern(
            self.names_from_signature,
            include_pattern=pattern_in_signature,
            exclude_pattern=pattern_in_placeholder,
        ) and self._check_pattern(
            self.names_from_placeholder,
            include_pattern=pattern_in_placeholder,
            exclude_pattern=pattern_in_signature,
        )

    def _check_pattern(self, names, include_pattern, exclude_pattern):
        # include_pattern occurs in at least one name; exclude_pattern in none.
        return any(include_pattern in name for name in names) and all(
            exclude_pattern not in name for name in names
        )


def _get_name_pattern2replacement(names_from_signature, names_from_placeholder):
    """Build and invoke the mismatch detector, returning its
    {placeholder_pattern: signature_pattern} mapping."""
    detector = NamePatternMismatchDetector(
        names_from_signature,
        names_from_placeholder,
    )
    return detector()


def _rename_placeholder(name, pattern2replacement):
assert name[:2] == "L_" or name[:2] == "l_", f"{name=}"
name = name[2:]
if name[0] == "l":
name = "L" + name[1:]
name = name.replace(
"modules_layer_norm_parameters",
"modules_LayerNorm_parameters",
)
for pattern, replacement in pattern2replacement.items():
name = name.replace(pattern, replacement)
return name


Expand All @@ -27,11 +135,6 @@ def my_backend(gm, sample_inputs):

torch.compile(module, backend=my_backend)(*inputs)
assert traced_module is not None
for node in traced_module.graph.nodes:
if node.op != "placeholder":
continue
node.target = _rename_placeholder(node.target)
node.name = _rename_placeholder(node.name)

def get_input_names_from_signature():
return inspect.signature(module.forward).parameters
Expand All @@ -41,6 +144,17 @@ def get_input_names_from_placeholder():
node.name for node in traced_module.graph.nodes if node.op == "placeholder"
]

pattern2replacement = _get_name_pattern2replacement(
names_from_signature=get_input_names_from_signature(),
names_from_placeholder=get_input_names_from_placeholder(),
)

for node in traced_module.graph.nodes:
if node.op != "placeholder":
continue
node.target = _rename_placeholder(node.target, pattern2replacement)
node.name = _rename_placeholder(node.name, pattern2replacement)

def get_diff_input_names():
placeholder_names = set(get_input_names_from_placeholder())
return [
Expand Down Expand Up @@ -83,6 +197,28 @@ def get_zip_filter_names():
traced_module, get_input_names_from_signature()
)

def handle_underscore_suffix_difference():
    # Reconcile placeholder names that differ from their signature names
    # only by a trailing underscore: strip that underscore on the fx
    # placeholders, then recompile the traced module.
    pairs = get_zip_filter_names()
    if len(pairs) == 0:
        return
    # Bail out unless EVERY mismatch is exactly the underscore-suffix kind.
    if any(sig != ph and f"{sig}_" != ph for _, sig, ph in pairs):
        return
    suffixed_names = {
        ph
        for _, sig, ph in pairs
        if ph == f"{sig}_"
    }
    for node in traced_module.graph.nodes:
        if node.op != "placeholder":
            continue
        if node.target not in suffixed_names:
            continue
        node.target = node.target[:-1]
        node.name = node.name[:-1]
    traced_module.recompile()

zip_filter_names = get_zip_filter_names()

def zip_filter_names_str():
Expand Down
Loading