@@ -240,6 +240,12 @@ def onehot_matmul_to_embedding(nnssa):
240240 print ('[Op Fusion] Node %s is removed.' % (inp_node .name ))
241241
242242
243+ def _search_nodes_by_type (gf , node_names , op_type ):
244+ for name in node_names :
245+ if gf [name ].op == op_type :
246+ return gf [name ]
247+
248+
243249def _match_layernorm_pattern (gf , entry_node ):
244250 """ Return the nodes that form the subgraph of a LayerNormalization layer
245251 """
@@ -248,7 +254,10 @@ def _axes_in_range(axes, rank):
248254
249255 try :
250256 params = {}
251- mean_1 , sqdiff_2 , mul_3 = [gf [x ] for x in entry_node .outputs ]
257+ mean_1 = _search_nodes_by_type (gf , entry_node .outputs , 'Mean' )
258+ sqdiff_2 = _search_nodes_by_type (gf , entry_node .outputs , 'SquaredDifference' )
259+ mul_3 = _search_nodes_by_type (gf , entry_node .outputs , 'Mul' )
260+
252261 if not (mean_1 .op == 'Mean' and sqdiff_2 .op == 'SquaredDifference' and
253262 mul_3 .op == 'Mul' ):
254263 return None
@@ -284,9 +293,11 @@ def _axes_in_range(axes, rank):
284293 return None
285294 const_11 = gf [mul_10 .inputs [1 ]]
286295 params ['gamma' ] = const_11 .value .val
287- if not (gf [ mul_10 . outputs [ 0 ]] == mul_3 and len (mul_10 .outputs ) == 2 ):
296+ if not (mul_3 . name in mul_10 . outputs and len (mul_10 .outputs ) == 2 ):
288297 return None
289- mul_12 = gf [mul_10 .outputs [1 ]]
298+ mul_12 = gf [mul_10 .outputs [1 ]] if gf [mul_10 .outputs [0 ]] == mul_3 else \
299+ gf [mul_10 .outputs [0 ]]
300+
290301 sub_13 = gf [mul_12 .outputs [0 ]]
291302 if not (mul_12 .op == 'Mul' and sub_13 .op == 'Sub' ):
292303 return None
@@ -303,7 +314,7 @@ def _axes_in_range(axes, rank):
303314 add_15 ]
304315
305316 return (layernorm_nodes , params )
306- except :
317+ except Exception as e :
307318 return None
308319
309320
@@ -357,7 +368,10 @@ def _match_gelu_pattern(gf, entry_node):
357368 try :
358369 if not len (entry_node .outputs ) == 3 :
359370 return None
360- pow_1 , add_2 , mul_3 = [gf [x ] for x in entry_node .outputs ]
371+ pow_1 = _search_nodes_by_type (gf , entry_node .outputs , 'Pow' )
372+ add_2 = _search_nodes_by_type (gf , entry_node .outputs , 'Add' )
373+ mul_3 = _search_nodes_by_type (gf , entry_node .outputs , 'Mul' )
374+
361375 if not (pow_1 .op == 'Pow' and add_2 .op == 'Add' and mul_3 .op == 'Mul' ):
362376 return None
363377 const_4 = gf [pow_1 .inputs [1 ]]
0 commit comments