Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 21 additions & 26 deletions pennylane/compiler/python_compiler/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from xdsl.dialects import builtin as xbuiltin
from xdsl.dialects import func as xfunc
from xdsl.ir import Dialect as xDialect
from xdsl.traits import SymbolOpInterface
from xdsl.traits import SymbolTable as xSymbolTable

from .parser import QuantumParser
Expand All @@ -49,14 +50,14 @@ def wrapper(*args, **kwargs) -> jModule:
return wrapper


def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover
def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str:
"""Create the generic textual representation for a jax.jitted function"""
lowered = func.lower(*args, **kwargs)
mod = lowered.compiler_ir()
return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True)


def generic_str(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover
def generic_str(func: JaxJittedFunction) -> Callable[..., str]:
"""Returns a wrapper that creates the generic textual representation for a
jax.jitted function."""

Expand All @@ -69,53 +70,41 @@ def wrapper(*args, **kwargs) -> str:

def parse_generic_to_xdsl_module(
program: str, extra_dialects: Sequence[xDialect] | None = None
) -> xbuiltin.ModuleOp: # pragma: no cover
) -> xbuiltin.ModuleOp:
"""Parses a generic MLIR program string to an xDSL module."""
ctx = xContext(allow_unregistered=True)
parser = QuantumParser(ctx, program, extra_dialects=extra_dialects)
moduleOp: xbuiltin.ModuleOp = parser.parse_module()
return moduleOp


def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover
def parse_generic_to_mlir_module(program: str) -> jModule:
"""Parses a generic MLIR program string to an MLIR module."""
with jContext() as ctx:
ctx.allow_unregistered_dialects = True
jstablehlo.register_dialect(ctx) # pylint: disable=no-member
return jModule.parse(program)


def mlir_from_docstring(func: Callable) -> jModule: # pragma: no cover
"""Returns a wrapper that parses an MLIR program string located in the docstring
into an MLIR module."""
def mlir_from_docstring(func: Callable) -> jModule:
"""Returns an MLIR module using the docstring of the input callable as an MLIR string."""

@wraps(func)
def wrapper(*_, **__):
return parse_generic_to_mlir_module(func.__doc__)

return wrapper
return parse_generic_to_mlir_module(func.__doc__)


def _xdsl_module_inline(
func: JaxJittedFunction, *args, **kwargs
) -> xbuiltin.ModuleOp: # pragma: no cover
def _xdsl_module_inline(func: JaxJittedFunction, *args, **kwargs) -> xbuiltin.ModuleOp:
"""Get the xDSL module from a jax.jitted function"""
generic_repr = _generic_str_inline(func, *args, **kwargs)
return parse_generic_to_xdsl_module(generic_repr)


def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover
"""Returns a wrapper that parses an MLIR program string located in the docstring
into an xDSL module."""
def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp:
"""Returns an xDSL module using the docstring of the input callable as an MLIR string."""

@wraps(func)
def wrapper(*_, **__):
return parse_generic_to_xdsl_module(func.__doc__)

return wrapper
return parse_generic_to_xdsl_module(func.__doc__)


def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover
def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]:
"""Returns a wrapper that creates an xDSL module from a jax.jitted function."""

@wraps(func)
Expand All @@ -126,7 +115,7 @@ def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp:


def inline_module(
from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None
from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str | None = None
) -> None:
"""Inline the contents of one xDSL module into another xDSL module. The inlined body is appended
to the end of ``to_mod``.
Expand All @@ -140,7 +129,13 @@ def inline_module(
main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to)

for op in from_mod.body.ops:
xSymbolTable.insert_or_update(to_mod, op.clone())
clone = op.clone()
if op.has_trait(SymbolOpInterface):
# Do safe insertion for symbol op
xSymbolTable.insert_or_update(to_mod, clone)

else:
to_mod.regions[0].blocks[0].add_op(clone)


def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp) -> Callable[..., None]:
Expand Down
3 changes: 1 addition & 2 deletions tests/python_compiler/dialects/test_transform_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None:
else:
print("hello world")

@xdsl_from_docstring
def program():
"""
builtin.module {
Expand All @@ -142,7 +141,7 @@ def program():
ctx.load_dialect(builtin.Builtin)
ctx.load_dialect(transform.Transform)

mod = program()
mod = xdsl_from_docstring(program)
pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),))
pipeline.apply(ctx, mod)

Expand Down
229 changes: 223 additions & 6 deletions tests/python_compiler/test_xdsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,30 @@

import pytest

pytestmark = pytest.mark.external
pytestmark = [pytest.mark.external, pytest.mark.capture]
xdsl = pytest.importorskip("xdsl")
jax = pytest.importorskip("jax")

# pylint: disable=wrong-import-position
from xdsl.dialects import arith, builtin, tensor, test
from jaxlib.mlir.ir import Module as jaxModule # pylint: disable=no-name-in-module
from xdsl.context import Context
from xdsl.dialects import arith, builtin, func, tensor, test

from pennylane.compiler.python_compiler.dialects.stablehlo import ConstantOp as hloConstantOp
import pennylane as qml
from pennylane.compiler.python_compiler import QuantumParser
from pennylane.compiler.python_compiler.conversion import (
generic_str,
inline_jit_to_module,
inline_module,
mlir_from_docstring,
mlir_module,
parse_generic_to_mlir_module,
parse_generic_to_xdsl_module,
xdsl_from_docstring,
xdsl_from_qjit,
xdsl_module,
)
from pennylane.compiler.python_compiler.dialects import stablehlo
from pennylane.compiler.python_compiler.utils import get_constant_from_ssa


Expand Down Expand Up @@ -58,7 +75,7 @@ def test_scalar_constant_arith(self, const, attr_type, dtype):
(-1.1 + 2.3j, builtin.ComplexType(builtin.Float64Type())),
],
)
@pytest.mark.parametrize("constant_op", [arith.ConstantOp, hloConstantOp])
@pytest.mark.parametrize("constant_op", [arith.ConstantOp, stablehlo.ConstantOp])
def test_scalar_constant_extracted_from_rank0_tensor(self, const, elt_type, constant_op):
"""Test that constants created by ``stablehlo.constant`` are returned correctly."""
data = const
Expand Down Expand Up @@ -92,7 +109,7 @@ def test_tensor_constant_stablehlo(self):
type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),
data=(1.0, 2.0, 3.0),
)
val = hloConstantOp(value=dense_attr).results[0]
val = stablehlo.ConstantOp(value=dense_attr).results[0]

assert get_constant_from_ssa(val) is None

Expand All @@ -106,7 +123,7 @@ def test_extract_scalar_from_constant_tensor_stablehlo(self):
type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)),
data=(1.0, 2.0, 3.0),
)
tensor_ = hloConstantOp(value=dense_attr).results[0]
tensor_ = stablehlo.ConstantOp(value=dense_attr).results[0]
val = tensor.ExtractOp(
tensor=tensor_, indices=[dummy_index], result_type=builtin.Float64Type()
).results[0]
Expand All @@ -116,5 +133,205 @@ def test_extract_scalar_from_constant_tensor_stablehlo(self):
assert get_constant_from_ssa(val) is None


class TestConversionUtils:
"""Unit tests for utilities for converting Python code to xDSL modules."""

def test_generic_str(self):
"""Test that the generic_str function works correctly."""

@jax.jit
def f(x):
return x + 1

gen_str = generic_str(f)(1)
context = Context()
module = QuantumParser(context, gen_str).parse_module()

assert len(module.regions[0].blocks[0].ops) == 1
func_op = module.regions[0].blocks[0].first_op
assert isinstance(func_op, func.FuncOp)

expected_op_names = ["stablehlo.constant", "stablehlo.add", "func.return"]
for op, expected_op_name in zip(func_op.body.ops, expected_op_names):
assert op.name == expected_op_name

def test_mlir_module(self):
"""Test that the mlir_module function works correctly."""

@jax.jit
def f(x):
return x + 1

mod = mlir_module(f)(1)
assert isinstance(mod, jaxModule)

def test_xdsl_module(self):
"""Test that the xdsl_module function works correctly."""

@jax.jit
def f(x):
return x + 1

mod = xdsl_module(f)(1)
assert isinstance(mod, builtin.ModuleOp)

assert len(mod.regions[0].blocks[0].ops) == 1
func_op = mod.regions[0].blocks[0].first_op
assert isinstance(func_op, func.FuncOp)

expected_op_names = ["stablehlo.constant", "stablehlo.add", "func.return"]
for op, expected_op_name in zip(func_op.body.ops, expected_op_names):
assert op.name == expected_op_name

def test_parse_generic_to_mlir_module(self):
"""Test that the parse_generic_to_mlir_module function works correctly."""
program_str = """
"builtin.module"() ({
%0 = "arith.constant"() <{value = 0 : i64}> : () -> i64
}) : () -> ()
"""

mod = parse_generic_to_mlir_module(program_str)
assert isinstance(mod, jaxModule)

def test_parse_generic_to_xdsl_module(self):
"""Test that the parse_generic_to_xdsl_module function works correctly."""
program_str = """
"builtin.module"() ({
%0 = "arith.constant"() <{value = 0 : i64}> : () -> i64
}) : () -> ()
"""

mod = parse_generic_to_xdsl_module(program_str)
assert isinstance(mod, builtin.ModuleOp)

assert len(mod.regions[0].blocks[0].ops) == 1
assert isinstance(mod.regions[0].blocks[0].first_op, arith.ConstantOp)

def test_mlir_from_docstring(self):
"""Test that the mlir_from_docstring function works correctly."""

def f():
"""
%0 = "arith.constant"() <{value = 0 : i64}> : () -> i64
"""

mod = mlir_from_docstring(f)
assert isinstance(mod, jaxModule)

def test_xdsl_from_docstring(self):
"""Test that the xdsl_from_docstring function works correctly."""

def f():
"""
%0 = "arith.constant"() <{value = 0 : i64}> : () -> i64
"""

mod = xdsl_from_docstring(f)
assert isinstance(mod, builtin.ModuleOp)

assert len(mod.regions[0].blocks[0].ops) == 1
assert isinstance(mod.regions[0].blocks[0].first_op, arith.ConstantOp)

def test_xdsl_from_qjit(self):
"""Test that the xdsl_from_qjit function works correctly."""

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit():
return qml.state()

mod = xdsl_from_qjit(circuit)()
assert isinstance(mod, builtin.ModuleOp)

nested_modules = []
for op in mod.body.ops:
if isinstance(op, builtin.ModuleOp):
nested_modules.append(op)

funcs = []
assert len(nested_modules) == 1
for op in nested_modules[0].body.ops:
if isinstance(op, func.FuncOp):
funcs.append(op)

assert len(funcs) == 1
# All qnodes have a UnitAttr attribute called qnode
assert funcs[0].attributes.get("qnode", None) is not None


class TestInliningUtils:
"""Unit tests for utilities for inlining operations into xDSL modules."""

@pytest.mark.parametrize("change_main_to", ["foo", None])
def test_inline_module(self, change_main_to):
"""Test that the inline_module function works correctly."""

mod1_main = func.FuncOp(name="main", function_type=((), ()))
mod1_func = func.FuncOp(name="not_main", function_type=((), ()))
mod1_ops = [mod1_main, mod1_func, test.TestPureOp()]
mod1 = builtin.ModuleOp(mod1_ops)

mod2_ops = [test.TestOp()]
mod2 = builtin.ModuleOp(mod2_ops)

inline_module(mod1, mod2, change_main_to=change_main_to)

assert len(mod2.ops) == 4
expected_mod2 = builtin.ModuleOp(ops=[op.clone() for op in mod2_ops + mod1_ops])
assert mod2.is_structurally_equivalent(expected_mod2)

# Check that mod1 is unchanged
expected_mod1 = builtin.ModuleOp(ops=[op.clone() for op in mod1_ops])
assert mod1.is_structurally_equivalent(expected_mod1)

expected_names = {"not_main", change_main_to or "main"}
actual_names = set()
for op in mod2.ops:
if isinstance(op, func.FuncOp):
actual_names.add(op.sym_name.data)

assert actual_names == expected_names

def test_inline_jit_to_module(self):
"""Test that the inline_jit_to_module function works correctly."""

@jax.jit
def f1(x):
return x

@jax.jit
def f2(x):
return f1(x)

mod = builtin.ModuleOp(ops=[])
# Mutate the module in-place
inline_jit_to_module(f2, mod)(1.5)

expected_func_names = {"f1", "f2"}
funcs = []
actual_func_names = set()
f2_func = None
assert len(mod.ops) == 2
for op in mod.body.ops:
assert isinstance(op, func.FuncOp)
funcs.append(op)
sym_name = op.sym_name.data
actual_func_names.add(sym_name)
if sym_name == "f2":
f2_func = op

assert actual_func_names == expected_func_names

# Check that f2 calls f1
call_op = None
for op in f2_func.body.ops:
if isinstance(op, func.CallOp):
call_op = op

assert call_op is not None
assert call_op.callee.root_reference.data == "f1"


if __name__ == "__main__":
pytest.main(["-x", __file__])