Merged
Changes from all commits
31 commits
07de1d0
[WIP] Initial implementation of attention masking in torch.ops.aten.s…
rohan-tan-bhowmik Aug 8, 2024
05b1065
(WIP) Added causal and boolean masking
rohan-tan-bhowmik Sep 3, 2024
dd290bf
(WIP) Added causal and boolean masking
rohan-tan-bhowmik Sep 3, 2024
a29fce4
(WIP) Added checks and tests
rohan-tan-bhowmik Sep 4, 2024
263b380
(WIP) Added checks and tests
rohan-tan-bhowmik Sep 4, 2024
8983210
Clean up and robustness work
rohan-tan-bhowmik Sep 6, 2024
2a112ec
Formatting
rohan-tan-bhowmik Sep 6, 2024
f6e873e
Merge branch 'llvm:main' into sdpa_mask
rohan-tan-bhowmik Sep 6, 2024
9d6f70d
Restored llvm-project
rohan-tan-bhowmik Sep 6, 2024
0e41ca3
Merge branch 'sdpa_mask' of https://github.com/rohan-tan-bhowmik/torc…
rohan-tan-bhowmik Sep 6, 2024
f067dd0
Restored llvm-project
rohan-tan-bhowmik Sep 6, 2024
f70166d
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 6, 2024
84cc5cb
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 6, 2024
1c0f138
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 6, 2024
fd981c4
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 7, 2024
ebd4c0c
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 8, 2024
daeff4a
Merge branch 'llvm:main' into sdpa_mask
rohan-tan-bhowmik Sep 8, 2024
3c269d4
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 9, 2024
0ea7cba
Merge branch 'sdpa_mask' of https://github.com/rohan-tan-bhowmik/torc…
rohan-tan-bhowmik Sep 9, 2024
67b318f
Merge branch 'llvm:main' into sdpa_mask
rohan-tan-bhowmik Sep 9, 2024
a743b2f
XFAIL on stable version, PASS on nightly version
rohan-tan-bhowmik Sep 9, 2024
2cab3c0
Merge branch 'sdpa_mask' of https://github.com/rohan-tan-bhowmik/torc…
rohan-tan-bhowmik Sep 9, 2024
fd94911
Review changes
rohan-tan-bhowmik Sep 9, 2024
75f1240
Review changes
rohan-tan-bhowmik Sep 9, 2024
d01f743
Review changes
rohan-tan-bhowmik Sep 9, 2024
f78f878
XFAIL Changes
rohan-tan-bhowmik Sep 9, 2024
fc41aa3
XFAIL Changes
rohan-tan-bhowmik Sep 9, 2024
04f9e4d
XFAIL Changes
rohan-tan-bhowmik Sep 9, 2024
bf88ec5
XFAIL Changes
rohan-tan-bhowmik Sep 9, 2024
1a9466e
XFAIL Changes
rohan-tan-bhowmik Sep 9, 2024
421b74e
XFAIL Changes
rohan-tan-bhowmik Sep 9, 2024
33 changes: 26 additions & 7 deletions include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
@@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
["generateScalarImplementation"]>]> {
let summary = "Attention operator";
let description = [{
This operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
the attention. Each of the inputs has shape BxNxd where B is the
batch dimension, N is the sequence length and d is head dimension.
Typically N >>> d. Mathematically, the attention is defined as
matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
this operator also performs scaling, masking and dropout, but we leave
that out of the current implementation.
This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
optional mask(M) to compute the attention. These tensors must take on shapes
BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
shapes, B represents the batch dimension, M represents sequence length, N
represents head dimension, and K1 and K2 are hidden dimensions.
Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
has shape BxMxN. Usually, this operator also performs scaling and dropout,
but we leave those out of the current implementation.
}];

let arguments = (ins Variadic<AnyShaped>:$inputs,
@@ -287,6 +288,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
Value getValue() {
return getInputOperand(2)->get();
}
std::optional<Value> getAttnMask() {
if (getNumInputs() < 4) {
return std::nullopt;
}
return getInputOperand(3)->get();
}
Value getOutput() {
return getOutputOperand(0)->get();
}
@@ -299,6 +306,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
ShapedType getValueType() {
return cast<ShapedType>(getValue().getType());
}
std::optional<ShapedType> getAttnMaskType() {
if (getAttnMask()){
return cast<ShapedType>((*getAttnMask()).getType());
}
return std::nullopt;
}
ShapedType getOutputType() {
return cast<ShapedType>(getOutput().getType());
}
@@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
int64_t getValueRank() {
return getValueType().getRank();
}
std::optional<int64_t> getAttnMaskRank() {
if (getAttnMask()){
return (*getAttnMaskType()).getRank();
}
return std::nullopt;
}
int64_t getOutputRank() {
return getOutputType().getRank();
}
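To make the shape contract in the op description above concrete, here is a minimal NumPy sketch (illustrative only, not part of the diff; the function and variable names are hypothetical) of the stated formula matmul(softmax(matmul(Q, transpose(K)) + M), V), with Q of shape BxMxK1, K of shape BxK2xK1, V of shape BxK2xN, an optional additive mask of shape BxMxK2, and a BxMxN result:

import numpy as np

def tm_attention_reference(q, k, v, mask=None):
    # q: BxMxK1, k: BxK2xK1, v: BxK2xN, mask: BxMxK2 (additive) -> out: BxMxN
    scores = q @ np.swapaxes(k, -1, -2)                    # BxMxK2
    if mask is not None:
        scores = scores + mask                             # 0 keeps a position, -inf drops it
    scores = scores - scores.max(axis=-1, keepdims=True)   # stabilize the softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v                                     # BxMxN

B, M, K1, K2, N = 2, 5, 8, 7, 4
q = np.random.randn(B, M, K1).astype(np.float32)
k = np.random.randn(B, K2, K1).astype(np.float32)
v = np.random.randn(B, K2, N).astype(np.float32)
m = np.zeros((B, M, K2), dtype=np.float32)
assert tm_attention_reference(q, k, v, m).shape == (B, M, N)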
111 changes: 89 additions & 22 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -1578,26 +1578,94 @@ class ConvertAtenScaledDotProductAttentionOp
LogicalResult
matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value mask = op.getAttnMask();

auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto mask = adaptor.getAttnMask();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());

Value dropoutP = op.getDropoutP();
Value isCausal = op.getIsCausal();
Value scale = op.getScale();
Value enableGQA = op.getEnableGqa();
Type elementType =
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();

// Verify inputs (only support defaults)
if (!isa<Torch::NoneType>(mask.getType()))
return rewriter.notifyMatchFailure(op.getLoc(),
"attention masking not supported");
double dropout;
if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
dropout > 0.0)
return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");

bool causal;
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal)
return rewriter.notifyMatchFailure(
op.getLoc(), "causal attention masking not supported");
if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
if (!isa<Torch::NoneType>(mask.getType())) {
return rewriter.notifyMatchFailure(
op.getLoc(), "expected no attention mask when isCausal is true");
}

SmallVector<OpFoldResult> maskSizes;

if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
auto seqLenQ =
rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
auto seqLenK =
rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
maskSizes = {seqLenQ, seqLenK};
for (int i = queryTy.getRank() - 3; i >= 0; --i) {
auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
maskSizes.insert(maskSizes.begin(), batchSize);
}
} else { // Dynamic shape case: <?x?x...x?xf32> for example
for (int i = 0; i < queryTy.getRank() - 2; ++i) {
Value batchSize =
rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
maskSizes.push_back(batchSize);
}
Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
queryTy.getRank() - 2);
Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
keyTy.getRank() - 2);
maskSizes.push_back(seqLenQ);
maskSizes.push_back(seqLenK);
}

Type maskType = getElementTypeOrSelf(queryTy);
Value emptyMask =
rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);

Value zero = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
Value negInf = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));

mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
.getResult(0);

int64_t rank = cast<ShapedType>(queryTy).getRank();
AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
auto genericOp = rewriter.create<linalg::GenericOp>(
op.getLoc(), mask.getType(), ValueRange{}, mask,
SmallVector<AffineMap>{maskMap}, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
Value j = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 1);

Value cond =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, i, j);
Value select = b.create<arith::SelectOp>(loc, cond, zero, negInf);
b.create<linalg::YieldOp>(loc, select);
});
mask = genericOp.getResult(0);
}

if (!isa<Torch::NoneType>(scale.getType())) {
double scaleFloat;
if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
@@ -1611,14 +1679,6 @@ class ConvertAtenScaledDotProductAttentionOp
return rewriter.notifyMatchFailure(
op.getLoc(), "grouped query attention not supported");

auto opTy = cast<ValueTensorType>(op.getType()).toBuiltinTensor();
auto query = adaptor.getQuery();
auto value = adaptor.getValue();
auto key = adaptor.getKey();
auto queryTy = cast<ShapedType>(query.getType());
auto valueTy = cast<ShapedType>(value.getType());
auto keyTy = cast<ShapedType>(key.getType());

if (queryTy.getRank() != valueTy.getRank() ||
queryTy.getRank() != keyTy.getRank())
return rewriter.notifyMatchFailure(op, "operand ranks do not match");
@@ -1659,6 +1719,9 @@ class ConvertAtenScaledDotProductAttentionOp
query = collapseBatch(query);
key = collapseBatch(key);
value = collapseBatch(value);
if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
mask = collapseBatch(mask);
}

SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
SmallVector<int64_t> valueSizes(
@@ -1672,13 +1735,17 @@ class ConvertAtenScaledDotProductAttentionOp
Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
elementType);

SmallVector<Value> inputs = SmallVector<Value>{query, key, value};

if (!isa<mlir::torch::Torch::NoneType>(mask.getType())) {
inputs.push_back(mask);
}

// Overwrite with tm_tensor::attention
Value attention =
rewriter
.create<AttentionOp>(loc, outType,
SmallVector<Value>{query, key, value},
SmallVector<Value>{output})
.getResult()[0];
Value attention = rewriter
.create<AttentionOp>(loc, outType, inputs,
SmallVector<Value>{output})
.getResult()[0];

if (opTy != outType) {
attention = rewriter.create<tensor::ExpandShapeOp>(loc, opTy, attention,
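A note on the causal branch in the lowering above: when is_causal is set and no explicit attn_mask is provided, the pattern materializes an additive mask over the batch and query/key sequence dimensions (0 at or below the diagonal, -inf above) using tensor.empty, linalg.fill, and linalg.generic, batch-collapses it together with Q, K, and V, and passes it to tm_tensor.attention as a fourth operand. The NumPy sketch below (illustrative, not part of the diff; the function name is hypothetical) shows the mask values that construction produces:

import numpy as np

def causal_additive_mask(batch_dims, seq_q, seq_k, dtype=np.float32):
    # 0 where the query index i may see the key index j (i >= j), -inf otherwise,
    # broadcast across the leading batch dimensions of the query.
    i = np.arange(seq_q)[:, None]
    j = np.arange(seq_k)[None, :]
    mask2d = np.where(i >= j, 0.0, -np.inf).astype(dtype)   # seq_q x seq_k
    return np.broadcast_to(mask2d, (*batch_dims, seq_q, seq_k))

# Adding this mask to the pre-softmax scores zeroes the weight each query
# position places on key positions that come after it.
print(causal_additive_mask((2,), seq_q=4, seq_k=4)[0])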
94 changes: 87 additions & 7 deletions lib/Dialect/TMTensor/IR/TMTensorOps.cpp
@@ -93,14 +93,49 @@ LogicalResult AttentionOp::verify() {
Operation *op = getOperation();
ShapedType queryType = getQueryType();
ShapedType keyType = getKeyType();
ShapedType valueType = getValueType();

auto optionalMaskType = getAttnMaskType();
ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType();

ArrayRef<int64_t> queryShape = queryType.getShape();
ArrayRef<int64_t> keyShape = keyType.getShape();
ArrayRef<int64_t> valueShape = valueType.getShape();
ArrayRef<int64_t> maskShape =
optionalMaskType ? maskType.getShape() : ArrayRef<int64_t>();

for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (keyShape[i] != queryShape[i])
if (keyShape[i] != queryShape[i]) {
return op->emitOpError("query and key batch mismatch");
}
}
if (keyShape.back() != queryShape.back())
if (keyShape.back() != queryShape.back()) {
return op->emitOpError("query and key head dimension mismatch");
}

for (int i = 0, s = queryShape.size() - 2; i < s; ++i) {
if (valueShape[i] != queryShape[i]) {
return op->emitOpError("query and value batch dimension mismatch");
}
}
if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) {
return op->emitOpError("key and value sequence length dimension mismatch");
}
if (optionalMaskType) {
for (int i = 0, s = maskShape.size() - 2; i < s; ++i) {
if (maskShape[i] != queryShape[i]) {
return op->emitOpError("query and mask batch dimension mismatch");
}
}
if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) {
return op->emitOpError(
"mask sequence length and query sequence length mismatch");
}
if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) {
return op->emitOpError(
"mask sequence lengt and key sequence length mismatch");
}
}
return success();
}

Expand Down Expand Up @@ -168,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
Value query = getQuery();
Value key = getKey();
Value value = getValue();

auto optionalMask = getAttnMask();
Value mask = optionalMask ? *optionalMask : Value();

Value output = getOutput();
auto queryType = cast<MemRefType>(query.getType());
auto keyType = cast<MemRefType>(key.getType());
auto valueType = cast<MemRefType>(value.getType());
auto maskType = mask ? cast<MemRefType>(mask.getType()) : MemRefType();
auto queryRank = queryType.getRank();
auto keyRank = keyType.getRank();
auto valueRank = valueType.getRank();
@@ -180,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,

Value zeroF = b.create<arith::ConstantOp>(loc, elementType,
b.getFloatAttr(elementType, 0.0));
Value negInfF = b.create<arith::ConstantOp>(
loc, elementType,
b.getFloatAttr(elementType, -std::numeric_limits<double>::infinity()));

// TODO: This needs to be fixed, it assumes everything is dynamic however if
// any shapes are static the `memref.alloc` generated is illegal.
@@ -214,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
/*transposed=*/true);

// weight = softmax(weight)
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value dim = weightDynSizes[weightRank - 1];
Value scaleFactor = b.create<math::SqrtOp>(
loc, b.create<arith::UIToFPOp>(
loc, elementType,
b.create<arith::IndexCastUIOp>(loc, b.getI32Type(),
queryDynSizes[queryRank - 1])));

// weight = (weight - max(weight)) / math.sqrt(querySizes[-1])
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});

// Apply mask to weights if mask is given
if (mask) {
b.create<scf::ParallelOp>(
loc, SmallVector<Value>(weightRank, zero), weightDynSizes,
SmallVector<Value>(weightRank, one),
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value weightValue = b.create<memref::LoadOp>(loc, weight, localIVs);
Value maskValue = b.create<memref::LoadOp>(loc, mask, localIVs);
if (maskType.getElementType().isInteger(1)) {
maskValue =
b.create<arith::SelectOp>(loc, maskValue, zeroF, negInfF);
}
Value maskedWeight =
b.create<arith::AddFOp>(loc, weightValue, maskValue);
b.create<memref::StoreOp>(loc, maskedWeight, weight, localIVs);
});
}

// calculate max(weight)
Value init = b.create<memref::LoadOp>(loc, weight,
SmallVector<Value>(weightRank, zero));
Expand Down Expand Up @@ -249,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
x = b.create<arith::SubFOp>(loc, x, globalMax);
x = b.create<arith::DivFOp>(loc, x, scaleFactor);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
});
// calculate exp(weight)
@@ -307,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
[&](OpBuilder &b, Location loc, ValueRange localIVs) {
SmallVector<Value> sumIVs(localIVs);
sumIVs.pop_back();

Value x = b.create<memref::LoadOp>(loc, weight, localIVs);
Value sum = b.create<memref::LoadOp>(loc, expWeightSum, sumIVs);
x = b.create<arith::DivFOp>(loc, x, sum);
b.create<memref::StoreOp>(loc, x, weight, localIVs);
Value divResult = b.create<arith::DivFOp>(loc, x, sum);

// Set to 0 if sum is 0 (can occur during boolean mask / large negative
// QK)
Value isSumZero =
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, sum, zeroF);
Value result =
b.create<arith::SelectOp>(loc, isSumZero, zeroF, divResult);

b.create<memref::StoreOp>(loc, result, weight, localIVs);
});

// output = weight @ value
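For reference, here is a minimal NumPy sketch (illustrative, not part of the diff; the function name is hypothetical) of the order of operations the scalar implementation above now follows: scale the raw scores by 1/sqrt of the query's last dimension, add the mask (an i1 mask is first turned into an additive 0/-inf mask), subtract the global max, exponentiate, normalize per row, and return 0 rather than NaN for rows the mask disables entirely:

import numpy as np

def masked_softmax(scores, mask, head_dim):
    if mask.dtype == np.bool_:
        mask = np.where(mask, 0.0, -np.inf)        # boolean mask -> additive mask
    w = scores / np.sqrt(float(head_dim)) + mask   # scale, then apply the mask
    w = np.exp(w - w.max())                        # global max, mirroring the lowering
    s = w.sum(axis=-1, keepdims=True)
    safe = np.where(s == 0.0, 1.0, s)              # avoid 0/0
    return np.where(s == 0.0, 0.0, w / safe)       # fully masked rows yield 0

scores = np.random.randn(2, 3, 3).astype(np.float32)
bool_mask = np.ones((2, 3, 3), dtype=bool)
bool_mask[0, 1, :] = False                         # this query attends to nothing
out = masked_softmax(scores, bool_mask, head_dim=8)
assert np.all(out[0, 1] == 0.0)                    # guarded row is all zeros, not NaN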