@@ -55,6 +55,20 @@ class OpStat:
     count: int = 0


+def resolve_native_multi_head_attention(*args, **kwargs):
+    query, key, value = args[0], args[1], args[2]
+    seq_len, batch_size, embed_dim = query.shape
+    attn_output = torch.empty(
+        (seq_len, batch_size, embed_dim), dtype=query.dtype, device="meta"
+    )
+
+    # seq_len_k = key.shape[0]
+    # num_heads = args[4]
+    # attn_output_weights = torch.empty((batch_size, num_heads, seq_len, seq_len_k),
+    #                                   dtype=query.dtype, device='meta')
+    return attn_output  # , attn_output_weights
+
+
 def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):
     attr_itr = node.target.split(".")
     val = gm
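A minimal sketch of how the new resolver behaves, assuming it is importable from this module; the shapes below are invented for illustration. Given meta query/key/value tensors, it returns a meta tensor mirroring the query's (seq_len, batch_size, embed_dim) layout, so shape and dtype inference can proceed without dispatching the real `_native_multi_head_attention` kernel:

    import torch

    # Hypothetical shapes: seq_len=128, batch_size=4, embed_dim=256.
    q = torch.empty((128, 4, 256), dtype=torch.float16, device="meta")
    k = torch.empty((128, 4, 256), dtype=torch.float16, device="meta")
    v = torch.empty((128, 4, 256), dtype=torch.float16, device="meta")

    out = resolve_native_multi_head_attention(q, k, v)
    assert out.shape == (128, 4, 256)
    assert out.dtype == torch.float16 and out.device.type == "meta"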
@@ -65,13 +79,13 @@ def resolve_get_attr(gm: torch.fx.GraphModule, node: torch.fx.Node):


 def collect_op_stats(model, input_dict):
-    # FX symbolic trace
     try:
+        # FX symbolic trace
        traced = torch.fx.symbolic_trace(model)
         # print(traced.graph)
     except Exception:
         print("Failed to FX symbolic trace")
-        return None
+        return False, None

     # Use meta tensors as input to avoid actually running the model
     meta_input_dict = {}
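Moving the comment inside the try block is cosmetic; the substantive change is returning (False, None), which keeps the failure path consistent with the new two-value contract. For context, `torch.fx.symbolic_trace` raises on models with data-dependent control flow, which is exactly what the except branch absorbs. A self-contained illustration (module invented for the example):

    import torch

    class Branchy(torch.nn.Module):
        def forward(self, x):
            # Data-dependent branching defeats symbolic tracing.
            if x.sum() > 0:
                return x + 1
            return x - 1

    try:
        torch.fx.symbolic_trace(Branchy())
    except Exception:
        print("Failed to FX symbolic trace")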
@@ -80,8 +94,9 @@ def collect_op_stats(model, input_dict):
             torch.empty_like(x, device="meta") if isinstance(x, torch.Tensor) else x
         )

-    node_outputs = {}
+    is_complete = True
     op_stats = {}
+    node_outputs = {}
     for node in traced.graph.nodes:
         op_name = None
         dtype = None
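The meta-device inputs built just above are what keep this loop cheap: PyTorch propagates shapes and dtypes through most ops on device="meta" without allocating storage or doing arithmetic. A small demonstration (tensors invented for the example):

    import torch

    x = torch.randn(2, 3)
    mx = torch.empty_like(x, device="meta")   # same shape/dtype, no real storage
    y = mx @ mx.transpose(0, 1)               # shape math only, no FLOPs
    print(y.shape, y.dtype, y.device)         # torch.Size([2, 2]) torch.float32 meta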
@@ -99,31 +114,35 @@ def collect_op_stats(model, input_dict):
                 lambda n: node_outputs[n.name] if isinstance(n, torch.fx.Node) else n,
             )

-            if node.op == "call_module":
-                # classname of module
-                submod = traced.get_submodule(node.target)
-                op_name = submod.__class__.__name__
-                op_func = submod
-            elif node.op == "call_function":
-                op_name = node.target.__name__
-                op_func = node.target
-            elif node.op == "call_method":
-                op_name = node.target
-                self_obj = (
-                    node_outputs[node.args[0].name]
-                    if isinstance(node.args[0], torch.fx.Node)
-                    else node.args[0]
-                )
-                op_func = getattr(self_obj, node.target)
-                node_args = node_args[1:]
-
             try:
-                out = op_func(*node_args, **node_kwargs)
+                if node.op == "call_module":
+                    # classname of module
+                    submod = traced.get_submodule(node.target)
+                    op_name = submod.__class__.__name__
+                    op_func = submod
+                elif node.op == "call_function":
+                    op_name = node.target.__name__
+                    op_func = node.target
+                elif node.op == "call_method":
+                    op_name = node.target
+                    self_obj = (
+                        node_outputs[node.args[0].name]
+                        if isinstance(node.args[0], torch.fx.Node)
+                        else node.args[0]
+                    )
+                    op_func = getattr(self_obj, node.target)
+                    node_args = node_args[1:]
+
+                if op_name == "_native_multi_head_attention":
+                    out = resolve_native_multi_head_attention(*node_args, **node_kwargs)
+                else:
+                    out = op_func(*node_args, **node_kwargs)
                 node_outputs[node.name] = out
                 dtype = out.dtype if isinstance(out, torch.Tensor) else None
             except Exception:
                 print(f"dtype inference failed: node.op={node.op}, op_name={op_name}")
                 node_outputs[node.name] = None
+                is_complete = False
         elif node.op == "get_attr":
             op_name = node.op
             out = resolve_get_attr(traced, node)
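Two things change in this hunk: the dispatch moves inside the try, so a failure anywhere (including get_submodule or the getattr lookup) is caught and recorded via is_complete = False, and `_native_multi_head_attention` is special-cased because the real kernel cannot run on meta tensors. In the call_method branch, the first FX argument is the receiver and the method is looked up on it by name; a standalone sketch of that dispatch (values invented):

    import torch

    t = torch.empty((4, 8), device="meta")  # stands in for node_outputs[args[0].name]
    op_func = getattr(t, "reshape")         # node.target of a call_method node
    out = op_func(8, 4)                     # node_args[1:] after dropping the receiver
    print(out.shape)                        # torch.Size([8, 4])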
@@ -149,11 +168,16 @@ def collect_op_stats(model, input_dict):
         else:
             op_stats[op_name].dtype.add(dtype_str)
             op_stats[op_name].count = op_stats[op_name].count + 1
-    return op_stats
+    return is_complete, op_stats


 def collect_model_stats(model_path, device, log_prompt):
-    print(f"Collect information for {model_path}")
+    if not hasattr(collect_model_stats, "_counter"):
+        collect_model_stats._counter = 0
+    else:
+        collect_model_stats._counter += 1
+    print(f"[{collect_model_stats._counter}] Collect information for {model_path}")
+
     model_class = load_class_from_file(
         os.path.join(model_path, "model.py"), "GraphModule"
     )
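The function attribute here is a simple per-process call counter, zero-based on the first call. The four-line hasattr branch is equivalent to this one-liner, shown only as a reference, not a requested change:

    collect_model_stats._counter = getattr(collect_model_stats, "_counter", -1) + 1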
@@ -164,7 +188,7 @@ def collect_model_stats(model_path, device, log_prompt):
     num_inputs = 0
     num_outputs = 0
     dtypes = set()
-    op_stats = collect_op_stats(model, input_dict)
+    is_complete, op_stats = collect_op_stats(model, input_dict)
     if op_stats is not None:
         for op_name, stat in op_stats.items():
             if op_name == "placeholder":
@@ -192,7 +216,7 @@ def collect_model_stats(model_path, device, log_prompt):
     dtypes_str = "[" + ",".join(dtypes) + "]"
     param_dtypes_str = "[" + ",".join(param_dtypes) + "]"
     print(
-        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str}",
+        f"{log_prompt} [ModelStats] model_path:{model_path} num_inputs:{num_inputs} num_outputs:{num_outputs} num_ops:{num_ops} num_params:{num_params_in_billion}B param_dtypes:{param_dtypes_str} op_dtypes:{dtypes_str} is_complete:{is_complete}",
         file=sys.stderr,
         flush=True,
     )
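For illustration only (every value below is invented), the stderr line now ends with the new field:

    my-prompt [ModelStats] model_path:models/example num_inputs:2 num_outputs:1 num_ops:120 num_params:1.5B param_dtypes:[torch.float32] op_dtypes:[torch.float32,torch.int64] is_complete:True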