Skip to content

Commit 94f5410

Browse files
[LINALG] Add complex tensor support for create[Zero|One]InitTensor utility (#3777)
Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent d49eabb commit 94f5410

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

lib/Conversion/Utils/Utils.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,25 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
132132
Type elemTy) {
133133
Value initTensor =
134134
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
135-
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
136-
Value c0 =
137-
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
135+
136+
Type fillValElemTy = elemTy;
137+
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
138+
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
139+
140+
Value c0 = b.create<arith::ConstantOp>(loc, b.getZeroAttr(fillValElemTy));
138141
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
139142
}
140143

141144
Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
142145
Type elemTy) {
143146
Value initTensor =
144147
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
145-
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
146-
Value c1 =
147-
b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
148+
149+
Type fillValElemTy = elemTy;
150+
if (auto dtypeComplex = dyn_cast<mlir::ComplexType>(elemTy))
151+
fillValElemTy = cast<mlir::FloatType>(dtypeComplex.getElementType());
152+
153+
Value c1 = b.create<arith::ConstantOp>(loc, b.getOneAttr(fillValElemTy));
148154
return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
149155
}
150156

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,7 @@
14231423
"SliceSizeTwoStepModule_basic",
14241424
"SliceStartEqEndModule_basic",
14251425
"SliceStaticModule_basic",
1426+
"SliceStaticComplexInputModule_basic",
14261427
"SliceWholeTensorModule_basic",
14271428
"SortIntListReverse_basic",
14281429
"SortIntList_basic",
@@ -2618,6 +2619,7 @@
26182619
"SliceCopyNegative_Module_basic",
26192620
"SliceCopyNonZeroDim_Module_basic",
26202621
"SliceCopy_Module_basic",
2622+
"SliceStaticComplexInputModule_basic",
26212623
"StdCorrectionLargeInputModule_basic",
26222624
"TupleModule_basic",
26232625
"VarCorrectionLargeInputModule_basic",
@@ -3778,6 +3780,7 @@
37783780
"SignAndLogarithmOfDeterminantModule_F32",
37793781
"SignAndLogarithmOfDeterminantBatchedModule_F32",
37803782
"SignAndLogarithmOfDeterminantDynamicModule_F32",
3783+
"SliceStaticComplexInputModule_basic",
37813784
"SliceCopyEndGreaterThanDimSize_Module_basic",
37823785
"SliceCopyNegative_Module_basic",
37833786
"SliceCopyNonZeroDim_Module_basic",
@@ -4714,6 +4717,7 @@
47144717
"SliceCopy_Module_basic",
47154718
"SliceEndSleStartModule_basic",
47164719
"SliceModule_basic",
4720+
"SliceStaticComplexInputModule_basic",
47174721
"SliceNegIdxModule_basic",
47184722
"SliceOutOfLowerBoundEndIndexModule_basic",
47194723
"SliceOutOfLowerBoundStartIndexModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils):
5858
# ==============================================================================
5959

6060

61+
class SliceStaticComplexInputModule(torch.nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
65+
@export
66+
@annotate_args(
67+
[
68+
None,
69+
([6, 4, 7], torch.complex64, True),
70+
]
71+
)
72+
def forward(self, x):
73+
return x[0:5:1, 1:3:1, 2:4:1]
74+
75+
76+
@register_test_case(module_factory=lambda: SliceStaticComplexInputModule())
77+
def SliceStaticComplexInputModule_basic(module, tu: TestUtils):
78+
module.forward(tu.rand(6, 4, 7).to(torch.complex64))
79+
80+
81+
# ==============================================================================
82+
83+
6184
class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
6285
def __init__(self):
6386
super().__init__()

0 commit comments

Comments
 (0)