Commit 5ce48df
[torch] Fix attention on linalg for dynamic shapes (#3714)
The current version does not work for a mixture of dynamic and statically shaped batch dimensions. Reworked to grab the correct dynamic shapes.

Co-authored-by: dan <[email protected]>
1 parent 3f46348 commit 5ce48df
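
For context, the shape pattern this fix targets is the one exercised by the new e2e test further down: static batch dimensions combined with dynamic sequence lengths. The eager-mode sketch below (plain PyTorch, not part of the commit; shapes borrowed from the test) shows the call whose causal-mask handling previously went wrong. In eager mode it runs fine; the failure was in the TorchToTMTensor lowering.

import torch

# Static batch dims (2, 3); the sequence dims (8 for query, 12 for key/value)
# are the ones the e2e test marks as dynamic (-1) when compiling.
query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)

# is_causal=True is the path where the lowering materializes a mask of shape
# (batch..., seq_q, seq_k), which here needs a mix of static and runtime sizes.
out = torch.ops.aten.scaled_dot_product_attention(query, key, value, is_causal=True)
print(out.shape)  # torch.Size([2, 3, 8, 20])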

File tree

3 files changed: +47 -26 lines

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 14 additions & 26 deletions
@@ -1607,35 +1607,23 @@ class ConvertAtenScaledDotProductAttentionOp
           op.getLoc(), "expected no attention mask when isCausal is true");
     }
 
-    SmallVector<OpFoldResult> maskSizes;
-
-    if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
-      auto seqLenQ =
-          rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
-      auto seqLenK =
-          rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
-      maskSizes = {seqLenQ, seqLenK};
-      for (int i = queryTy.getRank() - 3; i >= 0; --i) {
-        auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
-        maskSizes.insert(maskSizes.begin(), batchSize);
-      }
-    } else { // Dynamic shape case: <?x?x...x?xf32> for example
-      for (int i = 0; i < queryTy.getRank() - 2; ++i) {
-        Value batchSize =
-            rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
-        maskSizes.push_back(batchSize);
-      }
-      Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
-                                                     queryTy.getRank() - 2);
-      Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
-                                                     keyTy.getRank() - 2);
-      maskSizes.push_back(seqLenQ);
-      maskSizes.push_back(seqLenK);
+    SmallVector<int64_t> maskStatic;
+    SmallVector<Value> maskDyn;
+    for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
+      maskStatic.push_back(queryTy.getDimSize(i));
+      if (maskStatic.back() == ShapedType::kDynamic)
+        maskDyn.push_back(
+            rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
     }
 
+    maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
+    if (maskStatic.back() == ShapedType::kDynamic)
+      maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                       keyTy.getRank() - 2));
+
     Type maskType = getElementTypeOrSelf(queryTy);
-    Value emptyMask =
-        rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
+    Value emptyMask = rewriter.create<tensor::EmptyOp>(
+        op.getLoc(), maskStatic, maskType, maskDyn);
 
     Value zero = rewriter.create<arith::ConstantOp>(
         op.getLoc(),
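
To make the reworked size collection easier to follow, here is a rough Python model of the logic above (illustrative pseudo-logic with made-up names, not torch-mlir API): every query dimension except the last contributes a mask size (batch dims plus seq_q), the key's second-to-last dim contributes seq_k, and each dynamic entry additionally records where a runtime tensor.dim value must come from, mirroring the maskStatic / maskDyn pair.

# Rough Python model of the mask-size collection; DYNAMIC stands in for
# ShapedType::kDynamic, and the (tensor, index) pairs stand in for tensor.dim ops.
DYNAMIC = -1

def causal_mask_sizes(query_shape, key_shape):
    """Return (static_sizes, dynamic_dims) for a causal mask of shape
    (batch..., seq_q, seq_k), mirroring maskStatic / maskDyn above."""
    static_sizes = []
    dynamic_dims = []
    # Batch dims and seq_q come from the query: every dim except the last.
    for i, size in enumerate(query_shape[:-1]):
        static_sizes.append(size)
        if size == DYNAMIC:
            dynamic_dims.append(("query", i))
    # seq_k comes from the key's second-to-last dim.
    static_sizes.append(key_shape[-2])
    if key_shape[-2] == DYNAMIC:
        dynamic_dims.append(("key", len(key_shape) - 2))
    return static_sizes, dynamic_dims

# Mixed static batch dims and dynamic sequence dims, as in the new e2e test:
print(causal_mask_sizes([2, 3, DYNAMIC, 16], [2, 3, DYNAMIC, 16]))
# -> ([2, 3, -1, -1], [('query', 2), ('key', 2)])

Only the dynamic positions need runtime values; the static size list (with kDynamic sentinels) and the list of dynamic values are then handed together to tensor.empty, which is what the new emptyMask creation above does.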

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,7 @@
     # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
@@ -833,6 +834,7 @@
     "SafeSoftmaxNonNoneDtypeModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
@@ -3176,6 +3178,7 @@
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
     "ScatterAddStaticModule_basic",
     "TensorsConcatComplex128FloatModule_basic",
@@ -4679,6 +4682,7 @@
     "ScalarImplicitIntModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
+    "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterReduceFloatMaxModule",

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 29 additions & 0 deletions
@@ -5370,6 +5370,35 @@ def forward(self, query, key, value):
 @register_test_case(
     module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule()
 )
+def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils):
+    query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
+    key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
+    value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
+    module.forward(query, key, value)
+
+
+class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, -1, 16], torch.float32, True),
+            ([2, 3, -1, 16], torch.float32, True),
+            ([2, 3, -1, 20], torch.float32, True),
+        ]
+    )
+    def forward(self, query, key, value):
+        return torch.ops.aten.scaled_dot_product_attention(
+            query, key, value, is_causal=True
+        )
+
+
+@register_test_case(
+    module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule()
+)
 def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils):
     query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
     key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
