Skip to content

Commit a44361c

Browse files
committed
[mlir][transform-dialect] add unittest of named_sequence build.
1 parent ca0bc78 commit a44361c

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

mlir/unittests/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_unittest(MLIRTransformDialectTests
2+
TransformNamedSequenceCreate.cpp
23
BuildOnlyExtensionTest.cpp
34
Preload.cpp
45
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
2+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
3+
#include "mlir/IR/Builders.h"
4+
#include "mlir/IR/BuiltinAttributes.h"
5+
#include "mlir/IR/BuiltinOps.h"
6+
#include "mlir/IR/MLIRContext.h"
7+
#include "gtest/gtest.h"
8+
9+
using namespace mlir;
10+
using namespace mlir::transform;
11+
12+
TEST(NamedSequenceOpTest, ArgAttrsAreHonoredByBuilder) {
13+
MLIRContext ctx;
14+
ctx.loadDialect<TransformDialect>();
15+
16+
OpBuilder builder(&ctx);
17+
auto module = ModuleOp::create(UnknownLoc::get(&ctx));
18+
builder.setInsertionPointToEnd(module.getBody());
19+
20+
Location loc = UnknownLoc::get(&ctx);
21+
22+
static constexpr StringLiteral kMainSequenceName = "__transform_main";
23+
24+
NamedSequenceOp seqOp = builder.create<NamedSequenceOp>(
25+
loc,
26+
/*sym_name=*/kMainSequenceName,
27+
/*rootType=*/builder.getType<AnyOpType>(),
28+
/*resultType=*/TypeRange{},
29+
[](OpBuilder &b, Location nested, Value rootH) {
30+
b.create<YieldOp>(nested, ValueRange());
31+
},
32+
/*args=*/ArrayRef<NamedAttribute>{},
33+
/*attrArgs=*/
34+
ArrayRef<DictionaryAttr>{
35+
builder.getDictionaryAttr(ArrayRef<NamedAttribute>{
36+
builder.getNamedAttr(TransformDialect::kArgConsumedAttrName,
37+
builder.getUnitAttr())})});
38+
39+
// 检查 body argument 上有没有 transform.consumed
40+
Block &body = seqOp.getBody().front();
41+
ASSERT_EQ(body.getNumArguments(), 1u);
42+
43+
StringAttr arg0Name = seqOp.getArgAttrsAttrName();
44+
EXPECT_TRUE(arg0Name);
45+
}

0 commit comments

Comments
 (0)