
Commit c598e01

Author: Prashant Kumar
Add support for passing & returning memref of bool types
Support for passing a memref of bool type as a function argument, and for returning one, is added in the ref-backend. Signed-off-by: Prashant Kumar <[email protected]>
1 parent: 9958cf0
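
At the Python boundary, torch.bool corresponds to numpy's np.bool_, which is the dtype the ref-backend wrapper uses for i1 memrefs after this change. A minimal sketch of that dtype correspondence, using only plain PyTorch and numpy (no compiler involved):

    import numpy as np
    import torch

    # torch.bool tensors convert to numpy arrays of dtype np.bool_ (one byte
    # per element); this is the dtype the changes below accept as arguments
    # and produce for returned i1 memrefs.
    t = torch.tensor([True, False, True])
    a = t.numpy()
    assert a.dtype == np.bool_
    assert torch.from_numpy(a).dtype == torch.bool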

File tree

4 files changed: +71 -4 lines changed

e2e_testing/torchscript/basic.py

Lines changed: 54 additions & 0 deletions
@@ -1003,3 +1003,57 @@ def forward(self):
 @register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
 def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
     module.forward()
+
+
+class BoolTensorReturnFalseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return a
+
+
+@register_test_case(module_factory=lambda: BoolTensorReturnFalseModule())
+def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([0, 0], dtype=torch.bool))
+
+
+class BoolTensorReturnTrueModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return a
+
+
+@register_test_case(module_factory=lambda: BoolTensorReturnTrueModule())
+def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
+
+
+class BoolTensorReturnMixedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return a
+
+
+@register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
+def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
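
In the annotate_args lists above, None stands for the self argument, and each tensor entry pairs a shape (with -1 marking a dynamically sized dimension), a dtype, and a value-semantics flag, per the torch-mlir annotation convention. The three modules are identity functions, so the compiled artifact must hand the i1 buffer back unchanged; the harness compares the compiled output against the eager result. A quick eager-mode model of that expectation (plain PyTorch, no compiler):

    import torch

    # What BoolTensorReturnMixedModule's forward computes in eager mode:
    # identity on a bool tensor. The compiled result must match this.
    a = torch.tensor([[1, 0], [0, 1]], dtype=torch.bool)
    out = a  # forward(self, a): return a
    assert out.dtype == torch.bool and torch.equal(out, a)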

e2e_testing/torchscript/xfail_sets.py

Lines changed: 3 additions & 0 deletions
@@ -40,4 +40,7 @@
     "AddCMulModule_basic",
     "AddCDivModule_basic",
     "SqueezeModule_broadcast",
+    "BoolTensorReturnFalseModule_basic",
+    "BoolTensorReturnTrueModule_basic",
+    "BoolTensorReturnMixedModule_basic",
 }

lib/RefBackend/RefBackend.cpp

Lines changed: 6 additions & 3 deletions
@@ -56,6 +56,8 @@ static bool isArgMemRefTypeValid(Type type) {
        return true;
      if (integerTy.isSignlessInteger(32))
        return true;
+     if (integerTy.isSignlessInteger(1))
+       return true;
    }
  }
  return false;
@@ -128,7 +130,7 @@ static LogicalResult mungeFunction(
     auto type = arg.getType();
     if (!isArgMemRefTypeValid(type))
       return emitError(arg.getLoc(),
-                       "argument must be a memref of f32, f64, i32, i64");
+                       "argument must be a memref of f32, f64, i32, i64, i1");
     auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
     arg.replaceAllUsesExcept(cast, cast);
     arg.setType(getAbiTypeForMemRef(type));
@@ -163,7 +165,7 @@ static LogicalResult mungeFunction(
     if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
       op.emitError(
           "must have one return value of memref types or scalar types "
-          "of i32, i64, f32, f64 or three return values of memref f32");
+          "of i32, i64, f32, f64, i1, or three return values of memref f32");
       isSupported = false;
     }
@@ -182,6 +184,7 @@
 
 static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
   std::set<std::string> funcNames;
+  Type mri1 = UnrankedMemRefType::get(b.getI1Type(), 0);
   Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
   Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
   Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
@@ -191,7 +194,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
   Type f64 = b.getF64Type();
 
   SmallVector<TypeRange> supportedReturnTypes = {
-      mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
+      mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
 
   llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
     funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
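
The names inserted into funcNames must match what the Python wrapper registers with the ExecutionEngine (e.g. refbackend_consume_func_return_mri1 in the next file). A hedged Python model of the mangling, inferred from the names visible in this commit rather than from the C++ helper itself:

    # Hypothetical reconstruction for illustration only; the authoritative
    # logic is getConsumeReturnFunctionNameForReturnTypes in RefBackend.cpp.
    def consume_return_func_name(mnemonics):
        # "mri1" = unranked memref of i1, "mrf32" = unranked memref of f32, ...
        return "refbackend_consume_func_return_" + "_".join(mnemonics)

    assert (consume_return_func_name(["mri1"])
            == "refbackend_consume_func_return_mri1")
    assert (consume_return_func_name(["mrf32", "mrf32", "mrf32"])
            == "refbackend_consume_func_return_mrf32_mrf32_mrf32")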

python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py

Lines changed: 8 additions & 1 deletion
@@ -24,7 +24,7 @@
 
 
 def checkArgTypeIsSupported(ty):
-    SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
+    SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
     assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
 
 
@@ -33,6 +33,10 @@ def __init__(self, module):
         self.ee = ExecutionEngine(module)
         self.result = None
 
+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mri1(a):
+            self.result = unranked_memref_to_numpy(a, np.bool_)
+
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
         def consume_return_mri32(a):
             self.result = unranked_memref_to_numpy(a, np.int32)
@@ -70,6 +74,9 @@ def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
                 arg1,
                 np.float32), unranked_memref_to_numpy(arg2, np.float32)
 
+        self.ee.register_runtime("refbackend_consume_func_return_mri1",
+                                 consume_return_mri1)
+
         self.ee.register_runtime("refbackend_consume_func_return_mri32",
                                  consume_return_mri32)
 
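
The wrapper relies on ctypes trampolines: each @ctypes.CFUNCTYPE-decorated closure becomes a C-callable function pointer, and when the JIT-compiled code invokes it, the closure stores the converted result on the wrapper object. A self-contained sketch of just that mechanism (plain ctypes, no MLIR; a plain int stands in for the UnrankedMemRefDescriptor pointer):

    import ctypes

    class Wrapper:
        def __init__(self):
            self.result = None

            # CFUNCTYPE(None, ...) declares a void callback; decorating the
            # closure yields a C function pointer that captures `self`. The
            # real backend registers it under a
            # refbackend_consume_func_return_* name.
            @ctypes.CFUNCTYPE(None, ctypes.c_int64)
            def consume_return(value):
                self.result = value

            # Keep a reference so the trampoline outlives any native calls.
            self._consume_return = consume_return

    w = Wrapper()
    w._consume_return(42)  # simulates the compiled code calling back
    assert w.result == 42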
