lib/Conversion/TorchToStablehlo/Reduction.cpp (5 changes: 3 additions & 2 deletions)
@@ -110,7 +110,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
}
}

if (isa<AtenAllOp>(op)) {
if (isa<AtenAllOp, AtenAllDimOp>(op)) {
auto constAttr =
DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)});
return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
@@ -166,7 +166,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
AtenLinalgVectorNormOp>(op)) {
result = rewriter.create<stablehlo::AddOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAllOp>(op)) {
} else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
result = rewriter.create<stablehlo::AndOp>(
op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
} else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
@@ -887,6 +887,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
patterns.add<ConvertAtenReduceOneDimOp<AtenOp>>(typeConverter, context, \
options)
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp);
INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp);
#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN

#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \
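Taken together, these three changes let aten.all.dim reuse the existing single-region reduction scaffolding: the init value is the 1-bit constant true (the identity of logical AND), and the reduce region folds elements with stablehlo.and. A minimal semantic sketch of that computation in plain PyTorch (reference code for illustration only, not part of this patch):

    import torch

    def all_dim_reference(a: torch.Tensor, dim: int) -> torch.Tensor:
        # Init value: true, the AND identity (APInt(/*numBits=*/1, 1) above).
        acc = torch.ones(
            [s for i, s in enumerate(a.shape) if i != dim], dtype=torch.bool
        )
        # Reduce region: logical AND across the reduced dimension,
        # mirroring the stablehlo::AndOp the pattern emits.
        for idx in range(a.shape[dim]):
            acc = acc & a.select(dim, idx)
        return acc

    a = torch.randint(0, 2, (16, 50, 256), dtype=torch.bool)
    assert torch.equal(all_dim_reference(a, 2), torch.all(a, dim=2))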
projects/pt1/e2e_testing/xfail_sets.py (15 changes: 0 additions & 15 deletions)
@@ -815,10 +815,6 @@
"RandnLikeDtypeModule_basic",
"RandnLikeModule_basic",
"RandnModule_basic",
"ReduceAllDimBool_basic",
"ReduceAllDimEmpty_basic",
"ReduceAllDimFloat_basic",
"ReduceAllDimInt_basic",
"ReduceProdDimIntFloatModule_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
@@ -836,18 +832,7 @@
"ReplicationPad2dModule_top0",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarImplicitFloatModule_basic",
# need aten.all.dim lowering to stablehlo
"SafeSoftmaxModule_basic",
"SafeSoftmaxNonNoneDtypeModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
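Deleting these entries moves the corresponding e2e tests out of the StableHLO expected-failure set, so they now run and must pass under the StableHLO config. Something like `python -m e2e_testing.main --config=stablehlo --filter ReduceAllDim` should exercise them locally (flag spelling assumed from the pt1 e2e harness; double-check against projects/pt1/e2e_testing/main.py).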
projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py (23 changes: 23 additions & 0 deletions)
@@ -219,6 +219,29 @@ def ReduceAllBoolModule_basic(module, tu: TestUtils):
# ==============================================================================


class ReduceAllDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.bool, True),
]
)
def forward(self, a):
return torch.ops.aten.all(a, 2)


@register_test_case(module_factory=lambda: ReduceAllDimModule())
def ReduceAllDimModule_basic(module, tu: TestUtils):
module.forward(tu.randint(16, 50, 256, high=2).to(torch.bool))


# ==============================================================================


class ReduceAnyFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
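With the pattern registered, a module containing aten.all.dim should lower to StableHLO end to end. A hedged sketch using torch-mlir's fx importer (the fx.export_and_import entry point and the output_type keyword come from upstream torch-mlir, not from this diff; treat the exact API as an assumption):

    import torch
    from torch_mlir import fx

    class AllDim(torch.nn.Module):
        def forward(self, a):
            # Same op the new test exercises: reduce dim 2 with logical AND.
            return torch.ops.aten.all(a, 2)

    module = fx.export_and_import(
        AllDim(),
        torch.zeros(16, 50, 256, dtype=torch.bool),
        output_type="stablehlo",
    )
    # Expect a stablehlo.reduce with a stablehlo.and region in the output.
    print(module)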