
Conversation

IanWood1 (Contributor) commented Oct 4, 2024

The missing fold was preventing dynamic dims in an ONNX model from being reified (it caused the generation of tensor.casts and prevented fusion in IREE). The problematic pattern:

%2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
%7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
%8 = torch.aten.reshape %2, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
//... chain of foldable ops linking %2 to the `shape` operand of a `torch.aten.broadcast_to ... -> !torch.vtensor<[?,?],si64>`
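
Since both the operand and the result of the reshape are !torch.vtensor<[2],si64>, the reshape is a no-op; once it folds, the literal's values (4 and 256) can flow through the chain into the broadcast's shape operand. A rough sketch of the folded form (the refined result type shown is what one would expect, not verified output):

%2 = torch.vtensor.literal(dense<[4, 256]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
// %7 and %8 fold away; uses of %8 are replaced by %2
//... the chain of foldable ops now resolves statically, so the
// `torch.aten.broadcast_to` can produce !torch.vtensor<[4,256],si64>
// instead of !torch.vtensor<[?,?],si64>, with no tensor.cast downstream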

IanWood1 marked this pull request as ready for review October 7, 2024 15:49
IanWood1 requested review from renxida and zjgarvey October 10, 2024 17:03

zjgarvey (Collaborator) commented Oct 10, 2024

Hey @IanWood1, I'm certainly not against this, but I'm not sure why this is causing issues. For example:

func.func @reshape(%arg0 : !torch.vtensor<[2],si64>) -> !torch.vtensor<[2],si64> {
    %int2 = torch.constant.int 2
    %7 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
    %8 = torch.aten.reshape %arg0, %7 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
    return %8 : !torch.vtensor<[2],si64>
}

becomes

module {
  func.func @reshape(%arg0: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2],si64> {
    return %arg0 : !torch.vtensor<[2],si64>
  }
}

after applying the pass --torch-decompose-complex-ops.

zjgarvey (Collaborator) left a review comment

See my earlier comment. I'm happy to add this, since we should fold this kind of pattern before needing to decompose. I just want you to be sure that this was the problematic pattern, since I don't expect this to be the case.

IanWood1 (Contributor, Author) commented Oct 10, 2024

The problem arises when the op isn't folded before DropAbstractInterpCalculations, because folding it is necessary to recover the shape info. For example, running --torch-lower-to-backend-contract --torch-backend-to-linalg-on-tensors-backend-pipeline on this input:

module {
  func.func @reshape(%arg0: !torch.vtensor<[2],si64>, %arg1: !torch.vtensor<[10,10],f16>) -> !torch.vtensor<[?,?],f16> {
    %int2 = torch.constant.int 2
    %int0 = torch.constant.int 0
    %0 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[10,10],f16> -> !torch.vtensor<[2],si64>
    %1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
    %2 = torch.aten.reshape %0, %1 : !torch.vtensor<[2],si64>, !torch.list<int> -> !torch.vtensor<[2],si64>
    %3 = torch.aten.select.int %2, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.prim.ListConstruct %4, %4 : (!torch.int, !torch.int) -> !torch.list<int>
    %6 = torch.aten.broadcast_to %arg1, %5 : !torch.vtensor<[10,10],f16>, !torch.list<int> -> !torch.vtensor<[?,?],f16>
    return %6 : !torch.vtensor<[?,?],f16>
  }
}

With the reshape folded, the entire function body folds away and the result keeps its static shape; without the fold, the pipeline instead produces:

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @reshape(%arg0: tensor<2xi64>, %arg1: tensor<10x10xf16>) -> tensor<?x?xf16> {
    %cast = tensor.cast %arg1 : tensor<10x10xf16> to tensor<?x?xf16>
    return %cast : tensor<?x?xf16>
  }
}
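
For comparison, a rough sketch of what the folded lowering could look like (assuming the broadcast shape constant-folds to [10, 10] and the public return type is refined accordingly; illustrative, not exact pipeline output):

module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @reshape(%arg0: tensor<2xi64>, %arg1: tensor<10x10xf16>) -> tensor<10x10xf16> {
    // The broadcast becomes an identity on a statically shaped tensor,
    // so no tensor.cast is needed.
    return %arg1 : tensor<10x10xf16>
  }
}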

The tensor.cast isn't problematic in this example, but it can lead to dynamic shapes where they could be static. Maybe it makes sense to reify shapes after decomposing too (if that's possible).

IanWood1 merged commit 8787970 into llvm:main Oct 11, 2024
3 checks passed
IanWood1 deleted the fold_reshape branch October 11, 2024 01:54