Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,12 @@ currently rely on JAX’s API to lower to MLIR. This has the special
effect of lowering to a specific dialect called StableHLO, which is used
to represent all arithmetic operations present in the program.

Once lowered to MLIR, if the original ``qjit`` decorator specified the
xDSL pass plugin, we pass control over to the xDSL layer, which applies
all transforms that were requested by the user. We can request the use
of the xDSL plugin like so:
nce lowered to MLIR, if any xDSL registered passes are detected, we pass the control over to
the xDSL layer, which automatically detects and applies all xDSL transforms that were requested
by the user.

However, if you want to manually trigger the xDSL layer without using any xDSL registered passes,
you can do so by specifying the ``pass_plugins`` parameter:

.. code-block:: python

Expand Down Expand Up @@ -1003,9 +1005,7 @@ currently accessible as
qml.capture.enable()
dev = qml.device("lightning.qubit", wires=1)

@qml.qjit(
pass_plugins=[catalyst.passes.xdsl_plugin.getXDSLPluginAbsolutePath()]
)
@qml.qjit
@my_pass
@qml.qnode(dev)
def circuit(x):
Expand Down Expand Up @@ -1295,8 +1295,6 @@ will explain what is going on.

.. code-block:: python

from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath

def test_h_to_x_pass_integration(run_filecheck_qjit):
"""Test that Hadamard gets converted into PauliX."""
# The original program simply applies a Hadamard to a circuit
Expand All @@ -1305,7 +1303,7 @@ will explain what is going on.
# `compiler_transform`. To make sure that the xDSL API works
# correctly, program capture must be enabled.
# qml.capture.enable()
@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath])
@qml.qjit
@h_to_x_pass
def circuit():
# CHECK: [[q0:%.+]] = "test.op"() : () -> !quantum.bit
Expand Down
2 changes: 1 addition & 1 deletion frontend/test/pytest/python_interface/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_qjit(self, run_filecheck_qjit):
# Test that the merge_rotations_pass works as expected when used with `qjit`
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@merge_rotations_pass
@qml.qnode(dev)
def circuit(x: float, y: float):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import jax
import pennylane as qml

from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
from catalyst.python_interface.inspection import draw
from catalyst.python_interface.transforms import (
iterative_cancel_inverses_pass,
Expand Down Expand Up @@ -91,9 +90,7 @@ def test_multiple_levels_xdsl(self, transforms_circuit, level, qjit, expected):
)

if qjit:
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
transforms_circuit
)
transforms_circuit = qml.qjit(transforms_circuit)

assert draw(transforms_circuit, level=level)() == expected

Expand Down Expand Up @@ -127,9 +124,7 @@ def test_multiple_levels_catalyst(self, transforms_circuit, level, qjit, expecte
)

if qjit:
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
transforms_circuit
)
transforms_circuit = qml.qjit(transforms_circuit)

assert draw(transforms_circuit, level=level)() == expected

Expand Down Expand Up @@ -162,9 +157,7 @@ def test_multiple_levels_xdsl_catalyst(self, transforms_circuit, level, qjit, ex
qml.transforms.merge_rotations(transforms_circuit)
)
if qjit:
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
transforms_circuit
)
transforms_circuit = qml.qjit(transforms_circuit)

assert draw(transforms_circuit, level=level)() == expected

Expand Down Expand Up @@ -208,9 +201,7 @@ def test_no_passes(self, transforms_circuit, level, qjit, expected):
"""Test that if no passes are applied, the circuit is still visualized."""

if qjit:
transforms_circuit = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(
transforms_circuit
)
transforms_circuit = qml.qjit(transforms_circuit)

assert draw(transforms_circuit, level=level)() == expected

Expand Down Expand Up @@ -487,7 +478,7 @@ def circ(arg):
def adjoint_op_not_implemented(self):
"""Test that NotImplementedError is raised when AdjointOp is used."""

@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
qml.adjoint(qml.QubitUnitary)(jax.numpy.array([[0, 1], [1, 0]]), wires=[0])
Expand All @@ -499,7 +490,7 @@ def circuit():
def test_cond_not_implemented(self):
"""Test that NotImplementedError is raised when cond is used."""

@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit():
m0 = qml.measure(0, reset=False, postselect=0)
Expand All @@ -512,7 +503,7 @@ def circuit():
def test_for_loop_not_implemented(self):
"""Test that NotImplementedError is raised when for loop is used."""

@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True)
@qml.qjit(autograph=True)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
for _ in range(3):
Expand All @@ -525,7 +516,7 @@ def circuit():
def test_while_loop_not_implemented(self):
"""Test that NotImplementedError is raised when while loop is used."""

@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()], autograph=True)
@qml.qjit(autograph=True)
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import pennylane as qml

from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
from catalyst.python_interface.inspection import generate_mlir_graph
from catalyst.python_interface.transforms import (
iterative_cancel_inverses_pass,
Expand Down Expand Up @@ -68,7 +67,7 @@ def _():
return qml.state()

if qjit:
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
_ = qml.qjit(_)

generate_mlir_graph(_)()
assert collect_files(tmp_path) == {"QNode_level_0_no_transforms.svg"}
Expand All @@ -88,7 +87,7 @@ def _():
return qml.state()

if qjit:
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
_ = qml.qjit(_)

generate_mlir_graph(_)()
assert_files(
Expand All @@ -113,7 +112,7 @@ def _(x, y, w1, w2):
return qml.state()

if qjit:
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
_ = qml.qjit(_)

generate_mlir_graph(_)(0.1, 0.2, 0, 1)
assert_files(
Expand All @@ -138,7 +137,7 @@ def _(x, y, w1, w2):
return qml.state()

if qjit:
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
_ = qml.qjit(_)

generate_mlir_graph(_)(0.1, 0.2, 0, 1)
assert_files(
Expand All @@ -164,7 +163,7 @@ def _(x, y, w1, w2):
return qml.state()

if qjit:
_ = qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])(_)
_ = qml.qjit(_)

generate_mlir_graph(_)(0.1, 0.2, 0, 1)
assert_files(
Expand Down
10 changes: 5 additions & 5 deletions frontend/test/pytest/python_interface/test_unified_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def test_integration_catalyst_xdsl_pass_with_capture(self, capsys):

assert capture_enabled()

@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qjit
@hello_world_pass
@qml.qnode(qml.device("lightning.qubit", wires=2))
def f(x):
Expand All @@ -319,7 +319,7 @@ def test_integration_catalyst_xdsl_pass_no_capture(self, capsys):

assert not capture_enabled()

@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qjit
@apply_pass("hello-world")
@qml.qnode(qml.device("lightning.qubit", wires=2))
def f(x):
Expand All @@ -338,7 +338,7 @@ def test_integration_catalyst_mixed_passes_with_capture(self, capsys):

assert capture_enabled()

@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qjit
@hello_world_pass
@qml.transforms.cancel_inverses
@qml.qnode(qml.device("lightning.qubit", wires=2))
Expand All @@ -359,7 +359,7 @@ def test_integration_catalyst_mixed_passes_no_capture(self, capsys):

assert not capture_enabled()

@qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qjit
@apply_pass("hello-world")
@catalyst_cancel_inverses
@qml.qnode(qml.device("lightning.qubit", wires=2))
Expand Down Expand Up @@ -495,7 +495,7 @@ def print_between_passes(_, module, __, pass_level=0):
print("=== Between Pass ===")
print(module)

@qml.qjit(pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit
@iterative_cancel_inverses_pass
@merge_rotations_pass
@qml.qnode(qml.device("null.qubit", wires=2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from pennylane.ftqc import RotXZX

from catalyst.ftqc import mbqc_pipeline
from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
from catalyst.python_interface.transforms import (
OutlineStateEvolutionPass,
convert_to_mbqc_formalism_pass,
Expand Down Expand Up @@ -147,10 +146,7 @@ def test_outline_state_evolution_no_error(self):
"""Test outline_state_evolution_pass does not raise error for circuit with classical
operations only."""

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
def circuit(x, y):
return x * y + 5
Expand All @@ -165,10 +161,7 @@ def test_outline_state_evolution_no_terminal_op_error(self):
# the program is captured.
dev = qml.device("null.qubit", wires=10)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
@qml.qnode(dev)
def circuit():
Expand All @@ -184,10 +177,7 @@ def test_outline_state_evolution_pass_only(self, run_filecheck_qjit):
"""Test the outline_state_evolution_pass only."""
dev = qml.device("lightning.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
@qml.set_shots(1000)
@qml.qnode(dev)
Expand Down Expand Up @@ -223,11 +213,7 @@ def test_outline_state_evolution_pass_with_convert_to_mbqc_formalism(self, run_f
on lightning.qubit."""
dev = qml.device("lightning.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
pipelines=mbqc_pipeline(),
)
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
@decompose_graph_state_pass
@convert_to_mbqc_formalism_pass
@outline_state_evolution_pass
Expand Down Expand Up @@ -273,11 +259,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline(self, run_filecheck_qji
null.qubit."""
dev = qml.device("null.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
pipelines=mbqc_pipeline(),
)
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
@decompose_graph_state_pass
@convert_to_mbqc_formalism_pass
@measurements_from_samples_pass
Expand Down Expand Up @@ -323,11 +305,7 @@ def test_outline_state_evolution_pass_with_mbqc_pipeline_run_on_nullqubit(self):
transform pipeline can be executed on null.qubit."""
dev = qml.device("null.qubit", wires=1000)

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
pipelines=mbqc_pipeline(),
)
@qml.qjit(target="mlir", pipelines=mbqc_pipeline())
@decompose_graph_state_pass
@convert_to_mbqc_formalism_pass
@measurements_from_samples_pass
Expand Down Expand Up @@ -367,10 +345,7 @@ def while_fn(i):
i = i + 1
return i

@qml.qjit(
target="mlir",
pass_plugins=[getXDSLPluginAbsolutePath()],
)
@qml.qjit(target="mlir")
@outline_state_evolution_pass
@qml.qnode(dev)
def circuit():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import pennylane as qml

from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
from catalyst.python_interface.transforms import (
IterativeCancelInversesPass,
iterative_cancel_inverses_pass,
Expand Down Expand Up @@ -194,7 +193,7 @@ def test_qjit(self, run_filecheck_qjit):
"""Test that the IterativeCancelInversesPass works correctly with qjit."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@iterative_cancel_inverses_pass
@qml.qnode(dev)
def circuit():
Expand All @@ -212,7 +211,7 @@ def test_qjit_no_cancellation(self, run_filecheck_qjit):
there are no operations that can be cancelled."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@iterative_cancel_inverses_pass
@qml.qnode(dev)
def circuit():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import pennylane as qml

from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath
from catalyst.python_interface.transforms import (
CombineGlobalPhasesPass,
combine_global_phases_pass,
Expand Down Expand Up @@ -224,7 +223,7 @@ def test_qjit(self, run_filecheck_qjit):
"""Test that the CombineGlobalPhasesPass works correctly with qjit."""
dev = qml.device("lightning.qubit", wires=2)

@qml.qjit(target="mlir", pass_plugins=[getXDSLPluginAbsolutePath()])
@qml.qjit(target="mlir")
@combine_global_phases_pass
@qml.qnode(dev)
def circuit(x: float, y: float):
Expand Down
Loading
Loading