Background
I'm trying to write a custom backend for torch.compile, where I want to use torch_mlir to convert a GraphModule to linalg-on-tensors IR. My backend function is named mybackend. However, I'm running into FakeTensor-related issues when calling torch_mlir.export_and_import (or torch_mlir.export), because the example_inputs received by the backend are FakeTensors (not real tensors).
My workflow looks like this:
def mybackend(gm, example_inputs):
    import torch_mlir
    # example_inputs are FakeTensors here
    mlir_module = torch_mlir.export_and_import(gm, *example_inputs, output_type="linalg-on-tensors")
    # ...
    return compiled_fn
When I run this with torch.compile(model, backend=mybackend), I get errors like:
fake mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0xffffdea906190>) from tracing context 0 doesn't match mode (<torch._subclasses.fake_tensor.FakeTensorMode object at 0xfffde1a25970>) from fake tensor input 0
My questions:
- What is the recommended way to use torch_mlir inside a custom backend for torch.compile, given that example_inputs are FakeTensors?
- How can I convert FakeTensors to real tensors (with the correct shape/dtype) for use with torch_mlir? (A minimal sketch of what I mean follows this list.)
- If I create real tensors from FakeTensors using zeros_like or randn_like, will that affect the correctness of the exported MLIR or model execution?
- Besides export_and_import, are there other officially recommended or working ways to convert a GraphModule to linalg-on-tensors (especially in the context of a custom backend)?
- Is there a better integration pattern for torch.compile + torch_mlir under these circumstances?
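To make the conversion question concrete, here is a minimal sketch of what I have in mind; materialize_inputs is a helper I wrote purely for illustration, not a torch_mlir or PyTorch API:

import torch
from torch._subclasses import FakeTensor

def materialize_inputs(example_inputs):
    # Build real tensors with the same shape/dtype as the incoming
    # FakeTensors. Note that torch.zeros_like(fake_tensor) can itself
    # return a FakeTensor (the operation dispatches through the
    # subclass), so I construct fresh tensors from the shape/dtype
    # metadata instead. Hard-coding device="cpu" here is another
    # simplification on my part.
    return [
        torch.zeros(t.shape, dtype=t.dtype, device="cpu") if isinstance(t, FakeTensor) else t
        for t in example_inputs
    ]

Whether zero-filled values are safe for export is exactly my zeros_like / randn_like question above.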
Minimal reproducible example
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.view(3, 2)

model = M()
x = torch.randn(2, 3)
args = [x]

def mybackend(gm, example_inputs):
    import torch_mlir
    # At this point, example_inputs are FakeTensors
    print(type(example_inputs[0]))  # <class 'torch._subclasses.FakeTensor'>
    # How to safely use torch_mlir.export_and_import here?
    mlir_module = torch_mlir.export_and_import(gm, *example_inputs, output_type="linalg-on-tensors")
    # ...
    return lambda *a: torch.zeros((3, 2))  # dummy for illustration

model_compiled = torch.compile(model, backend=mybackend)
with torch.no_grad():
    output = model_compiled(*args)
print(output)
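For reference, the workaround variant I have been experimenting with looks like the sketch below. It reuses the materialize_inputs helper from above (again, my own code, not an official pattern), and I don't know whether this is actually sound:

def mybackend_materialized(gm, example_inputs):
    import torch_mlir
    # Swap FakeTensors for real zero-filled tensors of the same
    # shape/dtype before exporting (materialize_inputs is my own
    # helper sketched earlier, not a torch_mlir API).
    real_inputs = materialize_inputs(example_inputs)
    mlir_module = torch_mlir.export_and_import(gm, *real_inputs, output_type="linalg-on-tensors")
    # ... lower and compile mlir_module ...
    # For now, fall back to running the GraphModule eagerly.
    return gm.forward

This at least avoids passing FakeTensors directly to export_and_import, but I'm not sure it's correct, which is why I'm asking about the recommended pattern.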
Environment
- torch: 2.6.0
- torch-mlir: 7e7af67 (Date: Sat Aug 3 03:27:31 2024)
- Python: 3.9
Additional context
If there's an official or recommended pattern for using torch_mlir in a custom torch.compile backend (handling FakeTensor inputs), please let me know. Thanks!