Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2383,26 +2383,26 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto context{builder.getContext()};
auto argBases{getBasesForArgs(args)};

mlir::vector::SplatOp splatOp{nullptr};
mlir::vector::BroadcastOp splatOp{nullptr};
mlir::Type retTy{nullptr};
switch (vop) {
case VecOp::Splat: {
assert(args.size() == 2);
auto vecTyInfo{getVecTypeFromFir(argBases[0])};

auto extractOp{genVecExtract(resultType, args)};
splatOp =
mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
vecTyInfo.toMlirVectorType(context));
splatOp = mlir::vector::BroadcastOp::create(
builder, loc, vecTyInfo.toMlirVectorType(context),
*(extractOp.getUnboxed()));
retTy = vecTyInfo.toFirVectorType();
break;
}
case VecOp::Splats: {
assert(args.size() == 1);
auto vecTyInfo{getVecTypeFromEle(argBases[0])};

splatOp = mlir::vector::SplatOp::create(
builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
splatOp = mlir::vector::BroadcastOp::create(
builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
retTy = vecTyInfo.toFirVectorType();
break;
}
Expand All @@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto intOp{builder.createConvert(loc, eleTy, argBases[0])};

// the intrinsic always returns vector(integer(4))
splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
mlir::VectorType::get(4, eleTy));
splatOp = mlir::vector::BroadcastOp::create(
builder, loc, mlir::VectorType::get(4, eleTy), intOp);
retTy = fir::VectorType::get(4, eleTy);
break;
}
Expand Down Expand Up @@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};

auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
auto splatRes{
mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};

mlir::Value result{nullptr};
if (mlirTy != splatRes.getType()) {
Expand Down
4 changes: 1 addition & 3 deletions mlir/docs/Dialects/Vector.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
// Produces a vector<3x7x8xf32>
%b = arith.mulf %0, %1 : vector<3x7x8xf32>
// Produces a vector<3x7x8xf32>
%c = vector.splat %1 : vector<3x7x8xf32>
%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
%d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
%e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
Expand Down Expand Up @@ -176,8 +176,6 @@ infrastructure can apply iteratively.
### Virtual Vector to Hardware Vector Lowering

For now, `VV -> HWV` are specified in C++ (see for instance the
[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
or the
[VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).

Simple
Expand Down
47 changes: 0 additions & 47 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2881,53 +2881,6 @@ def Vector_PrintOp :
}];
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

def Vector_SplatOp : Vector_Op<"splat", [
Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"::llvm::cast<VectorType>($_self).getElementType()">
]> {
let summary = "vector splat or broadcast operation";
let description = [{
Note: This operation is deprecated. Please use vector.broadcast.

Broadcast the operand to all elements of the result vector. The type of the
operand must match the element type of the vector type.

Example:

```mlir
%s = arith.constant 10.1 : f32
%t = vector.splat %s : vector<8x16xf32>
```

This operation is deprecated, the preferred representation of the above is:

```mlir
%s = arith.constant 10.1 : f32
%t = vector.broadcast %s : f32 to vector<8x16xf32>
```
}];

let arguments = (ins AnyType:$input);
let results = (outs AnyVectorOfAnyRank:$aggregate);

let builders = [
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
[{ build($_builder, $_state, aggregateType, element); }]>];
let assemblyFormat = "$input attr-dict `:` type($aggregate)";

let hasFolder = 1;

// vector.splat is deprecated, and vector.broadcast should be used instead.
// Canonicalize vector.splat to vector.broadcast.
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// VectorScaleOp
Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
.Case<vector::SplatOp>([&current](auto op) {
current = op.getInput();
return false;
})
.Default([](Operation *) { return false; });

if (!skipOp) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
/// %pad_1d = vector.splat %pad : vector<[4]xi32>
/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
Expand Down
22 changes: 4 additions & 18 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
}
};

// Convert all `vector.splat` to `vector.broadcast`. There is a path from
// `vector.broadcast` to ArmSME via another pattern.
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
using Base::Base;

LogicalResult matchAndRewrite(vector::SplatOp splatOp,
PatternRewriter &rewriter) const final {

rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
splatOp.getInput());
return success();
}
};

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);
Expand Down
15 changes: 1 addition & 14 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2161,19 +2161,6 @@ class TransposeOpToMatrixTransposeOpLowering
}
};

/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
/// `vector.broadcast` through other patterns.
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
adaptor.getInput());
return success();
}
};

} // namespace

void mlir::vector::populateVectorRankReducingFMAPattern(
Expand Down Expand Up @@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
Expand Down
23 changes: 4 additions & 19 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};

// Convert `vector.splat` to `vector.broadcast`. There is a path from
// `vector.broadcast` to SPIRV via other patterns.
struct VectorSplatToBroadcast final
: public OpConversionPattern<vector::SplatOp> {
using Base::Base;
LogicalResult
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
adaptor.getInput());
return success();
}
};

struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using Base::Base;
Expand Down Expand Up @@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
VectorShuffleOpConvert, VectorInterleaveOpConvert,
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorScalarBroadcastPattern, VectorLoadOpConverter,
VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));

// Make sure that the more specialized dot product pattern has higher benefit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
arith::ConstantOp, arith::SelectOp, vector::SplatOp,
vector::BroadcastOp>();
arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
}

void EmulateUnsupportedFloatsPass::runOnOperation() {
Expand Down
52 changes: 6 additions & 46 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
/// 1s, are considered to be 'broadcastlike'.
/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
/// considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
if (isa<BroadcastOp, SplatOp>(op))
if (isa<BroadcastOp>(op))
return true;

auto shapeCast = dyn_cast<ShapeCastOp>(op);
Expand Down Expand Up @@ -3249,23 +3249,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
};

/// Consider the defining operation `defOp` of `value`. If `defOp` is a
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
/// value that is splatted. Otherwise return null.
/// vector.broadcast with a scalar operand, return the scalar value that is
/// splatted. Otherwise return null.
///
/// Examples:
/// Example:
///
/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
Operation *defOp = value.getDefiningOp();
if (!defOp)
return {};

// Splat:
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
return splat.getInput();

auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);

// Not broadcast (and not splat):
Expand Down Expand Up @@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
return {};

// SplatElementsAttr::get treats single value for second arg as being a splat.
return SplatElementsAttr::get(getType(), {constOperand});
}

// Canonicalizer for vector.splat. It always gets canonicalized to a
// vector.broadcast.
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
public:
using Base::Base;
LogicalResult matchAndRewrite(SplatOp splatOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
splatOp.getOperand());
return success();
}
};
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SplatToBroadcastPattern>(context);
}

void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}

Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,

Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
// TODO: add support to `vector.splat`.
// TODO: add support to `vector.broadcast`.
// Finding the mask creation operation.
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
Expand Down
32 changes: 3 additions & 29 deletions mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
}
};

/// This pattern converts the SplatOp to work on a linearized vector.
/// Following,
/// vector.splat %value : vector<4x4xf32>
/// is converted to:
/// %out_1d = vector.splat %value : vector<16xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
struct LinearizeVectorSplat final
: public OpConversionPattern<vector::SplatOp> {
using Base::Base;

LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}

LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstTy = getTypeConverter()->convertType(splatOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
dstTy);
return success();
}
};

/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
Expand Down Expand Up @@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
LinearizeVectorStore, LinearizeVectorFromElements,
LinearizeVectorToElements>(typeConverter, patterns.getContext());
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
LinearizeVectorFromElements, LinearizeVectorToElements>(
typeConverter, patterns.getContext());
}

void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
Expand Down
Loading