diff --git a/tripy/tests/integration/test_plugin.py b/tripy/tests/integration/test_plugin.py index 43c9c1934..116313b71 100644 --- a/tripy/tests/integration/test_plugin.py +++ b/tripy/tests/integration/test_plugin.py @@ -22,8 +22,7 @@ class TestPlugin: def test_gelu(self): - # TODO: We add `+ 1` as a hack to work around MLIR-TRT Issue #915. We should be able to remove it once fixed - inp = tp.iota((2, 2)) + 1 + inp = tp.iota((2, 2)) out = tp.plugin( "CustomGeluPluginDynamic", [inp], diff --git a/tripy/tripy/frontend/trace/ops/plugin.py b/tripy/tripy/frontend/trace/ops/plugin.py index 4651dc070..92683e73f 100644 --- a/tripy/tripy/frontend/trace/ops/plugin.py +++ b/tripy/tripy/frontend/trace/ops/plugin.py @@ -75,8 +75,7 @@ def plugin( :linenos: :caption: Example - # TODO: We add `+ 1` as a hack to work around MLIR-TRT Issue #915. We should be able to remove it once fixed # doc: omit - inp = tp.iota((2, 1, 4)) + 1 + inp = tp.iota((2, 1, 4)) out = tp.plugin( "CustomGeluPluginDynamic", [inp],