6 changes: 6 additions & 0 deletions frontend/Python/graph/operation.py
@@ -162,6 +162,12 @@ def __init__(self) -> None:
self._op_type = OpType.ReduceType


class MatmulWithAccOp(Op):
def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ReduceType


class GetItemOp(Op):
def __init__(self) -> None:
super().__init__()
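
The new MatmulWithAccOp node marks a matmul whose result is reshaped and then accumulated into a residual tensor, so that the later add can be folded into the matmul itself. As a reference for the pattern being fused (a hypothetical NumPy sketch, not part of this PR), the unfused graph computes:

import numpy as np

def matmul_reshape_add_reference(a, b, residual):
    # Unfused pattern matched by the new fusion pass:
    out = a @ b                        # MatmulOp
    out = out.reshape(residual.shape)  # ViewOp / ReshapeOp (the pass matches [1, 1, 1536])
    return out + residual              # AddOp consuming the residual
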
77 changes: 76 additions & 1 deletion frontend/Python/graph/transform/fuse_ops.py
@@ -23,7 +23,10 @@
from .. import DeviceType
from torch.fx.immutable_collections import immutable_list

classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp}
classicfuse_register = {
"transpose_matmul_fusion": TransposeMatmulFusedOp,
"residual_fusion": MatmulWithAccOp,
}

# TODO: classify op type for op fusion
# OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
@@ -91,6 +94,77 @@ def transpose_matmul_fusion(
graph.delete_node(target, targets_parent)


def residual_fuse_check(graph: Graph):
for op in graph.body:
pattern = None
if isinstance(op, MatmulOp):
child_1op = [graph.node_table[str(i)] for i in op._children]
for reshape in child_1op:
if isinstance(reshape, (ViewOp, ReshapeOp)) and (
reshape.args[1] == immutable_list([1, 1, 1536])
):
child_2op = [
graph.node_table[str(i)] for i in reshape._children
]
for add in child_2op:
if isinstance(add, AddOp):
pattern = (reshape, add, "residual_fusion")
break
else:
continue
break
if pattern:
residual_fusion(
graph,
op,
pattern[0],
pattern[1],
pattern[2],
)


def residual_fusion(graph: Graph, node, reshape, add: Op, pattern: str):
fuse_op = classicfuse_register.get(pattern)()
fuse_op.name = "fused_" + node.name
graph.displace_node(node, fuse_op)

reshape._children.extend(add._children)
add_children = [
graph.node_table[child_name] for child_name in add._children
]
for child in add_children:
if add.name in child._parents:
parent_idx = child._parents.index(add.name)
child._parents[parent_idx] = reshape.name

if add.name in child.args:
arg_idx = child.args.index(add.name)
child.args[arg_idx] = reshape.name

if add.name in reshape._children:
reshape._children.pop(reshape._children.index(add.name))

residual_parents = [p for p in add._parents if p != reshape.name]
fuse_op._parents.extend(residual_parents)
fuse_op.args.extend(residual_parents)

fuse_op._parents = list(dict.fromkeys(fuse_op._parents))
original_args = node.args.copy()
residual_only = [p for p in fuse_op._parents if p not in original_args]
fuse_op._parents = original_args + residual_only

add._children.clear()

add_parents = []
for parent_name in add._parents:
parent = graph.node_table[parent_name]
if add.name in parent._children:
add_parents.append(parent)

if graph.check_delete_node(add) and add_parents:
graph.delete_node(add, add_parents)


def apply_classic_fusion(graph: Graph):
"""
Function to fuse some typical operations into one operation and fuse
@@ -106,6 +180,7 @@ def apply_classic_fusion(graph: Graph):
device = DeviceType.CPU
# Run the first round of op fusion
classic_fuse_check(graph)
residual_fuse_check(graph)
for op in graph.body:
if isinstance(op, PlaceholderOp):
continue
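
residual_fuse_check looks for a MatmulOp whose result is reshaped to [1, 1, 1536] and then added to another tensor; residual_fusion then displaces the matmul with a MatmulWithAccOp, re-points the add's consumers at the reshape, feeds the residual producer into the fused op as an extra argument, and deletes the add. A simplified, stand-alone sketch of that edge rewiring (plain dicts instead of the real Graph/Op classes, purely illustrative):

children = {
    "matmul": ["reshape"],
    "reshape": ["add"],
    "add": ["consumer"],
    "residual": ["add"],
}
parents = {"add": ["reshape", "residual"], "consumer": ["add"]}

# Consumers of the add now read from the reshape instead.
parents["consumer"] = ["reshape" if p == "add" else p for p in parents["consumer"]]
children["reshape"] = ["consumer"]

# The residual producer becomes an extra input of the fused matmul.
children["residual"] = ["fused_matmul"]
children["fused_matmul"] = ["reshape"]

# The original matmul is displaced and the add is deleted.
children.pop("matmul")
children.pop("add")
print(children)
# {'reshape': ['consumer'], 'residual': ['fused_matmul'], 'fused_matmul': ['reshape']}
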
41 changes: 41 additions & 0 deletions frontend/Python/ops/linalg.py
@@ -1211,6 +1211,46 @@ def matmul_transpose_b_op(
return op


def matmul_bias_op(
node: MatmulWithAccOp,
symbol_table: Dict[Tuple[str, int], ir.Operation],
):
if len(node.args) < 3:
raise ValueError(
f"MatmulWithAccOp requires 3 arguments, got {len(node.args)}"
)

input_a = symbol_table.get((str(node.args[0]), 0))
input_b = symbol_table.get((str(node.args[1]), 0))
bias = symbol_table.get((str(node.args[2]), 0))

if input_a is None or input_b is None or bias is None:
return None

dtype = node.tensor_meta["dtype"]
output_shape = list(node.tensor_meta["shape"])
mlir_dtype = mlir_element_type_get(dtype)
tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype)
generic_map = ir.AffineMap.get_permutation([0, 1, 2])

# reshape bias to match matmul output
bias_reshaped = tosa.ReshapeOp(bias, output_shape).result

op = linalg.MatmulOp(
result_tensors=[tensor_type],
inputs=[input_a, input_b],
outputs=[bias_reshaped],
indexing_maps=[
generic_map.get_submap([0, 2]), # lhs: (m, k)
generic_map.get_submap([2, 1]), # rhs: (k, n)
generic_map.get_submap([0, 1]), # out: (m, n)
],
cast="cast_signed",
)
linalg.fill_builtin_region(op.operation)
return op


def transpose_op(
node: TransposeOp,
symbol_table: Dict[Tuple[str, int], ir.Operation],
@@ -3001,6 +3041,7 @@ def as_strided_op(
ops_registry = {
"MatmulOp": matmul_op,
"TransposeMatmulFusedOp": matmul_transpose_b_op,
"MatmulWithAccOp": matmul_bias_op,
"ArangeOp": arange_op,
"UnsqueezeOp": unsqueeze_op,
"ViewOp": view_op,
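
matmul_bias_op lowers the fused node by reshaping the residual to the matmul's output shape and passing it as the outs operand of linalg.matmul; because linalg.matmul accumulates into its output operand, the residual add happens inside the matmul and no separate elementwise op is emitted. A numerical reference for that accumulation semantics (hypothetical NumPy sketch, not the lowering itself):

import numpy as np

def linalg_matmul_reference(a, b, init):
    # linalg.matmul semantics: C[i, j] = init[i, j] + sum_k A[i, k] * B[k, j]
    return init + a @ b

a = np.ones((2, 4), dtype=np.float32)
b = np.ones((4, 3), dtype=np.float32)
residual = np.full((2, 3), 0.5, dtype=np.float32)
fused = linalg_matmul_reference(a, b, residual)  # matmul with the residual as init
unfused = (a @ b) + residual                     # MatmulOp followed by AddOp
assert np.allclose(fused, unfused)

Using the residual as the init value is what lets the registration of MatmulWithAccOp to matmul_bias_op in ops_registry replace the matmul-reshape-add chain with a single linalg op.
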