diff --git a/pennylane/compiler/python_compiler/conversion.py b/pennylane/compiler/python_compiler/conversion.py index 31ed5d49d13..663678906aa 100644 --- a/pennylane/compiler/python_compiler/conversion.py +++ b/pennylane/compiler/python_compiler/conversion.py @@ -27,6 +27,7 @@ from xdsl.dialects import builtin as xbuiltin from xdsl.dialects import func as xfunc from xdsl.ir import Dialect as xDialect +from xdsl.traits import SymbolOpInterface from xdsl.traits import SymbolTable as xSymbolTable from .parser import QuantumParser @@ -49,14 +50,14 @@ def wrapper(*args, **kwargs) -> jModule: return wrapper -def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: # pragma: no cover +def _generic_str_inline(func: JaxJittedFunction, *args, **kwargs) -> str: """Create the generic textual representation for a jax.jitted function""" lowered = func.lower(*args, **kwargs) mod = lowered.compiler_ir() return mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=True) -def generic_str(func: JaxJittedFunction) -> Callable[..., str]: # pragma: no cover +def generic_str(func: JaxJittedFunction) -> Callable[..., str]: """Returns a wrapper that creates the generic textual representation for a jax.jitted function.""" @@ -69,7 +70,7 @@ def wrapper(*args, **kwargs) -> str: def parse_generic_to_xdsl_module( program: str, extra_dialects: Sequence[xDialect] | None = None -) -> xbuiltin.ModuleOp: # pragma: no cover +) -> xbuiltin.ModuleOp: """Parses a generic MLIR program string to an xDSL module.""" ctx = xContext(allow_unregistered=True) parser = QuantumParser(ctx, program, extra_dialects=extra_dialects) @@ -77,7 +78,7 @@ def parse_generic_to_xdsl_module( return moduleOp -def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover +def parse_generic_to_mlir_module(program: str) -> jModule: """Parses a generic MLIR program string to an MLIR module.""" with jContext() as ctx: ctx.allow_unregistered_dialects = True @@ -85,37 +86,25 @@ def parse_generic_to_mlir_module(program: str) -> jModule: # pragma: no cover return jModule.parse(program) -def mlir_from_docstring(func: Callable) -> jModule: # pragma: no cover - """Returns a wrapper that parses an MLIR program string located in the docstring - into an MLIR module.""" +def mlir_from_docstring(func: Callable) -> jModule: + """Returns an MLIR module using the docstring of the input callable as an MLIR string.""" - @wraps(func) - def wrapper(*_, **__): - return parse_generic_to_mlir_module(func.__doc__) - - return wrapper + return parse_generic_to_mlir_module(func.__doc__) -def _xdsl_module_inline( - func: JaxJittedFunction, *args, **kwargs -) -> xbuiltin.ModuleOp: # pragma: no cover +def _xdsl_module_inline(func: JaxJittedFunction, *args, **kwargs) -> xbuiltin.ModuleOp: """Get the xDSL module from a jax.jitted function""" generic_repr = _generic_str_inline(func, *args, **kwargs) return parse_generic_to_xdsl_module(generic_repr) -def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: # pragma: no cover - """Returns a wrapper that parses an MLIR program string located in the docstring - into an xDSL module.""" +def xdsl_from_docstring(func: Callable) -> xbuiltin.ModuleOp: + """Returns an xDSL module using the docstring of the input callable as an MLIR string.""" - @wraps(func) - def wrapper(*_, **__): - return parse_generic_to_xdsl_module(func.__doc__) - - return wrapper + return parse_generic_to_xdsl_module(func.__doc__) -def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: # pragma: no cover +def xdsl_module(func: JaxJittedFunction) -> Callable[..., xbuiltin.ModuleOp]: """Returns a wrapper that creates an xDSL module from a jax.jitted function.""" @wraps(func) @@ -126,7 +115,7 @@ def wrapper(*args, **kwargs) -> xbuiltin.ModuleOp: def inline_module( - from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str = None + from_mod: xbuiltin.ModuleOp, to_mod: xbuiltin.ModuleOp, change_main_to: str | None = None ) -> None: """Inline the contents of one xDSL module into another xDSL module. The inlined body is appended to the end of ``to_mod``. @@ -140,7 +129,13 @@ def inline_module( main.properties["sym_name"] = xbuiltin.StringAttr(change_main_to) for op in from_mod.body.ops: - xSymbolTable.insert_or_update(to_mod, op.clone()) + clone = op.clone() + if op.has_trait(SymbolOpInterface): + # Do safe insertion for symbol op + xSymbolTable.insert_or_update(to_mod, clone) + + else: + to_mod.regions[0].blocks[0].add_op(clone) def inline_jit_to_module(func: JaxJittedFunction, mod: xbuiltin.ModuleOp) -> Callable[..., None]: diff --git a/tests/python_compiler/dialects/test_transform_dialect.py b/tests/python_compiler/dialects/test_transform_dialect.py index 138d91c5d9b..883185ef91e 100644 --- a/tests/python_compiler/dialects/test_transform_dialect.py +++ b/tests/python_compiler/dialects/test_transform_dialect.py @@ -125,7 +125,6 @@ def apply(self, _ctx: Context, _module: builtin.ModuleOp) -> None: else: print("hello world") - @xdsl_from_docstring def program(): """ builtin.module { @@ -142,7 +141,7 @@ def program(): ctx.load_dialect(builtin.Builtin) ctx.load_dialect(transform.Transform) - mod = program() + mod = xdsl_from_docstring(program) pipeline = xdsl.passes.PassPipeline((ApplyTransformSequence(),)) pipeline.apply(ctx, mod) diff --git a/tests/python_compiler/test_xdsl_utils.py b/tests/python_compiler/test_xdsl_utils.py index 0d33e2130e5..a7352d225f6 100644 --- a/tests/python_compiler/test_xdsl_utils.py +++ b/tests/python_compiler/test_xdsl_utils.py @@ -16,13 +16,30 @@ import pytest -pytestmark = pytest.mark.external +pytestmark = [pytest.mark.external, pytest.mark.capture] xdsl = pytest.importorskip("xdsl") +jax = pytest.importorskip("jax") # pylint: disable=wrong-import-position -from xdsl.dialects import arith, builtin, tensor, test +from jaxlib.mlir.ir import Module as jaxModule # pylint: disable=no-name-in-module +from xdsl.context import Context +from xdsl.dialects import arith, builtin, func, tensor, test -from pennylane.compiler.python_compiler.dialects.stablehlo import ConstantOp as hloConstantOp +import pennylane as qml +from pennylane.compiler.python_compiler import QuantumParser +from pennylane.compiler.python_compiler.conversion import ( + generic_str, + inline_jit_to_module, + inline_module, + mlir_from_docstring, + mlir_module, + parse_generic_to_mlir_module, + parse_generic_to_xdsl_module, + xdsl_from_docstring, + xdsl_from_qjit, + xdsl_module, +) +from pennylane.compiler.python_compiler.dialects import stablehlo from pennylane.compiler.python_compiler.utils import get_constant_from_ssa @@ -58,7 +75,7 @@ def test_scalar_constant_arith(self, const, attr_type, dtype): (-1.1 + 2.3j, builtin.ComplexType(builtin.Float64Type())), ], ) - @pytest.mark.parametrize("constant_op", [arith.ConstantOp, hloConstantOp]) + @pytest.mark.parametrize("constant_op", [arith.ConstantOp, stablehlo.ConstantOp]) def test_scalar_constant_extracted_from_rank0_tensor(self, const, elt_type, constant_op): """Test that constants created by ``stablehlo.constant`` are returned correctly.""" data = const @@ -92,7 +109,7 @@ def test_tensor_constant_stablehlo(self): type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), data=(1.0, 2.0, 3.0), ) - val = hloConstantOp(value=dense_attr).results[0] + val = stablehlo.ConstantOp(value=dense_attr).results[0] assert get_constant_from_ssa(val) is None @@ -106,7 +123,7 @@ def test_extract_scalar_from_constant_tensor_stablehlo(self): type=builtin.TensorType(element_type=builtin.Float64Type(), shape=(3,)), data=(1.0, 2.0, 3.0), ) - tensor_ = hloConstantOp(value=dense_attr).results[0] + tensor_ = stablehlo.ConstantOp(value=dense_attr).results[0] val = tensor.ExtractOp( tensor=tensor_, indices=[dummy_index], result_type=builtin.Float64Type() ).results[0] @@ -116,5 +133,205 @@ def test_extract_scalar_from_constant_tensor_stablehlo(self): assert get_constant_from_ssa(val) is None +class TestConversionUtils: + """Unit tests for utilities for converting Python code to xDSL modules.""" + + def test_generic_str(self): + """Test that the generic_str function works correctly.""" + + @jax.jit + def f(x): + return x + 1 + + gen_str = generic_str(f)(1) + context = Context() + module = QuantumParser(context, gen_str).parse_module() + + assert len(module.regions[0].blocks[0].ops) == 1 + func_op = module.regions[0].blocks[0].first_op + assert isinstance(func_op, func.FuncOp) + + expected_op_names = ["stablehlo.constant", "stablehlo.add", "func.return"] + for op, expected_op_name in zip(func_op.body.ops, expected_op_names): + assert op.name == expected_op_name + + def test_mlir_module(self): + """Test that the mlir_module function works correctly.""" + + @jax.jit + def f(x): + return x + 1 + + mod = mlir_module(f)(1) + assert isinstance(mod, jaxModule) + + def test_xdsl_module(self): + """Test that the xdsl_module function works correctly.""" + + @jax.jit + def f(x): + return x + 1 + + mod = xdsl_module(f)(1) + assert isinstance(mod, builtin.ModuleOp) + + assert len(mod.regions[0].blocks[0].ops) == 1 + func_op = mod.regions[0].blocks[0].first_op + assert isinstance(func_op, func.FuncOp) + + expected_op_names = ["stablehlo.constant", "stablehlo.add", "func.return"] + for op, expected_op_name in zip(func_op.body.ops, expected_op_names): + assert op.name == expected_op_name + + def test_parse_generic_to_mlir_module(self): + """Test that the parse_generic_to_mlir_module function works correctly.""" + program_str = """ + "builtin.module"() ({ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + }) : () -> () + """ + + mod = parse_generic_to_mlir_module(program_str) + assert isinstance(mod, jaxModule) + + def test_parse_generic_to_xdsl_module(self): + """Test that the parse_generic_to_xdsl_module function works correctly.""" + program_str = """ + "builtin.module"() ({ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + }) : () -> () + """ + + mod = parse_generic_to_xdsl_module(program_str) + assert isinstance(mod, builtin.ModuleOp) + + assert len(mod.regions[0].blocks[0].ops) == 1 + assert isinstance(mod.regions[0].blocks[0].first_op, arith.ConstantOp) + + def test_mlir_from_docstring(self): + """Test that the mlir_from_docstring function works correctly.""" + + def f(): + """ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + """ + + mod = mlir_from_docstring(f) + assert isinstance(mod, jaxModule) + + def test_xdsl_from_docstring(self): + """Test that the xdsl_from_docstring function works correctly.""" + + def f(): + """ + %0 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + """ + + mod = xdsl_from_docstring(f) + assert isinstance(mod, builtin.ModuleOp) + + assert len(mod.regions[0].blocks[0].ops) == 1 + assert isinstance(mod.regions[0].blocks[0].first_op, arith.ConstantOp) + + def test_xdsl_from_qjit(self): + """Test that the xdsl_from_qjit function works correctly.""" + + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circuit(): + return qml.state() + + mod = xdsl_from_qjit(circuit)() + assert isinstance(mod, builtin.ModuleOp) + + nested_modules = [] + for op in mod.body.ops: + if isinstance(op, builtin.ModuleOp): + nested_modules.append(op) + + funcs = [] + assert len(nested_modules) == 1 + for op in nested_modules[0].body.ops: + if isinstance(op, func.FuncOp): + funcs.append(op) + + assert len(funcs) == 1 + # All qnodes have a UnitAttr attribute called qnode + assert funcs[0].attributes.get("qnode", None) is not None + + +class TestInliningUtils: + """Unit tests for utilities for inlining operations into xDSL modules.""" + + @pytest.mark.parametrize("change_main_to", ["foo", None]) + def test_inline_module(self, change_main_to): + """Test that the inline_module function works correctly.""" + + mod1_main = func.FuncOp(name="main", function_type=((), ())) + mod1_func = func.FuncOp(name="not_main", function_type=((), ())) + mod1_ops = [mod1_main, mod1_func, test.TestPureOp()] + mod1 = builtin.ModuleOp(mod1_ops) + + mod2_ops = [test.TestOp()] + mod2 = builtin.ModuleOp(mod2_ops) + + inline_module(mod1, mod2, change_main_to=change_main_to) + + assert len(mod2.ops) == 4 + expected_mod2 = builtin.ModuleOp(ops=[op.clone() for op in mod2_ops + mod1_ops]) + assert mod2.is_structurally_equivalent(expected_mod2) + + # Check that mod1 is unchanged + expected_mod1 = builtin.ModuleOp(ops=[op.clone() for op in mod1_ops]) + assert mod1.is_structurally_equivalent(expected_mod1) + + expected_names = {"not_main", change_main_to or "main"} + actual_names = set() + for op in mod2.ops: + if isinstance(op, func.FuncOp): + actual_names.add(op.sym_name.data) + + assert actual_names == expected_names + + def test_inline_jit_to_module(self): + """Test that the inline_jit_to_module function works correctly.""" + + @jax.jit + def f1(x): + return x + + @jax.jit + def f2(x): + return f1(x) + + mod = builtin.ModuleOp(ops=[]) + # Mutate the module in-place + inline_jit_to_module(f2, mod)(1.5) + + expected_func_names = {"f1", "f2"} + funcs = [] + actual_func_names = set() + f2_func = None + assert len(mod.ops) == 2 + for op in mod.body.ops: + assert isinstance(op, func.FuncOp) + funcs.append(op) + sym_name = op.sym_name.data + actual_func_names.add(sym_name) + if sym_name == "f2": + f2_func = op + + assert actual_func_names == expected_func_names + + # Check that f2 calls f1 + call_op = None + for op in f2_func.body.ops: + if isinstance(op, func.CallOp): + call_op = op + + assert call_op is not None + assert call_op.callee.root_reference.data == "f1" + + if __name__ == "__main__": pytest.main(["-x", __file__])