From b3caf59894f04f3ae31ca9fc3f3d2804ed2d6ea5 Mon Sep 17 00:00:00 2001 From: Mehrdad Malekmohammadi Date: Mon, 17 Nov 2025 17:12:27 -0500 Subject: [PATCH] Remove references to the xdsl pass plugin in existing tests and docs --- .../doc/unified_compiler_cookbook.rst | 18 ++++---- .../python_compiler/visualization/draw.py | 3 +- tests/python_compiler/conftest.py | 2 +- .../mbqc/test_xdsl_outline_state_evolution.py | 43 ++++-------------- .../quantum/test_xdsl_cancel_inverses.py | 6 +-- .../test_xdsl_combine_global_phases.py | 4 +- .../test_xdsl_diagonalize_measurements.py | 16 ++----- .../test_xdsl_measurements_from_samples.py | 44 +++++-------------- .../quantum/test_xdsl_merge_rotations.py | 4 +- .../quantum/test_xdsl_split_non_commuting.py | 16 ++----- .../test_draw_python_compiler.py | 26 ++++------- .../visualization/test_mlir_graph.py | 13 +++--- 12 files changed, 52 insertions(+), 143 deletions(-) diff --git a/pennylane/compiler/python_compiler/doc/unified_compiler_cookbook.rst b/pennylane/compiler/python_compiler/doc/unified_compiler_cookbook.rst index b2f88b6060c..711ab87b012 100644 --- a/pennylane/compiler/python_compiler/doc/unified_compiler_cookbook.rst +++ b/pennylane/compiler/python_compiler/doc/unified_compiler_cookbook.rst @@ -664,10 +664,12 @@ currently rely on JAX’s API to lower to MLIR. This has the special effect of lowering to a specific dialect called StableHLO, which is used to represent all arithmetic operations present in the program. -Once lowered to MLIR, if the original ``qjit`` decorator specified the -xDSL pass plugin, we pass control over to the xDSL layer, which applies -all transforms that were requested by the user. We can request the use -of the xDSL plugin like so: +Once lowered to MLIR, if any xDSL registered passes are detected, we pass the control over to +the xDSL layer, which automatically detects and applies all xDSL transforms that were requested +by the user. + +However, if you want to manually trigger the xDSL layer without using any xDSL registered passes, +you can do so by specifying the ``pass_plugins`` parameter: .. code-block:: python @@ -1003,9 +1005,7 @@ currently accessible as qml.capture.enable() dev = qml.device("lightning.qubit", wires=1) - @qml.qjit( - pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()] - ) + @qml.qjit @my_pass @qml.qnode(dev) def circuit(x): @@ -1296,8 +1296,6 @@ will explain what is going on. .. code-block:: python - from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - def test_h_to_x_pass_integration(run_filecheck_qjit): """Test that Hadamard gets converted into PauliX.""" # The original program simply applies a Hadamard to a circuit @@ -1306,7 +1304,7 @@ will explain what is going on. # `compiler_transform`. To make sure that the xDSL API works # correctly, program capture must be enabled. # qml.capture.enable() - @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath]) + @qml.qjit @h_to_x_pass def circuit(): # CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit diff --git a/pennylane/compiler/python_compiler/visualization/draw.py b/pennylane/compiler/python_compiler/visualization/draw.py index ad3e16f0db9..00550785847 100644 --- a/pennylane/compiler/python_compiler/visualization/draw.py +++ b/pennylane/compiler/python_compiler/visualization/draw.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING from catalyst import qjit -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath from pennylane.tape import QuantumScript @@ -44,7 +43,7 @@ def _get_mlir_module(qnode: QNode, args, kwargs) -> ModuleOp: return qnode.mlir_module func = getattr(qnode, "user_function", qnode) - jitted_qnode = qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(func) + jitted_qnode = qjit(func) jitted_qnode.jit_compile(args, **kwargs) return jitted_qnode.mlir_module diff --git a/tests/python_compiler/conftest.py b/tests/python_compiler/conftest.py index 813e0956130..115fc3babc9 100644 --- a/tests/python_compiler/conftest.py +++ b/tests/python_compiler/conftest.py @@ -183,7 +183,7 @@ def test_qjit(self, run_filecheck_qjit): # Test that the merge_rotations_pass works as expected when used with `qjit` dev = qml.device("lightning.qubit", wires=2) - @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @merge_rotations_pass @qml.qnode(dev) def circuit(x: float, y: float): diff --git a/tests/python_compiler/transforms/mbqc/test_xdsl_outline_state_evolution.py b/tests/python_compiler/transforms/mbqc/test_xdsl_outline_state_evolution.py index 30ffca97f75..26ce78c4b2c 100644 --- a/tests/python_compiler/transforms/mbqc/test_xdsl_outline_state_evolution.py +++ b/tests/python_compiler/transforms/mbqc/test_xdsl_outline_state_evolution.py @@ -21,7 +21,6 @@ # pylint: disable=wrong-import-position from catalyst.ftqc import mbqc_pipeline -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( @@ -148,10 +147,7 @@ def test_multiple_func_w_qnode_attr(self, run_filecheck): def test_outline_state_evolution_no_error(self): """Test outline_state_evolution_pass does not raise error for circuit with classical operations only.""" - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - ) + @qml.qjit(target="mlir") @outline_state_evolution_pass def circuit(x, y): return x * y + 5 @@ -164,10 +160,7 @@ def test_outline_state_evolution_no_terminal_op_error(self): # TODOs: we can resolve this issue if the boundary op is inserted when the program is captured. dev = qml.device("null.qubit", wires=10) - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - ) + @qml.qjit(target="mlir") @outline_state_evolution_pass @qml.qnode(dev) def circuit(): @@ -183,10 +176,7 @@ def test_outline_state_evolution_pass_only(self, run_filecheck_qjit): """Test the outline_state_evolution_pass only.""" dev = qml.device("lightning.qubit", wires=1000) - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - ) + @qml.qjit(target="mlir") @outline_state_evolution_pass @qml.set_shots(1000) @qml.qnode(dev) @@ -221,11 +211,7 @@ def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_f """Test if the outline_state_evolution_pass works with the convert-to-mbqc-formalism pass on lightning.qubit.""" dev = qml.device("lightning.qubit", wires=1000) - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - pipelines=mbqc_pipeline(), - ) + @qml.qjit(target="mlir", pipelines=mbqc_pipeline()) @decompose_graph_state_pass @convert_to_mbqc_formalism_pass @outline_state_evolution_pass @@ -270,11 +256,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qji """Test if the outline_state_evolution_pass works with all mbqc transform pipeline on null.qubit.""" dev = qml.device("null.qubit", wires=1000) - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - pipelines=mbqc_pipeline(), - ) + @qml.qjit(target="mlir", pipelines=mbqc_pipeline()) @decompose_graph_state_pass @convert_to_mbqc_formalism_pass @measurements_from_samples_pass @@ -319,11 +301,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self): """Test if a circuit can be transfored with the outline_state_evolution_pass and all mbqc transform pipeline can be executed on null.qubit.""" dev = qml.device("null.qubit", wires=1000) - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - pipelines=mbqc_pipeline(), - ) + @qml.qjit(target="mlir", pipelines=mbqc_pipeline()) @decompose_graph_state_pass @convert_to_mbqc_formalism_pass @measurements_from_samples_pass @@ -362,10 +340,7 @@ def while_fn(i): i = i + 1 return i - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - ) + @qml.qjit(target="mlir") @outline_state_evolution_pass @qml.qnode(dev) def circuit(): @@ -376,9 +351,7 @@ def circuit(): res = circuit() - @qml.qjit( - target="mlir", - ) + @qml.qjit(target="mlir") @qml.qnode(dev) def circuit_ref(): for_fn() diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_cancel_inverses.py b/tests/python_compiler/transforms/quantum/test_xdsl_cancel_inverses.py index 09e195bcbc9..5cbc095ed0d 100644 --- a/tests/python_compiler/transforms/quantum/test_xdsl_cancel_inverses.py +++ b/tests/python_compiler/transforms/quantum/test_xdsl_cancel_inverses.py @@ -21,8 +21,6 @@ pytest.importorskip("catalyst") # pylint: disable=wrong-import-position -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( IterativeCancelInversesPass, @@ -197,7 +195,7 @@ def test_qjit(self, run_filecheck_qjit): """Test that the IterativeCancelInversesPass works correctly with qjit.""" dev = qml.device("lightning.qubit", wires=2) - @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @iterative_cancel_inverses_pass @qml.qnode(dev) def circuit(): @@ -215,7 +213,7 @@ def test_qjit_no_cancellation(self, run_filecheck_qjit): there are no operations that can be cancelled.""" dev = qml.device("lightning.qubit", wires=2) - @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @iterative_cancel_inverses_pass @qml.qnode(dev) def circuit(): diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_combine_global_phases.py b/tests/python_compiler/transforms/quantum/test_xdsl_combine_global_phases.py index f51d888c2f2..bcae6d1e0bc 100644 --- a/tests/python_compiler/transforms/quantum/test_xdsl_combine_global_phases.py +++ b/tests/python_compiler/transforms/quantum/test_xdsl_combine_global_phases.py @@ -20,8 +20,6 @@ pytest.importorskip("catalyst") # pylint: disable=wrong-import-position -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( CombineGlobalPhasesPass, @@ -226,7 +224,7 @@ def test_qjit(self, run_filecheck_qjit): """Test that the CombineGlobalPhasesPass works correctly with qjit.""" dev = qml.device("lightning.qubit", wires=2) - @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @combine_global_phases_pass @qml.qnode(dev) def circuit(x: float, y: float): diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_diagonalize_measurements.py b/tests/python_compiler/transforms/quantum/test_xdsl_diagonalize_measurements.py index 1185868f538..e82169543e0 100644 --- a/tests/python_compiler/transforms/quantum/test_xdsl_diagonalize_measurements.py +++ b/tests/python_compiler/transforms/quantum/test_xdsl_diagonalize_measurements.py @@ -24,7 +24,6 @@ xdsl = pytest.importorskip("xdsl") catalyst = pytest.importorskip("catalyst") -from catalyst.passes import xdsl_plugin import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( @@ -345,7 +344,6 @@ def circuit_ref(phi): ), "Sanity check failed, is expected_res correct?" circuit_compiled = qml.qjit( diagonalize_final_measurements_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], ) assert np.allclose(expected_res(angle), circuit_compiled(angle)) @@ -376,10 +374,7 @@ def expected_res(x, y): assert np.allclose( expected_res(phi, theta), circuit_ref(phi, theta) ), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - diagonalize_final_measurements_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(diagonalize_final_measurements_pass(circuit_ref)) assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) @@ -406,10 +401,7 @@ def expected_res(x, y): expected_res(phi, theta), circuit_ref(phi, theta) ), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - diagonalize_final_measurements_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(diagonalize_final_measurements_pass(circuit_ref)) assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta)) @@ -420,7 +412,7 @@ def test_overlapping_observables_raises_error(self): dev = qml.device("lightning.qubit", wires=2) - @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @qml.qjit @diagonalize_final_measurements_pass @qml.qnode(dev) def circuit(x): @@ -439,7 +431,7 @@ def test_non_commuting_observables_raise_error(self): non-commuting observables.""" dev = qml.device("lightning.qubit", wires=1) - @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @qml.qjit @diagonalize_final_measurements_pass @qml.qnode(dev) def circuit(x): diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_measurements_from_samples.py b/tests/python_compiler/transforms/quantum/test_xdsl_measurements_from_samples.py index 2f7495eb54d..530fb474e76 100644 --- a/tests/python_compiler/transforms/quantum/test_xdsl_measurements_from_samples.py +++ b/tests/python_compiler/transforms/quantum/test_xdsl_measurements_from_samples.py @@ -25,8 +25,6 @@ xdsl = pytest.importorskip("xdsl") catalyst = pytest.importorskip("catalyst") -from catalyst.passes import xdsl_plugin - import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( MeasurementsFromSamplesPass, @@ -551,10 +549,7 @@ def circuit_ref(): return mp(obs(wires=0)) assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) assert expected_res == circuit_compiled() @@ -583,10 +578,7 @@ def circuit_ref(): assert np.array_equal( expected_res, circuit_ref() ), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) assert np.array_equal(expected_res, circuit_compiled()) @@ -621,10 +613,7 @@ def circuit_ref(): expected_res, circuit_ref() ), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) assert np.array_equal(expected_res, _counts_catalyst_to_pl(*circuit_compiled())) @@ -651,10 +640,7 @@ def circuit_ref(): initial_op(wires=0) return qml.sample(wires=0) - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) expected_res = expected_res_base * np.ones(shape=(shots, 1), dtype=int) @@ -695,7 +681,6 @@ def circuit_ref(): assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" circuit_compiled = qml.qjit( measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], ) assert expected_res == circuit_compiled() @@ -736,10 +721,7 @@ def circuit_ref(): assert expected_res == circuit_ref(), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) assert expected_res == circuit_compiled() @@ -770,10 +752,7 @@ def circuit_ref(): assert np.array_equal( expected_res, circuit_ref() ), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) assert np.array_equal(expected_res, circuit_compiled()) @@ -804,10 +783,7 @@ def circuit_ref(): assert np.array_equal( expected_res, circuit_ref() ), "Sanity check failed, is expected_res correct?" - circuit_compiled = qml.qjit( - measurements_from_samples_pass(circuit_ref), - pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()], - ) + circuit_compiled = qml.qjit(measurements_from_samples_pass(circuit_ref)) assert np.array_equal(expected_res, circuit_compiled()) @@ -820,7 +796,7 @@ def test_exec_expval_dynamic_shots(self): This use case is not currently supported. """ - @qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @qml.qjit def workload(shots): dev = qml.device("lightning.qubit", wires=1) @@ -838,7 +814,7 @@ def test_qjit_filecheck(self, run_filecheck_qjit): """Test that the measurements_from_samples_pass works correctly with qjit.""" dev = qml.device("lightning.qubit", wires=2) - @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @measurements_from_samples_pass @qml.qnode(dev, shots=25) def circuit(): @@ -857,7 +833,7 @@ def circuit(): def test_integrate_with_decompose(self): dev = qml.device("null.qubit", wires=4) - @qml.qjit(target="mlir", pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @measurements_from_samples_pass @partial( qml.transforms.decompose, diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_merge_rotations.py b/tests/python_compiler/transforms/quantum/test_xdsl_merge_rotations.py index ec5aa491b41..e3e8d24762e 100644 --- a/tests/python_compiler/transforms/quantum/test_xdsl_merge_rotations.py +++ b/tests/python_compiler/transforms/quantum/test_xdsl_merge_rotations.py @@ -20,8 +20,6 @@ pytest.importorskip("catalyst") # pylint: disable=wrong-import-position -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - import pennylane as qml from pennylane.compiler.python_compiler.transforms import MergeRotationsPass, merge_rotations_pass @@ -229,7 +227,7 @@ def test_qjit(self, run_filecheck_qjit): """Test that the MergeRotationsPass works correctly with qjit.""" dev = qml.device("lightning.qubit", wires=1) - @qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit(target="mlir") @merge_rotations_pass @qml.qnode(dev) def circuit(x: float, y: float): diff --git a/tests/python_compiler/transforms/quantum/test_xdsl_split_non_commuting.py b/tests/python_compiler/transforms/quantum/test_xdsl_split_non_commuting.py index 9f93174ec4f..f15c576c6ca 100644 --- a/tests/python_compiler/transforms/quantum/test_xdsl_split_non_commuting.py +++ b/tests/python_compiler/transforms/quantum/test_xdsl_split_non_commuting.py @@ -20,8 +20,6 @@ catalyst = pytest.importorskip("catalyst") # pylint: disable=wrong-import-position -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( SplitNonCommutingPass, @@ -231,10 +229,7 @@ def _while_for(i): i = i + 1 return i - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - ) + @qml.qjit(target="mlir") @split_non_commuting_pass @qml.set_shots(10) @qml.qnode(dev) @@ -277,10 +272,7 @@ def while_fn(i): i = i + 1 return i - @qml.qjit( - target="mlir", - pass_plugins=[getXDSLPluginAbsolutePath()], - ) + @qml.qjit(target="mlir") @split_non_commuting_pass @qml.qnode(dev) def circuit(): @@ -295,9 +287,7 @@ def circuit(): res = circuit() - @qml.qjit( - target="mlir", - ) + @qml.qjit(target="mlir") @qml.qnode(dev) def circuit_ref(): for_fn() diff --git a/tests/python_compiler/visualization/test_draw_python_compiler.py b/tests/python_compiler/visualization/test_draw_python_compiler.py index de09bfef619..a3199ec2c4c 100644 --- a/tests/python_compiler/visualization/test_draw_python_compiler.py +++ b/tests/python_compiler/visualization/test_draw_python_compiler.py @@ -25,8 +25,6 @@ import jax # pylint: disable=wrong-import-position -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - import pennylane as qml from pennylane.compiler.python_compiler.transforms import iterative_cancel_inverses_pass from pennylane.compiler.python_compiler.visualization import draw @@ -93,9 +91,7 @@ def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected): ) if qjit: - transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( - transforms_circuit - ) + transforms_circuit = qml.qjit(transforms_circuit) assert draw(transforms_circuit, level=level)() == expected @@ -128,9 +124,7 @@ def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expecte ) if qjit: - transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( - transforms_circuit - ) + transforms_circuit = qml.qjit(transforms_circuit) assert draw(transforms_circuit, level=level)() == expected @@ -162,9 +156,7 @@ def test_multiple_levels_xdsl_catalyst(self, transforms_circuit, level, qjit, ex qml.transforms.merge_rotations(transforms_circuit) ) if qjit: - transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( - transforms_circuit - ) + transforms_circuit = qml.qjit(transforms_circuit) assert draw(transforms_circuit, level=level)() == expected @@ -208,9 +200,7 @@ def test_no_passes(self, transforms_circuit, level, qjit, expected): """Test that if no passes are applied, the circuit is still visualized.""" if qjit: - transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])( - transforms_circuit - ) + transforms_circuit = qml.qjit(transforms_circuit) assert draw(transforms_circuit, level=level)() == expected @@ -480,7 +470,7 @@ def circ(arg): def adjoint_op_not_implemented(self): """Test that NotImplementedError is raised when AdjointOp is used.""" - @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit(): qml.adjoint(qml.QubitUnitary)(jax.numpy.array([[0, 1], [1, 0]]), wires=[0]) @@ -492,7 +482,7 @@ def circuit(): def test_cond_not_implemented(self): """Test that NotImplementedError is raised when cond is used.""" - @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()]) + @qml.qjit @qml.qnode(qml.device("lightning.qubit", wires=2)) def circuit(): m0 = qml.measure(0, reset=False, postselect=0) @@ -505,7 +495,7 @@ def circuit(): def test_for_loop_not_implemented(self): """Test that NotImplementedError is raised when for loop is used.""" - @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True) + @qml.qjit(autograph=True) @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit(): for _ in range(3): @@ -518,7 +508,7 @@ def circuit(): def test_while_loop_not_implemented(self): """Test that NotImplementedError is raised when while loop is used.""" - @qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True) + @qml.qjit(autograph=True) @qml.qnode(qml.device("lightning.qubit", wires=1)) def circuit(): i = 0 diff --git a/tests/python_compiler/visualization/test_mlir_graph.py b/tests/python_compiler/visualization/test_mlir_graph.py index 5be78876b94..ec95391ce0e 100644 --- a/tests/python_compiler/visualization/test_mlir_graph.py +++ b/tests/python_compiler/visualization/test_mlir_graph.py @@ -22,10 +22,7 @@ pytest.importorskip("xdsl") pytest.importorskip("catalyst") - # pylint: disable=wrong-import-position -from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath - import pennylane as qml from pennylane.compiler.python_compiler.transforms import ( iterative_cancel_inverses_pass, @@ -70,7 +67,7 @@ def _(): return qml.state() if qjit: - _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + _ = qml.qjit(_) generate_mlir_graph(_)() assert collect_files(tmp_path) == {"QNode_level_0_no_transforms.svg"} @@ -90,7 +87,7 @@ def _(): return qml.state() if qjit: - _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + _ = qml.qjit(_) generate_mlir_graph(_)() assert_files( @@ -115,7 +112,7 @@ def _(x, y, w1, w2): return qml.state() if qjit: - _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + _ = qml.qjit(_) generate_mlir_graph(_)(0.1, 0.2, 0, 1) assert_files( @@ -140,7 +137,7 @@ def _(x, y, w1, w2): return qml.state() if qjit: - _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + _ = qml.qjit(_) generate_mlir_graph(_)(0.1, 0.2, 0, 1) assert_files( @@ -165,7 +162,7 @@ def _(x, y, w1, w2): return qml.state() if qjit: - _ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_) + _ = qml.qjit(_) generate_mlir_graph(_)(0.1, 0.2, 0, 1) assert_files(