torch.cond operator not supported on a simple example (error: failed to legalize operation 'torch.operator' that was explicitly marked illegal) #937

@JibAxelera

Problem:

I am trying to compile a minimal example of a model whose computation graph depends on the input (it uses torch.cond), but iree-compile fails with "error: failed to legalize operation 'torch.operator' that was explicitly marked illegal".

Steps to reproduce:

Run the following code:

import torch
import torch.nn as nn
import copy

class CondNetwork(nn.Module):

    def __init__(self):
        super(CondNetwork, self).__init__()

        self.confidence_threshold = 2

    def true_fn(self):
        return torch.rand(10, dtype=torch.float32)

    def false_fn(self):
        return torch.rand(10, dtype=torch.float32)

    def forward(self, x):

        with torch.no_grad():

            condition = x.sum() > self.confidence_threshold

            return torch.cond(condition, self.true_fn, self.false_fn)

def model_export(model, device):

    model.eval()

    cond_model = copy.deepcopy(model).to(device)

    x = torch.randn(1, 3, 32, 32).to(device)

    with torch.no_grad():
        cond_model.eval()
        torch.onnx.export(cond_model, x, './conditional_model.onnx', verbose=True, dynamo=True, report=True)

###-- Main
def main():

    model = CondNetwork()

    model.cuda()

    model_export(model, device="cuda")

if __name__ == '__main__':
    main()

After running this script, import the exported ONNX model:

iree-import-onnx conditional_model.onnx -o conditional_model.mlir

Then compile the imported module:

iree-compile conditional_model.mlir -o conditional_model.vmfb

The compile step fails with the following error:

conditional_model.mlir:12:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
      %5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32> 
           ^
conditional_model.mlir:12:12: note: see current operation: %16 = "torch.operator"() <{name = "onnx.true_graph_0"}> : () -> !torch.vtensor<[10],f32>

Is there something wrong with my implementation, or is the operation simply not supported?
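
One way to narrow this down is to list every name that the importer left wrapped in a generic torch.operator and see which ones do not look like standard ONNX operators. The sketch below is only a heuristic and not part of IREE: it assumes that standard ONNX operator names start with an uppercase letter, so names such as true_graph_0 are more likely exporter-generated functions. The helper names find_onnx_ops and nonstandard are hypothetical, and the sample text is abbreviated from the IR in this report.

```python
import re
from collections import Counter

# Matches `torch.operator "onnx.<Name>"`; `torch.operator_terminator`
# is not matched because no quoted name follows it directly.
_OP_RE = re.compile(r'torch\.operator\s+"onnx\.([^"]+)"')


def find_onnx_ops(mlir_text: str) -> Counter:
    """Count each onnx.* name that appears inside a torch.operator."""
    return Counter(_OP_RE.findall(mlir_text))


def nonstandard(ops) -> list:
    """Heuristic (assumption): standard ONNX op names start with an
    uppercase letter, so lowercase names are flagged for inspection."""
    return sorted(name for name in ops if not name[0].isupper())


# Abbreviated lines from conditional_model.mlir (types elided).
sample = '''
%3 = torch.operator "onnx.Greater"(%0, %2) : ...
%4 = torch.operator "onnx.If"(%3) : ... {
  %5 = torch.operator "onnx.false_graph_0"() : ...
  torch.operator_terminator %5 : ...
}, {
  %5 = torch.operator "onnx.true_graph_0"() : ...
  torch.operator_terminator %5 : ...
}
'''

ops = find_onnx_ops(sample)
print(nonstandard(ops))  # ['false_graph_0', 'true_graph_0']
```

To scan the real file, pass its contents instead of the sample, for example find_onnx_ops(open("conditional_model.mlir").read()). In this report the only flagged names are the two branch bodies of onnx.If, which matches the operation named in the error.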

Additional information:

Package versions:
torch: 2.6.0
iree-turbine: 3.2.0

Associated IR (file conditional_model.mlir):

module {
  func.func @main_graph(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.opset_versions = {pkg.onnxscript.torch_lib.common = 1 : si64, pkg.torch.__subgraph__ = 1 : si64}, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.6.0+cu124"} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.ReduceSum"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.noop_with_empty_axes = 0 : si64} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[],f32> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %2 = torch.operator "onnx.Cast"(%1) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],f32> 
    %3 = torch.operator "onnx.Greater"(%0, %2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[],i1> 
    %4 = torch.operator "onnx.If"(%3) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[10],f32> {
      %5 = torch.operator "onnx.false_graph_0"() : () -> !torch.vtensor<[10],f32> 
      torch.operator_terminator %5 : !torch.vtensor<[10],f32>
    }, {
      %5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32> 
      torch.operator_terminator %5 : !torch.vtensor<[10],f32>
    }
    return %4 : !torch.vtensor<[10],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _: "0x080000000200000000000000"
    }
  }
#-}
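
As a sanity check on the IR, the dense_resource<_> blob can be decoded to confirm that the onnx.Constant feeding onnx.Greater is the threshold from the model. This is a minimal sketch, assuming MLIR's textual blob layout in which the first four bytes hold the alignment (little-endian uint32) and the remaining bytes are the raw little-endian payload. The helper name decode_si64_blob is hypothetical.

```python
import struct


def decode_si64_blob(blob_hex: str):
    """Split an MLIR resource blob hex string into (alignment, value).

    Assumption: 4-byte little-endian alignment header followed by a
    single little-endian signed 64-bit integer (tensor<si64> scalar).
    """
    raw = bytes.fromhex(blob_hex.removeprefix("0x"))
    alignment = struct.unpack_from("<I", raw, 0)[0]
    value = struct.unpack_from("<q", raw, 4)[0]
    return alignment, value


alignment, value = decode_si64_blob("0x080000000200000000000000")
print(alignment, value)  # 8 2
```

The decoded value 2 matches self.confidence_threshold in CondNetwork, so the comparison itself was exported as intended and the failure is limited to the branch bodies of onnx.If.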
