
Maxunpool does not support overlapping Maxpool inputs #4308

@ziliangzl

Description


Steps to reproduce the issue

Modify the MaxUnpool2dModule_basic test so that the pooling uses stride 1, which makes the max-pool windows overlap (a quick shape check follows the two listings below):

class MaxUnpool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 2, 3, 7], torch.float32, True),
            ([2, 2, 3, 7], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool2d(x, indices, (4, 8))


@register_test_case(module_factory=lambda: MaxUnpool2dModule())
def MaxUnpool2dModule_basic(module, tu: TestUtils):
    input = tu.rand(2, 2, 4, 8)
    pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), return_indices=True)
    output, indices = pool(input)

    module.forward(output, indices)


For reference, the original (non-overlapping) test is:

class MaxUnpool2dModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([2, 2, 2, 4], torch.float32, True),
            ([2, 2, 2, 4], torch.int64, True),
        ]
    )
    def forward(self, x, indices):
        return torch.ops.aten.max_unpool2d(x, indices, (4, 8))


@register_test_case(module_factory=lambda: MaxUnpool2dModule())
def MaxUnpool2dModule_basic(module, tu: TestUtils):
    input = tu.rand(2, 2, 4, 8)
    pool = torch.nn.MaxPool2d(kernel_size=(2, 2), return_indices=True)
    output, indices = pool(input)

    module.forward(output, indices)
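
With stride 1 the pooling windows overlap, so the pooled tensor no longer has the 2x4 shape that the non-overlapping lowering assumes (in the IR below, the 3x7 operands are padded down to 2x4 before the linalg.generic). A quick check in plain PyTorch, independent of torch-mlir, shows the shapes involved:

import torch

pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), return_indices=True)
output, indices = pool(torch.rand(2, 2, 4, 8))
print(output.shape)  # torch.Size([2, 2, 3, 7]) -- overlapping windows, not 2x4
# Adjacent windows share elements, so several entries of `indices` can point
# at the same location of the original 4x8 input.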

Run the test:

python -m e2e_testing.main -f 'MaxUnpool2dModule_basic' -v

Output:

====================
LINALG Backend IR
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module attributes {torch.debug_module_name = "MaxUnpool2dModule"} {
  func.func @forward(%arg0: tensor<2x2x3x7xf32>, %arg1: tensor<2x2x3x7xi64>) -> tensor<2x2x4x8xf32> {
    %c8 = arith.constant 8 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c-1_i64 = arith.constant -1 : i64
    %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 0, -1, -3] {
    ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
      tensor.yield %cst : f32
    } : tensor<2x2x3x7xf32> to tensor<2x2x2x4xf32>
    %padded_0 = tensor.pad %arg1 low[0, 0, 0, 0] high[0, 0, -1, -3] {
    ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index):
      tensor.yield %c-1_i64 : i64
    } : tensor<2x2x3x7xi64> to tensor<2x2x2x4xi64>
    %0 = tensor.empty() : tensor<2x2x4x8xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%padded, %padded_0 : tensor<2x2x2x4xf32>, tensor<2x2x2x4xi64>) outs(%0 : tensor<2x2x4x8xf32>) {
    ^bb0(%in: f32, %in_1: i64, %out: f32):
      %2 = arith.index_cast %in_1 : i64 to index
      %3 = linalg.index 2 : index
      %4 = linalg.index 3 : index
      %5 = arith.muli %3, %c8 : index
      %6 = arith.addi %5, %4 : index
      %7 = arith.cmpi eq, %2, %6 : index
      %8 = arith.select %7, %in, %cst : f32
      linalg.yield %8 : f32
    } -> tensor<2x2x4x8xf32>
    return %1 : tensor<2x2x4x8xf32>
  }
}

TORCH_VERSION_FOR_COMPARISON = 2.9.0.dev20250820
FAIL - "MaxUnpool2dModule_basic"

Unexpected outcome summary: (linalg)

****** Failed tests - 1 tests
    FAIL - "MaxUnpool2dModule_basic"
        Compilation error: Traceback (most recent call last):
          File "/home/zzl/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 332, in compile_and_run_test
            compiled = config.compile(test.program_factory(), verbose=verbose)
          File "/home/zzl/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/configs/jit_importer_backend.py", line 70, in compile
            return self.backend.compile(module)
          File "/home/zzl/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py", line 229, in compile
            run_pipeline_with_repro_report(
          File "/home/zzl/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 127, in run_pipeline_with_repro_report
            raise TorchMlirCompilerError(trimmed_message) from None
        torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering Linalg-on-Tensors IR to LLVM with RefBackend failed with the following diagnostics:
        error: slice along dimension 2 runs out-of-bounds: 2 >= 2
        note: see current operation: %4 = "memref.subview"(%3) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: 2, 2, 3, 7>, static_strides = array<i64: 1, 1, 1, 1>}> : (memref<2x2x2x4xf32>) -> memref<2x2x3x7xf32, strided<[16, 8, 4, 1]>>


        python exception: Failure while executing pass pipeline

        For Torch-MLIR developers, the error can be reproduced with:
        $ torch-mlir-opt -pass-pipeline='builtin.module(func.func(linalg-generalize-named-ops),func.func(linalg-fuse-elementwise-ops),convert-shape-to-std,sparse-assembler{direct-out},sparsification-and-bufferization,sparse-storage-specifier-to-llvm,func.func(expand-realloc),func.func(refback-generalize-tensor-pad),func.func(refback-generalize-tensor-concat),func.func(tm-tensor-bufferize),one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map},refback-mlprogram-bufferize,func.func(buffer-deallocation-pipeline),inline,refback-munge-calling-conventions,func.func(tm-tensor-to-loops),func.func(refback-munge-memref-copy),func.func(convert-linalg-to-loops),func.func(lower-affine),convert-scf-to-cf,generate-runtime-verification,func.func(refback-expand-ops-for-llvm),func.func(arith-expand),func.func(convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,lower-affine,convert-bufferization-to-memref,finalize-memref-to-llvm,func.func(convert-arith-to-llvm),convert-vector-to-llvm,convert-func-to-llvm,convert-cf-to-llvm,convert-complex-to-llvm,reconcile-unrealized-casts)' /tmp/MaxUnpool2dModule.mlir
        Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.



Summary:
    Failed: 1

Reason

The current implementation doesn't handle the overlapping case; worse, it cannot even detect it and raise an error, so the conversion silently returns invalid values.
See the comment on the lowering pattern (a plain-Python restatement of the trick follows the excerpt):

// Max unpooling operation, takes result of max_pooling op and indices and
// tries to reconstructs original pooling input by filling out values by either
// values from self or zero.
// Upstream CPU implementation use parallel loop over the indices array to fill
// out tensor but such approach requires random access writes, which is tricky
// to represent in linalg.
// Instead we are using a different method: we are mapping each input/index
// value to multiple output values via affine maps in linalg.generic, then,
// inside the body of generic, we compute out index and compare it with expected
// index we got from input, returning either input or zero.
// This method only works if we have non-overlapping pooling windows.
// In case of overlap (e.g. kernel_size=2, stride=1) we need to map many-to-many
// input to output values and do a reduction. To construct such mapping we need
// to know original Kernel size, but it doesn't encoded in aten op. We cannot
// reconstruct kernel_size either as such reconstruction is ambiguous (e.g. for
// input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
// What worse, without knowing kernel size we cannot even reliably detect such
// cases and this conversion will just return invalid values.
class ConvertAtenMaxUnpool3dOp final
    : public OpConversionPattern<AtenMaxUnpool3dOp> {
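
Restated in plain Python, the non-overlapping trick the comment describes looks roughly like this (a hypothetical reference sketch, not the actual lowering; unpool2d_nonoverlapping and its arguments are made up for illustration):

import torch

def unpool2d_nonoverlapping(pooled, indices, output_size, kernel):
    # Each output element (h, w) belongs to exactly one pooling window
    # (h // kh, w // kw), so it only has to read that single pooled value and
    # index, and compare the stored flat index with its own flat index
    # h * W + w. This mirrors the affine maps + select in the linalg.generic
    # above, and it is only valid when stride == kernel_size.
    N, C = pooled.shape[0], pooled.shape[1]
    H, W = output_size
    kh, kw = kernel
    out = torch.zeros(N, C, H, W, dtype=pooled.dtype)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    ph, pw = h // kh, w // kw
                    if indices[n, c, ph, pw] == h * W + w:
                        out[n, c, h, w] = pooled[n, c, ph, pw]
    return out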

I'm open to contributing if others here determine that overlapping support is essential. That said, implementing it would introduce a performance overhead, since each output element would need to scan all input elements. Given this trade-off, I'd appreciate others' perspectives on whether the functionality justifies the potential performance impact.
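
For reference, a minimal sketch of what an overlap-tolerant variant could compute (hypothetical; each output element now has to reduce over the whole pooled plane, because without the kernel size the owning window is unknown):

import torch

def unpool2d_overlapping(pooled, indices, output_size):
    # Without the kernel size we cannot tell which pooled element "owns" a
    # given output position, so every output element scans all pooled
    # elements and keeps the value whose stored index matches. This is the
    # many-to-many mapping / reduction, and the performance cost, discussed
    # above.
    N, C, PH, PW = pooled.shape
    H, W = output_size
    out = torch.zeros(N, C, H, W, dtype=pooled.dtype)
    for n in range(N):
        for c in range(C):
            for h in range(H):
                for w in range(W):
                    target = h * W + w
                    for ph in range(PH):
                        for pw in range(PW):
                            if indices[n, c, ph, pw] == target:
                                out[n, c, h, w] = pooled[n, c, ph, pw]
    return out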
