Skip to content

Commit e3a13d9

Browse files
authored
Merge pull request #448 from slin07/bugfix/layernorm-fuse-py2
Python 2 GeLU and Layer Normalization Fusion
2 parents 74a54eb + 173a136 commit e3a13d9

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

coremltools/converters/nnssa/coreml/graph_pass/op_fusions.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,12 @@ def onehot_matmul_to_embedding(nnssa):
240240
print('[Op Fusion] Node %s is removed.' %(inp_node.name))
241241

242242

243+
def _search_nodes_by_type(gf, node_names, op_type):
244+
for name in node_names:
245+
if gf[name].op == op_type:
246+
return gf[name]
247+
248+
243249
def _match_layernorm_pattern(gf, entry_node):
244250
""" Return the nodes that form the subgraph of a LayerNormalization layer
245251
"""
@@ -248,7 +254,10 @@ def _axes_in_range(axes, rank):
248254

249255
try:
250256
params = {}
251-
mean_1, sqdiff_2, mul_3 = [gf[x] for x in entry_node.outputs]
257+
mean_1 = _search_nodes_by_type(gf, entry_node.outputs, 'Mean')
258+
sqdiff_2 = _search_nodes_by_type(gf, entry_node.outputs, 'SquaredDifference')
259+
mul_3 = _search_nodes_by_type(gf, entry_node.outputs, 'Mul')
260+
252261
if not (mean_1.op == 'Mean' and sqdiff_2.op == 'SquaredDifference' and
253262
mul_3.op == 'Mul'):
254263
return None
@@ -284,9 +293,11 @@ def _axes_in_range(axes, rank):
284293
return None
285294
const_11 = gf[mul_10.inputs[1]]
286295
params['gamma'] = const_11.value.val
287-
if not (gf[mul_10.outputs[0]] == mul_3 and len(mul_10.outputs) == 2):
296+
if not (mul_3.name in mul_10.outputs and len(mul_10.outputs) == 2):
288297
return None
289-
mul_12 = gf[mul_10.outputs[1]]
298+
mul_12 = gf[mul_10.outputs[1]] if gf[mul_10.outputs[0]] == mul_3 else \
299+
gf[mul_10.outputs[0]]
300+
290301
sub_13 = gf[mul_12.outputs[0]]
291302
if not (mul_12.op == 'Mul' and sub_13.op == 'Sub'):
292303
return None
@@ -303,7 +314,7 @@ def _axes_in_range(axes, rank):
303314
add_15]
304315

305316
return (layernorm_nodes, params)
306-
except:
317+
except Exception as e:
307318
return None
308319

309320

@@ -357,7 +368,10 @@ def _match_gelu_pattern(gf, entry_node):
357368
try:
358369
if not len(entry_node.outputs) == 3:
359370
return None
360-
pow_1, add_2, mul_3 = [gf[x] for x in entry_node.outputs]
371+
pow_1 = _search_nodes_by_type(gf, entry_node.outputs, 'Pow')
372+
add_2 = _search_nodes_by_type(gf, entry_node.outputs, 'Add')
373+
mul_3 = _search_nodes_by_type(gf, entry_node.outputs, 'Mul')
374+
361375
if not (pow_1.op == 'Pow' and add_2.op == 'Add' and mul_3.op == 'Mul'):
362376
return None
363377
const_4 = gf[pow_1.inputs[1]]

coremltools/converters/nnssa/coreml/ssa_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1675,7 +1675,7 @@ def _convert_gelu(self, node):
16751675
name=node.name,
16761676
input_name=input_names[0],
16771677
output_name=node.name,
1678-
mode='EXACT')
1678+
mode='TANH_APPROXIMATION')
16791679

16801680
output_shape = self._get_tensor_shape_from_type(node.datatype)
16811681
shapes.propagate_single_layer(layer, self.tensor_shapes,

coremltools/converters/nnssa/frontend/tensorflow/load.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def load(tfgraph, resume_on_errors=False, **kwargs):
4040
ssa = graphdef_to_ssa(gd)
4141

4242
placeholder_shape = kwargs.get("inputs", {})
43+
4344
if len(placeholder_shape) > 0:
4445
graph = ssa.functions['main'].graph
4546
required_plhd_nodes = [node for node in graph if

0 commit comments

Comments
 (0)