
[TorchToLinalg] Conversion fails to legalize 'torch.constant.int' / No conversion of tm_tensor.attention ops #4302

@DavidGinten

Description


Hello,

I am running a sentence-transformer model through the Torch-MLIR to Linalg lowering pipeline, but the tm_tensor.attention ops cannot be converted. I tried decomposing the attention ops with PyTorch's decompositions, as suggested in #4279 (a sketch of what I do is at the end of the export script below), and with that no attention ops appear anymore. However, setting the output type of fx.export_and_import to LINALG_ON_TENSORS fails with failed to legalize operation 'torch.constant.int'. More specifically, when I first convert to the TORCH output type (see the sketch after the error below) and then run the following passes:
torch-mlir-opt -canonicalize -torch-backend-to-linalg-on-tensors-backend-pipeline model_torch.mlir > model_linalg.mlir

I get

error: failed to legalize operation 'torch.constant.int'
<unknown>:0: note: see current operation: %120 = "torch.constant.int"() <{value = -1 : i64}> : () -> !torch.int

(I tried different passes here.)
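For completeness, this is roughly how I produce model_torch.mlir before running the pass pipeline above. It is the same exporter call as in the full script below, just with the output type set to TORCH; the import paths are how I get at fx and OutputType from the torch-mlir Python package:

from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

# Emit the Torch dialect first, then lower with torch-mlir-opt as shown above.
m_torch = fx.export_and_import(ep, output_type=OutputType.TORCH,
                               func_name="transformer_model")
with open("model_torch.mlir", "w") as f:
    f.write(str(m_torch))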

How I export the model:

import torch
from transformers import AutoTokenizer, AutoModel
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

sentences = ["This is an example sentence", "Each sentence is converted"]
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Wrap the model so that export sees plain tensor inputs and a single tensor output.
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

wrapped_model = Wrapper(model)

ep = torch.export.export(
    wrapped_model,
    (
        encoded_input["input_ids"],
        encoded_input["attention_mask"],
    ),
)

# Apply the default export decompositions, then import straight to Linalg-on-tensors.
ep = ep.run_decompositions()
m = fx.export_and_import(ep, output_type=OutputType.LINALG_ON_TENSORS,
                         func_name="transformer_model")

mlir_str = str(m)
with open("model_linalg.mlir", "w") as f:
    f.write(mlir_str)
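When I say I decompose the attention ops with PyTorch's decompositions (per #4279), what I do is roughly the following. The exact set of aten ops to list in the decomposition table is my own guess, so treat this as a sketch rather than the canonical recipe:

from torch._decomp import get_decompositions

# Assumption: these are the attention ops that survive torch.export in my trace.
decomp_table = get_decompositions([
    torch.ops.aten.scaled_dot_product_attention,
    torch.ops.aten._scaled_dot_product_flash_attention,
])
ep = ep.run_decompositions(decomp_table)

With this applied before fx.export_and_import, no attention ops remain in the module, but the torch.constant.int legalization error above shows up instead.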

I also tried #3461 (with that I would not need the PyTorch decompositions), but it does not get rid of the tm_tensor.attention ops.

Thanks in advance for any help/advice.
