From 498eb8aba50c54dab0a9b2121e724ba069393cae Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Wed, 11 Dec 2024 09:36:27 -0800 Subject: [PATCH] Remove plugin WAR due to mlir-tensorrt issue #915 (#438) --- tripy/tests/integration/test_plugin.py | 3 +-- tripy/tripy/frontend/trace/ops/plugin.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tripy/tests/integration/test_plugin.py b/tripy/tests/integration/test_plugin.py index 43c9c1934..116313b71 100644 --- a/tripy/tests/integration/test_plugin.py +++ b/tripy/tests/integration/test_plugin.py @@ -22,8 +22,7 @@ class TestPlugin: def test_gelu(self): - # TODO: We add `+ 1` as a hack to work around MLIR-TRT Issue #915. We should be able to remove it once fixed - inp = tp.iota((2, 2)) + 1 + inp = tp.iota((2, 2)) out = tp.plugin( "CustomGeluPluginDynamic", [inp], diff --git a/tripy/tripy/frontend/trace/ops/plugin.py b/tripy/tripy/frontend/trace/ops/plugin.py index 4651dc070..92683e73f 100644 --- a/tripy/tripy/frontend/trace/ops/plugin.py +++ b/tripy/tripy/frontend/trace/ops/plugin.py @@ -75,8 +75,7 @@ def plugin( :linenos: :caption: Example - # TODO: We add `+ 1` as a hack to work around MLIR-TRT Issue #915. We should be able to remove it once fixed # doc: omit - inp = tp.iota((2, 1, 4)) + 1 + inp = tp.iota((2, 1, 4)) out = tp.plugin( "CustomGeluPluginDynamic", [inp],