@@ -60,21 +60,26 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
6060 val = gm
6161 for a in attr_itr :
6262 val = getattr (val , a )
63- return val
63+ out = val .to (device = "meta" ) if isinstance (val , torch .Tensor ) else val
64+ return out
6465
6566
6667def collect_op_stats (model , input_dict ):
68+ # FX symbolic trace
69+ try :
70+ traced = torch .fx .symbolic_trace (model )
71+ # print(traced.graph)
72+ except Exception :
73+ print ("Failed to FX symbolic trace" )
74+ return None
75+
6776 # Use meta tensors as input to avoid actually running the model
6877 meta_input_dict = {}
6978 for name , x in input_dict .items ():
7079 meta_input_dict [name ] = (
7180 torch .empty_like (x , device = "meta" ) if isinstance (x , torch .Tensor ) else x
7281 )
7382
74- # FX symbolic trace
75- traced = torch .fx .symbolic_trace (model )
76- # print(traced.graph)
77-
7883 node_outputs = {}
7984 op_stats = {}
8085 for node in traced .graph .nodes :
@@ -84,7 +89,7 @@ def collect_op_stats(model, input_dict):
8489 node_outputs [node .name ] = meta_input_dict [node .target ]
8590 op_name = node .op
8691 dtype = node_outputs [node .name ].dtype
87- elif node .op in ["call_function" , "call_method " , "call_module " ]:
92+ elif node .op in ["call_function" , "call_module " , "call_method " ]:
8893 node_args = torch .fx .map_arg (
8994 node .args ,
9095 lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
@@ -96,28 +101,32 @@ def collect_op_stats(model, input_dict):
96101
97102 if node .op == "call_module" :
98103 # classname of module
99- submod = dict ( traced .named_modules ())[ node .target ]
104+ submod = traced .get_submodule ( node .target )
100105 op_name = submod .__class__ .__name__
101- try :
102- out = submod (* node_args , ** node_kwargs )
103- node_outputs [node .name ] = out
104- dtype = out .dtype if isinstance (out , torch .Tensor ) else None
105- except Exception :
106- node_outputs [node .name ] = None
107- elif node .op in ["call_function" , "call_method" ]:
108- op_name = (
109- node .target .__name__ if node .op == "call_function" else node .target
106+ op_func = submod
107+ elif node .op == "call_function" :
108+ op_name = node .target .__name__
109+ op_func = node .target
110+ elif node .op == "call_method" :
111+ op_name = node .target
112+ self_obj = (
113+ node_outputs [node .args [0 ].name ]
114+ if isinstance (node .args [0 ], torch .fx .Node )
115+ else node .args [0 ]
110116 )
111- try :
112- out = node .target (* node_args , ** node_kwargs )
113- node_outputs [node .name ] = out
114- dtype = out .dtype if isinstance (out , torch .Tensor ) else None
115- except Exception :
116- print (f"dtype inference failed: op_name={ op_name } " )
117- node_outputs [node .name ] = None
117+ op_func = getattr (self_obj , node .target )
118+ node_args = node_args [1 :]
119+
120+ try :
121+ out = op_func (* node_args , ** node_kwargs )
122+ node_outputs [node .name ] = out
123+ dtype = out .dtype if isinstance (out , torch .Tensor ) else None
124+ except Exception :
125+ print (f"dtype inference failed: node.op={ node .op } , op_name={ op_name } " )
126+ node_outputs [node .name ] = None
118127 elif node .op == "get_attr" :
119- val = resolve_get_attr ( traced , node )
120- out = val . to ( device = "meta" ) if isinstance ( val , torch . Tensor ) else val
128+ op_name = node . op
129+ out = resolve_get_attr ( traced , node )
121130 node_outputs [node .name ] = out
122131 dtype = out .dtype if isinstance (out , torch .Tensor ) else None
123132 elif node .op == "output" :
@@ -156,18 +165,20 @@ def collect_model_stats(model_path, device, log_prompt):
156165 num_outputs = 0
157166 dtypes = set ()
158167 op_stats = collect_op_stats (model , input_dict )
159- for op_name , stat in op_stats .items ():
160- if op_name == "placeholder" :
161- num_inputs += stat .count
162- elif op_name == "output" :
163- num_outputs += stat .count
164- else :
165- num_ops += stat .count
166- for v in stat .dtype :
167- if v is not None :
168- dtypes .add (v )
168+ if op_stats is not None :
169+ for op_name , stat in op_stats .items ():
170+ if op_name == "placeholder" :
171+ num_inputs += stat .count
172+ elif op_name == "output" :
173+ num_outputs += stat .count
174+ else :
175+ num_ops += stat .count
176+ for v in stat .dtype :
177+ if v is not None :
178+ dtypes .add (v )
169179
170180 arg_types = get_argument_types (model_class , "forward" )
181+ num_inputs = len (arg_types ) if op_stats is None else num_inputs
171182 num_params = 0
172183 param_dtypes = set ()
173184 for name , arg_type in arg_types .items ():
0 commit comments