-
Notifications
You must be signed in to change notification settings - Fork 59
feat: visualize static operators and measurements as nodes #2218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 250 commits
6cda3bf
86b5662
9c26efd
81a9aa6
ab06276
3e4102b
194f14a
199ab70
e28a1a9
24d3dd3
685842c
986fb3f
07b7655
11f5864
8c64d81
606f5f5
edfcb93
0370886
9847b39
57ce573
805d21b
c0e0283
e1a6c64
2a960cf
a21e6dd
b06585b
f666a9f
e28b2b7
af5e52c
952fd7f
99f9fc6
492276d
8f2dc98
eac25d5
b06e92e
53497ba
83225b3
7d04249
fbb0c5a
941675f
bb9549d
2bad538
683768e
2adb3b2
f7bbb6b
f7f4b5a
a21f878
ae0ff32
26ae968
1d4f1f1
70ec17a
ed8bae0
25e8051
ffc9726
a243614
0659cdf
e58891a
6d1ab99
710e6a8
cb931d1
049a683
9ea917e
fd41bc5
fd26a37
1ad2ec0
7fdbac0
8fc4a30
bfd28b2
4753071
6200c76
1f90c6d
9aef399
e87dba2
14ed4bc
be1961c
c92736a
ca44e59
f84a6f1
3a8da55
26e95f9
d7f21bd
c9594ee
d76fcca
c0554c5
a80141c
529cace
4a172dd
a54de21
5a731c0
53f80b3
2bb5bde
434c500
cfa5f3b
a62f2b1
d9026ef
faf338f
5d03a81
56e4756
c95c45f
18aa30b
929bd23
709e961
b51bb7b
89f45f3
b45b2f0
2028fa7
4e994e6
e590847
eb6a0c3
7eb31b6
8e787c9
8b6287f
83ebb18
6d84d36
1d8ac14
3c42841
e319aae
1c9c857
093a5d5
2ca7bb8
9d724bf
61b70d9
3a2ab18
e9c69d8
f47b1a5
5e83fe4
d5a0611
98bf7ea
818448c
b924c43
089cb14
2a2d55c
220c6d5
f7261a9
a8cd49c
44bf9d3
70b0d93
d983e97
cb05921
1883fb3
5e2d456
36b5f7d
ee15b67
16e003d
23561f7
d5ccaf7
1dd9f38
6a9b01d
a2ac9ff
0dbe4fe
d92aea4
8afc67a
c0d4d67
9706967
5e62c31
1474a66
82430af
ced6f75
649aed2
9c74b6a
5641e68
4ffc9d3
22c8ae1
ab7c471
dd1df44
57409bd
9996684
303fb03
c262df9
05df88c
3175efa
194b357
506f839
f67802f
ad35efc
fac6150
9bdf8da
b9479b1
60e83c4
ebd108e
bfd1d10
aec9f83
c36a43e
6c1d825
7b56f89
4f83355
9f6a131
9855985
30db67e
2fe2074
d069fef
b65ae16
e29f546
b34d8e2
aa1d90c
d3fdfea
27ff971
1a0ac68
5e3a6d2
a25f7b4
02d4a55
7237a3f
6410aae
20efb54
81e9fe5
dcf59ae
7eb1da4
f7daa8a
80da43c
5fc9cf1
ef5550f
5476e6d
dfde76d
ecddff7
c2bd854
4176d03
f8b2922
89dad36
e58cc4e
1877cf9
8643a24
c35b8c1
5728859
5345dfc
04ee8e2
74e8358
d8c1b7a
bafd278
284f56b
8b2ec3c
488de58
490949b
3494f17
fee7e1b
e15bdc0
2071360
c86043f
5aeba5e
ea8c9b7
d224c0c
8392c2b
80798f9
cb1e531
1f48498
ac5ffe7
27cc772
90e6588
2983d67
24ffa25
c9ec867
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,12 +14,18 @@ | |
|
|
||
| """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" | ||
|
|
||
| from functools import singledispatchmethod | ||
|
|
||
| from xdsl.dialects import builtin, func | ||
| from xdsl.ir import Block, Operation, Region | ||
|
|
||
| from catalyst.python_interface.dialects import catalyst, quantum | ||
| from functools import singledispatch, singledispatchmethod | ||
|
|
||
| from pennylane.measurements import MeasurementProcess | ||
| from pennylane.operation import Operator | ||
| from xdsl.dialects import builtin, func, scf | ||
| from xdsl.ir import Block, Operation, Region, SSAValue | ||
|
|
||
| from catalyst.python_interface.dialects import quantum | ||
| from catalyst.python_interface.inspection.xdsl_conversion import ( | ||
| xdsl_to_qml_measurement, | ||
| xdsl_to_qml_op, | ||
| ) | ||
| from catalyst.python_interface.visualization.dag_builder import DAGBuilder | ||
|
|
||
|
|
||
|
|
@@ -85,6 +91,206 @@ def _visit_block(self, block: Block) -> None: | |
| for op in block.ops: | ||
| self._visit_operation(op) | ||
|
|
||
| # =================== | ||
| # QUANTUM OPERATIONS | ||
| # =================== | ||
|
|
||
| @_visit_operation.register | ||
| def _unitary( | ||
andrijapau marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self, | ||
| op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, | ||
| ) -> None: | ||
| """Generic handler for unitary gates.""" | ||
|
|
||
| # Create PennyLane instance | ||
| qml_op = xdsl_to_qml_op(op) | ||
|
|
||
| # Add node to current cluster | ||
| node_uid = f"node{self._node_uid_counter}" | ||
| self.dag_builder.add_node( | ||
| uid=node_uid, | ||
| label=get_label(qml_op), | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) | ||
| shape="record", | ||
| ) | ||
| self._node_uid_counter += 1 | ||
|
|
||
| @_visit_operation.register | ||
| def _projective_measure_op(self, op: quantum.MeasureOp) -> None: | ||
mehrdad2m marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Handler for the single-qubit projective measurement operation.""" | ||
|
|
||
| # Create PennyLane instance | ||
| meas = xdsl_to_qml_measurement(op) | ||
|
|
||
| # Add node to current cluster | ||
| node_uid = f"node{self._node_uid_counter}" | ||
| self.dag_builder.add_node( | ||
| uid=node_uid, | ||
| label=get_label(meas), | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) | ||
| shape="record", | ||
| ) | ||
| self._node_uid_counter += 1 | ||
|
|
||
| # ===================== | ||
| # QUANTUM MEASUREMENTS | ||
| # ===================== | ||
|
|
||
| @_visit_operation.register | ||
| def _state_op(self, op: quantum.StateOp) -> None: | ||
| """Handler for the terminal state measurement operation.""" | ||
|
|
||
| # Create PennyLane instance | ||
| meas = xdsl_to_qml_measurement(op) | ||
|
|
||
| # Add node to current cluster | ||
| node_uid = f"node{self._node_uid_counter}" | ||
| self.dag_builder.add_node( | ||
| uid=node_uid, | ||
| label=get_label(meas), | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| fillcolor="lightpink", | ||
| color="lightpink3", | ||
| ) | ||
| self._node_uid_counter += 1 | ||
|
|
||
| @_visit_operation.register | ||
| def _statistical_measurement_ops( | ||
andrijapau marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| self, | ||
| op: quantum.ExpvalOp | quantum.VarianceOp, | ||
| ) -> None: | ||
| """Handler for statistical measurement operations.""" | ||
|
|
||
| # Create PennyLane instance | ||
| obs_op = op.obs.owner | ||
| meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) | ||
|
|
||
| # Add node to current cluster | ||
| node_uid = f"node{self._node_uid_counter}" | ||
| self.dag_builder.add_node( | ||
| uid=node_uid, | ||
| label=get_label(meas), | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| fillcolor="lightpink", | ||
| color="lightpink3", | ||
| ) | ||
| self._node_uid_counter += 1 | ||
|
|
||
| @_visit_operation.register | ||
| def _visit_sample_and_probs_ops( | ||
|
||
| self, | ||
| op: quantum.SampleOp | quantum.ProbsOp, | ||
| ) -> None: | ||
| """Handler for sample operations.""" | ||
|
|
||
| # Create PennyLane instance | ||
| obs_op = op.obs.owner | ||
|
|
||
| # TODO: This doesn't logically make sense, but quantum.compbasis | ||
| # is obs_op and function below just pulls out the static wires | ||
| wires = xdsl_to_qml_measurement(obs_op) | ||
| meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) | ||
|
|
||
| # Add node to current cluster | ||
| node_uid = f"node{self._node_uid_counter}" | ||
| self.dag_builder.add_node( | ||
| uid=node_uid, | ||
| label=get_label(meas), | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| fillcolor="lightpink", | ||
| color="lightpink3", | ||
| ) | ||
| self._node_uid_counter += 1 | ||
|
|
||
| # ============= | ||
| # CONTROL FLOW | ||
| # ============= | ||
|
|
||
| @_visit_operation.register | ||
| def _for_op(self, operation: scf.ForOp) -> None: | ||
| """Handle an xDSL ForOp operation.""" | ||
|
|
||
| uid = f"cluster{self._cluster_uid_counter}" | ||
| self.dag_builder.add_cluster( | ||
| uid, | ||
| node_label="for loop", | ||
| label="", | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| ) | ||
| self._cluster_uid_stack.append(uid) | ||
| self._cluster_uid_counter += 1 | ||
|
|
||
| self._visit_region(operation.regions[0]) | ||
|
|
||
| self._cluster_uid_stack.pop() | ||
|
|
||
| @_visit_operation.register | ||
| def _while_op(self, operation: scf.WhileOp) -> None: | ||
| """Handle an xDSL WhileOp operation.""" | ||
| uid = f"cluster{self._cluster_uid_counter}" | ||
| self.dag_builder.add_cluster( | ||
| uid, | ||
| node_label="while loop", | ||
| label="", | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| ) | ||
| self._cluster_uid_stack.append(uid) | ||
| self._cluster_uid_counter += 1 | ||
|
|
||
| for region in operation.regions: | ||
| self._visit_region(region) | ||
|
|
||
| self._cluster_uid_stack.pop() | ||
|
|
||
| @_visit_operation.register | ||
| def _if_op(self, operation: scf.IfOp): | ||
| """Handles the scf.IfOp operation.""" | ||
| flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) | ||
|
|
||
| uid = f"cluster{self._cluster_uid_counter}" | ||
| self.dag_builder.add_cluster( | ||
| uid, | ||
| node_label="", | ||
| label="conditional", | ||
| labeljust="l", | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| ) | ||
| self._cluster_uid_stack.append(uid) | ||
| self._cluster_uid_counter += 1 | ||
|
|
||
| # Loop through each branch and visualize as a cluster | ||
| num_regions = len(flattened_if_op) | ||
| for i, (condition_ssa, region) in enumerate(flattened_if_op): | ||
| node_label = "elif" | ||
| if i == 0: | ||
| node_label = "if" | ||
| elif i == num_regions - 1: | ||
| node_label = "else" | ||
|
|
||
| uid = f"cluster{self._cluster_uid_counter}" | ||
| self.dag_builder.add_cluster( | ||
| uid, | ||
| node_label=node_label, | ||
| label="", | ||
| style="dashed", | ||
| penwidth=1, | ||
| cluster_uid=self._cluster_uid_stack[-1], | ||
| ) | ||
| self._cluster_uid_stack.append(uid) | ||
| self._cluster_uid_counter += 1 | ||
|
|
||
| # Go recursively into the branch to process internals | ||
| self._visit_region(region) | ||
|
|
||
| # Pop branch cluster after processing to ensure | ||
| # logical branches are treated as 'parallel' | ||
| self._cluster_uid_stack.pop() | ||
|
|
||
| # Pop IfOp cluster before leaving this handler | ||
| self._cluster_uid_stack.pop() | ||
|
|
||
| # ============ | ||
| # DEVICE NODE | ||
| # ============ | ||
|
|
@@ -138,3 +344,51 @@ def _func_return(self, operation: func.ReturnOp) -> None: | |
| # If we hit a func.return operation we know we are leaving | ||
| # the FuncOp's scope and so we can pop the ID off the stack. | ||
| self._cluster_uid_stack.pop() | ||
|
|
||
|
|
||
| def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: | ||
| """Recursively flattens a nested IfOp (if/elif/else chains).""" | ||
|
|
||
| condition_ssa: SSAValue = op.operands[0] | ||
| then_region, else_region = op.regions | ||
|
|
||
| # Save condition SSA in case we want to visualize it eventually | ||
| flattened_op: list[tuple[SSAValue | None, Region]] = [(condition_ssa, then_region)] | ||
|
|
||
| # Peak into else region to see if there's another IfOp | ||
| else_block: Block = else_region.block | ||
| # Completely relies on the structure that the second last operation | ||
| # will be an IfOp (seems to hold true) | ||
| if isinstance(else_block.ops.last.prev_op, scf.IfOp): | ||
| # Recursively flatten any IfOps found in said block | ||
| nested_flattened_op = _flatten_if_op(else_block.ops.last.prev_op) | ||
| flattened_op.extend(nested_flattened_op) | ||
| return flattened_op | ||
|
|
||
| # No more nested IfOps, therefore append final region | ||
| # with no SSAValue | ||
| flattened_op.extend([(None, else_region)]) | ||
| return flattened_op | ||
|
|
||
|
|
||
| @singledispatch | ||
| def get_label(op: Operator | MeasurementProcess) -> str: | ||
| """Gets the appropriate label for a PennyLane object.""" | ||
| return str(op) | ||
|
|
||
|
|
||
| @get_label.register | ||
| def _operator(op: Operator) -> str: | ||
| """Returns the appropriate label for an xDSL operation.""" | ||
| wires = list(op.wires.labels) | ||
| if wires == []: | ||
| wires_str = "all" | ||
| else: | ||
| wires_str = f"[{', '.join(map(str, wires))}]" | ||
| return f"<name> {op.name}|<wire> {wires_str}" | ||
|
|
||
|
|
||
| @get_label.register | ||
| def _mp(mp: MeasurementProcess) -> str: | ||
| """Returns the appropriate label for an xDSL operation.""" | ||
| return str(mp) | ||
andrijapau marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.