Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,12 @@ currently rely on JAX’s API to lower to MLIR. This has the special
effect of lowering to a specific dialect called StableHLO, which is used
to represent all arithmetic operations present in the program.

Once lowered to MLIR, if the original ``qjit`` decorator specified the
xDSL pass plugin, we pass control over to the xDSL layer, which applies
all transforms that were requested by the user. We can request the use
of the xDSL plugin like so:
Once lowered to MLIR, if any xDSL registered passes are detected, we pass the control over to
the xDSL layer, which automatically detects and applies all xDSL transforms that were requested
by the user.

However, if you want to manually trigger the xDSL layer without using any xDSL registered passes,
you can do so by specifying the ``pass_plugins`` parameter:

.. code-block:: python

Expand Down Expand Up @@ -1003,9 +1005,7 @@ currently accessible as
qml.capture.enable()
dev = qml.device("lightning.qubit", wires=1)

@qml.qjit(
pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()]
)
@qml.qjit
@my_pass
@qml.qnode(dev)
def circuit(x):
Expand Down Expand Up @@ -1296,8 +1296,6 @@ will explain what is going on.

.. code-block:: python

from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

def test_h_to_x_pass_integration(run_filecheck_qjit):
"""Test that Hadamard gets converted into PauliX."""
# The original program simply applies a Hadamard to a circuit
Expand All @@ -1306,7 +1304,7 @@ will explain what is going on.
# `compiler_transform`. To make sure that the xDSL API works
# correctly, program capture must be enabled.
# qml.capture.enable()
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath])
@qml.qjit
@h_to_x_pass
def circuit():
# CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
Expand Down
3 changes: 1 addition & 2 deletions pennylane/compiler/python_compiler/visualization/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import TYPE_CHECKING

from catalyst import qjit
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

from pennylane.tape import QuantumScript

Expand All @@ -44,7 +43,7 @@ def _get_mlir_module(qnode: QNode, args, kwargs) -> ModuleOp:
return qnode.mlir_module

func = getattr(qnode, "user_function", qnode)
jitted_qnode = qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(func)
jitted_qnode = qjit(func)
jitted_qnode.jit_compile(args, **kwargs)
return jitted_qnode.mlir_module

Expand Down
2 changes: 1 addition & 1 deletion tests/python_compiler/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_qjit(self, run_filecheck_qjit):
# Test that the merge_rotations_pass works as expected when used with `qjit`
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@merge_rotations_pass
@qml.qnode(dev)
def circuit(x: float, y: float):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

# pylint: disable=wrong-import-position
from catalyst.ftqc import mbqc_pipeline
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

import pennylane as qml
from pennylane.compiler.python_compiler.transforms import (
Expand Down Expand Up @@ -148,10 +147,7 @@ def test_multiple_func_w_qnode_attr(self, run_filecheck):
def test_outline_state_evolution_no_error(self):
"""Test outline_state_evolution_pass does not raise error for circuit with classical operations only."""

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
def circuit(x, y):
return x * y + 5
Expand All @@ -164,10 +160,7 @@ def test_outline_state_evolution_no_terminal_op_error(self):
# TODOs: we can resolve this issue if the boundary op is inserted when the program is captured.
dev = qml.device("null.qubit", wires=10)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
@qml.qnode(dev)
def circuit():
Expand All @@ -183,10 +176,7 @@ def test_outline_state_evolution_pass_only(self, run_filecheck_qjit):
"""Test the outline_state_evolution_pass only."""
dev = qml.device("lightning.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
@qml.set_shots(1000)
@qml.qnode(dev)
Expand Down Expand Up @@ -221,11 +211,7 @@ def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_f
"""Test if the outline_state_evolution_pass works with the convert-to-mbqc-formalism pass on lightning.qubit."""
dev = qml.device("lightning.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
pipelines=mbqc_pipeline(),
)
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
@decompose_graph_state_pass
@convert_to_mbqc_formalism_pass
@outline_state_evolution_pass
Expand Down Expand Up @@ -270,11 +256,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qji
"""Test if the outline_state_evolution_pass works with all mbqc transform pipeline on null.qubit."""
dev = qml.device("null.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
pipelines=mbqc_pipeline(),
)
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
@decompose_graph_state_pass
@convert_to_mbqc_formalism_pass
@measurements_from_samples_pass
Expand Down Expand Up @@ -319,11 +301,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self):
"""Test if a circuit can be transfored with the outline_state_evolution_pass and all mbqc transform pipeline can be executed on null.qubit."""
dev = qml.device("null.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
pipelines=mbqc_pipeline(),
)
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
@decompose_graph_state_pass
@convert_to_mbqc_formalism_pass
@measurements_from_samples_pass
Expand Down Expand Up @@ -362,10 +340,7 @@ def while_fn(i):
i = i + 1
return i

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
@qml.qnode(dev)
def circuit():
Expand All @@ -376,9 +351,7 @@ def circuit():

res = circuit()

@qml.qjit(
target="mlir",
)
@qml.qjit(target="mlir")
@qml.qnode(dev)
def circuit_ref():
for_fn()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
pytest.importorskip("catalyst")

# pylint: disable=wrong-import-position
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

import pennylane as qml
from pennylane.compiler.python_compiler.transforms import (
IterativeCancelInversesPass,
Expand Down Expand Up @@ -197,7 +195,7 @@ def test_qjit(self, run_filecheck_qjit):
"""Test that the IterativeCancelInversesPass works correctly with qjit."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@iterative_cancel_inverses_pass
@qml.qnode(dev)
def circuit():
Expand All @@ -215,7 +213,7 @@ def test_qjit_no_cancellation(self, run_filecheck_qjit):
there are no operations that can be cancelled."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@iterative_cancel_inverses_pass
@qml.qnode(dev)
def circuit():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
pytest.importorskip("catalyst")

# pylint: disable=wrong-import-position
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

import pennylane as qml
from pennylane.compiler.python_compiler.transforms import (
CombineGlobalPhasesPass,
Expand Down Expand Up @@ -226,7 +224,7 @@ def test_qjit(self, run_filecheck_qjit):
"""Test that the CombineGlobalPhasesPass works correctly with qjit."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@combine_global_phases_pass
@qml.qnode(dev)
def circuit(x: float, y: float):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
xdsl = pytest.importorskip("xdsl")

catalyst = pytest.importorskip("catalyst")
from catalyst.passes import xdsl_plugin

import pennylane as qml
from pennylane.compiler.python_compiler.transforms import (
Expand Down Expand Up @@ -345,7 +344,6 @@ def circuit_ref(phi):
), "Sanity check failed, is expected_res correct?"
circuit_compiled = qml.qjit(
diagonalize_final_measurements_pass(circuit_ref),
pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()],
)

assert np.allclose(expected_res(angle), circuit_compiled(angle))
Expand Down Expand Up @@ -376,10 +374,7 @@ def expected_res(x, y):
assert np.allclose(
expected_res(phi, theta), circuit_ref(phi, theta)
), "Sanity check failed, is expected_res correct?"
circuit_compiled = qml.qjit(
diagonalize_final_measurements_pass(circuit_ref),
pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()],
)
circuit_compiled = qml.qjit(diagonalize_final_measurements_pass(circuit_ref))

assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta))

Expand All @@ -406,10 +401,7 @@ def expected_res(x, y):
expected_res(phi, theta), circuit_ref(phi, theta)
), "Sanity check failed, is expected_res correct?"

circuit_compiled = qml.qjit(
diagonalize_final_measurements_pass(circuit_ref),
pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()],
)
circuit_compiled = qml.qjit(diagonalize_final_measurements_pass(circuit_ref))

assert np.allclose(expected_res(phi, theta), circuit_compiled(phi, theta))

Expand All @@ -420,7 +412,7 @@ def test_overlapping_observables_raises_error(self):

dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()])
@qml.qjit
@diagonalize_final_measurements_pass
@qml.qnode(dev)
def circuit(x):
Expand All @@ -439,7 +431,7 @@ def test_non_commuting_observables_raise_error(self):
non-commuting observables."""
dev = qml.device("lightning.qubit", wires=1)

@qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()])
@qml.qjit
@diagonalize_final_measurements_pass
@qml.qnode(dev)
def circuit(x):
Expand Down
Loading
Loading