@@ -55,6 +55,14 @@ class OpStat:
5555 count : int = 0
5656
5757
58+ def resolve_get_attr (gm : torch .fx .GraphModule , node : torch .fx .Node ):
59+ attr_itr = node .target .split ("." )
60+ val = gm
61+ for a in attr_itr :
62+ val = getattr (val , a )
63+ return val
64+
65+
5866def collect_op_stats (model , input_dict ):
5967 # Use meta tensors as input to avoid actually running the model
6068 meta_input_dict = {}
@@ -77,14 +85,14 @@ def collect_op_stats(model, input_dict):
7785 op_name = node .op
7886 dtype = node_outputs [node .name ].dtype
7987 elif node .op in ["call_function" , "call_method" , "call_module" ]:
80- node_args = []
81- for arg in node .args :
82- node_args . append (
83- node_outputs [ arg . name ] if hasattr ( arg , "name" ) else arg
84- )
85- node_kwargs = {}
86- for k , v in node . kwargs . items ():
87- node_kwargs [ k ] = node_outputs [ v . name ] if hasattr ( v , "name" ) else v
88+ node_args = torch . fx . map_arg (
89+ node .args ,
90+ lambda n : node_outputs [ n . name ] if isinstance ( n , torch . fx . Node ) else n ,
91+ )
92+ node_kwargs = torch . fx . map_arg (
93+ node . kwargs ,
94+ lambda n : node_outputs [ n . name ] if isinstance ( n , torch . fx . Node ) else n ,
95+ )
8896
8997 if node .op == "call_module" :
9098 # classname of module
@@ -107,13 +115,17 @@ def collect_op_stats(model, input_dict):
107115 except Exception :
108116 print (f"dtype inference failed: op_name={ op_name } " )
109117 node_outputs [node .name ] = None
118+ elif node .op == "get_attr" :
119+ val = resolve_get_attr (traced , node )
120+ out = val .to (device = "meta" ) if isinstance (val , torch .Tensor ) else val
121+ node_outputs [node .name ] = out
122+ dtype = out .dtype if isinstance (out , torch .Tensor ) else None
110123 elif node .op == "output" :
111124 op_name = node .op
112- node_args = []
113- for arg in node .args :
114- node_args .append (
115- node_outputs [arg .name ] if hasattr (arg , "name" ) else arg
116- )
125+ node_args = torch .fx .map_arg (
126+ node .args ,
127+ lambda n : node_outputs [n .name ] if isinstance (n , torch .fx .Node ) else n ,
128+ )
117129 node_outputs [node .name ] = node_args [0 ] if len (node_args ) == 1 else node_args
118130 dtype = (
119131 node_args [0 ].dtype if isinstance (node_args [0 ], torch .Tensor ) else None
0 commit comments