Skip to content

Commit 5f74de5

Browse files
authored
[Stablehlo] support aten.all.dim (#3746)
1 parent eb4e59e commit 5f74de5

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

lib/Conversion/TorchToStablehlo/Reduction.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
110110
}
111111
}
112112

113-
if (isa<AtenAllOp>(op)) {
113+
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
114114
auto constAttr =
115115
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
116116
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
166166
AtenLinalgVectorNormOp>(op)) {
167167
result = rewriter.create<stablehlo::AddOp>(
168168
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
169-
} else if (isa<AtenAllOp>(op)) {
169+
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
170170
result = rewriter.create<stablehlo::AndOp>(
171171
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
172172
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
@@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
887887
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
888888
options)
889889
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
890+
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
890891
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN
891892

892893
#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -815,10 +815,6 @@
815815
"RandnLikeDtypeModule_basic",
816816
"RandnLikeModule_basic",
817817
"RandnModule_basic",
818-
"ReduceAllDimBool_basic",
819-
"ReduceAllDimEmpty_basic",
820-
"ReduceAllDimFloat_basic",
821-
"ReduceAllDimInt_basic",
822818
"ReduceProdDimIntFloatModule_basic",
823819
"ReflectionPad1dModule2dInput_Right",
824820
"ReflectionPad1dModule2dInput_basic",
@@ -836,18 +832,7 @@
836832
"ReplicationPad2dModule_top0",
837833
"RsubInt0d_NumToTensor_Module_basic",
838834
"ScalarImplicitFloatModule_basic",
839-
# need aten.all.dim lowering to stablehlo
840-
"SafeSoftmaxModule_basic",
841-
"SafeSoftmaxNonNoneDtypeModule_basic",
842835
# REMOVE WHEN ENABLE_GQA IS ADDED
843-
"ScaledDotProductAttentionBoolMaskModule_basic",
844-
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
845-
"ScaledDotProductAttentionDifferentCausalModule_basic",
846-
"ScaledDotProductAttentionDifferentModule_basic",
847-
"ScaledDotProductAttentionMaskModule_basic",
848-
"ScaledDotProductAttentionSameCausalModule_basic",
849-
"ScaledDotProductAttentionSameDynamicModule_basic",
850-
"ScaledDotProductAttentionSameModule_basic",
851836
"ScatterReduceFloatMaxModule",
852837
"ScatterReduceFloatMaxModuleIncludeSelf",
853838
"ScatterReduceFloatMeanModule",

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils):
170170
module.forward(tu.rand(3, 4, 5))
171171

172172

173+
class ReduceAllDimFloatModule(torch.nn.Module):
174+
def __init__(self):
175+
super().__init__()
176+
177+
@export
178+
@annotate_args(
179+
[
180+
None,
181+
([-1, -1, -1], torch.float32, True),
182+
]
183+
)
184+
def forward(self, a):
185+
return torch.ops.aten.all(a, dim=0)
186+
187+
188+
@register_test_case(module_factory=lambda: ReduceAllDimFloatModule())
189+
def ReduceAllDimFloatModule_basic(module, tu: TestUtils):
190+
module.forward(tu.rand(3, 4, 5))
191+
192+
173193
# ==============================================================================
174194

175195

0 commit comments

Comments (0)