diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index debad683e8..7b78e8ffc3 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -2,6 +2,16 @@

New features since last release

+* Compiled programs can be visualized. + [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) + [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229) + [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) + [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) + [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) + [(#2285)](https://github.com/PennyLaneAI/catalyst/pull/2285) + [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) + [(#2218)](https://github.com/PennyLaneAI/catalyst/pull/2218) + * Catalyst now features a unified compilation framework, which enables users and developers to design and implement compilation passes in Python in addition to C++, on the same Catalyst IR. The Python interface relies on the xDSL library to represent and manipulate programs (analogous to the MLIR library diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py new file mode 100644 index 0000000000..271c354c44 --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -0,0 +1,425 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" + +from functools import singledispatch, singledispatchmethod + +from pennylane.measurements import ExpectationMP, MeasurementProcess, ProbabilityMP, VarianceMP +from pennylane.operation import Operator +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.inspection.xdsl_conversion import ( + xdsl_to_qml_measurement, + xdsl_to_qml_op, +) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class ConstructCircuitDAG: + """Utility tool following the director pattern to build a DAG representation of a compiled quantum program. + + This tool traverses an xDSL module and constructs a Directed Acyclic Graph (DAG) + of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module. + + **Example** + + >>> builder = PyDotDAGBuilder() + >>> director = ConstructCircuitDAG(builder) + >>> director.construct(module) + >>> director.dag_builder.to_string() + ... + """ + + def __init__(self, dag_builder: DAGBuilder) -> None: + self.dag_builder: DAGBuilder = dag_builder + + # Keep track of nesting clusters using a stack + self._cluster_uid_stack: list[str] = [] + + # Use counter internally for UID + self._node_uid_counter: int = 0 + self._cluster_uid_counter: int = 0 + + def _reset(self) -> None: + """Resets the instance.""" + self._cluster_uid_stack: list[str] = [] + self._node_uid_counter: int = 0 + self._cluster_uid_counter: int = 0 + + def construct(self, module: builtin.ModuleOp) -> None: + """Constructs the DAG from the module. + + Args: + module (xdsl.builtin.ModuleOp): The module containing the quantum program to visualize. + + """ + self._reset() + for op in module.ops: + self._visit_operation(op) + + # ============= + # IR TRAVERSAL + # ============= + + @singledispatchmethod + def _visit_operation(self, operation: Operation) -> None: + """Visit an xDSL Operation. Default to visiting each region contained in the operation.""" + for region in operation.regions: + self._visit_region(region) + + def _visit_region(self, region: Region) -> None: + """Visit an xDSL Region operation.""" + for block in region.blocks: + self._visit_block(block) + + def _visit_block(self, block: Block) -> None: + """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" + for op in block.ops: + self._visit_operation(op) + + # =================== + # QUANTUM OPERATIONS + # =================== + + @_visit_operation.register + def _gate_op( + self, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + ) -> None: + """Generic handler for unitary gates.""" + + # Create PennyLane instance + qml_op = xdsl_to_qml_op(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(qml_op), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _projective_measure_op(self, op: quantum.MeasureOp) -> None: + """Handler for the single-qubit projective measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + # ===================== + # QUANTUM MEASUREMENTS + # ===================== + + @_visit_operation.register + def _state_op(self, op: quantum.StateOp) -> None: + """Handler for the terminal state measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _expval_and_var_ops( + self, + op: quantum.ExpvalOp | quantum.VarianceOp, + ) -> None: + """Handler for statistical measurement operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _sample_counts_probs_ops( + self, + op: quantum.SampleOp | quantum.ProbsOp, + ) -> None: + """Handler for sample operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + + # TODO: This doesn't logically make sense, but quantum.compbasis + # is obs_op and function below just pulls out the static wires + wires = xdsl_to_qml_measurement(obs_op) + meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + # ============= + # CONTROL FLOW + # ============= + + @_visit_operation.register + def _for_op(self, operation: scf.ForOp) -> None: + """Handle an xDSL ForOp operation.""" + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label="for loop", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + self._visit_region(operation.regions[0]) + + self._cluster_uid_stack.pop() + + @_visit_operation.register + def _while_op(self, operation: scf.WhileOp) -> None: + """Handle an xDSL WhileOp operation.""" + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label="while loop", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + for region in operation.regions: + self._visit_region(region) + + self._cluster_uid_stack.pop() + + @_visit_operation.register + def _if_op(self, operation: scf.IfOp): + """Handles the scf.IfOp operation.""" + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label="conditional", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + # Loop through each branch and visualize as a cluster + flattened_if_op: list[Region] = _flatten_if_op(operation) + num_regions = len(flattened_if_op) + for i, region in enumerate(flattened_if_op): + cluster_label = "elif" + if i == 0: + cluster_label = "if" + elif i == num_regions - 1: + cluster_label = "else" + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + label=cluster_label, + labeljust="l", + style="dashed", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + # Go recursively into the branch to process internals + self._visit_region(region) + + # Pop branch cluster after processing to ensure + # logical branches are treated as 'parallel' + self._cluster_uid_stack.pop() + + # Pop IfOp cluster before leaving this handler + self._cluster_uid_stack.pop() + + # ============ + # DEVICE NODE + # ============ + + @_visit_operation.register + def _device_init(self, operation: quantum.DeviceInitOp) -> None: + """Handles the initialization of a quantum device.""" + node_id = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + node_id, + label=operation.device_name.data, + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="grey", + color="black", + penwidth=2, + shape="rectangle", + ) + self._node_uid_counter += 1 + + # ======================= + # FuncOp NESTING UTILITY + # ======================= + + @_visit_operation.register + def _func_op(self, operation: func.FuncOp) -> None: + """Visit a FuncOp Operation.""" + + label = operation.sym_name.data + if "jit_" in operation.sym_name.data: + label = "qjit" + + uid = f"cluster{self._cluster_uid_counter}" + parent_cluster_uid = None if self._cluster_uid_stack == [] else self._cluster_uid_stack[-1] + self.dag_builder.add_cluster( + uid, + label=label, + cluster_uid=parent_cluster_uid, + ) + self._cluster_uid_counter += 1 + self._cluster_uid_stack.append(uid) + + self._visit_block(operation.regions[0].blocks[0]) + + @_visit_operation.register + def _func_return(self, operation: func.ReturnOp) -> None: + """Handle func.return to exit FuncOp's cluster scope.""" + + # NOTE: Skip first cluster as it is the "base" of the graph diagram. + # In our case, it is the `qjit` bounding box. + if len(self._cluster_uid_stack) > 1: + # If we hit a func.return operation we know we are leaving + # the FuncOp's scope and so we can pop the ID off the stack. + self._cluster_uid_stack.pop() + + +def _flatten_if_op(op: scf.IfOp) -> list[Region]: + """Recursively flattens a nested IfOp (if/elif/else chains).""" + + then_region, else_region = op.regions + + flattened_op: list[Region] = [then_region] + + # Check to see if there are any nested quantum operations in the else block + else_block: Block = else_region.block + has_quantum_ops = False + nested_if_op = None + for op in else_block.ops: + if isinstance(op, scf.IfOp): + nested_if_op = op + # No need to walk this op as this will be + # recursively handled down below + continue + for internal_op in op.walk(): + if type(internal_op) in quantum.Quantum.operations: + has_quantum_ops = True + # No need to check anything else + break + + if nested_if_op and not has_quantum_ops: + # Recursively flatten any IfOps found in said block + nested_flattened_op: list[Region] = _flatten_if_op(nested_if_op) + flattened_op.extend(nested_flattened_op) + return flattened_op + + # No more nested IfOps, therefore append final region + flattened_op.append(else_region) + return flattened_op + + +@singledispatch +def get_label(op: Operator | MeasurementProcess) -> str: + """Gets the appropriate label for a PennyLane object.""" + return str(op) + + +@get_label.register +def _operator(op: Operator) -> str: + """Returns the appropriate label for PennyLane Operator""" + wires = list(op.wires.labels) + if wires == []: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires))}]" + # Using <...> lets us use ports (https://graphviz.org/doc/info/shapes.html#record) + return f" {op.name}| {wires_str}" + + +@get_label.register +def _meas(meas: MeasurementProcess) -> str: + """Returns the appropriate label for a PennyLane MeasurementProcess using match/case.""" + + wires_str = list(meas.wires.labels) + if not wires_str: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires_str))}]" + + base_name = meas._shortname + + match meas: + case ExpectationMP() | VarianceMP() | ProbabilityMP(): + if meas.obs is not None: + obs_name = meas.obs.name + base_name = f"{base_name}({obs_name})" + + case _: + pass + + return f" {base_name}| {wires_str}" diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py new file mode 100644 index 0000000000..ae8f18e77e --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -0,0 +1,141 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File that defines the DAGBuilder abstract base class.""" + +from abc import ABC, abstractmethod +from typing import Any, TypeAlias + +ClusterUID: TypeAlias = str +NodeUID: TypeAlias = str + + +class DAGBuilder(ABC): + """An abstract base class for building Directed Acyclic Graphs (DAGs). + + This class provides a simple interface with three core methods (`add_node`, `add_edge` and `add_cluster`). + You can override these methods to implement any backend, like `pydot` or `graphviz` or even `matplotlib`. + + Outputting your graph can be done by overriding `to_file` and `to_string`. + """ + + @abstractmethod + def add_node( + self, + uid: NodeUID, + label: str, + *, + cluster_uid: ClusterUID | None = None, + **attrs: Any, + ) -> None: + """Add a single node to the graph. + + Args: + uid (str): Unique node ID to identify this node. + label (str): The text to display on the node when rendered. + cluster_uid (str | None): Optional unique ID of the cluster this node belongs to. If `None`, this node gets + added on the base graph. + **attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def add_edge(self, from_uid: NodeUID, to_uid: NodeUID, **attrs: Any) -> None: + """Add a single directed edge between nodes in the graph. + + Args: + from_uid (str): The unique ID of the source node. + to_uid (str): The unique ID of the destination node. + **attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def add_cluster( + self, + uid: ClusterUID, + *, + label: str | None = None, + cluster_uid: ClusterUID | None = None, + **attrs: Any, + ) -> None: + """Add a single cluster to the graph. + + A cluster is a specific type of subgraph where the nodes and edges contained + within it are visually and logically grouped. + + Args: + uid (str): Unique cluster ID to identify this cluster. + label (str | None): Optional text to display as a label on the cluster when rendered. + cluster_uid (str | None): Optional unique ID of the cluster this cluster belongs to. If `None`, the cluster will be + placed on the base graph. + **attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @property + @abstractmethod + def nodes(self) -> dict[NodeUID, dict[str, Any]]: + """Retrieve the current set of nodes in the graph. + + Returns: + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information. + """ + raise NotImplementedError + + @property + @abstractmethod + def edges(self) -> list[dict[str, Any]]: + """Retrieve the current set of edges in the graph. + + Returns: + edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. + """ + raise NotImplementedError + + @property + @abstractmethod + def clusters(self) -> dict[ClusterUID, dict[str, Any]]: + """Retrieve the current set of clusters in the graph. + + Returns: + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information. + """ + raise NotImplementedError + + @abstractmethod + def to_file(self, output_filename: str) -> None: + """Save the graph to a file. + + The implementation should ideally infer the output format + (e.g., 'png', 'svg') from this filename's extension. + + Args: + output_filename (str): Desired filename for the graph. + + """ + raise NotImplementedError + + @abstractmethod + def to_string(self) -> str: + """Return the graph as a string. + + This is typically used to get the graph's representation in a standard string format like DOT. + + Returns: + str: A string representation of the graph. + """ + raise NotImplementedError diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py new file mode 100644 index 0000000000..908cc5fc9d --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -0,0 +1,289 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File that defines the PyDotDAGBuilder subclass of DAGBuilder.""" + +import pathlib +from collections import ChainMap +from typing import Any + +from .dag_builder import DAGBuilder + +has_pydot = True +try: + import pydot + from pydot import Cluster, Dot, Edge, Graph, Node, Subgraph +except ImportError: + has_pydot = False + + +class PyDotDAGBuilder(DAGBuilder): + """A Directed Acyclic Graph builder for the PyDot backend. + + Args: + attrs (dict | None): User default attributes to be used for all elements (nodes, edges, clusters) in the graph. + node_attrs (dict | None): User default attributes for a node. + edge_attrs (dict | None): User default attributes for an edge. + cluster_attrs (dict | None): User default attributes for a cluster. + + Example: + >>> builder = PyDotDAGBuilder() + >>> builder.add_node("n0", "node 0") + >>> builder.add_cluster("c0") + >>> builder.add_node("n1", "node 1", cluster_uid="c0") + >>> print(builder.to_string()) + strict digraph G { + rankdir=TB; + compound=true; + n0 [label="node 0", fontname=Helvetica, penwidth=3, shape=ellipse, style=filled, fillcolor=lightblue, color=lightblue4]; + subgraph cluster_c0 { + fontname=Helvetica; + penwidth=2; + shape=rectangle; + style=solid; + } + + n1 [label="node 1", fontname=Helvetica, penwidth=3, shape=ellipse, style=filled, fillcolor=lightblue, color=lightblue4, cluster_uid=c0]; + } + + """ + + def __init__( + self, + attrs: dict | None = None, + node_attrs: dict | None = None, + edge_attrs: dict | None = None, + cluster_attrs: dict | None = None, + ) -> None: + # Initialize the pydot graph: + # - graph_type="digraph": Create a directed graph (edges have arrows). + # - rankdir="TB": Set layout direction from Top to Bottom. + # - compound="true": Allow edges to connect directly to clusters/subgraphs. + # - strict=True: Prevent duplicate edges (e.g., A -> B added twice). + # - splines="ortho": Edges connecting clusters are orthogonal + self.graph: Dot = Dot( + graph_type="digraph", rankdir="TB", compound="true", strict=True, splines="ortho" + ) + + # Use internal cache that maps cluster ID to actual pydot (Dot or Cluster) object + # NOTE: This is needed so we don't need to traverse the graph to find the relevant + # cluster object to modify + self._subgraph_cache: dict[str, Graph] = {} + + # Internal state for graph structure + self._nodes: dict[str, dict[str, Any]] = {} + self._edges: list[dict[str, Any]] = [] + self._clusters: dict[str, dict[str, Any]] = {} + + _default_attrs: dict = ( + {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + ) + self._default_node_attrs: dict = ( + { + **_default_attrs, + "shape": "ellipse", + "style": "filled", + "fillcolor": "lightblue", + "color": "lightblue4", + "penwidth": 3, + } + if node_attrs is None + else node_attrs + ) + self._default_edge_attrs: dict = ( + { + "color": "lightblue4", + "penwidth": 3, + } + if edge_attrs is None + else edge_attrs + ) + self._default_cluster_attrs: dict = ( + { + **_default_attrs, + "shape": "rectangle", + "style": "solid", + } + if cluster_attrs is None + else cluster_attrs + ) + + def add_node( + self, + uid: str, + label: str, + cluster_uid: str | None = None, + **attrs: Any, + ) -> None: + """Add a single node to the graph. + + Args: + uid (str): Unique node ID to identify this node. + label (str): The text to display on the node when rendered. + cluster_uid (str | None): Optional unique ID of the cluster this node belongs to. + **attrs (Any): Any additional styling keyword arguments. + + Raises: + ValueError: Node ID is already present in the graph. + + """ + if uid in self.nodes: + raise ValueError(f"Node ID {uid} already present in graph.") + + # Use ChainMap so you don't need to construct a new dictionary + node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs) + node = Node(uid, label=label, **node_attrs) + + # Add node to cluster + if cluster_uid is None: + self.graph.add_node(node) + else: + self._subgraph_cache[cluster_uid].add_node(node) + + self._nodes[uid] = { + "uid": uid, + "label": label, + "cluster_uid": cluster_uid, + "attrs": dict(node_attrs), + } + + def add_edge(self, from_uid: str, to_uid: str, **attrs: Any) -> None: + """Add a single directed edge between nodes in the graph. + + Args: + from_uid (str): The unique ID of the source node. + to_uid (str): The unique ID of the destination node. + **attrs (Any): Any additional styling keyword arguments. + + Raises: + ValueError: Source and destination have the same ID + ValueError: Source is not found in the graph. + ValueError: Destination is not found in the graph. + + """ + if from_uid == to_uid: + raise ValueError("Edges must connect two unique IDs.") + if from_uid not in self.nodes: + raise ValueError("Source is not found in the graph.") + if to_uid not in self.nodes: + raise ValueError("Destination is not found in the graph.") + + # Use ChainMap so you don't need to construct a new dictionary + edge_attrs: ChainMap = ChainMap(attrs, self._default_edge_attrs) + edge = Edge(from_uid, to_uid, **edge_attrs) + + self.graph.add_edge(edge) + + self._edges.append( + {"from_uid": from_uid, "to_uid": to_uid, "attrs": dict(edge_attrs)} + ) + + def add_cluster( + self, + uid: str, + label: str | None = None, + cluster_uid: str | None = None, + **attrs: Any, + ) -> None: + """Add a single cluster to the graph. + + A cluster is a specific type of subgraph where the nodes and edges contained + within it are visually and logically grouped. + + Args: + uid (str): Unique cluster ID to identify this cluster. + label (str | None): Optional text to display as a label on the cluster when rendered. + cluster_uid (str | None): Optional unique ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. + **attrs (Any): Any additional styling keyword arguments. + + Raises: + ValueError: Cluster ID is already present in the graph. + """ + if uid in self.clusters: + raise ValueError(f"Cluster ID {uid} already present in graph.") + + # Use ChainMap so you don't need to construct a new dictionary + cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs) + cluster = Cluster(uid, label=label, **cluster_attrs) + + # Record new cluster + self._subgraph_cache[uid] = cluster + + # Add node to cluster + if cluster_uid is None: + self.graph.add_subgraph(cluster) + else: + self._subgraph_cache[cluster_uid].add_subgraph(cluster) + + self._clusters[uid] = { + "uid": uid, + "label": label, + "cluster_uid": cluster_uid, + "attrs": dict(cluster_attrs), + } + + @property + def nodes(self) -> dict[str, dict[str, Any]]: + """Retrieve the current set of nodes in the graph. + + Returns: + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information. + """ + return self._nodes + + @property + def edges(self) -> list[dict[str, Any]]: + """Retrieve the current set of edges in the graph. + + Returns: + edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. + """ + return self._edges + + @property + def clusters(self) -> dict[str, dict[str, Any]]: + """Retrieve the current set of clusters in the graph. + + Returns: + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information. + """ + return self._clusters + + def to_file(self, output_filename: str) -> None: + """Save the graph to a file. + + This method will infer the file's format (e.g., 'png', 'svg') from this filename's extension. + If no extension is provided, the 'png' format will be the default. + + Args: + output_filename (str): Desired filename for the graph. File extension can be included + and if no file extension is provided, it will default to a `.png` file. + + """ + output_filename_path: pathlib.Path = pathlib.Path(output_filename) + if not output_filename_path.suffix: + output_filename_path = output_filename_path.with_suffix(".png") + + format = output_filename_path.suffix[1:].lower() + + self.graph.write(str(output_filename_path), format=format) + + def to_string(self) -> str: + """Return the graph as a string. + + This is typically used to get the graph's representation in a standard string format like DOT. + + Returns: + str: A string representation of the graph. + """ + return self.graph.to_string() diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py new file mode 100644 index 0000000000..d39bee45f1 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -0,0 +1,974 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the ConstructCircuitDAG utility.""" + +from unittest.mock import Mock + +import pytest + +pytestmark = pytest.mark.xdsl +xdsl = pytest.importorskip("xdsl") + +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +import pennylane as qml +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region + +from catalyst import measure +from catalyst.python_interface.conversion import xdsl_from_qjit +from catalyst.python_interface.visualization.construct_circuit_dag import ( + ConstructCircuitDAG, + get_label, +) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class FakeDAGBuilder(DAGBuilder): + """ + A concrete implementation of DAGBuilder used ONLY for testing. + It stores all graph manipulation calls in data structures + for easy assertion of the final graph state. + """ + + def __init__(self): + self._nodes = {} + self._edges = [] + self._clusters = {} + + def add_node(self, uid, label, cluster_uid=None, **attrs) -> None: + self._nodes[uid] = { + "uid": uid, + "label": label, + "parent_cluster_uid": cluster_uid, + "attrs": attrs, + } + + def add_edge(self, from_uid: str, to_uid: str, **attrs) -> None: + self._edges.append( + { + "from_uid": from_uid, + "to_uid": to_uid, + "attrs": attrs, + } + ) + + def add_cluster( + self, + uid, + label=None, + cluster_uid=None, + **attrs, + ) -> None: + self._clusters[uid] = { + "uid": uid, + "label": label, + "parent_cluster_uid": cluster_uid, + "attrs": attrs, + } + + @property + def nodes(self): + return self._nodes + + @property + def edges(self): + return self._edges + + @property + def clusters(self): + return self._clusters + + def to_file(self, output_filename): + pass + + def to_string(self) -> str: + return "graph" + + +@pytest.mark.unit +def test_dependency_injection(): + """Tests that relevant dependencies are injected.""" + + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + assert utility.dag_builder is mock_dag_builder + + +@pytest.mark.unit +def test_does_not_mutate_module(): + """Test that the module is not mutated.""" + + # Create module + op = test.TestOp() + block = Block(ops=[op]) + region = Region(blocks=[block]) + container_op = test.TestOp(regions=[region]) + module_op = ModuleOp(ops=[container_op]) + + # Save state before + module_op_str_before = str(module_op) + + # Process module + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + utility.construct(module_op) + + # Ensure not mutated + assert str(module_op) == module_op_str_before + + +@pytest.mark.unit +class TestFuncOpVisualization: + """Tests the visualization of FuncOps with bounding boxes""" + + def test_standard_qnode(self): + """Tests that a standard QJIT'd QNode is visualized correctly""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + graph_clusters = utility.dag_builder.clusters + + # Check nesting is correct + # graph + # └── qjit + # └── my_workflow + + # Check qjit is nested under graph + assert graph_clusters["cluster0"]["label"] == "qjit" + assert graph_clusters["cluster0"]["parent_cluster_uid"] is None + + # Check that my_workflow is under qjit + assert graph_clusters["cluster1"]["label"] == "my_workflow" + assert graph_clusters["cluster1"]["parent_cluster_uid"] == "cluster0" + + def test_nested_qnodes(self): + """Tests that nested QJIT'd QNodes are visualized correctly""" + + dev = qml.device("null.qubit", wires=1) + + @qml.qnode(dev) + def my_qnode2(): + qml.X(0) + + @qml.qnode(dev) + def my_qnode1(): + qml.H(0) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + def my_workflow(): + my_qnode1() + my_qnode2() + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + graph_clusters = utility.dag_builder.clusters + + # Check nesting is correct + # graph + # └── qjit + # ├── my_qnode1 + # └── my_qnode2 + + # Check qjit is under graph + assert graph_clusters["cluster0"]["label"] == "qjit" + assert graph_clusters["cluster0"]["parent_cluster_uid"] is None + + # Check both qnodes are under my_workflow + assert graph_clusters["cluster1"]["label"] == "my_qnode1" + assert graph_clusters["cluster1"]["parent_cluster_uid"] == "cluster0" + + assert graph_clusters["cluster2"]["label"] == "my_qnode2" + assert graph_clusters["cluster2"]["parent_cluster_uid"] == "cluster0" + + +class TestDeviceNode: + """Tests that the device node is correctly visualized.""" + + def test_standard_qnode(self): + """Tests that a standard setup works.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + graph_nodes = utility.dag_builder.nodes + graph_clusters = utility.dag_builder.clusters + + # Check nesting is correct + # graph + # └── qjit + # └── my_workflow: NullQubit + + # Assert device node is inside my_workflow cluster + assert graph_clusters["cluster1"]["label"] == "my_workflow" + assert graph_nodes["node0"]["parent_cluster_uid"] == "cluster1" + + # Assert label is as expected + assert graph_nodes["node0"]["label"] == "NullQubit" + + def test_nested_qnodes(self): + """Tests that nested QJIT'd QNodes are visualized correctly""" + + dev1 = qml.device("null.qubit", wires=1) + dev2 = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev2) + def my_qnode2(): + qml.X(0) + + @qml.qnode(dev1) + def my_qnode1(): + qml.H(0) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + def my_workflow(): + my_qnode1() + my_qnode2() + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + graph_nodes = utility.dag_builder.nodes + graph_clusters = utility.dag_builder.clusters + + # Check nesting is correct + # graph + # └── qjit + # ├── my_qnode1: NullQubit + # └── my_qnode2: LightningSimulator + + # Assert lightning.qubit device node is inside my_qnode1 cluster + assert graph_clusters["cluster1"]["label"] == "my_qnode1" + assert graph_nodes["node0"]["parent_cluster_uid"] == "cluster1" + + # Assert label is as expected + assert graph_nodes["node0"]["label"] == "NullQubit" + + # Assert null qubit device node is inside my_qnode2 cluster + assert graph_clusters["cluster2"]["label"] == "my_qnode2" + # NOTE: node1 is the qml.H(0) in my_qnode1 + assert graph_nodes["node2"]["parent_cluster_uid"] == "cluster2" + + # Assert label is as expected + assert graph_nodes["node2"]["label"] == "LightningSimulator" + + +class TestForOp: + """Tests that the for loop control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Tests that the for loop cluster can be visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + for i in range(3): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "for loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + @pytest.mark.unit + def test_nested_loop(self): + """Tests that nested for loops are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + for i in range(0, 5, 2): + for j in range(1, 6, 2): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "for loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "for loop" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + +class TestWhileOp: + """Tests that the while loop control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Test that the while loop is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + counter = 0 + while counter < 5: + qml.H(0) + counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "while loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + @pytest.mark.unit + def test_nested_loop(self): + """Tests that nested while loops are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + outer_counter = 0 + inner_counter = 0 + while outer_counter < 5: + while inner_counter < 6: + qml.H(0) + inner_counter += 1 + outer_counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["label"] == "while loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "while loop" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + +class TestIfOp: + """Tests that the conditional control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Test that the conditional operation is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 2: + qml.X(0) + else: + qml.Y(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # Check conditional is a cluster within cluster1 (my_workflow) + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check three clusters live within cluster2 (conditional) + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "else" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + + @pytest.mark.unit + def test_if_elif_else_conditional(self): + """Test that the conditional operation is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + qml.Z(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # Check conditional is a cluster within my_workflow + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check three clusters live within conditional + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + @pytest.mark.unit + def test_nested_conditionals(self): + """Tests that nested conditionals are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x, y): + if x == 1: + if y == 2: + qml.H(0) + else: + qml.Z(0) + qml.X(0) + else: + qml.Z(0) + + args = (1, 2) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # cluster2 -> conditional (1) + # cluster3 -> if + # cluster4 -> conditional () + # cluster5 -> if + # cluster6 -> else + # cluster7 -> else + + # Check first conditional is a cluster within my_workflow + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check 'if' cluster of first conditional has another conditional + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + # Second conditional + assert clusters["cluster4"]["label"] == "conditional" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster3" + # Check 'if' and 'else' in second conditional + assert clusters["cluster5"]["label"] == "if" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster4" + assert clusters["cluster6"]["label"] == "else" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster4" + + # Check nested if / else is within the first if cluster + assert clusters["cluster7"]["label"] == "else" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster2" + + def test_nested_conditionals_with_quantum_ops(self): + """Tests that nested conditionals are unflattend if quantum operations + are present""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + qml.Z(0) + if x == 3: + qml.RX(0, 0) + elif x == 4: + qml.RY(0, 0) + else: + qml.RZ(0, 0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # node0 -> NullQubit + # cluster2 -> conditional (1) + # cluster3 -> if + # node1 -> X(0) + # cluster4 -> elif + # node2 -> Y(0) + # cluster5 -> else + # node3 -> Z(0) + # cluster6 -> conditional (2) + # cluster7 -> if + # node4 -> RX(0,0) + # cluster8 -> elif + # node5 -> RY(0,0) + # cluster9 -> else + # node6 -> RZ(0,0) + + # check outer conditional (1) + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + # Nested conditional (2) inside conditional (1) + assert clusters["cluster6"]["label"] == "conditional" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster5" + assert clusters["cluster7"]["label"] == "if" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster6" + assert clusters["cluster8"]["label"] == "elif" + assert clusters["cluster8"]["parent_cluster_uid"] == "cluster6" + assert clusters["cluster9"]["label"] == "else" + assert clusters["cluster9"]["parent_cluster_uid"] == "cluster6" + + def test_nested_conditionals_with_nested_quantum_ops(self): + """Tests that nested conditionals are unflattend if quantum operations + are present but nested in other operations""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + for i in range(3): + qml.Z(0) + if x == 3: + qml.RX(0, 0) + elif x == 4: + qml.RY(0, 0) + else: + qml.RZ(0, 0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # node0 -> NullQubit + # cluster2 -> conditional (1) + # cluster3 -> if + # node1 -> X(0) + # cluster4 -> elif + # node2 -> Y(0) + # cluster5 -> else + # cluster6 -> for loop + # node3 -> Z(0) + # cluster7 -> conditional (2) + # cluster8 -> if + # node4 -> RX(0,0) + # cluster9 -> elif + # node5 -> RY(0,0) + # cluster10 -> else + # node6 -> RZ(0,0) + + # check outer conditional (1) + assert clusters["cluster2"]["label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + # Nested conditional (2) inside conditional (1) + assert clusters["cluster6"]["label"] == "for loop" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster5" + + assert clusters["cluster7"]["label"] == "conditional" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster5" + assert clusters["cluster8"]["label"] == "if" + assert clusters["cluster8"]["parent_cluster_uid"] == "cluster7" + assert clusters["cluster9"]["label"] == "elif" + assert clusters["cluster9"]["parent_cluster_uid"] == "cluster7" + assert clusters["cluster10"]["label"] == "else" + assert clusters["cluster10"]["parent_cluster_uid"] == "cluster7" + + +class TestGetLabel: + """Tests the get_label utility.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "op, label", + [ + (qml.H(0), " Hadamard| [0]"), + ( + qml.QubitUnitary([[0, 1], [1, 0]], 0), + " QubitUnitary| [0]", + ), + (qml.SWAP([0, 1]), " SWAP| [0, 1]"), + ], + ) + def test_standard_operator(self, op, label): + """Tests against an operator instance.""" + assert get_label(op) == label + + def test_global_phase_operator(self): + """Tests against a GlobalPhase operator instance.""" + assert get_label(qml.GlobalPhase(0.5)) == f" GlobalPhase| all" + + @pytest.mark.unit + @pytest.mark.parametrize( + "meas, label", + [ + (qml.state(), " state| all"), + (qml.expval(qml.Z(0)), " expval(PauliZ)| [0]"), + (qml.var(qml.Z(0)), " var(PauliZ)| [0]"), + (qml.probs(), " probs| all"), + (qml.probs(wires=0), " probs| [0]"), + (qml.probs(wires=[0, 1]), " probs| [0, 1]"), + (qml.sample(), " sample| all"), + (qml.sample(wires=0), " sample| [0]"), + (qml.sample(wires=[0, 1]), " sample| [0, 1]"), + ], + ) + def test_standard_measurement(self, meas, label): + """Tests against an operator instance.""" + assert get_label(meas) == label + + +class TestCreateStaticOperatorNodes: + """Tests that operators with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])]) + def test_custom_op(self, op): + """Tests that the CustomOp operation node can be created and visualized.""" + + # Build module with only a CustomOp + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + # Make sure label has relevant info + assert nodes["node1"]["label"] == get_label(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.GlobalPhase(0.5), + qml.GlobalPhase(0.5, wires=0), + qml.GlobalPhase(0.5, wires=[0, 1]), + ], + ) + def test_global_phase_op(self, op): + """Test that GlobalPhase can be handled.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + # Compiler throws out the wires and they get converted to wires=[] no matter what + assert nodes["node1"]["label"] == get_label(qml.GlobalPhase(0.5)) + + @pytest.mark.unit + def test_qubit_unitary_op(self): + """Test that QubitUnitary operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) + + @pytest.mark.unit + def test_multi_rz_op(self): + """Test that MultiRZ operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.MultiRZ(0.5, wires=[0]) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.MultiRZ(0.5, wires=[0])) + + @pytest.mark.unit + def test_projective_measurement_op(self): + """Test that projective measurements can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + measure(0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == f" MidMeasureMP| [0]" + + +class TestCreateStaticMeasurementNodes: + """Tests that measurements with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + def test_state_op(self): + """Test that qml.state can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return qml.state() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.state()) + + @pytest.mark.unit + @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) + def test_expval_var_measurement_op(self, meas_fn): + """Test that statistical measurement operators can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return meas_fn(qml.Z(0)) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(meas_fn(qml.Z(0))) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.probs(), + qml.probs(wires=0), + qml.probs(wires=[0, 1]), + ], + ) + def test_probs_measurement_op(self, op): + """Tests that the probs measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.sample(), + qml.sample(wires=0), + qml.sample(wires=[0, 1]), + ], + ) + def test_valid_sample_measurement_op(self, op): + """Tests that the sample measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(op) diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py new file mode 100644 index 0000000000..d6de2d4011 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -0,0 +1,113 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the DAGBuilder abstract base class.""" + +from typing import Any + +import pytest + +pytestmark = pytest.mark.xdsl +xdsl = pytest.importorskip("xdsl") + +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +def test_concrete_implementation_works(): + """Unit test for concrete implementation of abc.""" + + # pylint: disable=unused-argument + class ConcreteDAGBuilder(DAGBuilder): + """Concrete subclass of an ABC for testing purposes.""" + + def add_node( + self, + uid: str, + label: str, + cluster_id: str | None = None, + **attrs: Any, + ) -> None: + return + + def add_edge(self, from_uid: str, to_uid: str, **attrs: Any) -> None: + return + + def add_cluster( + self, + uid: str, + label: str | None = None, + cluster_id: str | None = None, + **attrs: Any, + ) -> None: + return + + @property + def nodes(self) -> dict[str, dict[str, Any]]: + return {} + + @property + def edges(self) -> list[dict[str, Any]]: + return [] + + @property + def clusters(self) -> dict[str, dict[str, Any]]: + return {} + + def to_file(self, output_filename: str) -> None: + return + + def to_string(self) -> str: + return "test" + + dag_builder = ConcreteDAGBuilder() + # pylint: disable = assignment-from-none + node = dag_builder.add_node("0", "node0") + edge = dag_builder.add_edge("0", "1") + cluster = dag_builder.add_cluster("0") + nodes = dag_builder.nodes + edges = dag_builder.edges + clusters = dag_builder.clusters + render = dag_builder.to_file("test.png") + string = dag_builder.to_string() + + assert node is None + assert nodes == {} + assert edge is None + assert edges == [] + assert cluster is None + assert clusters == {} + assert render is None + assert string == "test" + + +def test_abc_cannot_be_instantiated(): + """Tests that the DAGBuilder ABC cannot be instantiated.""" + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # pylint: disable=abstract-class-instantiated + DAGBuilder() + + +def test_incomplete_subclass(): + """Tests that an incomplete subclass will fail""" + + # pylint: disable=too-few-public-methods + class IncompleteDAGBuilder(DAGBuilder): + def add_node(self, *args, **kwargs): + pass + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # pylint: disable=abstract-class-instantiated + IncompleteDAGBuilder() diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py new file mode 100644 index 0000000000..e7255bfbd6 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -0,0 +1,414 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the PyDotDAGBuilder subclass.""" + +from unittest.mock import MagicMock + +import pytest + +pydot = pytest.importorskip("pydot") + +pytestmark = pytest.mark.xdsl +xdsl = pytest.importorskip("xdsl") +# pylint: disable=wrong-import-position +from catalyst.python_interface.visualization.pydot_dag_builder import PyDotDAGBuilder + + +@pytest.mark.unit +def test_initialization_defaults(): + """Tests the default graph attributes are as expected.""" + + dag_builder = PyDotDAGBuilder() + + assert isinstance(dag_builder.graph, pydot.Dot) + # Ensure it's a directed graph + assert dag_builder.graph.get_graph_type() == "digraph" + # Ensure that it flows top to bottom + assert dag_builder.graph.get_rankdir() == "TB" + # Ensure edges can be connected directly to clusters / subgraphs + assert dag_builder.graph.get_compound() == "true" + # Ensure duplicated edges cannot be added + assert dag_builder.graph.obj_dict["strict"] is True + # Ensure edges are orthogonal + assert dag_builder.graph.obj_dict["attributes"]["splines"]== "ortho" + + +class TestExceptions: + """Tests the various exceptions defined in the class.""" + + def test_duplicate_node_ids(self): + """Tests that a ValueError is raised for duplicate nodes.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Node ID 0 already present in graph."): + dag_builder.add_node("0", "node1") + + def test_edge_duplicate_source_destination(self): + """Tests that a ValueError is raised when an edge is created with the + same source and destination""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Edges must connect two unique IDs."): + dag_builder.add_edge("0", "0") + + def test_edge_missing_ids(self): + """Tests that an error is raised if IDs are missing.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Destination is not found in the graph."): + dag_builder.add_edge("0", "1") + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("1", "node1") + with pytest.raises(ValueError, match="Source is not found in the graph."): + dag_builder.add_edge("0", "1") + + def test_duplicate_cluster_id(self): + """Tests that an exception is raised if an ID is already present.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_cluster("0") + with pytest.raises(ValueError, match="Cluster ID 0 already present in graph."): + dag_builder.add_cluster("0") + + +class TestAddMethods: + """Test that elements can be added to the graph.""" + + @pytest.mark.unit + def test_add_node(self): + """Unit test the `add_node` method.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + node_list = dag_builder.graph.get_node_list() + assert len(node_list) == 1 + assert node_list[0].get_label() == "node0" + + @pytest.mark.unit + def test_add_edge(self): + """Unit test the `add_edge` method.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + + assert len(dag_builder.graph.get_edges()) == 1 + edge = dag_builder.graph.get_edges()[0] + assert edge.get_source() == "0" + assert edge.get_destination() == "1" + + @pytest.mark.unit + def test_add_cluster(self): + """Unit test the 'add_cluster' method.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_cluster("0") + + assert len(dag_builder.graph.get_subgraphs()) == 1 + assert dag_builder.graph.get_subgraphs()[0].get_name() == "cluster_0" + + @pytest.mark.unit + def test_add_node_to_parent_graph(self): + """Tests that you can add a node to a parent graph.""" + dag_builder = PyDotDAGBuilder() + + # Create node + dag_builder.add_node("0", "node0") + + # Create cluster + dag_builder.add_cluster("c0") + + # Create node inside cluster + dag_builder.add_node("1", "node1", cluster_uid="c0") + + # Verify graph structure + root_graph = dag_builder.graph + + # Make sure the base graph has node0 + assert root_graph.get_node("0"), "Node 0 not found in root graph" + + # Get the cluster and verify it has node1 and not node0 + cluster_list = root_graph.get_subgraph("cluster_c0") + assert cluster_list, "Subgraph 'cluster_c0' not found" + cluster_graph = cluster_list[0] # Get the actual subgraph object + + assert cluster_graph.get_node("1"), "Node 1 not found in cluster 'c0'" + assert not cluster_graph.get_node("0"), ( + "Node 0 was incorrectly added to cluster" + ) + + assert not root_graph.get_node("1"), "Node 1 was incorrectly added to root" + + @pytest.mark.unit + def test_add_cluster_to_parent_graph(self): + """Test that you can add a cluster to a parent graph.""" + dag_builder = PyDotDAGBuilder() + + # Level 0 (Root): Adds cluster on top of base graph + dag_builder.add_node("n_root", "node_root") + + # Level 1 (c0): Add node on outer cluster + dag_builder.add_cluster("c0") + dag_builder.add_node("n_outer", "node_outer", cluster_uid="c0") + + # Level 2 (c1): Add node on inner cluster + dag_builder.add_cluster("c1", cluster_uid="c0") + dag_builder.add_node("n_inner", "node_inner", cluster_uid="c1") + + root_graph = dag_builder.graph + + outer_cluster_list = root_graph.get_subgraph("cluster_c0") + assert outer_cluster_list, "Outer cluster 'c0' not found in root" + c0 = outer_cluster_list[0] + + inner_cluster_list = c0.get_subgraph("cluster_c1") + assert inner_cluster_list, "Inner cluster 'c1' not found in 'c0'" + c1 = inner_cluster_list[0] + + # Check Level 0 (Root) + assert root_graph.get_node("n_root"), "n_root not found in root" + assert root_graph.get_subgraph("cluster_c0"), "c0 not found in root" + assert not root_graph.get_node("n_outer"), "n_outer incorrectly found in root" + assert not root_graph.get_node("n_inner"), "n_inner incorrectly found in root" + assert not root_graph.get_subgraph("cluster_c1"), "c1 incorrectly found in root" + + # Check Level 1 (c0) + assert c0.get_node("n_outer"), "n_outer not found in c0" + assert c0.get_subgraph("cluster_c1"), "c1 not found in c0" + assert not c0.get_node("n_root"), "n_root incorrectly found in c0" + assert not c0.get_node("n_inner"), "n_inner incorrectly found in c0" + + # Check Level 2 (c1) + assert c1.get_node("n_inner"), "n_inner not found in c1" + assert not c1.get_node("n_root"), "n_root incorrectly found in c1" + assert not c1.get_node("n_outer"), "n_outer incorrectly found in c1" + + +class TestAttributes: + """Tests that the attributes for elements in the graph are overridden correctly.""" + + @pytest.mark.unit + def test_default_graph_attrs(self): + """Test that default graph attributes can be set.""" + + dag_builder = PyDotDAGBuilder(attrs={"fontname": "Times"}) + + dag_builder.add_node("0", "node0") + node0 = dag_builder.graph.get_node("0")[0] + assert node0.get("fontname") == "Times" + + dag_builder.add_cluster("1") + cluster = dag_builder.graph.get_subgraphs()[0] + assert cluster.get("fontname") == "Times" + + @pytest.mark.unit + def test_add_node_with_attrs(self): + """Tests that default attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder(attrs={"fillcolor": "lightblue", "penwidth": 3}) + + # Defaults + dag_builder.add_node("0", "node0") + node0 = dag_builder.graph.get_node("0")[0] + assert node0.get("fillcolor") == "lightblue" + assert node0.get("penwidth") == 3 + + # Make sure we can override + dag_builder.add_node("1", "node1", fillcolor="red", penwidth=4) + node1 = dag_builder.graph.get_node("1")[0] + assert node1.get("fillcolor") == "red" + assert node1.get("penwidth") == 4 + + @pytest.mark.unit + def test_add_edge_with_attrs(self): + """Tests that default attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder(attrs={"color": "lightblue4", "penwidth": 3}) + + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + edge = dag_builder.graph.get_edges()[0] + # Defaults defined earlier + assert edge.get("color") == "lightblue4" + assert edge.get("penwidth") == 3 + + # Make sure we can override + dag_builder.add_edge("0", "1", color="red", penwidth=4) + edge = dag_builder.graph.get_edges()[1] + assert edge.get("color") == "red" + assert edge.get("penwidth") == 4 + + @pytest.mark.unit + def test_add_cluster_with_attrs(self): + """Tests that default cluster attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder( + attrs={ + "style": "solid", + "fillcolor": None, + "penwidth": 2, + "fontname": "Helvetica", + } + ) + + dag_builder.add_cluster("0") + cluster1 = dag_builder.graph.get_subgraph("cluster_0")[0] + + # Defaults + assert cluster1.get("style") == "solid" + assert cluster1.get("fillcolor") is None + assert cluster1.get("penwidth") == 2 + assert cluster1.get("fontname") == "Helvetica" + + dag_builder.add_cluster("1", style="filled", penwidth=10, fillcolor="red") + cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0] + + # Make sure we can override + assert cluster2.get("style") == "filled" + assert cluster2.get("penwidth") == 10 + assert cluster2.get("fillcolor") == "red" + + # Check that other defaults are still present + assert cluster2.get("fontname") == "Helvetica" + + +class TestProperties: + """Tests the properties.""" + + def test_nodes(self): + """Tests that nodes works.""" + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0", fillcolor="red") + dag_builder.add_cluster("c0") + dag_builder.add_node("1", "node1", cluster_uid="c0") + + nodes = dag_builder.nodes + + assert len(nodes) == 2 + assert len(nodes["0"]) == 4 + + assert nodes["0"]["uid"] == "0" + assert nodes["0"]["label"] == "node0" + assert nodes["0"]["cluster_uid"] == None + assert nodes["0"]["attrs"]["fillcolor"] == "red" + + assert nodes["1"]["uid"] == "1" + assert nodes["1"]["label"] == "node1" + assert nodes["1"]["cluster_uid"] == "c0" + + def test_edges(self): + """Tests that edges works.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1", penwidth=10) + + edges = dag_builder.edges + + assert len(edges) == 1 + + assert edges[0]["from_uid"] == "0" + assert edges[0]["to_uid"] == "1" + assert edges[0]["attrs"]["penwidth"] == 10 + + def test_clusters(self): + """Tests that clusters property works.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_cluster("0", "my_cluster", penwidth=10) + + clusters = dag_builder.clusters + + dag_builder.add_cluster( + "1", "my_nested_cluster", cluster_uid="0", + ) + clusters = dag_builder.clusters + assert len(clusters) == 2 + + assert len(clusters["0"]) == 4 + assert clusters["0"]["uid"] == "0" + assert clusters["0"]["label"] == "my_cluster" + assert clusters["0"]["cluster_uid"] == None + assert clusters["0"]["attrs"]["penwidth"] == 10 + + assert len(clusters["1"]) == 4 + assert clusters["1"]["uid"] == "1" + assert clusters["1"]["label"] == "my_nested_cluster" + assert clusters["1"]["cluster_uid"] == "0" + + +class TestOutput: + """Test that the graph can be outputted correctly.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "filename, format", + [("my_graph", None), ("my_graph", "png"), ("prototype.trial1", "png")], + ) + def test_to_file(self, monkeypatch, filename, format): + """Tests that the `to_file` method works correctly.""" + dag_builder = PyDotDAGBuilder() + + # mock out the graph writing functionality + mock_write = MagicMock() + monkeypatch.setattr(dag_builder.graph, "write", mock_write) + dag_builder.to_file(filename + "." + (format or "png")) + + # make sure the function handles extensions correctly + mock_write.assert_called_once_with( + filename + "." + (format or "png"), format=format or "png" + ) + + @pytest.mark.unit + @pytest.mark.parametrize("format", ["pdf", "svg", "jpeg"]) + def test_other_supported_formats(self, monkeypatch, format): + """Tests that the `to_file` method works with other formats.""" + dag_builder = PyDotDAGBuilder() + + # mock out the graph writing functionality + mock_write = MagicMock() + monkeypatch.setattr(dag_builder.graph, "write", mock_write) + dag_builder.to_file(f"my_graph.{format}") + + # make sure the function handles extensions correctly + mock_write.assert_called_once_with(f"my_graph.{format}", format=format) + + @pytest.mark.unit + def test_to_string(self): + """Tests that the `to_string` method works correclty.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("n0", "node0") + dag_builder.add_node("n1", "node1") + dag_builder.add_edge("n0", "n1") + + string = dag_builder.to_string() + assert isinstance(string, str) + + # make sure important things show up in the string + assert "digraph" in string + assert "n0" in string + assert "n1" in string + assert "n0 -> n1" in string diff --git a/requirements.txt b/requirements.txt index 2a2db94191..3fde869394 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,3 +34,4 @@ nbmake # optional rt/test dependencies pennylane-lightning-kokkos amazon-braket-pennylane-plugin>1.27.1 +pydot