-
Notifications
You must be signed in to change notification settings - Fork 46
Open
Description
Problem:
I am trying to compile a dummy example of a model whose computation graph depends on its input (via `torch.cond`), but compilation fails with the error: "error: failed to legalize operation 'torch.operator' that was explicitly marked illegal".
Steps to reproduce:
Run the following code:
import torch
import torch.nn as nn
import copy
class CondNetwork(nn.Module):
    """Toy model whose output is chosen by a data-dependent ``torch.cond``.

    ``forward`` compares ``x.sum()`` against ``confidence_threshold`` and
    dispatches to :meth:`true_fn` or :meth:`false_fn`. Both branches
    currently return a fresh ``torch.rand(10)`` float32 tensor. The output
    therefore does not depend on which branch runs; the branches exist only
    to put a conditional into the exported graph.

    Args:
        confidence_threshold: Value that ``x.sum()`` must exceed for the
            true branch to be selected. Defaults to 2.
    """

    def __init__(self, confidence_threshold: float = 2):
        super().__init__()
        self.confidence_threshold = confidence_threshold

    def true_fn(self):
        """Branch for ``x.sum() > confidence_threshold``: a random (10,) float32 tensor."""
        return torch.rand(10, dtype=torch.float32)

    def false_fn(self):
        """Branch for ``x.sum() <= confidence_threshold``: a random (10,) float32 tensor."""
        return torch.rand(10, dtype=torch.float32)

    def forward(self, x):
        """Select a branch with ``torch.cond`` based on ``x.sum()``.

        Args:
            x: Input tensor; only its total sum is used.

        Returns:
            The tensor produced by whichever branch ``torch.cond`` selects.
        """
        with torch.no_grad():
            # Scalar boolean tensor: the data-dependent predicate for torch.cond.
            condition = x.sum() > self.confidence_threshold
            # The branch callables take no operands (torch.cond's default operands=()).
            return torch.cond(condition, self.true_fn, self.false_fn)
def model_export(model, device, output_path='./conditional_model.onnx'):
    """Export a deep copy of ``model`` to ONNX using the dynamo-based exporter.

    Args:
        model: The ``nn.Module`` to export. It is switched to eval mode in
            place before being copied.
        device: Device that the copied model and the example input are moved to.
        output_path: Destination ``.onnx`` file. Defaults to
            ``./conditional_model.onnx``.

    Returns:
        The value returned by ``torch.onnx.export``. With ``dynamo=True``
        this is presumably an ``ONNXProgram``; confirm for your torch version.
    """
    model.eval()
    # Deep copy so the move to ``device`` does not relocate the caller's
    # module. The copy inherits eval mode from ``model``, so no second
    # ``.eval()`` is needed.
    cond_model = copy.deepcopy(model).to(device)
    # Example input used to trace the graph; its values do not matter.
    x = torch.randn(1, 3, 32, 32).to(device)
    with torch.no_grad():
        return torch.onnx.export(
            cond_model,
            x,
            output_path,
            verbose=True,
            dynamo=True,
            report=True,
        )
###-- Main
def main():
    """Build a ``CondNetwork`` and export it to ONNX, on CUDA when available."""
    model = CondNetwork()
    # Prefer CUDA, but fall back to CPU so the export also runs on machines
    # without a GPU. model_export moves its own deep copy to this device,
    # so the model itself does not need to be moved here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_export(model, device=device)
if __name__ == '__main__':
    # Run the export only when executed as a script, not on import.
    main()
# After running this code, run:
iree-import-onnx conditional_model.onnx -o conditional_model.mlir
Then run:
iree-compile conditional_model.mlir -o conditional_model.vmfb
which fails with the following error:
conditional_model.mlir:12:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
%5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32>
^
conditional_model.mlir:12:12: note: see current operation: %16 = "torch.operator"() <{name = "onnx.true_graph_0"}> : () -> !torch.vtensor<[10],f32>
Is there something wrong with my implementation, or is this operation simply not supported?
Additional information:
Package versions:
torch : 2.6.0
iree-turbine : 3.2.0
Associated IR (file conditional_model.mlir):
module {
func.func @main_graph(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.opset_versions = {pkg.onnxscript.torch_lib.common = 1 : si64, pkg.torch.__subgraph__ = 1 : si64}, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.6.0+cu124"} { // NOTE(review): opset_versions lists a pkg.torch.__subgraph__ domain; presumably the torch.cond branches were exported as ONNX local functions there rather than inlined. Confirm.
%none = torch.constant.none // not referenced below
%0 = torch.operator "onnx.ReduceSum"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.noop_with_empty_axes = 0 : si64} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[],f32> // sum over all elements of the input -> scalar f32 (x.sum())
%1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<si64>} : () -> !torch.vtensor<[],si64> // int64 scalar read from the `_` resource blob at the end of the file; presumably confidence_threshold = 2. Confirm.
%2 = torch.operator "onnx.Cast"(%1) {torch.onnx.to = 1 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],f32> // threshold cast to f32 to match %0
%3 = torch.operator "onnx.Greater"(%0, %2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[],i1> // branch predicate: sum > threshold
%4 = torch.operator "onnx.If"(%3) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[10],f32> { // first region calls false_graph_0, second region calls true_graph_0
%5 = torch.operator "onnx.false_graph_0"() : () -> !torch.vtensor<[10],f32> // bare call to a subgraph name with no inlined body; presumably unlowerable for the same reason as the true branch below
torch.operator_terminator %5 : !torch.vtensor<[10],f32>
}, {
%5 = torch.operator "onnx.true_graph_0"() : () -> !torch.vtensor<[10],f32> // the op that iree-compile reports at conditional_model.mlir:12:12 as failing to legalize
torch.operator_terminator %5 : !torch.vtensor<[10],f32>
}
return %4 : !torch.vtensor<[10],f32>
}
}
{-#
dialect_resources: {
builtin: {
_: "0x080000000200000000000000"
}
}
#-}
Metadata
Metadata
Assignees
Labels
No labels