Skip to content

Commit 173a136

Browse files
committed
Fix for the different subgraph topology of the GELU and LayerNorm fusion patterns under Python 2 (the order of a node's outputs is apparently not deterministic there, so the fusion now matches output nodes by op type instead of by position).
1 parent 308ef9f commit 173a136

File tree

2 files changed

+20
-58
lines changed

2 files changed

+20
-58
lines changed

coremltools/converters/nnssa/coreml/graph_pass/op_fusions.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,12 @@ def onehot_matmul_to_embedding(nnssa):
240240
print('[Op Fusion] Node %s is removed.' %(inp_node.name))
241241

242242

243+
def _search_nodes_by_type(gf, node_names, op_type):
244+
for name in node_names:
245+
if gf[name].op == op_type:
246+
return gf[name]
247+
248+
243249
def _match_layernorm_pattern(gf, entry_node):
244250
""" Return the nodes that form the subgraph of a LayerNormalization layer
245251
"""
@@ -248,7 +254,10 @@ def _axes_in_range(axes, rank):
248254

249255
try:
250256
params = {}
251-
mean_1, sqdiff_2, mul_3 = [gf[x] for x in entry_node.outputs]
257+
mean_1 = _search_nodes_by_type(gf, entry_node.outputs, 'Mean')
258+
sqdiff_2 = _search_nodes_by_type(gf, entry_node.outputs, 'SquaredDifference')
259+
mul_3 = _search_nodes_by_type(gf, entry_node.outputs, 'Mul')
260+
252261
if not (mean_1.op == 'Mean' and sqdiff_2.op == 'SquaredDifference' and
253262
mul_3.op == 'Mul'):
254263
return None
@@ -284,9 +293,11 @@ def _axes_in_range(axes, rank):
284293
return None
285294
const_11 = gf[mul_10.inputs[1]]
286295
params['gamma'] = const_11.value.val
287-
if not (gf[mul_10.outputs[0]] == mul_3 and len(mul_10.outputs) == 2):
296+
if not (mul_3.name in mul_10.outputs and len(mul_10.outputs) == 2):
288297
return None
289-
mul_12 = gf[mul_10.outputs[1]]
298+
mul_12 = gf[mul_10.outputs[1]] if gf[mul_10.outputs[0]] == mul_3 else \
299+
gf[mul_10.outputs[0]]
300+
290301
sub_13 = gf[mul_12.outputs[0]]
291302
if not (mul_12.op == 'Mul' and sub_13.op == 'Sub'):
292303
return None
@@ -303,7 +314,7 @@ def _axes_in_range(axes, rank):
303314
add_15]
304315

305316
return (layernorm_nodes, params)
306-
except:
317+
except Exception as e:
307318
return None
308319

309320

@@ -357,7 +368,10 @@ def _match_gelu_pattern(gf, entry_node):
357368
try:
358369
if not len(entry_node.outputs) == 3:
359370
return None
360-
pow_1, add_2, mul_3 = [gf[x] for x in entry_node.outputs]
371+
pow_1 = _search_nodes_by_type(gf, entry_node.outputs, 'Pow')
372+
add_2 = _search_nodes_by_type(gf, entry_node.outputs, 'Add')
373+
mul_3 = _search_nodes_by_type(gf, entry_node.outputs, 'Mul')
374+
361375
if not (pow_1.op == 'Pow' and add_2.op == 'Add' and mul_3.op == 'Mul'):
362376
return None
363377
const_4 = gf[pow_1.inputs[1]]

coremltools/converters/nnssa/coreml/ssa_converter.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,7 +1675,7 @@ def _convert_gelu(self, node):
16751675
name=node.name,
16761676
input_name=input_names[0],
16771677
output_name=node.name,
1678-
mode='EXACT')
1678+
mode='TANH_APPROXIMATION')
16791679

16801680
output_shape = self._get_tensor_shape_from_type(node.datatype)
16811681
shapes.propagate_single_layer(layer, self.tensor_shapes,
@@ -1786,58 +1786,6 @@ def _convert_resize_nearest_neighbor(self, node):
17861786
output_shapes=[output_shape])
17871787

17881788

1789-
def _convert_layer_normalization(self, node):
1790-
assert len(node.inputs) == 1
1791-
input_nodes, input_names, input_types = self._get_input_tensors(node)
1792-
input_name = input_names[0]
1793-
builder = self._get_builder()
1794-
gamma = node.attr['gamma']
1795-
beta = node.attr['beta']
1796-
axes = node.attr['axes']
1797-
epsilon = node.attr['epsilon']
1798-
input_shape = list(input_types[0].get_shape())
1799-
1800-
if (len(input_shape) in [2,3] and len(axes) == 1 and \
1801-
axes[0] == len(input_shape) - 1):
1802-
# Performance enhancement for some models with layer-norm
1803-
builder.add_reshape_static(name=input_name + '_reshape',
1804-
input_name=input_name,
1805-
output_name=input_name + '_reshape',
1806-
output_shape=input_shape + [1,1])
1807-
1808-
builder.add_mvn(name=input_name + '_mvn',
1809-
input_name=input_name + '_reshape',
1810-
output_name=input_name + '_mvn', across_channels=True,
1811-
normalize_variance=True, epsilon=epsilon)
1812-
1813-
builder.add_scale(name=node.name + '_5d',
1814-
input_name=input_name + '_mvn',
1815-
output_name=node.name + '_5d', W=gamma, b=beta, has_bias=True,
1816-
shape_scale=[len(gamma)], shape_bias=[len(beta)])
1817-
1818-
builder.add_reshape_static(name=node.name,
1819-
input_name=node.name + '_5d',
1820-
output_name=node.name,
1821-
output_shape=input_shape)
1822-
1823-
else:
1824-
# General implementation
1825-
input_shape = input_types[0].get_shape()
1826-
rdims = len(axes)
1827-
normalized_shape = node.datatype.get_shape()[-rdims:]
1828-
if gamma.shape != normalized_shape:
1829-
gamma = np.zeros(normalized_shape) + gamma
1830-
if beta.shape != normalized_shape:
1831-
beta = np.zeros(normalized_shape) + beta
1832-
1833-
builder.add_layer_normalization(node.name, input_name, node.name,
1834-
normalized_shape, gamma, beta, eps=1e-5)
1835-
1836-
self.tensor_shapes[node.name] = self._get_tensor_shape_from_type(
1837-
node.datatype)
1838-
1839-
1840-
18411789
def _convert_layer_normalization(self, node):
18421790
assert len(node.inputs) == 1
18431791
input_nodes, input_names, input_types = self._get_input_tensors(node)

0 commit comments

Comments (0)