diff --git a/onnxscript/rewriter/rules/fusion/_gqa.py b/onnxscript/rewriter/rules/fusion/_gqa.py index 8d6f156ed..c12dcc714 100644 --- a/onnxscript/rewriter/rules/fusion/_gqa.py +++ b/onnxscript/rewriter/rules/fusion/_gqa.py @@ -52,7 +52,7 @@ def pattern( _outputs=["attention_BHSDh"], ) - return attention_BHSDh + return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD def check( self, @@ -103,6 +103,7 @@ def rewrite( past_key_BHkvSpD, past_value_BHkvSpD, **original_attrs, + _outputs=3, )