Skip to content

Commit 42d99b9

Browse files
author
ADchampion3
committed
pre-commit success
1 parent 8ba1b8f commit 42d99b9

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import argparse
2+
import os
3+
import torch
4+
from torchvision import transforms
5+
import graph_net
6+
7+
8+
def extract_visio_graph(model_name: str, model_path: str):
9+
# Normalization parameters for ImageNet
10+
normalize = transforms.Normalize(
11+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
12+
)
13+
14+
# Create dummy input
15+
batch_size = 1
16+
height, width = 224, 224 # Standard ImageNet size
17+
num_channels = 3
18+
random_input = torch.rand(batch_size, num_channels, height, width)
19+
normalized_input = normalize(random_input)
20+
21+
# 使用get_model下载模型
22+
# all_models = list_models(module=torchvision.models)
23+
# if(model_path not in all_models):
24+
# print("不存在该模型, 请校验模型名称是否相同")
25+
# return
26+
# model = get_model(model_path, weights="DEFAULT")
27+
28+
# 使用torch.hub下载模型
29+
# 相关使用办法见https://docs.pytorch.org/docs/stable/hub.html
30+
torch.hub.set_dir("../../../test") # 缓存目录默认为$TORCH_HOME/hub 如果没有设置环境变量则为 ~/.cache
31+
endpoints = torch.hub.list("pytorch/vision")
32+
if model_path not in endpoints:
33+
print("Model not found")
34+
return
35+
model = torch.hub.load("pytorch/vision", model_path)
36+
37+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38+
model.to(device)
39+
normalized_input = normalized_input.to(device)
40+
41+
model = graph_net.torch.extract(name=model_name, dynamic=True)(model)
42+
43+
print("Running inference...")
44+
print("Input shape:", normalized_input.shape)
45+
output = model(normalized_input)
46+
print("Inference finished. Output shape:", output.shape)
47+
48+
49+
if __name__ == "__main__":
50+
# get parameters from command line
51+
workspace_default = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE", "../../workspace")
52+
53+
parser = argparse.ArgumentParser()
54+
parser.add_argument(
55+
"--model_name", type=str, default="resnet18"
56+
) # 模型名称(自定义,推荐与官网相同或者简写)
57+
parser.add_argument("--model_path", type=str, default="resnet18") # 官网定义模型名称
58+
parser.add_argument("--workspace", type=str, default=workspace_default)
59+
args = parser.parse_args()
60+
61+
os.environ["GRAPH_NET_EXTRACT_WORKSPACE"] = args.workspace
62+
63+
extract_visio_graph(args.model_name, args.model_path)

0 commit comments

Comments
 (0)