diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index d7ba7d7de8..aab482a765 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,6 +10,7 @@ [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) [(#2285)](https://github.com/PennyLaneAI/catalyst/pull/2285) [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) + [(#2218)](https://github.com/PennyLaneAI/catalyst/pull/2218) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index c5128bcd6d..271c354c44 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,12 +14,18 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" -from functools import singledispatchmethod +from functools import singledispatch, singledispatchmethod +from pennylane.measurements import ExpectationMP, MeasurementProcess, ProbabilityMP, VarianceMP +from pennylane.operation import Operator from xdsl.dialects import builtin, func, scf from xdsl.ir import Block, Operation, Region, SSAValue from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.inspection.xdsl_conversion import ( + xdsl_to_qml_measurement, + xdsl_to_qml_op, +) from catalyst.python_interface.visualization.dag_builder import DAGBuilder @@ -85,6 +91,125 @@ def _visit_block(self, block: Block) -> None: for op in block.ops: self._visit_operation(op) + # =================== + # QUANTUM OPERATIONS + # =================== + + @_visit_operation.register + def _gate_op( + self, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + ) -> None: + """Generic handler for unitary gates.""" + + # Create PennyLane instance + qml_op = xdsl_to_qml_op(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(qml_op), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _projective_measure_op(self, op: quantum.MeasureOp) -> None: + """Handler for the single-qubit projective measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + # ===================== + # QUANTUM MEASUREMENTS + # ===================== + + @_visit_operation.register + def _state_op(self, op: quantum.StateOp) -> None: + """Handler for the terminal state measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _expval_and_var_ops( + self, + op: quantum.ExpvalOp | quantum.VarianceOp, + ) -> None: + """Handler for statistical measurement operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _sample_counts_probs_ops( + self, + op: quantum.SampleOp | quantum.ProbsOp, + ) -> None: + """Handler for sample operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + + # TODO: This doesn't logically make sense, but quantum.compbasis + # is obs_op and function below just pulls out the static wires + wires = xdsl_to_qml_measurement(obs_op) + meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + # ============= # CONTROL FLOW # ============= @@ -256,3 +381,45 @@ def _flatten_if_op(op: scf.IfOp) -> list[Region]: # No more nested IfOps, therefore append final region flattened_op.append(else_region) return flattened_op + + +@singledispatch +def get_label(op: Operator | MeasurementProcess) -> str: + """Gets the appropriate label for a PennyLane object.""" + return str(op) + + +@get_label.register +def _operator(op: Operator) -> str: + """Returns the appropriate label for PennyLane Operator""" + wires = list(op.wires.labels) + if wires == []: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires))}]" + # Using <...> lets us use ports (https://graphviz.org/doc/info/shapes.html#record) + return f" {op.name}| {wires_str}" + + +@get_label.register +def _meas(meas: MeasurementProcess) -> str: + """Returns the appropriate label for a PennyLane MeasurementProcess using match/case.""" + + wires_str = list(meas.wires.labels) + if not wires_str: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires_str))}]" + + base_name = meas._shortname + + match meas: + case ExpectationMP() | VarianceMP() | ProbabilityMP(): + if meas.obs is not None: + obs_name = meas.obs.name + base_name = f"{base_name}({obs_name})" + + case _: + pass + + return f" {base_name}| {wires_str}" diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 3954ca8d3c..d5706f9261 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -26,9 +26,11 @@ from xdsl.dialects.builtin import ModuleOp from xdsl.ir.core import Block, Region +from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, + get_label, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder @@ -282,10 +284,11 @@ def my_workflow(): # Assert null qubit device node is inside my_qnode2 cluster assert graph_clusters["cluster2"]["label"] == "my_qnode2" - assert graph_nodes["node1"]["parent_cluster_uid"] == "cluster2" + # NOTE: node1 is the qml.H(0) in my_qnode1 + assert graph_nodes["node2"]["parent_cluster_uid"] == "cluster2" # Assert label is as expected - assert graph_nodes["node1"]["label"] == "LightningSimulator" + assert graph_nodes["node2"]["label"] == "LightningSimulator" class TestForOp: @@ -681,3 +684,290 @@ def my_workflow(x): assert clusters["cluster9"]["parent_cluster_uid"] == "cluster7" assert clusters["cluster10"]["label"] == "else" assert clusters["cluster10"]["parent_cluster_uid"] == "cluster7" + + +class TestGetLabel: + """Tests the get_label utility.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "op, label", + [ + (qml.H(0), " Hadamard| [0]"), + ( + qml.QubitUnitary([[0, 1], [1, 0]], 0), + " QubitUnitary| [0]", + ), + (qml.SWAP([0, 1]), " SWAP| [0, 1]"), + ], + ) + def test_standard_operator(self, op, label): + """Tests against an operator instance.""" + assert get_label(op) == label + + def test_global_phase_operator(self): + """Tests against a GlobalPhase operator instance.""" + assert get_label(qml.GlobalPhase(0.5)) == f" GlobalPhase| all" + + @pytest.mark.unit + @pytest.mark.parametrize( + "meas, label", + [ + (qml.state(), " state| all"), + (qml.expval(qml.Z(0)), " expval(PauliZ)| [0]"), + (qml.var(qml.Z(0)), " var(PauliZ)| [0]"), + (qml.probs(), " probs| all"), + (qml.probs(wires=0), " probs| [0]"), + (qml.probs(wires=[0, 1]), " probs| [0, 1]"), + (qml.sample(), " sample| all"), + (qml.sample(wires=0), " sample| [0]"), + (qml.sample(wires=[0, 1]), " sample| [0, 1]"), + ], + ) + def test_standard_measurement(self, meas, label): + """Tests against an operator instance.""" + assert get_label(meas) == label + + +class TestCreateStaticOperatorNodes: + """Tests that operators with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])]) + def test_custom_op(self, op): + """Tests that the CustomOp operation node can be created and visualized.""" + + # Build module with only a CustomOp + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + # Make sure label has relevant info + assert nodes["node1"]["label"] == get_label(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.GlobalPhase(0.5), + qml.GlobalPhase(0.5, wires=0), + qml.GlobalPhase(0.5, wires=[0, 1]), + ], + ) + def test_global_phase_op(self, op): + """Test that GlobalPhase can be handled.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + # Compiler throws out the wires and they get converted to wires=[] no matter what + assert nodes["node1"]["label"] == get_label(qml.GlobalPhase(0.5)) + + @pytest.mark.unit + def test_qubit_unitary_op(self): + """Test that QubitUnitary operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) + + @pytest.mark.unit + def test_multi_rz_op(self): + """Test that MultiRZ operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.MultiRZ(0.5, wires=[0]) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.MultiRZ(0.5, wires=[0])) + + @pytest.mark.unit + def test_projective_measurement_op(self): + """Test that projective measurements can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + measure(0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == f" MidMeasureMP| [0]" + + +class TestCreateStaticMeasurementNodes: + """Tests that measurements with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + def test_state_op(self): + """Test that qml.state can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return qml.state() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.state()) + + @pytest.mark.unit + @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) + def test_expval_var_measurement_op(self, meas_fn): + """Test that statistical measurement operators can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return meas_fn(qml.Z(0)) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(meas_fn(qml.Z(0))) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.probs(), + qml.probs(wires=0), + qml.probs(wires=[0, 1]), + ], + ) + def test_probs_measurement_op(self, op): + """Tests that the probs measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.sample(), + qml.sample(wires=0), + qml.sample(wires=[0, 1]), + ], + ) + def test_valid_sample_measurement_op(self, op): + """Tests that the sample measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(op)