
Commit 8ba1b8f

Change the entry of naive_graph_decomposer from graph_net.torch.run_model to graph_net.model_path_handler (#413)
* Fix
* Optimize typical_sequence_decomposer_test
* Change the entry of naive_graph_decomposer from graph_net.torch.run_model to graph_net.model_path_handler
1 parent d349727 commit 8ba1b8f

6 files changed: +223 −94 lines changed

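Note on the new entry point: graph_net.model_path_handler consumes a base64-encoded JSON config whose keys (handler_path, handler_class_name, handler_config) appear in the updated test script below. Its implementation is not part of this diff; the following is only a minimal sketch of that dispatch pattern, written against those config keys alone — the function names and the handler constructor signature are assumptions.

# Hypothetical sketch of a handler-style entry point; only the config keys
# (handler_path, handler_class_name, handler_config) come from this commit.
import argparse
import base64
import importlib.util
import json


def load_handler(handler_config_b64: str):
    # Decode the base64-encoded JSON config, as produced by the test script.
    cfg = json.loads(base64.b64decode(handler_config_b64))
    # Import the module referenced by handler_path from its file location.
    spec = importlib.util.spec_from_file_location("handler_module", cfg["handler_path"])
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Instantiate the named handler class with its per-handler config
    # (constructor signature is an assumption).
    handler_cls = getattr(module, cfg["handler_class_name"])
    return handler_cls(cfg.get("handler_config", {}))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--handler-config", required=True)
    args = parser.parse_args()
    handler = load_handler(args.handler_config)
    print(f"Loaded {type(handler).__name__} for {args.model_path}")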

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 11 additions & 14 deletions
@@ -6,23 +6,20 @@ os.path.dirname(graph_net.__file__))")
 # input model path
 MODEL_NAME=resnet18
 MODEL_PATH_IN_SAMPLES=/timm/$MODEL_NAME
-decorator_config_json_str=$(cat <<EOF
+config_json_str=$(cat <<EOF
 {
-    "decorator_path": "$GRAPH_NET_ROOT/torch/extractor.py",
-    "decorator_config": {
-        "name": "$MODEL_NAME",
-        "custom_extractor_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
-        "custom_extractor_config": {
-            "output_dir": "/tmp/naive_decompose_workspace",
-            "split_positions": [8, 16, 32],
-            "group_head_and_tail": true,
-            "filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
-            "filter_config": {}
-        }
+    "handler_path": "$GRAPH_NET_ROOT/torch/naive_graph_decomposer.py",
+    "handler_class_name": "NaiveDecomposerExtractor",
+    "handler_config": {
+        "output_dir": "/tmp/naive_decompose_workspace",
+        "split_positions": [8, 16, 32],
+        "group_head_and_tail": true,
+        "filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
+        "filter_config": {}
     }
 }
 EOF
 )
-DECORATOR_CONFIG=$(echo $decorator_config_json_str | base64 -w 0)
+CONFIG=$(echo $config_json_str | base64 -w 0)

-python3 -m graph_net.torch.run_model --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --decorator-config=$DECORATOR_CONFIG
+python3 -m graph_net.model_path_handler --model-path $GRAPH_NET_ROOT/../samples/$MODEL_PATH_IN_SAMPLES --handler-config=$CONFIG
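For reference, the payload the script builds with a heredoc and base64 -w 0 can be produced in Python as well; a small sketch, with graph_net_root standing in for the script's $GRAPH_NET_ROOT:

# Build the handler config the way naive_graph_decomposer_test.sh does,
# then base64-encode it for --handler-config.
import base64
import json

graph_net_root = "/path/to/graph_net"  # assumption: adjust to your checkout
config = {
    "handler_path": f"{graph_net_root}/torch/naive_graph_decomposer.py",
    "handler_class_name": "NaiveDecomposerExtractor",
    "handler_config": {
        "output_dir": "/tmp/naive_decompose_workspace",
        "split_positions": [8, 16, 32],
        "group_head_and_tail": True,
        "filter_path": f"{graph_net_root}/torch/naive_subgraph_filter.py",
        "filter_config": {},
    },
}
encoded = base64.b64encode(json.dumps(config).encode()).decode()
print(encoded)  # pass as --handler-config=<encoded>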
Lines changed: 40 additions & 26 deletions
@@ -1,46 +1,60 @@
 #!/bin/bash

 GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
+DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_workspace
+
+mkdir -p "$DECOMPOSE_PATH"

-MODEL1="$GRAPH_NET_ROOT/samples/torchvision/resnet18"
-MODEL2="$GRAPH_NET_ROOT/samples/torchvision/resnet34"
-MODEL_LIST_FILE=$(mktemp)
-echo "$MODEL1" > "$MODEL_LIST_FILE"
-echo "$MODEL2" >> "$MODEL_LIST_FILE"
+temp_model_list=$(mktemp)
+cat "$GRAPH_NET_ROOT/graph_net/config/torch_samples_list.txt" > "$temp_model_list"

 python3 -m graph_net.torch.typical_sequence_split_points \
-    --model-list "$MODEL_LIST_FILE" \
+    --model-list "$temp_model_list" \
     --device "cuda" \
     --window-size 10 \
-    --output-json "$GRAPH_NET_ROOT/split_results.json"
+    --output-json "$DECOMPOSE_PATH/split_results.json"

-rm -f "$MODEL_LIST_FILE"
+while IFS= read -r MODEL_PATH_IN_SAMPLES; do
+    if [[ -n "$MODEL_PATH_IN_SAMPLES" ]]; then
+        MODEL_FULL_PATH="$GRAPH_NET_ROOT/$MODEL_PATH_IN_SAMPLES"
+        MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")

+        echo "== Decomposing $MODEL_PATH_IN_SAMPLES. =="

-MODEL_PATH_IN_SAMPLES=/torchvision/resnet18
-MODEL_NAME=$(basename "$MODEL_PATH_IN_SAMPLES")
-
-decomposer_config_json_str=$(cat <<EOF
+        decomposer_config_json_str=$(cat <<EOF
 {
-    "split_results_path": "$GRAPH_NET_ROOT/split_results.json",
-    "workspace_path": "$GRAPH_NET_ROOT/decompose_workspace",
-    "chain_style": "True"
+    "split_results_path": "$DECOMPOSE_PATH/split_results.json",
+    "workspace_path": "$DECOMPOSE_PATH",
+    "chain_style": true,
+    "target_model_name": "$MODEL_NAME"
 }
 EOF
-)
-DECOMPOSER_CONFIG=$(echo $decomposer_config_json_str | base64 -w 0)
+        )
+        DECOMPOSER_CONFIG=$(echo $decomposer_config_json_str | base64 -w 0)

-python3 -m graph_net.torch.test_compiler --model-path $GRAPH_NET_ROOT/samples/$MODEL_PATH_IN_SAMPLES --compiler range_decomposer --device cuda --config=$DECOMPOSER_CONFIG
+        python3 -m graph_net.torch.test_compiler \
+            --model-path "$MODEL_FULL_PATH" \
+            --compiler range_decomposer \
+            --device cuda \
+            --config="$DECOMPOSER_CONFIG"

+        cp -r "$MODEL_FULL_PATH" "$DECOMPOSE_PATH/"

-DECOMPOSE_PATH=$GRAPH_NET_ROOT/decompose_workspace
-cp -r "$GRAPH_NET_ROOT/samples/$MODEL_PATH_IN_SAMPLES" "$DECOMPOSE_PATH/"
+        echo "== Validating $MODEL_PATH_IN_SAMPLES. =="
+
+        python3 -m graph_net.torch.test_compiler \
+            --model-path "$DECOMPOSE_PATH/$MODEL_NAME" \
+            --compiler range_decomposer_validator \
+            --device cuda > "$DECOMPOSE_PATH/${MODEL_NAME}_validation.log" 2>&1
+
+        echo "== Finished processing $MODEL_PATH_IN_SAMPLES. =="
+    fi
+done < $temp_model_list
+
+rm -f "$temp_model_list"

-python3 -m graph_net.torch.test_compiler \
-    --model-path $DECOMPOSE_PATH/$MODEL_NAME \
-    --compiler range_decomposer_validator \
-    --device cuda > "$DECOMPOSE_PATH/log.log" 2>&1
+cat $DECOMPOSE_PATH/*_validation.log >> $DECOMPOSE_PATH/combined.log

 python3 -m graph_net.plot_ESt \
-    --benchmark-path $DECOMPOSE_PATH/log.log \
-    --output-dir $DECOMPOSE_PATH \
+    --benchmark-path "$DECOMPOSE_PATH/combined.log" \
+    --output-dir "$DECOMPOSE_PATH"
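The loop above is plain bash; the decompose half of it can be sketched as a Python driver, which also makes it explicit that chain_style is now a JSON boolean rather than the old string "True". Paths are stand-ins, and only the CLI flags and config keys shown in the diff are assumed:

# Sketch of the decompose half of the loop above; the copy/validation steps
# are omitted. graph_net_root is a stand-in for the script's $GRAPH_NET_ROOT.
import base64
import json
import os
import subprocess

graph_net_root = "/path/to/GraphNet"  # assumption: adjust to your checkout
decompose_path = os.path.join(graph_net_root, "decompose_workspace")
model_list = os.path.join(graph_net_root, "graph_net/config/torch_samples_list.txt")

with open(model_list) as f:
    for rel_path in (line.strip() for line in f):
        if not rel_path:
            continue
        config = {
            "split_results_path": os.path.join(decompose_path, "split_results.json"),
            "workspace_path": decompose_path,
            "chain_style": True,  # JSON boolean, not the string "True"
            "target_model_name": os.path.basename(rel_path),
        }
        encoded = base64.b64encode(json.dumps(config).encode()).decode()
        subprocess.run(
            [
                "python3", "-m", "graph_net.torch.test_compiler",
                "--model-path", os.path.join(graph_net_root, rel_path),
                "--compiler", "range_decomposer",
                "--device", "cuda",
                f"--config={encoded}",
            ],
            check=True,
        )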
Lines changed: 2 additions & 38 deletions
@@ -1,51 +1,15 @@
-import logging
-import torch
 import copy
 import os
-import inspect
-from graph_net.tensor_meta import TensorMeta
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module
-from graph_net.imp_util import load_module
-from dataclasses import asdict
+from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs


 def parse_immutable_model_path_into_sole_graph_module(model_path):
     model_path = os.path.realpath(model_path)
     if model_path not in g_model_path2graph_module:
-        module = _get_torch_module(model_path)
-        tensor_metas = _get_tensor_metas(model_path)
-        logging.warning("before _create_inputs_by_metas")
-        inputs = _create_inputs_by_metas(module, tensor_metas)
-        logging.warning("after _create_inputs_by_metas")
-        logging.warning("before parse_sole_graph_module")
+        module, inputs = get_torch_module_and_inputs(model_path)
         g_model_path2graph_module[model_path] = parse_sole_graph_module(module, inputs)
-        logging.warning("after parse_sole_graph_module")
     return copy.deepcopy(g_model_path2graph_module[model_path])


-def _get_torch_module(model_path):
-    py_module = load_module(f"{model_path}/model.py")
-    torch_module_cls = py_module.GraphModule
-    return torch_module_cls()
-
-
-def _get_tensor_metas(model_path):
-    make = TensorMeta.unserialize_from_py_file
-    return [
-        *make(os.path.join(model_path, "input_meta.py")),
-        *make(os.path.join(model_path, "weight_meta.py")),
-    ]
-
-
-def _create_inputs_by_metas(module, tensor_metas):
-    tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
-    from graph_net.torch.utils import get_dummy_named_tensors
-
-    named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
-    name2tensor = {k: v for k, v in named_tensors}
-    return tuple(
-        name2tensor[name] for name in inspect.signature(module.forward).parameters
-    )
-
-
 g_model_path2graph_module = {}
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import os
+import inspect
+from graph_net.tensor_meta import TensorMeta
+from graph_net.imp_util import load_module
+from dataclasses import asdict
+
+
+def get_torch_module_and_inputs(model_path):
+    module = _get_torch_module(model_path)
+    tensor_metas = _get_tensor_metas(model_path)
+    inputs = _create_inputs_by_metas(module, tensor_metas)
+    return module, inputs
+
+
+def _get_torch_module(model_path):
+    py_module = load_module(f"{model_path}/model.py")
+    torch_module_cls = py_module.GraphModule
+    return torch_module_cls()
+
+
+def _get_tensor_metas(model_path):
+    make = TensorMeta.unserialize_from_py_file
+    return [
+        *make(os.path.join(model_path, "input_meta.py")),
+        *make(os.path.join(model_path, "weight_meta.py")),
+    ]
+
+
+def _create_inputs_by_metas(module, tensor_metas):
+    tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
+    from graph_net.torch.utils import get_dummy_named_tensors
+
+    named_tensors = get_dummy_named_tensors(tensor_meta_attrs_list)
+    name2tensor = {k: v for k, v in named_tensors}
+    return tuple(
+        name2tensor[name] for name in inspect.signature(module.forward).parameters
+    )
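The new module centralizes model loading: it instantiates GraphModule from model.py, rebuilds dummy tensors from input_meta.py and weight_meta.py, and orders them by the parameters of forward. A usage sketch, with the sample path as a stand-in:

# Usage sketch for the new helper; the sample path is a stand-in.
import torch
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs

model_path = "/path/to/GraphNet/samples/torchvision/resnet18"  # assumption
module, inputs = get_torch_module_and_inputs(model_path)
with torch.no_grad():
    outputs = module(*inputs)  # inputs follow the order of module.forward's parameters
print(type(module).__name__, len(inputs))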

graph_net/torch/fx_graph_parse_util.py

Lines changed: 33 additions & 0 deletions
@@ -76,12 +76,45 @@ def get_zip_filter_names():
             if name_from_signature != name_from_placeholder
         )

+    if len(get_zip_filter_names()) > 0 and set(get_input_names_from_signature()) == set(
+        get_input_names_from_placeholder()
+    ):
+        traced_module = _reorder_placeholders(
+            traced_module, get_input_names_from_signature()
+        )
+
     zip_filter_names = get_zip_filter_names()

     def zip_filter_names_str():
         for triple in zip_filter_names:
             print(triple)
         return "<printed before>"

+    from pathlib import Path
+
+    Path("/tmp/a.py").write_text(traced_module.code)
     assert len(zip_filter_names) == 0, f"{zip_filter_names_str()=}"
     return traced_module
+
+
+def _reorder_placeholders(gm, sorted_names):
+    sorted_names = list(sorted_names)
+    name2placeholder = {
+        node.name: node for node in gm.graph.nodes if node.op == "placeholder"
+    }
+    for i, current_placeholder_name in enumerate(sorted_names):
+        if i == 0:
+            continue
+        prev_node = name2placeholder[sorted_names[i - 1]]
+        current_node = name2placeholder[current_placeholder_name]
+        with gm.graph.inserting_after(prev_node):
+            new_node = gm.graph.placeholder(current_node.name)
+            # force rename
+            new_node.name = current_node.name
+            new_node.target = current_node.target
+        current_node.replace_all_uses_with(new_node)
+        name2placeholder[current_placeholder_name] = new_node
+        gm.graph.erase_node(current_node)
+
+    gm.recompile()
+    return gm
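_reorder_placeholders rewrites placeholder nodes so their order matches the names taken from the forward signature, then recompile() regenerates the module code with that argument order. A toy illustration of the placeholder/signature relationship it relies on (not the repository's code path), assuming a trivial two-input module:

# Toy illustration: FX placeholder order drives the argument order of the
# regenerated forward, which is what _reorder_placeholders normalizes.
import torch
from torch import fx


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


gm = fx.symbolic_trace(Add())
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x', 'y']
print(gm.code)  # generated forward(self, x, y) follows the placeholder order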
