Hello,
I run a sentence-transformer model through the Torch MLIR to Linalg MLIR conversion pipeline, but the tm_tensor.attention ops cannot be converted. I tried decomposing the attention ops with PyTorch's decompositions, as suggested in #4279, and with that the attention ops no longer appear. However, setting the output type of fx.export_and_import to LINALG_ON_TENSORS yields this error: failed to legalize operation 'torch.constant.int'. More specifically, when I first convert to TORCH and then run the following passes:
torch-mlir-opt -canonicalize -torch-backend-to-linalg-on-tensors-backend-pipeline \
  model_torch.mlir > model_linalg.mlir
I get
error: failed to legalize operation 'torch.constant.int'
<unknown>:0: note: see current operation: %120 = "torch.constant.int"() <{value = -1 : i64}> : () -> !torch.int
(I tried different passes here)
How I export the model:
import torch
from transformers import AutoModel, AutoTokenizer
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

sentences = ["This is an example sentence", "Each sentence is converted"]
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Wrapper so that the exported graph returns a single tensor.
class Wrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

wrapped_model = Wrapper(model)
ep = torch.export.export(
    wrapped_model,
    (
        encoded_input["input_ids"],
        encoded_input["attention_mask"],
    ),
)
# Decompose composite ops (including attention) with PyTorch's decompositions.
ep = ep.run_decompositions()
m = fx.export_and_import(ep, output_type=OutputType.LINALG_ON_TENSORS,
                         func_name="transformer_model")
mlir_str = str(m)
with open("model_linalg.mlir", "w") as f:
    f.write(mlir_str)
I also tried #3461 (which would make the PyTorch decomposition unnecessary), but it doesn't get rid of the tm_tensor.attention ops.
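To check which torch or tm_tensor ops are still present in an exported module, a small text scan along the following lines may help. This is only a sketch: the helper name count_ops is hypothetical, and it matches op names line by line with a regular expression instead of using the MLIR Python bindings, so it can miss ops in unusual formatting.

```python
import re
from collections import Counter

# Optional result list ("%0 = " or "%0:2 = "), then a dialect-qualified op name,
# either in pretty form (torch.constant.int) or generic form ("torch.constant.int").
_OP_RE = re.compile(r'^\s*(?:%[^=]*=\s*)?"?([A-Za-z_]\w*(?:\.\w+)+)"?')

def count_ops(mlir_text, dialects=("torch", "tm_tensor")):
    """Count ops per name, keeping only ops from the given dialects (hypothetical helper)."""
    counts = Counter()
    for line in mlir_text.splitlines():
        match = _OP_RE.match(line)
        if match and match.group(1).split(".")[0] in dialects:
            counts[match.group(1)] += 1
    return counts

if __name__ == "__main__":
    # Assumes the module text was written to this file, as in the export script.
    with open("model_linalg.mlir") as f:
        for name, n in count_ops(f.read()).most_common():
            print(n, name)
```

Running it on the TORCH-level module and on the output of the failing pipeline would show whether tm_tensor.attention is really gone and which torch ops are left unconverted.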
Thanks in advance for any help or advice.