diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index d6a6cbfcc8..d7ba7d7de8 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -9,6 +9,7 @@ [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) [(#2285)](https://github.com/PennyLaneAI/catalyst/pull/2285) + [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index ef13948b18..c5128bcd6d 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,10 +16,10 @@ from functools import singledispatchmethod -from xdsl.dialects import builtin, func -from xdsl.ir import Block, Operation, Region +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue -from catalyst.python_interface.dialects import catalyst, quantum +from catalyst.python_interface.dialects import quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder @@ -85,6 +85,90 @@ def _visit_block(self, block: Block) -> None: for op in block.ops: self._visit_operation(op) + # ============= + # CONTROL FLOW + # ============= + + @_visit_operation.register + def _for_op(self, operation: scf.ForOp) -> None: + """Handle an xDSL ForOp operation.""" + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label="for loop", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + self._visit_region(operation.regions[0]) + + self._cluster_uid_stack.pop() + + @_visit_operation.register + def _while_op(self, operation: scf.WhileOp) -> None: + """Handle an xDSL WhileOp operation.""" + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label="while loop", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + for region in operation.regions: + self._visit_region(region) + + self._cluster_uid_stack.pop() + + @_visit_operation.register + def _if_op(self, operation: scf.IfOp): + """Handles the scf.IfOp operation.""" + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label="conditional", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + # Loop through each branch and visualize as a cluster + flattened_if_op: list[Region] = _flatten_if_op(operation) + num_regions = len(flattened_if_op) + for i, region in enumerate(flattened_if_op): + cluster_label = "elif" + if i == 0: + cluster_label = "if" + elif i == num_regions - 1: + cluster_label = "else" + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label=cluster_label, + labeljust="l", + style="dashed", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + # Go recursively into the branch to process internals + self._visit_region(region) + + # Pop branch cluster after processing to ensure + # logical branches are treated as 'parallel' + self._cluster_uid_stack.pop() + + # Pop IfOp cluster before leaving this handler + self._cluster_uid_stack.pop() + # ============ # DEVICE NODE # ============ @@ -138,3 +222,37 @@ def _func_return(self, operation: func.ReturnOp) -> None: # If we hit a func.return operation we know we are leaving # the FuncOp's scope and so we can pop the ID off the stack. self._cluster_uid_stack.pop() + + +def _flatten_if_op(op: scf.IfOp) -> list[Region]: + """Recursively flattens a nested IfOp (if/elif/else chains).""" + + then_region, else_region = op.regions + + flattened_op: list[Region] = [then_region] + + # Check to see if there are any nested quantum operations in the else block + else_block: Block = else_region.block + has_quantum_ops = False + nested_if_op = None + for op in else_block.ops: + if isinstance(op, scf.IfOp): + nested_if_op = op + # No need to walk this op as this will be + # recursively handled down below + continue + for internal_op in op.walk(): + if type(internal_op) in quantum.Quantum.operations: + has_quantum_ops = True + # No need to check anything else + break + + if nested_if_op and not has_quantum_ops: + # Recursively flatten any IfOps found in said block + nested_flattened_op: list[Region] = _flatten_if_op(nested_if_op) + flattened_op.extend(nested_flattened_op) + return flattened_op + + # No more nested IfOps, therefore append final region + flattened_op.append(else_region) + return flattened_op diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index d4cac366a2..3954ca8d3c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -65,14 +65,13 @@ def add_edge(self, from_uid: str, to_uid: str, **attrs) -> None: def add_cluster( self, uid, - node_label=None, + label=None, cluster_uid=None, **attrs, ) -> None: self._clusters[uid] = { "uid": uid, - "node_label": node_label, - "cluster_label": attrs.get("label"), + "label": label, "parent_cluster_uid": cluster_uid, "attrs": attrs, } @@ -155,11 +154,11 @@ def my_workflow(): # └── my_workflow # Check qjit is nested under graph - assert graph_clusters["cluster0"]["cluster_label"] == "qjit" + assert graph_clusters["cluster0"]["label"] == "qjit" assert graph_clusters["cluster0"]["parent_cluster_uid"] is None # Check that my_workflow is under qjit - assert graph_clusters["cluster1"]["cluster_label"] == "my_workflow" + assert graph_clusters["cluster1"]["label"] == "my_workflow" assert graph_clusters["cluster1"]["parent_cluster_uid"] == "cluster0" def test_nested_qnodes(self): @@ -195,14 +194,14 @@ def my_workflow(): # └── my_qnode2 # Check qjit is under graph - assert graph_clusters["cluster0"]["cluster_label"] == "qjit" + assert graph_clusters["cluster0"]["label"] == "qjit" assert graph_clusters["cluster0"]["parent_cluster_uid"] is None # Check both qnodes are under my_workflow - assert graph_clusters["cluster1"]["cluster_label"] == "my_qnode1" + assert graph_clusters["cluster1"]["label"] == "my_qnode1" assert graph_clusters["cluster1"]["parent_cluster_uid"] == "cluster0" - assert graph_clusters["cluster2"]["cluster_label"] == "my_qnode2" + assert graph_clusters["cluster2"]["label"] == "my_qnode2" assert graph_clusters["cluster2"]["parent_cluster_uid"] == "cluster0" @@ -234,7 +233,7 @@ def my_workflow(): # └── my_workflow: NullQubit # Assert device node is inside my_workflow cluster - assert graph_clusters["cluster1"]["cluster_label"] == "my_workflow" + assert graph_clusters["cluster1"]["label"] == "my_workflow" assert graph_nodes["node0"]["parent_cluster_uid"] == "cluster1" # Assert label is as expected @@ -275,15 +274,410 @@ def my_workflow(): # └── my_qnode2: LightningSimulator # Assert lightning.qubit device node is inside my_qnode1 cluster - assert graph_clusters["cluster1"]["cluster_label"] == "my_qnode1" + assert graph_clusters["cluster1"]["label"] == "my_qnode1" assert graph_nodes["node0"]["parent_cluster_uid"] == "cluster1" # Assert label is as expected assert graph_nodes["node0"]["label"] == "NullQubit" # Assert null qubit device node is inside my_qnode2 cluster - assert graph_clusters["cluster2"]["cluster_label"] == "my_qnode2" + assert graph_clusters["cluster2"]["label"] == "my_qnode2" assert graph_nodes["node1"]["parent_cluster_uid"] == "cluster2" # Assert label is as expected assert graph_nodes["node1"]["label"] == "LightningSimulator" + + +class TestForOp: + """Tests that the for loop control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Tests that the for loop cluster can be visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + for i in range(3): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "for loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + @pytest.mark.unit + def test_nested_loop(self): + """Tests that nested for loops are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + for i in range(0, 5, 2): + for j in range(1, 6, 2): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "for loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "for loop" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + +class TestWhileOp: + """Tests that the while loop control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Test that the while loop is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + counter = 0 + while counter < 5: + qml.H(0) + counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "while loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + @pytest.mark.unit + def test_nested_loop(self): + """Tests that nested while loops are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + outer_counter = 0 + inner_counter = 0 + while outer_counter < 5: + while inner_counter < 6: + qml.H(0) + inner_counter += 1 + outer_counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "while loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "while loop" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + +class TestIfOp: + """Tests that the conditional control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Test that the conditional operation is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 2: + qml.X(0) + else: + qml.Y(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # Check conditional is a cluster within cluster1 (my_workflow) + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check three clusters live within cluster2 (conditional) + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "else" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + + @pytest.mark.unit + def test_if_elif_else_conditional(self): + """Test that the conditional operation is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + qml.Z(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # Check conditional is a cluster within my_workflow + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check three clusters live within conditional + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + @pytest.mark.unit + def test_nested_conditionals(self): + """Tests that nested conditionals are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x, y): + if x == 1: + if y == 2: + qml.H(0) + else: + qml.Z(0) + qml.X(0) + else: + qml.Z(0) + + args = (1, 2) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # cluster2 -> conditional (1) + # cluster3 -> if + # cluster4 -> conditional () + # cluster5 -> if + # cluster6 -> else + # cluster7 -> else + + # Check first conditional is a cluster within my_workflow + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check 'if' cluster of first conditional has another conditional + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + # Second conditional + assert clusters["cluster4"]["label"] == "conditional" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster3" + # Check 'if' and 'else' in second conditional + assert clusters["cluster5"]["label"] == "if" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster4" + assert clusters["cluster6"]["label"] == "else" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster4" + + # Check nested if / else is within the first if cluster + assert clusters["cluster7"]["label"] == "else" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster2" + + def test_nested_conditionals_with_quantum_ops(self): + """Tests that nested conditionals are unflattend if quantum operations + are present""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + qml.Z(0) + if x == 3: + qml.RX(0, 0) + elif x == 4: + qml.RY(0, 0) + else: + qml.RZ(0, 0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # node0 -> NullQubit + # cluster2 -> conditional (1) + # cluster3 -> if + # node1 -> X(0) + # cluster4 -> elif + # node2 -> Y(0) + # cluster5 -> else + # node3 -> Z(0) + # cluster6 -> conditional (2) + # cluster7 -> if + # node4 -> RX(0,0) + # cluster8 -> elif + # node5 -> RY(0,0) + # cluster9 -> else + # node6 -> RZ(0,0) + + # check outer conditional (1) + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + # Nested conditional (2) inside conditional (1) + assert clusters["cluster6"]["label"] == "conditional" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster5" + assert clusters["cluster7"]["label"] == "if" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster6" + assert clusters["cluster8"]["label"] == "elif" + assert clusters["cluster8"]["parent_cluster_uid"] == "cluster6" + assert clusters["cluster9"]["label"] == "else" + assert clusters["cluster9"]["parent_cluster_uid"] == "cluster6" + + def test_nested_conditionals_with_nested_quantum_ops(self): + """Tests that nested conditionals are unflattend if quantum operations + are present but nested in other operations""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + for i in range(3): + qml.Z(0) + if x == 3: + qml.RX(0, 0) + elif x == 4: + qml.RY(0, 0) + else: + qml.RZ(0, 0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # node0 -> NullQubit + # cluster2 -> conditional (1) + # cluster3 -> if + # node1 -> X(0) + # cluster4 -> elif + # node2 -> Y(0) + # cluster5 -> else + # cluster6 -> for loop + # node3 -> Z(0) + # cluster7 -> conditional (2) + # cluster8 -> if + # node4 -> RX(0,0) + # cluster9 -> elif + # node5 -> RY(0,0) + # cluster10 -> else + # node6 -> RZ(0,0) + + # check outer conditional (1) + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + # Nested conditional (2) inside conditional (1) + assert clusters["cluster6"]["label"] == "for loop" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster5" + + assert clusters["cluster7"]["label"] == "conditional" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster5" + assert clusters["cluster8"]["label"] == "if" + assert clusters["cluster8"]["parent_cluster_uid"] == "cluster7" + assert clusters["cluster9"]["label"] == "elif" + assert clusters["cluster9"]["parent_cluster_uid"] == "cluster7" + assert clusters["cluster10"]["label"] == "else" + assert clusters["cluster10"]["parent_cluster_uid"] == "cluster7"