
[Question] How to convert GraphModule to linalg-on-tensors in custom backend (torch.compile + torch_mlir) ? #4300

@DaysChan

Description


Background

I'm trying to write a custom backend for torch.compile, where I want to use torch_mlir to convert a GraphModule to linalg-on-tensors IR. My backend function is named mybackend. However, I'm running into FakeTensor-related issues when calling torch_mlir.export_and_import (or torch_mlir.export), because the example_inputs received by the backend are FakeTensors (not real tensors).

My workflow looks like this:

def mybackend(gm, example_inputs):
    import torch_mlir
    # example_inputs are FakeTensors here, not real tensors
    mlir_module = torch_mlir.export_and_import(gm, *example_inputs, output_type="linalg-on-tensors")
    # ... (lower and compile the MLIR module)
    return compiled_fn  # the callable that torch.compile will invoke with real inputs

When I run this with torch.compile(model, backend=mybackend), I get errors like:

fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0xffffdea906190>) from tracing context 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0xfffde1a25970>) from fake tensor input 0
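My reading of this error is that two different FakeTensorMode instances are in play: example_inputs carry the mode Dynamo traced under, while export_and_import apparently re-traces under a fresh mode of its own. As a diagnostic I printed the mode attached to the inputs; note that detect_fake_mode is a torch internal, so its exact location may vary across versions.

from torch._guards import detect_fake_mode  # torch internal; may move between versions

# Inside mybackend: which FakeTensorMode do the inputs carry?
input_mode = detect_fake_mode(example_inputs)
print(input_mode)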

My questions:

  1. What is the recommended way to use torch_mlir inside a custom backend for torch.compile, given that example_inputs are FakeTensors?
  2. How can I convert FakeTensors to real tensors (with the correct shape/dtype) for use with torch_mlir? (See the sketch after this list for the kind of helper I mean.)
  3. If I create real tensors from FakeTensors using zeros_like or randn_like, will that affect the correctness of the exported MLIR or model execution?
  4. Besides export_and_import, are there other officially recommended or working ways to convert a GraphModule to linalg-on-tensors (especially in the context of a custom backend)?
  5. Is there a better integration pattern for torch.compile + torch_mlir under these circumstances?
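
To make question 2 concrete, here is the kind of helper I have in mind. It's only a sketch under my own assumptions: materialize is my name (not a torch API), and it assumes every dimension is a concrete int, so dynamic/SymInt shapes would need extra handling.

import torch
from torch._subclasses.fake_tensor import FakeTensor

def materialize(t):
    # Build a real tensor mirroring a FakeTensor's shape/dtype/device.
    # Assumption: all dims are concrete ints; SymInt dims from dynamic
    # shapes would have to be resolved to concrete values first.
    if isinstance(t, FakeTensor):
        return torch.zeros(tuple(int(d) for d in t.shape), dtype=t.dtype, device=t.device)
    return t

Since tracing should only need metadata, zeros ought to work as well as randn here, but that is exactly what question 3 is asking.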

Minimal reproducible example

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.view(3, 2)

model = M()
x = torch.randn(2, 3)
args = [x]

def mybackend(gm, example_inputs):
    import torch_mlir
    # At this point, example_inputs are FakeTensors
    print(type(example_inputs[0]))  # <class 'torch._subclasses.fake_tensor.FakeTensor'>
    # How to safely use torch_mlir.export_and_import here?
    mlir_module = torch_mlir.export_and_import(gm, *example_inputs, output_type="linalg-on-tensors")
    # ...
    return lambda *a: torch.zeros((3, 2))  # dummy for illustration

model_compiled = torch.compile(model, backend=mybackend)
with torch.no_grad():
    output = model_compiled(*args)
print(output)
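
For completeness, this is the workaround I'm currently experimenting with: materialize real tensors from the fakes, export, and fall back to eager execution. mybackend_workaround is my own name, and I'm assuming the entry point lives at torch_mlir.fx.export_and_import; I haven't verified this is sound in general, which is what questions 3-5 are about.

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch_mlir import fx  # assuming export_and_import lives under torch_mlir.fx

def mybackend_workaround(gm, example_inputs):
    # Replace each FakeTensor with a real zero tensor of the same metadata.
    real_inputs = [
        torch.zeros(tuple(int(d) for d in t.shape), dtype=t.dtype, device=t.device)
        if isinstance(t, FakeTensor) else t
        for t in example_inputs
    ]
    mlir_module = fx.export_and_import(gm, *real_inputs, output_type="linalg-on-tensors")
    print(mlir_module)
    # Run the captured graph eagerly until real codegen is wired up.
    return gm.forward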

Environment

  • torch: 2.6.0
  • torch-mlir: 7e7af67 (Date: Sat Aug 3 03:27:31 2024)
  • Python: 3.9

Additional context

If there's an official or recommended pattern for using torch_mlir in a custom torch.compile backend (handling FakeTensor inputs), please let me know. Thanks!
