Hello, I encountered the following error when compiling a model that contains an nn.GRU layer with torch-mlir:
Traceback (most recent call last):
  File "/home/dgx/model_hub/test.py", line 25, in <module>
    module = torchscript.compile(
             ^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/tm/lib/python3.11/site-packages/torch_mlir/torchscript.py", line 401, in compile
    run_pipeline_with_repro_report(
  File "/root/anaconda3/envs/tm/lib/python3.11/site-packages/torch_mlir/compiler_utils.py", line 78, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
note: see current operation: %21:2 = "torch.operator"(%9, %19, %20, %5, %0, %6, %4, %4, %5) <{name = "aten.gru.input"}> : (!torch.tensor<[1,10,32],f32>, !torch.tensor<[1,1,32],f32>, !torch.list<tensor>, !torch.bool, !torch.int, !torch.float, !torch.bool, !torch.bool, !torch.bool) -> (!torch.tensor<[1,10,32],f32>, !torch.tensor<[1,1,32],f32>)
python exception: Failure while executing pass pipeline
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.adaptive_avg_pool1d,aten.adaptive_avg_pool2d,aten.unflatten.int extra-library=})' /tmp/MyModel.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
Here is the code I used:
import torch
import torch.nn as nn
from torch_mlir import torchscript


class MyModel(nn.Module):
    def __init__(self, embedding_dim):
        super(MyModel, self).__init__()
        self.gru_based_layer = nn.GRU(
            embedding_dim * 2, embedding_dim * 2, batch_first=True
        )

    def forward(self, data):
        output_based_gru, _ = self.gru_based_layer(data)
        return output_based_gru


model = MyModel(16)
model.eval()
data = torch.randn(1, 10, 32)
res = model(data)
print(res)

module = torchscript.compile(
    model,
    data,
    output_type="linalg-on-tensors",
    use_tracing=True,
)
I saw that #3447 mentions support for GRU, but the conversion still fails for me. The torch-mlir version I am using is 20240720.158.
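In case it helps with triage, here is a small sketch for pulling the name of the op that failed to legalize out of a diagnostic like the one above. It uses only the Python standard library. The helper name `unlowered_op_names` is made up for illustration and is not part of torch-mlir, and the operand list in the sample line is shortened from the real message.

```python
import re

# Diagnostic line taken from the traceback in this report.
# The operand and type lists are shortened here for readability.
note = (
    'note: see current operation: %21:2 = "torch.operator"(%9, %19, %20) '
    '<{name = "aten.gru.input"}> : (...) -> (...)'
)


def unlowered_op_names(diagnostic: str) -> list[str]:
    """Return the op names found in name attributes of the form <{name = "..."}>.

    Hypothetical helper for illustration only; not a torch-mlir API.
    """
    return re.findall(r'<\{name = "([^"]+)"\}>', diagnostic)


print(unlowered_op_names(note))  # ['aten.gru.input']
```

For the message in this report, the only op it finds is `aten.gru.input`, which is the overload the traced `nn.GRU` call was imported as.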