diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8b96d4e37b..3a7d714ea8 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -11,6 +11,7 @@ [(#2285)](https://github.com/PennyLaneAI/catalyst/pull/2285) [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) [(#2218)](https://github.com/PennyLaneAI/catalyst/pull/2218) + [(#2260)](https://github.com/PennyLaneAI/catalyst/pull/2260) * Catalyst now features a unified compilation framework, which enables users and developers to design and implement compilation passes in Python in addition to C++, on the same Catalyst IR. The Python diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 271c354c44..7d11a5fae5 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,9 +14,15 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" +from collections import defaultdict from functools import singledispatch, singledispatchmethod -from pennylane.measurements import ExpectationMP, MeasurementProcess, ProbabilityMP, VarianceMP +from pennylane.measurements import ( + ExpectationMP, + MeasurementProcess, + ProbabilityMP, + VarianceMP, +) from pennylane.operation import Operator from xdsl.dialects import builtin, func, scf from xdsl.ir import Block, Operation, Region, SSAValue @@ -50,6 +56,11 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # Keep track of nesting clusters using a stack self._cluster_uid_stack: list[str] = [] + # Create a map of wire to node uid + # Keys represent static (int) or dynamic wires (str) + # Values represent the set of all node uids that are on that wire. + self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) + # Use counter internally for UID self._node_uid_counter: int = 0 self._cluster_uid_counter: int = 0 @@ -59,6 +70,7 @@ def _reset(self) -> None: self._cluster_uid_stack: list[str] = [] self._node_uid_counter: int = 0 self._cluster_uid_counter: int = 0 + self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module. @@ -116,6 +128,17 @@ def _gate_op( ) self._node_uid_counter += 1 + # Search through previous ops found on current wires and connect + prev_node_uids: set[str] = set.union( + set(), *(self._wire_to_node_uids[wire] for wire in qml_op.wires) + ) + for prev_node_uid in prev_node_uids: + self.dag_builder.add_edge(prev_node_uid, node_uid) + + # Update affected wires to source from this node UID + for wire in qml_op.wires: + self._wire_to_node_uids[wire] = {node_uid} + @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" @@ -134,70 +157,66 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: ) self._node_uid_counter += 1 + # Search through previous ops found on current wires and connect + prev_node_uids: set[str] = set.union( + set(), *(self._wire_to_node_uids[wire] for wire in meas.wires) + ) + for prev_node_uid in prev_node_uids: + self.dag_builder.add_edge(prev_node_uid, node_uid) + + # Update affected wires to source from this node UID + for wire in meas.wires: + self._wire_to_node_uids[wire] = {node_uid} + # ===================== # QUANTUM MEASUREMENTS # ===================== @_visit_operation.register - def _state_op(self, op: quantum.StateOp) -> None: - """Handler for the terminal state measurement operation.""" - - # Create PennyLane instance - meas = xdsl_to_qml_measurement(op) - - # Add node to current cluster - node_uid = f"node{self._node_uid_counter}" - self.dag_builder.add_node( - uid=node_uid, - label=get_label(meas), - cluster_uid=self._cluster_uid_stack[-1], - fillcolor="lightpink", - color="lightpink3", - # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) - shape="record", - ) - self._node_uid_counter += 1 - - @_visit_operation.register - def _expval_and_var_ops( + def _measurements( self, - op: quantum.ExpvalOp | quantum.VarianceOp, + op: ( + quantum.StateOp + | quantum.ExpvalOp + | quantum.VarianceOp + | quantum.SampleOp + | quantum.ProbsOp + ), ) -> None: - """Handler for statistical measurement operations.""" + """Handler for all quantum measurement operations.""" - # Create PennyLane instance - obs_op = op.obs.owner - meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + prev_wires = [] + meas = None - # Add node to current cluster - node_uid = f"node{self._node_uid_counter}" - self.dag_builder.add_node( - uid=node_uid, - label=get_label(meas), - cluster_uid=self._cluster_uid_stack[-1], - fillcolor="lightpink", - color="lightpink3", - # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) - shape="record", - ) - self._node_uid_counter += 1 + match op: + case quantum.StateOp(): + meas = xdsl_to_qml_measurement(op) + # NOTE: state can only handle all wires + prev_wires = self._wire_to_node_uids.keys() - @_visit_operation.register - def _sample_counts_probs_ops( - self, - op: quantum.SampleOp | quantum.ProbsOp, - ) -> None: - """Handler for sample operations.""" + case quantum.ExpvalOp() | quantum.VarianceOp(): + obs_op = op.obs.owner + meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + prev_wires = meas.wires.labels - # Create PennyLane instance - obs_op = op.obs.owner + case quantum.SampleOp() | quantum.ProbsOp(): + obs_op = op.obs.owner - # TODO: This doesn't logically make sense, but quantum.compbasis - # is obs_op and function below just pulls out the static wires - wires = xdsl_to_qml_measurement(obs_op) - meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) + # TODO: This doesn't logically make sense, but quantum.compbasis + # is obs_op and function below just pulls out the static wires + wires = xdsl_to_qml_measurement(obs_op) + meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) + + if wires == []: + # If no wires specified, connect to all seen current wires + prev_wires = self._wire_to_node_uids.keys() + else: + # Use the specific wires from the observable + prev_wires = wires + + case _: + return - # Add node to current cluster node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, @@ -210,6 +229,11 @@ def _sample_counts_probs_ops( ) self._node_uid_counter += 1 + for wire in prev_wires: + if wire in self._wire_to_node_uids: + for seen_node in self._wire_to_node_uids[wire]: + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") + # ============= # CONTROL FLOW # ============= @@ -263,6 +287,10 @@ def _if_op(self, operation: scf.IfOp): self._cluster_uid_stack.append(uid) self._cluster_uid_counter += 1 + # Save wires state before all of the branches + wire_map_before = self._wire_to_node_uids.copy() + region_wire_maps: list[dict[int | str, set[str]]] = [] + # Loop through each branch and visualize as a cluster flattened_if_op: list[Region] = _flatten_if_op(operation) num_regions = len(flattened_if_op) @@ -284,9 +312,16 @@ def _if_op(self, operation: scf.IfOp): self._cluster_uid_stack.append(uid) self._cluster_uid_counter += 1 + # Make fresh wire map before going into region + self._wire_to_node_uids = wire_map_before.copy() + # Go recursively into the branch to process internals self._visit_region(region) + # Update branch wire maps + if self._wire_to_node_uids != wire_map_before: + region_wire_maps.append(self._wire_to_node_uids) + # Pop branch cluster after processing to ensure # logical branches are treated as 'parallel' self._cluster_uid_stack.pop() @@ -294,6 +329,25 @@ def _if_op(self, operation: scf.IfOp): # Pop IfOp cluster before leaving this handler self._cluster_uid_stack.pop() + # Check what wires were affected + affected_wires: set[str | int] = set(wire_map_before.keys()) + for region_wire_map in region_wire_maps: + affected_wires.update(region_wire_map.keys()) + + # Update state to be the union of all branch wire maps + final_wire_map = defaultdict(set) + for wire in affected_wires: + all_nodes: set = set() + for region_wire_map in region_wire_maps: + if not wire in region_wire_map: + # IfOp region didn't apply anything on this wire + # so default to node before the IfOp + all_nodes.update(wire_map_before.get(wire, set())) + else: + all_nodes.update(region_wire_map.get(wire, set())) + final_wire_map[wire] = all_nodes + self._wire_to_node_uids = final_wire_map + # ============ # DEVICE NODE # ============ @@ -348,6 +402,9 @@ def _func_return(self, operation: func.ReturnOp) -> None: # the FuncOp's scope and so we can pop the ID off the stack. self._cluster_uid_stack.pop() + # Clear seen wires as we are exiting a FuncOp (qnode) + self._wire_to_node_uids = defaultdict(set) + def _flatten_if_op(op: scf.IfOp) -> list[Region]: """Recursively flattens a nested IfOp (if/elif/else chains).""" diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 4634eb7373..685ecd3798 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -45,7 +45,7 @@ class FakeDAGBuilder(DAGBuilder): def __init__(self): self._nodes = {} - self._edges = [] + self._edges = {} self._clusters = {} def add_node(self, uid, label, cluster_uid=None, **attrs) -> None: @@ -57,13 +57,11 @@ def add_node(self, uid, label, cluster_uid=None, **attrs) -> None: } def add_edge(self, from_uid: str, to_uid: str, **attrs) -> None: - self._edges.append( - { - "from_uid": from_uid, - "to_uid": to_uid, - "attrs": attrs, - } - ) + # O(1) look up + edge_key = (from_uid, to_uid) + self._edges[edge_key] = { + "attrs": attrs, + } def add_cluster( self, @@ -951,3 +949,342 @@ def my_circuit(): assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == get_label(op) + + +class TestOperatorConnectivity: + """Tests that operators are properly connected.""" + + def test_static_connection_within_cluster(self): + """Tests that connections can be made within the same cluster.""" + + dev = qml.device("null.qubit", wires=3) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.X(0) + qml.Z(1) + qml.Y(0) + qml.H(1) + qml.S(1) + qml.T(2) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + + # node0 -> NullQubit + + # Check all nodes + assert "PauliX" in nodes["node1"]["label"] + assert "PauliY" in nodes["node2"]["label"] + assert "PauliZ" in nodes["node3"]["label"] + assert "Hadamard" in nodes["node4"]["label"] + assert "S" in nodes["node5"]["label"] + assert "T" in nodes["node6"]["label"] + + # Check edges + # X -> Y + # Z -> H -> S + # T + assert len(edges) == 3 + assert ("node1", "node2") in edges + assert ("node3", "node4") in edges + assert ("node4", "node5") in edges + + def test_static_connection_through_for_loop(self): + """Tests that connections can be made through a for loop cluster.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.X(0) + for i in range(3): + qml.Y(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + # node0 -> NullQubit + + # Check all nodes + assert "PauliX" in nodes["node1"]["label"] + assert "PauliY" in nodes["node2"]["label"] + + # Check edges + # for loop + # X ----------> Y + assert len(edges) == 1 + assert ("node1", "node2") in edges + + def test_static_connection_through_while_loop(self): + """Tests that connections can be made through a while loop cluster.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + counter = 0 + qml.X(0) + while counter < 5: + qml.Y(0) + counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + # node0 -> NullQubit + + # Check all nodes + assert "PauliX" in nodes["node1"]["label"] + assert "PauliY" in nodes["node2"]["label"] + + # Check edges + # while loop + # X ----------> Y + assert len(edges) == 1 + assert ("node1", "node2") in edges + + def test_static_connection_through_conditional(self): + """Tests that connections through conditionals make sense.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + qml.X(0) + qml.T(1) + if x == 1: + qml.RX(0, 0) + qml.S(1) + elif x == 2: + qml.RY(0, 0) + else: + qml.RZ(0, 0) + qml.H(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + + # node0 -> NullQubit + + # Check all nodes + # NOTE: depth first traversal hence T first then PauliX + assert "T" in nodes["node1"]["label"] + assert "PauliX" in nodes["node2"]["label"] + assert "RX" in nodes["node3"]["label"] + assert "S" in nodes["node4"]["label"] + assert "RY" in nodes["node5"]["label"] + assert "RZ" in nodes["node6"]["label"] + assert "Hadamard" in nodes["node7"]["label"] + + # Check all edges + assert len(edges) == 7 + assert ("node1", "node4") in edges + assert ("node2", "node3") in edges + assert ("node2", "node5") in edges + assert ("node2", "node6") in edges + assert ("node3", "node7") in edges + assert ("node5", "node7") in edges + assert ("node6", "node7") in edges + + def test_multi_wire_connectivity(self): + """Ensures that multi wire connectivity holds.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.RX(0.1, 0) + qml.RY(0.2, 1) + qml.RZ(0.3, 2) + qml.CNOT(wires=[0, 1]) + qml.Toffoli(wires=[1, 2, 0]) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + + # node0 -> NullQubit + + # Check all nodes + assert "RZ" in nodes["node1"]["label"] + assert "RX" in nodes["node2"]["label"] + assert "RY" in nodes["node3"]["label"] + assert "CNOT" in nodes["node4"]["label"] + assert "Toffoli" in nodes["node5"]["label"] + + # Check all edges + assert len(edges) == 4 + assert ("node3", "node4") in edges # RX -> CNOT + assert ("node2", "node4") in edges # RY -> CNOT + assert ("node1", "node5") in edges # RZ -> Toffoli + assert ("node4", "node5") in edges # CNOT -> Toffoli + + +class TestTerminalMeasurementConnectivity: + """Test that terminal measurements connect properly.""" + + @pytest.mark.parametrize("meas_fn", [qml.probs, qml.state]) + def test_connect_all_wires(self, meas_fn): + """Tests connection to terminal measurements that operate on all wires.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.X(0) + qml.T(1) + return meas_fn() + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + + # node0 -> NullQubit + + # Check all nodes + assert "PauliX" in nodes["node1"]["label"] + assert "T" in nodes["node2"]["label"] + assert meas_fn.__name__ in nodes["node3"]["label"] + + # Check all edges + assert len(edges) == 2 + assert ("node1", "node3") in edges + assert ("node2", "node3") in edges + + def test_connect_specific_wires(self): + """Tests connection to terminal measurements that operate on specific wires.""" + + dev = qml.device("null.qubit", wires=5) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_workflow(): + qml.X(0) + qml.Y(1) + qml.Z(2) + qml.H(3) + return ( + qml.expval(qml.Z(0)), + qml.var(qml.Z(1)), + qml.probs(wires=[2]), + qml.sample(wires=[3]), + ) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + + # node0 -> NullQubit + + # Check all nodes + assert "PauliX" in nodes["node1"]["label"] + assert "PauliY" in nodes["node2"]["label"] + assert "PauliZ" in nodes["node3"]["label"] + assert "Hadamard" in nodes["node4"]["label"] + assert "expval" in nodes["node5"]["label"] + assert "var" in nodes["node6"]["label"] + assert "probs" in nodes["node7"]["label"] + assert "sample" in nodes["node8"]["label"] + + # Check all edges + assert len(edges) == 4 + assert ("node1", "node5") in edges + assert ("node2", "node6") in edges + assert ("node3", "node7") in edges + assert ("node4", "node8") in edges + + def test_multi_wire_connectivity(self): + """Ensures that multi wire connectivity holds.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.X(0) + qml.Y(1) + return qml.probs(wires=[0, 1]) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + edges = utility.dag_builder.edges + nodes = utility.dag_builder.nodes + + # node0 -> NullQubit + + # Check all nodes + assert "PauliX" in nodes["node1"]["label"] + assert "PauliY" in nodes["node2"]["label"] + assert "probs" in nodes["node3"]["label"] + + # Check all edges + assert len(edges) == 2 + assert ("node1", "node3") in edges + assert ("node2", "node3") in edges + + def test_no_quantum_ops_before_measurement(self): + """Tests a workflow with no quantum operations.""" + + dev = qml.device("null.qubit", wires=2) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_empty_workflow(): + return qml.expval(qml.Z(0)) + + module = my_empty_workflow() + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + assert len(utility.dag_builder.edges) == 0