diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index debad683e8..7b78e8ffc3 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -2,6 +2,16 @@
New features since last release
+* Compiled programs can be visualized.
+ [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213)
+ [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229)
+ [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214)
+ [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246)
+ [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231)
+ [(#2285)](https://github.com/PennyLaneAI/catalyst/pull/2285)
+ [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234)
+ [(#2218)](https://github.com/PennyLaneAI/catalyst/pull/2218)
+
* Catalyst now features a unified compilation framework, which enables users and developers to design
and implement compilation passes in Python in addition to C++, on the same Catalyst IR. The Python
interface relies on the xDSL library to represent and manipulate programs (analogous to the MLIR library
diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py
new file mode 100644
index 0000000000..271c354c44
--- /dev/null
+++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py
@@ -0,0 +1,425 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module."""
+
+from functools import singledispatch, singledispatchmethod
+
+from pennylane.measurements import ExpectationMP, MeasurementProcess, ProbabilityMP, VarianceMP
+from pennylane.operation import Operator
+from xdsl.dialects import builtin, func, scf
+from xdsl.ir import Block, Operation, Region, SSAValue
+
+from catalyst.python_interface.dialects import quantum
+from catalyst.python_interface.inspection.xdsl_conversion import (
+ xdsl_to_qml_measurement,
+ xdsl_to_qml_op,
+)
+from catalyst.python_interface.visualization.dag_builder import DAGBuilder
+
+
+class ConstructCircuitDAG:
+ """Utility tool following the director pattern to build a DAG representation of a compiled quantum program.
+
+ This tool traverses an xDSL module and constructs a Directed Acyclic Graph (DAG)
+ of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module.
+
+ **Example**
+
+ >>> builder = PyDotDAGBuilder()
+ >>> director = ConstructCircuitDAG(builder)
+ >>> director.construct(module)
+ >>> director.dag_builder.to_string()
+ ...
+ """
+
+ def __init__(self, dag_builder: DAGBuilder) -> None:
+ self.dag_builder: DAGBuilder = dag_builder
+
+ # Keep track of nesting clusters using a stack
+ self._cluster_uid_stack: list[str] = []
+
+ # Use counter internally for UID
+ self._node_uid_counter: int = 0
+ self._cluster_uid_counter: int = 0
+
+ def _reset(self) -> None:
+ """Resets the instance."""
+ self._cluster_uid_stack: list[str] = []
+ self._node_uid_counter: int = 0
+ self._cluster_uid_counter: int = 0
+
+ def construct(self, module: builtin.ModuleOp) -> None:
+ """Constructs the DAG from the module.
+
+ Args:
+ module (xdsl.builtin.ModuleOp): The module containing the quantum program to visualize.
+
+ """
+ self._reset()
+ for op in module.ops:
+ self._visit_operation(op)
+
+ # =============
+ # IR TRAVERSAL
+ # =============
+
+ @singledispatchmethod
+ def _visit_operation(self, operation: Operation) -> None:
+ """Visit an xDSL Operation. Default to visiting each region contained in the operation."""
+ for region in operation.regions:
+ self._visit_region(region)
+
+ def _visit_region(self, region: Region) -> None:
+ """Visit an xDSL Region operation."""
+ for block in region.blocks:
+ self._visit_block(block)
+
+ def _visit_block(self, block: Block) -> None:
+ """Visit an xDSL Block operation, dispatching handling for each contained Operation."""
+ for op in block.ops:
+ self._visit_operation(op)
+
+ # ===================
+ # QUANTUM OPERATIONS
+ # ===================
+
+ @_visit_operation.register
+ def _gate_op(
+ self,
+ op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp,
+ ) -> None:
+ """Generic handler for unitary gates."""
+
+ # Create PennyLane instance
+ qml_op = xdsl_to_qml_op(op)
+
+ # Add node to current cluster
+ node_uid = f"node{self._node_uid_counter}"
+ self.dag_builder.add_node(
+ uid=node_uid,
+ label=get_label(qml_op),
+ cluster_uid=self._cluster_uid_stack[-1],
+ # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record)
+ shape="record",
+ )
+ self._node_uid_counter += 1
+
+ @_visit_operation.register
+ def _projective_measure_op(self, op: quantum.MeasureOp) -> None:
+ """Handler for the single-qubit projective measurement operation."""
+
+ # Create PennyLane instance
+ meas = xdsl_to_qml_measurement(op)
+
+ # Add node to current cluster
+ node_uid = f"node{self._node_uid_counter}"
+ self.dag_builder.add_node(
+ uid=node_uid,
+ label=get_label(meas),
+ cluster_uid=self._cluster_uid_stack[-1],
+ # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record)
+ shape="record",
+ )
+ self._node_uid_counter += 1
+
+ # =====================
+ # QUANTUM MEASUREMENTS
+ # =====================
+
+ @_visit_operation.register
+ def _state_op(self, op: quantum.StateOp) -> None:
+ """Handler for the terminal state measurement operation."""
+
+ # Create PennyLane instance
+ meas = xdsl_to_qml_measurement(op)
+
+ # Add node to current cluster
+ node_uid = f"node{self._node_uid_counter}"
+ self.dag_builder.add_node(
+ uid=node_uid,
+ label=get_label(meas),
+ cluster_uid=self._cluster_uid_stack[-1],
+ fillcolor="lightpink",
+ color="lightpink3",
+ # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record)
+ shape="record",
+ )
+ self._node_uid_counter += 1
+
+ @_visit_operation.register
+ def _expval_and_var_ops(
+ self,
+ op: quantum.ExpvalOp | quantum.VarianceOp,
+ ) -> None:
+ """Handler for statistical measurement operations."""
+
+ # Create PennyLane instance
+ obs_op = op.obs.owner
+ meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op))
+
+ # Add node to current cluster
+ node_uid = f"node{self._node_uid_counter}"
+ self.dag_builder.add_node(
+ uid=node_uid,
+ label=get_label(meas),
+ cluster_uid=self._cluster_uid_stack[-1],
+ fillcolor="lightpink",
+ color="lightpink3",
+ # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record)
+ shape="record",
+ )
+ self._node_uid_counter += 1
+
+ @_visit_operation.register
+ def _sample_counts_probs_ops(
+ self,
+ op: quantum.SampleOp | quantum.ProbsOp,
+ ) -> None:
+ """Handler for sample operations."""
+
+ # Create PennyLane instance
+ obs_op = op.obs.owner
+
+ # TODO: This doesn't logically make sense, but quantum.compbasis
+ # is obs_op and function below just pulls out the static wires
+ wires = xdsl_to_qml_measurement(obs_op)
+ meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires)
+
+ # Add node to current cluster
+ node_uid = f"node{self._node_uid_counter}"
+ self.dag_builder.add_node(
+ uid=node_uid,
+ label=get_label(meas),
+ cluster_uid=self._cluster_uid_stack[-1],
+ fillcolor="lightpink",
+ color="lightpink3",
+ # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record)
+ shape="record",
+ )
+ self._node_uid_counter += 1
+
+ # =============
+ # CONTROL FLOW
+ # =============
+
+ @_visit_operation.register
+ def _for_op(self, operation: scf.ForOp) -> None:
+ """Handle an xDSL ForOp operation."""
+
+ uid = f"cluster{self._cluster_uid_counter}"
+ self.dag_builder.add_cluster(
+ uid,
+ label="for loop",
+ labeljust="l",
+ cluster_uid=self._cluster_uid_stack[-1],
+ )
+ self._cluster_uid_stack.append(uid)
+ self._cluster_uid_counter += 1
+
+ self._visit_region(operation.regions[0])
+
+ self._cluster_uid_stack.pop()
+
+ @_visit_operation.register
+ def _while_op(self, operation: scf.WhileOp) -> None:
+ """Handle an xDSL WhileOp operation."""
+ uid = f"cluster{self._cluster_uid_counter}"
+ self.dag_builder.add_cluster(
+ uid,
+ label="while loop",
+ labeljust="l",
+ cluster_uid=self._cluster_uid_stack[-1],
+ )
+ self._cluster_uid_stack.append(uid)
+ self._cluster_uid_counter += 1
+
+ for region in operation.regions:
+ self._visit_region(region)
+
+ self._cluster_uid_stack.pop()
+
+ @_visit_operation.register
+ def _if_op(self, operation: scf.IfOp):
+ """Handles the scf.IfOp operation."""
+ uid = f"cluster{self._cluster_uid_counter}"
+ self.dag_builder.add_cluster(
+ uid,
+ label="conditional",
+ labeljust="l",
+ cluster_uid=self._cluster_uid_stack[-1],
+ )
+ self._cluster_uid_stack.append(uid)
+ self._cluster_uid_counter += 1
+
+ # Loop through each branch and visualize as a cluster
+ flattened_if_op: list[Region] = _flatten_if_op(operation)
+ num_regions = len(flattened_if_op)
+ for i, region in enumerate(flattened_if_op):
+ cluster_label = "elif"
+ if i == 0:
+ cluster_label = "if"
+ elif i == num_regions - 1:
+ cluster_label = "else"
+
+ uid = f"cluster{self._cluster_uid_counter}"
+ self.dag_builder.add_cluster(
+ uid,
+ label=cluster_label,
+ labeljust="l",
+ style="dashed",
+ cluster_uid=self._cluster_uid_stack[-1],
+ )
+ self._cluster_uid_stack.append(uid)
+ self._cluster_uid_counter += 1
+
+ # Go recursively into the branch to process internals
+ self._visit_region(region)
+
+ # Pop branch cluster after processing to ensure
+ # logical branches are treated as 'parallel'
+ self._cluster_uid_stack.pop()
+
+ # Pop IfOp cluster before leaving this handler
+ self._cluster_uid_stack.pop()
+
+ # ============
+ # DEVICE NODE
+ # ============
+
+ @_visit_operation.register
+ def _device_init(self, operation: quantum.DeviceInitOp) -> None:
+ """Handles the initialization of a quantum device."""
+ node_id = f"node{self._node_uid_counter}"
+ self.dag_builder.add_node(
+ node_id,
+ label=operation.device_name.data,
+ cluster_uid=self._cluster_uid_stack[-1],
+ fillcolor="grey",
+ color="black",
+ penwidth=2,
+ shape="rectangle",
+ )
+ self._node_uid_counter += 1
+
+ # =======================
+ # FuncOp NESTING UTILITY
+ # =======================
+
+ @_visit_operation.register
+ def _func_op(self, operation: func.FuncOp) -> None:
+ """Visit a FuncOp Operation."""
+
+ label = operation.sym_name.data
+ if "jit_" in operation.sym_name.data:
+ label = "qjit"
+
+ uid = f"cluster{self._cluster_uid_counter}"
+ parent_cluster_uid = None if self._cluster_uid_stack == [] else self._cluster_uid_stack[-1]
+ self.dag_builder.add_cluster(
+ uid,
+ label=label,
+ cluster_uid=parent_cluster_uid,
+ )
+ self._cluster_uid_counter += 1
+ self._cluster_uid_stack.append(uid)
+
+ self._visit_block(operation.regions[0].blocks[0])
+
+ @_visit_operation.register
+ def _func_return(self, operation: func.ReturnOp) -> None:
+ """Handle func.return to exit FuncOp's cluster scope."""
+
+ # NOTE: Skip first cluster as it is the "base" of the graph diagram.
+ # In our case, it is the `qjit` bounding box.
+ if len(self._cluster_uid_stack) > 1:
+ # If we hit a func.return operation we know we are leaving
+ # the FuncOp's scope and so we can pop the ID off the stack.
+ self._cluster_uid_stack.pop()
+
+
+def _flatten_if_op(op: scf.IfOp) -> list[Region]:
+ """Recursively flattens a nested IfOp (if/elif/else chains)."""
+
+ then_region, else_region = op.regions
+
+ flattened_op: list[Region] = [then_region]
+
+ # Check to see if there are any nested quantum operations in the else block
+ else_block: Block = else_region.block
+ has_quantum_ops = False
+ nested_if_op = None
+ for op in else_block.ops:
+ if isinstance(op, scf.IfOp):
+ nested_if_op = op
+ # No need to walk this op as this will be
+ # recursively handled down below
+ continue
+ for internal_op in op.walk():
+ if type(internal_op) in quantum.Quantum.operations:
+ has_quantum_ops = True
+ # No need to check anything else
+ break
+
+ if nested_if_op and not has_quantum_ops:
+ # Recursively flatten any IfOps found in said block
+ nested_flattened_op: list[Region] = _flatten_if_op(nested_if_op)
+ flattened_op.extend(nested_flattened_op)
+ return flattened_op
+
+ # No more nested IfOps, therefore append final region
+ flattened_op.append(else_region)
+ return flattened_op
+
+
+@singledispatch
+def get_label(op: Operator | MeasurementProcess) -> str:
+ """Gets the appropriate label for a PennyLane object."""
+ return str(op)
+
+
+@get_label.register
+def _operator(op: Operator) -> str:
+ """Returns the appropriate label for PennyLane Operator"""
+ wires = list(op.wires.labels)
+ if wires == []:
+ wires_str = "all"
+ else:
+ wires_str = f"[{', '.join(map(str, wires))}]"
+ # Using <...> lets us use ports (https://graphviz.org/doc/info/shapes.html#record)
+ return f" {op.name}| {wires_str}"
+
+
+@get_label.register
+def _meas(meas: MeasurementProcess) -> str:
+ """Returns the appropriate label for a PennyLane MeasurementProcess using match/case."""
+
+ wires_str = list(meas.wires.labels)
+ if not wires_str:
+ wires_str = "all"
+ else:
+ wires_str = f"[{', '.join(map(str, wires_str))}]"
+
+ base_name = meas._shortname
+
+ match meas:
+ case ExpectationMP() | VarianceMP() | ProbabilityMP():
+ if meas.obs is not None:
+ obs_name = meas.obs.name
+ base_name = f"{base_name}({obs_name})"
+
+ case _:
+ pass
+
+ return f" {base_name}| {wires_str}"
diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py
new file mode 100644
index 0000000000..ae8f18e77e
--- /dev/null
+++ b/frontend/catalyst/python_interface/visualization/dag_builder.py
@@ -0,0 +1,141 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""File that defines the DAGBuilder abstract base class."""
+
+from abc import ABC, abstractmethod
+from typing import Any, TypeAlias
+
+ClusterUID: TypeAlias = str
+NodeUID: TypeAlias = str
+
+
+class DAGBuilder(ABC):
+ """An abstract base class for building Directed Acyclic Graphs (DAGs).
+
+ This class provides a simple interface with three core methods (`add_node`, `add_edge` and `add_cluster`).
+ You can override these methods to implement any backend, like `pydot` or `graphviz` or even `matplotlib`.
+
+ Outputting your graph can be done by overriding `to_file` and `to_string`.
+ """
+
+ @abstractmethod
+ def add_node(
+ self,
+ uid: NodeUID,
+ label: str,
+ *,
+ cluster_uid: ClusterUID | None = None,
+ **attrs: Any,
+ ) -> None:
+ """Add a single node to the graph.
+
+ Args:
+ uid (str): Unique node ID to identify this node.
+ label (str): The text to display on the node when rendered.
+ cluster_uid (str | None): Optional unique ID of the cluster this node belongs to. If `None`, this node gets
+ added on the base graph.
+ **attrs (Any): Any additional styling keyword arguments.
+
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_edge(self, from_uid: NodeUID, to_uid: NodeUID, **attrs: Any) -> None:
+ """Add a single directed edge between nodes in the graph.
+
+ Args:
+ from_uid (str): The unique ID of the source node.
+ to_uid (str): The unique ID of the destination node.
+ **attrs (Any): Any additional styling keyword arguments.
+
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def add_cluster(
+ self,
+ uid: ClusterUID,
+ *,
+ label: str | None = None,
+ cluster_uid: ClusterUID | None = None,
+ **attrs: Any,
+ ) -> None:
+ """Add a single cluster to the graph.
+
+ A cluster is a specific type of subgraph where the nodes and edges contained
+ within it are visually and logically grouped.
+
+ Args:
+ uid (str): Unique cluster ID to identify this cluster.
+ label (str | None): Optional text to display as a label on the cluster when rendered.
+ cluster_uid (str | None): Optional unique ID of the cluster this cluster belongs to. If `None`, the cluster will be
+ placed on the base graph.
+ **attrs (Any): Any additional styling keyword arguments.
+
+ """
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def nodes(self) -> dict[NodeUID, dict[str, Any]]:
+ """Retrieve the current set of nodes in the graph.
+
+ Returns:
+ nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information.
+ """
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def edges(self) -> list[dict[str, Any]]:
+ """Retrieve the current set of edges in the graph.
+
+ Returns:
+ edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information.
+ """
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def clusters(self) -> dict[ClusterUID, dict[str, Any]]:
+ """Retrieve the current set of clusters in the graph.
+
+ Returns:
+ clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def to_file(self, output_filename: str) -> None:
+ """Save the graph to a file.
+
+ The implementation should ideally infer the output format
+ (e.g., 'png', 'svg') from this filename's extension.
+
+ Args:
+ output_filename (str): Desired filename for the graph.
+
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def to_string(self) -> str:
+ """Return the graph as a string.
+
+ This is typically used to get the graph's representation in a standard string format like DOT.
+
+ Returns:
+ str: A string representation of the graph.
+ """
+ raise NotImplementedError
diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py
new file mode 100644
index 0000000000..908cc5fc9d
--- /dev/null
+++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py
@@ -0,0 +1,289 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""File that defines the PyDotDAGBuilder subclass of DAGBuilder."""
+
+import pathlib
+from collections import ChainMap
+from typing import Any
+
+from .dag_builder import DAGBuilder
+
+has_pydot = True
+try:
+ import pydot
+ from pydot import Cluster, Dot, Edge, Graph, Node, Subgraph
+except ImportError:
+ has_pydot = False
+
+
+class PyDotDAGBuilder(DAGBuilder):
+ """A Directed Acyclic Graph builder for the PyDot backend.
+
+ Args:
+ attrs (dict | None): User default attributes to be used for all elements (nodes, edges, clusters) in the graph.
+ node_attrs (dict | None): User default attributes for a node.
+ edge_attrs (dict | None): User default attributes for an edge.
+ cluster_attrs (dict | None): User default attributes for a cluster.
+
+ Example:
+ >>> builder = PyDotDAGBuilder()
+ >>> builder.add_node("n0", "node 0")
+ >>> builder.add_cluster("c0")
+ >>> builder.add_node("n1", "node 1", cluster_uid="c0")
+ >>> print(builder.to_string())
+ strict digraph G {
+ rankdir=TB;
+ compound=true;
+ n0 [label="node 0", fontname=Helvetica, penwidth=3, shape=ellipse, style=filled, fillcolor=lightblue, color=lightblue4];
+ subgraph cluster_c0 {
+ fontname=Helvetica;
+ penwidth=2;
+ shape=rectangle;
+ style=solid;
+ }
+
+ n1 [label="node 1", fontname=Helvetica, penwidth=3, shape=ellipse, style=filled, fillcolor=lightblue, color=lightblue4, cluster_uid=c0];
+ }
+
+ """
+
+ def __init__(
+ self,
+ attrs: dict | None = None,
+ node_attrs: dict | None = None,
+ edge_attrs: dict | None = None,
+ cluster_attrs: dict | None = None,
+ ) -> None:
+ # Initialize the pydot graph:
+ # - graph_type="digraph": Create a directed graph (edges have arrows).
+ # - rankdir="TB": Set layout direction from Top to Bottom.
+ # - compound="true": Allow edges to connect directly to clusters/subgraphs.
+ # - strict=True: Prevent duplicate edges (e.g., A -> B added twice).
+ # - splines="ortho": Edges connecting clusters are orthogonal
+ self.graph: Dot = Dot(
+ graph_type="digraph", rankdir="TB", compound="true", strict=True, splines="ortho"
+ )
+
+ # Use internal cache that maps cluster ID to actual pydot (Dot or Cluster) object
+ # NOTE: This is needed so we don't need to traverse the graph to find the relevant
+ # cluster object to modify
+ self._subgraph_cache: dict[str, Graph] = {}
+
+ # Internal state for graph structure
+ self._nodes: dict[str, dict[str, Any]] = {}
+ self._edges: list[dict[str, Any]] = []
+ self._clusters: dict[str, dict[str, Any]] = {}
+
+ _default_attrs: dict = (
+ {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs
+ )
+ self._default_node_attrs: dict = (
+ {
+ **_default_attrs,
+ "shape": "ellipse",
+ "style": "filled",
+ "fillcolor": "lightblue",
+ "color": "lightblue4",
+ "penwidth": 3,
+ }
+ if node_attrs is None
+ else node_attrs
+ )
+ self._default_edge_attrs: dict = (
+ {
+ "color": "lightblue4",
+ "penwidth": 3,
+ }
+ if edge_attrs is None
+ else edge_attrs
+ )
+ self._default_cluster_attrs: dict = (
+ {
+ **_default_attrs,
+ "shape": "rectangle",
+ "style": "solid",
+ }
+ if cluster_attrs is None
+ else cluster_attrs
+ )
+
+ def add_node(
+ self,
+ uid: str,
+ label: str,
+ cluster_uid: str | None = None,
+ **attrs: Any,
+ ) -> None:
+ """Add a single node to the graph.
+
+ Args:
+ uid (str): Unique node ID to identify this node.
+ label (str): The text to display on the node when rendered.
+ cluster_uid (str | None): Optional unique ID of the cluster this node belongs to.
+ **attrs (Any): Any additional styling keyword arguments.
+
+ Raises:
+ ValueError: Node ID is already present in the graph.
+
+ """
+ if uid in self.nodes:
+ raise ValueError(f"Node ID {uid} already present in graph.")
+
+ # Use ChainMap so you don't need to construct a new dictionary
+ node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs)
+ node = Node(uid, label=label, **node_attrs)
+
+ # Add node to cluster
+ if cluster_uid is None:
+ self.graph.add_node(node)
+ else:
+ self._subgraph_cache[cluster_uid].add_node(node)
+
+ self._nodes[uid] = {
+ "uid": uid,
+ "label": label,
+ "cluster_uid": cluster_uid,
+ "attrs": dict(node_attrs),
+ }
+
+ def add_edge(self, from_uid: str, to_uid: str, **attrs: Any) -> None:
+ """Add a single directed edge between nodes in the graph.
+
+ Args:
+ from_uid (str): The unique ID of the source node.
+ to_uid (str): The unique ID of the destination node.
+ **attrs (Any): Any additional styling keyword arguments.
+
+ Raises:
+ ValueError: Source and destination have the same ID
+ ValueError: Source is not found in the graph.
+ ValueError: Destination is not found in the graph.
+
+ """
+ if from_uid == to_uid:
+ raise ValueError("Edges must connect two unique IDs.")
+ if from_uid not in self.nodes:
+ raise ValueError("Source is not found in the graph.")
+ if to_uid not in self.nodes:
+ raise ValueError("Destination is not found in the graph.")
+
+ # Use ChainMap so you don't need to construct a new dictionary
+ edge_attrs: ChainMap = ChainMap(attrs, self._default_edge_attrs)
+ edge = Edge(from_uid, to_uid, **edge_attrs)
+
+ self.graph.add_edge(edge)
+
+ self._edges.append(
+ {"from_uid": from_uid, "to_uid": to_uid, "attrs": dict(edge_attrs)}
+ )
+
+ def add_cluster(
+ self,
+ uid: str,
+ label: str | None = None,
+ cluster_uid: str | None = None,
+ **attrs: Any,
+ ) -> None:
+ """Add a single cluster to the graph.
+
+ A cluster is a specific type of subgraph where the nodes and edges contained
+ within it are visually and logically grouped.
+
+ Args:
+ uid (str): Unique cluster ID to identify this cluster.
+ label (str | None): Optional text to display as a label on the cluster when rendered.
+ cluster_uid (str | None): Optional unique ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph.
+ **attrs (Any): Any additional styling keyword arguments.
+
+ Raises:
+ ValueError: Cluster ID is already present in the graph.
+ """
+ if uid in self.clusters:
+ raise ValueError(f"Cluster ID {uid} already present in graph.")
+
+ # Use ChainMap so you don't need to construct a new dictionary
+ cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs)
+ cluster = Cluster(uid, label=label, **cluster_attrs)
+
+ # Record new cluster
+ self._subgraph_cache[uid] = cluster
+
+ # Add node to cluster
+ if cluster_uid is None:
+ self.graph.add_subgraph(cluster)
+ else:
+ self._subgraph_cache[cluster_uid].add_subgraph(cluster)
+
+ self._clusters[uid] = {
+ "uid": uid,
+ "label": label,
+ "cluster_uid": cluster_uid,
+ "attrs": dict(cluster_attrs),
+ }
+
+ @property
+ def nodes(self) -> dict[str, dict[str, Any]]:
+ """Retrieve the current set of nodes in the graph.
+
+ Returns:
+ nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information.
+ """
+ return self._nodes
+
+ @property
+ def edges(self) -> list[dict[str, Any]]:
+ """Retrieve the current set of edges in the graph.
+
+ Returns:
+ edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information.
+ """
+ return self._edges
+
+ @property
+ def clusters(self) -> dict[str, dict[str, Any]]:
+ """Retrieve the current set of clusters in the graph.
+
+ Returns:
+ clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information.
+ """
+ return self._clusters
+
+ def to_file(self, output_filename: str) -> None:
+ """Save the graph to a file.
+
+ This method will infer the file's format (e.g., 'png', 'svg') from this filename's extension.
+ If no extension is provided, the 'png' format will be the default.
+
+ Args:
+ output_filename (str): Desired filename for the graph. File extension can be included
+ and if no file extension is provided, it will default to a `.png` file.
+
+ """
+ output_filename_path: pathlib.Path = pathlib.Path(output_filename)
+ if not output_filename_path.suffix:
+ output_filename_path = output_filename_path.with_suffix(".png")
+
+ format = output_filename_path.suffix[1:].lower()
+
+ self.graph.write(str(output_filename_path), format=format)
+
+ def to_string(self) -> str:
+ """Return the graph as a string.
+
+ This is typically used to get the graph's representation in a standard string format like DOT.
+
+ Returns:
+ str: A string representation of the graph.
+ """
+ return self.graph.to_string()
diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py
new file mode 100644
index 0000000000..d39bee45f1
--- /dev/null
+++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py
@@ -0,0 +1,974 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for the ConstructCircuitDAG utility."""
+
+from unittest.mock import Mock
+
+import pytest
+
+pytestmark = pytest.mark.xdsl
+xdsl = pytest.importorskip("xdsl")
+
+# pylint: disable=wrong-import-position
+# This import needs to be after pytest in order to prevent ImportErrors
+import pennylane as qml
+from xdsl.dialects import test
+from xdsl.dialects.builtin import ModuleOp
+from xdsl.ir.core import Block, Region
+
+from catalyst import measure
+from catalyst.python_interface.conversion import xdsl_from_qjit
+from catalyst.python_interface.visualization.construct_circuit_dag import (
+ ConstructCircuitDAG,
+ get_label,
+)
+from catalyst.python_interface.visualization.dag_builder import DAGBuilder
+
+
+class FakeDAGBuilder(DAGBuilder):
+ """
+ A concrete implementation of DAGBuilder used ONLY for testing.
+ It stores all graph manipulation calls in data structures
+ for easy assertion of the final graph state.
+ """
+
+ def __init__(self):
+ self._nodes = {}
+ self._edges = []
+ self._clusters = {}
+
+ def add_node(self, uid, label, cluster_uid=None, **attrs) -> None:
+ self._nodes[uid] = {
+ "uid": uid,
+ "label": label,
+ "parent_cluster_uid": cluster_uid,
+ "attrs": attrs,
+ }
+
+ def add_edge(self, from_uid: str, to_uid: str, **attrs) -> None:
+ self._edges.append(
+ {
+ "from_uid": from_uid,
+ "to_uid": to_uid,
+ "attrs": attrs,
+ }
+ )
+
+ def add_cluster(
+ self,
+ uid,
+ label=None,
+ cluster_uid=None,
+ **attrs,
+ ) -> None:
+ self._clusters[uid] = {
+ "uid": uid,
+ "label": label,
+ "parent_cluster_uid": cluster_uid,
+ "attrs": attrs,
+ }
+
+ @property
+ def nodes(self):
+ return self._nodes
+
+ @property
+ def edges(self):
+ return self._edges
+
+ @property
+ def clusters(self):
+ return self._clusters
+
+ def to_file(self, output_filename):
+ pass
+
+ def to_string(self) -> str:
+ return "graph"
+
+
+@pytest.mark.unit
+def test_dependency_injection():
+ """Tests that relevant dependencies are injected."""
+
+ mock_dag_builder = Mock(DAGBuilder)
+ utility = ConstructCircuitDAG(mock_dag_builder)
+ assert utility.dag_builder is mock_dag_builder
+
+
+@pytest.mark.unit
+def test_does_not_mutate_module():
+ """Test that the module is not mutated."""
+
+ # Create module
+ op = test.TestOp()
+ block = Block(ops=[op])
+ region = Region(blocks=[block])
+ container_op = test.TestOp(regions=[region])
+ module_op = ModuleOp(ops=[container_op])
+
+ # Save state before
+ module_op_str_before = str(module_op)
+
+ # Process module
+ mock_dag_builder = Mock(DAGBuilder)
+ utility = ConstructCircuitDAG(mock_dag_builder)
+ utility.construct(module_op)
+
+ # Ensure not mutated
+ assert str(module_op) == module_op_str_before
+
+
+@pytest.mark.unit
+class TestFuncOpVisualization:
+ """Tests the visualization of FuncOps with bounding boxes"""
+
+ def test_standard_qnode(self):
+ """Tests that a standard QJIT'd QNode is visualized correctly"""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow():
+ qml.H(0)
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ graph_clusters = utility.dag_builder.clusters
+
+ # Check nesting is correct
+ # graph
+ # └── qjit
+ # └── my_workflow
+
+ # Check qjit is nested under graph
+ assert graph_clusters["cluster0"]["label"] == "qjit"
+ assert graph_clusters["cluster0"]["parent_cluster_uid"] is None
+
+ # Check that my_workflow is under qjit
+ assert graph_clusters["cluster1"]["label"] == "my_workflow"
+ assert graph_clusters["cluster1"]["parent_cluster_uid"] == "cluster0"
+
+ def test_nested_qnodes(self):
+ """Tests that nested QJIT'd QNodes are visualized correctly"""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @qml.qnode(dev)
+ def my_qnode2():
+ qml.X(0)
+
+ @qml.qnode(dev)
+ def my_qnode1():
+ qml.H(0)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ def my_workflow():
+ my_qnode1()
+ my_qnode2()
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ graph_clusters = utility.dag_builder.clusters
+
+ # Check nesting is correct
+ # graph
+ # └── qjit
+ # ├── my_qnode1
+ # └── my_qnode2
+
+ # Check qjit is under graph
+ assert graph_clusters["cluster0"]["label"] == "qjit"
+ assert graph_clusters["cluster0"]["parent_cluster_uid"] is None
+
+ # Check both qnodes are under my_workflow
+ assert graph_clusters["cluster1"]["label"] == "my_qnode1"
+ assert graph_clusters["cluster1"]["parent_cluster_uid"] == "cluster0"
+
+ assert graph_clusters["cluster2"]["label"] == "my_qnode2"
+ assert graph_clusters["cluster2"]["parent_cluster_uid"] == "cluster0"
+
+
+class TestDeviceNode:
+ """Tests that the device node is correctly visualized."""
+
+ def test_standard_qnode(self):
+ """Tests that a standard setup works."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow():
+ qml.H(0)
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ graph_nodes = utility.dag_builder.nodes
+ graph_clusters = utility.dag_builder.clusters
+
+ # Check nesting is correct
+ # graph
+ # └── qjit
+ # └── my_workflow: NullQubit
+
+ # Assert device node is inside my_workflow cluster
+ assert graph_clusters["cluster1"]["label"] == "my_workflow"
+ assert graph_nodes["node0"]["parent_cluster_uid"] == "cluster1"
+
+ # Assert label is as expected
+ assert graph_nodes["node0"]["label"] == "NullQubit"
+
+ def test_nested_qnodes(self):
+ """Tests that nested QJIT'd QNodes are visualized correctly"""
+
+ dev1 = qml.device("null.qubit", wires=1)
+ dev2 = qml.device("lightning.qubit", wires=1)
+
+ @qml.qnode(dev2)
+ def my_qnode2():
+ qml.X(0)
+
+ @qml.qnode(dev1)
+ def my_qnode1():
+ qml.H(0)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ def my_workflow():
+ my_qnode1()
+ my_qnode2()
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ graph_nodes = utility.dag_builder.nodes
+ graph_clusters = utility.dag_builder.clusters
+
+ # Check nesting is correct
+ # graph
+ # └── qjit
+ # ├── my_qnode1: NullQubit
+ # └── my_qnode2: LightningSimulator
+
+ # Assert lightning.qubit device node is inside my_qnode1 cluster
+ assert graph_clusters["cluster1"]["label"] == "my_qnode1"
+ assert graph_nodes["node0"]["parent_cluster_uid"] == "cluster1"
+
+ # Assert label is as expected
+ assert graph_nodes["node0"]["label"] == "NullQubit"
+
+ # Assert null qubit device node is inside my_qnode2 cluster
+ assert graph_clusters["cluster2"]["label"] == "my_qnode2"
+ # NOTE: node1 is the qml.H(0) in my_qnode1
+ assert graph_nodes["node2"]["parent_cluster_uid"] == "cluster2"
+
+ # Assert label is as expected
+ assert graph_nodes["node2"]["label"] == "LightningSimulator"
+
+
+class TestForOp:
+ """Tests that the for loop control flow can be visualized correctly."""
+
+ @pytest.mark.unit
+ def test_basic_example(self):
+ """Tests that the for loop cluster can be visualized correctly."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow():
+ for i in range(3):
+ qml.H(0)
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ assert clusters["cluster2"]["label"] == "for loop"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+
+ @pytest.mark.unit
+ def test_nested_loop(self):
+ """Tests that nested for loops are visualized correctly."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow():
+ for i in range(0, 5, 2):
+ for j in range(1, 6, 2):
+ qml.H(0)
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ assert clusters["cluster2"]["label"] == "for loop"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+ assert clusters["cluster3"]["label"] == "for loop"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+
+
+class TestWhileOp:
+ """Tests that the while loop control flow can be visualized correctly."""
+
+ @pytest.mark.unit
+ def test_basic_example(self):
+ """Test that the while loop is visualized correctly."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow():
+ counter = 0
+ while counter < 5:
+ qml.H(0)
+ counter += 1
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ assert clusters["cluster2"]["label"] == "while loop"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+
+ @pytest.mark.unit
+ def test_nested_loop(self):
+ """Tests that nested while loops are visualized correctly."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow():
+ outer_counter = 0
+ inner_counter = 0
+ while outer_counter < 5:
+ while inner_counter < 6:
+ qml.H(0)
+ inner_counter += 1
+ outer_counter += 1
+
+ module = my_workflow()
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ assert clusters["cluster2"]["label"] == "while loop"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+ assert clusters["cluster3"]["label"] == "while loop"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+
+
+class TestIfOp:
+ """Tests that the conditional control flow can be visualized correctly."""
+
+ @pytest.mark.unit
+ def test_basic_example(self):
+ """Test that the conditional operation is visualized correctly."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow(x):
+ if x == 2:
+ qml.X(0)
+ else:
+ qml.Y(0)
+
+ args = (1,)
+ module = my_workflow(*args)
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ # Check conditional is a cluster within cluster1 (my_workflow)
+ assert clusters["cluster2"]["label"] == "conditional"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+
+ # Check three clusters live within cluster2 (conditional)
+ assert clusters["cluster3"]["label"] == "if"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster4"]["label"] == "else"
+ assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2"
+
+ @pytest.mark.unit
+ def test_if_elif_else_conditional(self):
+ """Test that the conditional operation is visualized correctly."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow(x):
+ if x == 1:
+ qml.X(0)
+ elif x == 2:
+ qml.Y(0)
+ else:
+ qml.Z(0)
+
+ args = (1,)
+ module = my_workflow(*args)
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ # Check conditional is a cluster within my_workflow
+ assert clusters["cluster2"]["label"] == "conditional"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+
+ # Check three clusters live within conditional
+ assert clusters["cluster3"]["label"] == "if"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster4"]["label"] == "elif"
+ assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster5"]["label"] == "else"
+ assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2"
+
+ @pytest.mark.unit
+ def test_nested_conditionals(self):
+ """Tests that nested conditionals are visualized correctly."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow(x, y):
+ if x == 1:
+ if y == 2:
+ qml.H(0)
+ else:
+ qml.Z(0)
+ qml.X(0)
+ else:
+ qml.Z(0)
+
+ args = (1, 2)
+ module = my_workflow(*args)
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ # cluster2 -> conditional (1)
+ # cluster3 -> if
+ # cluster4 -> conditional ()
+ # cluster5 -> if
+ # cluster6 -> else
+ # cluster7 -> else
+
+ # Check first conditional is a cluster within my_workflow
+ assert clusters["cluster2"]["label"] == "conditional"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+
+ # Check 'if' cluster of first conditional has another conditional
+ assert clusters["cluster3"]["label"] == "if"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+
+ # Second conditional
+ assert clusters["cluster4"]["label"] == "conditional"
+ assert clusters["cluster4"]["parent_cluster_uid"] == "cluster3"
+ # Check 'if' and 'else' in second conditional
+ assert clusters["cluster5"]["label"] == "if"
+ assert clusters["cluster5"]["parent_cluster_uid"] == "cluster4"
+ assert clusters["cluster6"]["label"] == "else"
+ assert clusters["cluster6"]["parent_cluster_uid"] == "cluster4"
+
+ # Check nested if / else is within the first if cluster
+ assert clusters["cluster7"]["label"] == "else"
+ assert clusters["cluster7"]["parent_cluster_uid"] == "cluster2"
+
+ def test_nested_conditionals_with_quantum_ops(self):
+ """Tests that nested conditionals are unflattend if quantum operations
+ are present"""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow(x):
+ if x == 1:
+ qml.X(0)
+ elif x == 2:
+ qml.Y(0)
+ else:
+ qml.Z(0)
+ if x == 3:
+ qml.RX(0, 0)
+ elif x == 4:
+ qml.RY(0, 0)
+ else:
+ qml.RZ(0, 0)
+
+ args = (1,)
+ module = my_workflow(*args)
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ # node0 -> NullQubit
+ # cluster2 -> conditional (1)
+ # cluster3 -> if
+ # node1 -> X(0)
+ # cluster4 -> elif
+ # node2 -> Y(0)
+ # cluster5 -> else
+ # node3 -> Z(0)
+ # cluster6 -> conditional (2)
+ # cluster7 -> if
+ # node4 -> RX(0,0)
+ # cluster8 -> elif
+ # node5 -> RY(0,0)
+ # cluster9 -> else
+ # node6 -> RZ(0,0)
+
+ # check outer conditional (1)
+ assert clusters["cluster2"]["label"] == "conditional"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+ assert clusters["cluster3"]["label"] == "if"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster4"]["label"] == "elif"
+ assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster5"]["label"] == "else"
+ assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2"
+
+ # Nested conditional (2) inside conditional (1)
+ assert clusters["cluster6"]["label"] == "conditional"
+ assert clusters["cluster6"]["parent_cluster_uid"] == "cluster5"
+ assert clusters["cluster7"]["label"] == "if"
+ assert clusters["cluster7"]["parent_cluster_uid"] == "cluster6"
+ assert clusters["cluster8"]["label"] == "elif"
+ assert clusters["cluster8"]["parent_cluster_uid"] == "cluster6"
+ assert clusters["cluster9"]["label"] == "else"
+ assert clusters["cluster9"]["parent_cluster_uid"] == "cluster6"
+
+ def test_nested_conditionals_with_nested_quantum_ops(self):
+ """Tests that nested conditionals are unflattend if quantum operations
+ are present but nested in other operations"""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_workflow(x):
+ if x == 1:
+ qml.X(0)
+ elif x == 2:
+ qml.Y(0)
+ else:
+ for i in range(3):
+ qml.Z(0)
+ if x == 3:
+ qml.RX(0, 0)
+ elif x == 4:
+ qml.RY(0, 0)
+ else:
+ qml.RZ(0, 0)
+
+ args = (1,)
+ module = my_workflow(*args)
+
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ clusters = utility.dag_builder.clusters
+
+ # cluster0 -> qjit
+ # cluster1 -> my_workflow
+ # node0 -> NullQubit
+ # cluster2 -> conditional (1)
+ # cluster3 -> if
+ # node1 -> X(0)
+ # cluster4 -> elif
+ # node2 -> Y(0)
+ # cluster5 -> else
+ # cluster6 -> for loop
+ # node3 -> Z(0)
+ # cluster7 -> conditional (2)
+ # cluster8 -> if
+ # node4 -> RX(0,0)
+ # cluster9 -> elif
+ # node5 -> RY(0,0)
+ # cluster10 -> else
+ # node6 -> RZ(0,0)
+
+ # check outer conditional (1)
+ assert clusters["cluster2"]["label"] == "conditional"
+ assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1"
+ assert clusters["cluster3"]["label"] == "if"
+ assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster4"]["label"] == "elif"
+ assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2"
+ assert clusters["cluster5"]["label"] == "else"
+ assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2"
+
+ # Nested conditional (2) inside conditional (1)
+ assert clusters["cluster6"]["label"] == "for loop"
+ assert clusters["cluster6"]["parent_cluster_uid"] == "cluster5"
+
+ assert clusters["cluster7"]["label"] == "conditional"
+ assert clusters["cluster7"]["parent_cluster_uid"] == "cluster5"
+ assert clusters["cluster8"]["label"] == "if"
+ assert clusters["cluster8"]["parent_cluster_uid"] == "cluster7"
+ assert clusters["cluster9"]["label"] == "elif"
+ assert clusters["cluster9"]["parent_cluster_uid"] == "cluster7"
+ assert clusters["cluster10"]["label"] == "else"
+ assert clusters["cluster10"]["parent_cluster_uid"] == "cluster7"
+
+
+class TestGetLabel:
+ """Tests the get_label utility."""
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "op, label",
+ [
+ (qml.H(0), " Hadamard| [0]"),
+ (
+ qml.QubitUnitary([[0, 1], [1, 0]], 0),
+ " QubitUnitary| [0]",
+ ),
+ (qml.SWAP([0, 1]), " SWAP| [0, 1]"),
+ ],
+ )
+ def test_standard_operator(self, op, label):
+ """Tests against an operator instance."""
+ assert get_label(op) == label
+
+ def test_global_phase_operator(self):
+ """Tests against a GlobalPhase operator instance."""
+ assert get_label(qml.GlobalPhase(0.5)) == f" GlobalPhase| all"
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "meas, label",
+ [
+ (qml.state(), " state| all"),
+ (qml.expval(qml.Z(0)), " expval(PauliZ)| [0]"),
+ (qml.var(qml.Z(0)), " var(PauliZ)| [0]"),
+ (qml.probs(), " probs| all"),
+ (qml.probs(wires=0), " probs| [0]"),
+ (qml.probs(wires=[0, 1]), " probs| [0, 1]"),
+ (qml.sample(), " sample| all"),
+ (qml.sample(wires=0), " sample| [0]"),
+ (qml.sample(wires=[0, 1]), " sample| [0, 1]"),
+ ],
+ )
+ def test_standard_measurement(self, meas, label):
+ """Tests against an operator instance."""
+ assert get_label(meas) == label
+
+
+class TestCreateStaticOperatorNodes:
+ """Tests that operators with static parameters can be created and visualized as nodes."""
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])])
+ def test_custom_op(self, op):
+ """Tests that the CustomOp operation node can be created and visualized."""
+
+ # Build module with only a CustomOp
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ qml.apply(op)
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ # Ensure DAG only has one node
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ # Make sure label has relevant info
+ assert nodes["node1"]["label"] == get_label(op)
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "op",
+ [
+ qml.GlobalPhase(0.5),
+ qml.GlobalPhase(0.5, wires=0),
+ qml.GlobalPhase(0.5, wires=[0, 1]),
+ ],
+ )
+ def test_global_phase_op(self, op):
+ """Test that GlobalPhase can be handled."""
+
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ qml.apply(op)
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ # Ensure DAG only has one node
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ # Compiler throws out the wires and they get converted to wires=[] no matter what
+ assert nodes["node1"]["label"] == get_label(qml.GlobalPhase(0.5))
+
+ @pytest.mark.unit
+ def test_qubit_unitary_op(self):
+ """Test that QubitUnitary operations can be handled."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ qml.QubitUnitary([[0, 1], [1, 0]], wires=0)
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ # Ensure DAG only has one node
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == get_label(qml.QubitUnitary([[0, 1], [1, 0]], wires=0))
+
+ @pytest.mark.unit
+ def test_multi_rz_op(self):
+ """Test that MultiRZ operations can be handled."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ qml.MultiRZ(0.5, wires=[0])
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ # Ensure DAG only has one node
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == get_label(qml.MultiRZ(0.5, wires=[0]))
+
+ @pytest.mark.unit
+ def test_projective_measurement_op(self):
+ """Test that projective measurements can be captured as nodes."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ measure(0)
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == f" MidMeasureMP| [0]"
+
+
+class TestCreateStaticMeasurementNodes:
+ """Tests that measurements with static parameters can be created and visualized as nodes."""
+
+ @pytest.mark.unit
+ def test_state_op(self):
+ """Test that qml.state can be handled."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ return qml.state()
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ # Ensure DAG only has one node
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == get_label(qml.state())
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var])
+ def test_expval_var_measurement_op(self, meas_fn):
+ """Test that statistical measurement operators can be captured as nodes."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ return meas_fn(qml.Z(0))
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ # Ensure DAG only has one node
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == get_label(meas_fn(qml.Z(0)))
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "op",
+ [
+ qml.probs(),
+ qml.probs(wires=0),
+ qml.probs(wires=[0, 1]),
+ ],
+ )
+ def test_probs_measurement_op(self, op):
+ """Tests that the probs measurement function can be captured as a node."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.qnode(dev)
+ def my_circuit():
+ return op
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == get_label(op)
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "op",
+ [
+ qml.sample(),
+ qml.sample(wires=0),
+ qml.sample(wires=[0, 1]),
+ ],
+ )
+ def test_valid_sample_measurement_op(self, op):
+ """Tests that the sample measurement function can be captured as a node."""
+ dev = qml.device("null.qubit", wires=1)
+
+ @xdsl_from_qjit
+ @qml.qjit(autograph=True, target="mlir")
+ @qml.set_shots(10)
+ @qml.qnode(dev)
+ def my_circuit():
+ return op
+
+ module = my_circuit()
+
+ # Construct DAG
+ utility = ConstructCircuitDAG(FakeDAGBuilder())
+ utility.construct(module)
+
+ nodes = utility.dag_builder.nodes
+ assert len(nodes) == 2 # Device node + operator
+
+ assert nodes["node1"]["label"] == get_label(op)
diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py
new file mode 100644
index 0000000000..d6de2d4011
--- /dev/null
+++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py
@@ -0,0 +1,113 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for the DAGBuilder abstract base class."""
+
+from typing import Any
+
+import pytest
+
+pytestmark = pytest.mark.xdsl
+xdsl = pytest.importorskip("xdsl")
+
+# pylint: disable=wrong-import-position
+# This import needs to be after pytest in order to prevent ImportErrors
+from catalyst.python_interface.visualization.dag_builder import DAGBuilder
+
+
+def test_concrete_implementation_works():
+ """Unit test for concrete implementation of abc."""
+
+ # pylint: disable=unused-argument
+ class ConcreteDAGBuilder(DAGBuilder):
+ """Concrete subclass of an ABC for testing purposes."""
+
+ def add_node(
+ self,
+ uid: str,
+ label: str,
+ cluster_id: str | None = None,
+ **attrs: Any,
+ ) -> None:
+ return
+
+ def add_edge(self, from_uid: str, to_uid: str, **attrs: Any) -> None:
+ return
+
+ def add_cluster(
+ self,
+ uid: str,
+ label: str | None = None,
+ cluster_id: str | None = None,
+ **attrs: Any,
+ ) -> None:
+ return
+
+ @property
+ def nodes(self) -> dict[str, dict[str, Any]]:
+ return {}
+
+ @property
+ def edges(self) -> list[dict[str, Any]]:
+ return []
+
+ @property
+ def clusters(self) -> dict[str, dict[str, Any]]:
+ return {}
+
+ def to_file(self, output_filename: str) -> None:
+ return
+
+ def to_string(self) -> str:
+ return "test"
+
+ dag_builder = ConcreteDAGBuilder()
+ # pylint: disable = assignment-from-none
+ node = dag_builder.add_node("0", "node0")
+ edge = dag_builder.add_edge("0", "1")
+ cluster = dag_builder.add_cluster("0")
+ nodes = dag_builder.nodes
+ edges = dag_builder.edges
+ clusters = dag_builder.clusters
+ render = dag_builder.to_file("test.png")
+ string = dag_builder.to_string()
+
+ assert node is None
+ assert nodes == {}
+ assert edge is None
+ assert edges == []
+ assert cluster is None
+ assert clusters == {}
+ assert render is None
+ assert string == "test"
+
+
+def test_abc_cannot_be_instantiated():
+ """Tests that the DAGBuilder ABC cannot be instantiated."""
+
+ with pytest.raises(TypeError, match="Can't instantiate abstract class"):
+ # pylint: disable=abstract-class-instantiated
+ DAGBuilder()
+
+
+def test_incomplete_subclass():
+ """Tests that an incomplete subclass will fail"""
+
+ # pylint: disable=too-few-public-methods
+ class IncompleteDAGBuilder(DAGBuilder):
+ def add_node(self, *args, **kwargs):
+ pass
+
+ with pytest.raises(TypeError, match="Can't instantiate abstract class"):
+ # pylint: disable=abstract-class-instantiated
+ IncompleteDAGBuilder()
diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py
new file mode 100644
index 0000000000..e7255bfbd6
--- /dev/null
+++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py
@@ -0,0 +1,414 @@
+# Copyright 2025 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for the PyDotDAGBuilder subclass."""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+pydot = pytest.importorskip("pydot")
+
+pytestmark = pytest.mark.xdsl
+xdsl = pytest.importorskip("xdsl")
+# pylint: disable=wrong-import-position
+from catalyst.python_interface.visualization.pydot_dag_builder import PyDotDAGBuilder
+
+
+@pytest.mark.unit
+def test_initialization_defaults():
+ """Tests the default graph attributes are as expected."""
+
+ dag_builder = PyDotDAGBuilder()
+
+ assert isinstance(dag_builder.graph, pydot.Dot)
+ # Ensure it's a directed graph
+ assert dag_builder.graph.get_graph_type() == "digraph"
+ # Ensure that it flows top to bottom
+ assert dag_builder.graph.get_rankdir() == "TB"
+ # Ensure edges can be connected directly to clusters / subgraphs
+ assert dag_builder.graph.get_compound() == "true"
+ # Ensure duplicated edges cannot be added
+ assert dag_builder.graph.obj_dict["strict"] is True
+ # Ensure edges are orthogonal
+ assert dag_builder.graph.obj_dict["attributes"]["splines"]== "ortho"
+
+
+class TestExceptions:
+ """Tests the various exceptions defined in the class."""
+
+ def test_duplicate_node_ids(self):
+ """Tests that a ValueError is raised for duplicate nodes."""
+
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_node("0", "node0")
+ with pytest.raises(ValueError, match="Node ID 0 already present in graph."):
+ dag_builder.add_node("0", "node1")
+
+ def test_edge_duplicate_source_destination(self):
+ """Tests that a ValueError is raised when an edge is created with the
+ same source and destination"""
+
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_node("0", "node0")
+ with pytest.raises(ValueError, match="Edges must connect two unique IDs."):
+ dag_builder.add_edge("0", "0")
+
+ def test_edge_missing_ids(self):
+ """Tests that an error is raised if IDs are missing."""
+
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_node("0", "node0")
+ with pytest.raises(ValueError, match="Destination is not found in the graph."):
+ dag_builder.add_edge("0", "1")
+
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_node("1", "node1")
+ with pytest.raises(ValueError, match="Source is not found in the graph."):
+ dag_builder.add_edge("0", "1")
+
+ def test_duplicate_cluster_id(self):
+ """Tests that an exception is raised if an ID is already present."""
+
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_cluster("0")
+ with pytest.raises(ValueError, match="Cluster ID 0 already present in graph."):
+ dag_builder.add_cluster("0")
+
+
+class TestAddMethods:
+ """Test that elements can be added to the graph."""
+
+ @pytest.mark.unit
+ def test_add_node(self):
+ """Unit test the `add_node` method."""
+
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_node("0", "node0")
+ node_list = dag_builder.graph.get_node_list()
+ assert len(node_list) == 1
+ assert node_list[0].get_label() == "node0"
+
+ @pytest.mark.unit
+ def test_add_edge(self):
+ """Unit test the `add_edge` method."""
+
+ dag_builder = PyDotDAGBuilder()
+ dag_builder.add_node("0", "node0")
+ dag_builder.add_node("1", "node1")
+ dag_builder.add_edge("0", "1")
+
+ assert len(dag_builder.graph.get_edges()) == 1
+ edge = dag_builder.graph.get_edges()[0]
+ assert edge.get_source() == "0"
+ assert edge.get_destination() == "1"
+
+ @pytest.mark.unit
+ def test_add_cluster(self):
+ """Unit test the 'add_cluster' method."""
+
+ dag_builder = PyDotDAGBuilder()
+ dag_builder.add_cluster("0")
+
+ assert len(dag_builder.graph.get_subgraphs()) == 1
+ assert dag_builder.graph.get_subgraphs()[0].get_name() == "cluster_0"
+
+ @pytest.mark.unit
+ def test_add_node_to_parent_graph(self):
+ """Tests that you can add a node to a parent graph."""
+ dag_builder = PyDotDAGBuilder()
+
+ # Create node
+ dag_builder.add_node("0", "node0")
+
+ # Create cluster
+ dag_builder.add_cluster("c0")
+
+ # Create node inside cluster
+ dag_builder.add_node("1", "node1", cluster_uid="c0")
+
+ # Verify graph structure
+ root_graph = dag_builder.graph
+
+ # Make sure the base graph has node0
+ assert root_graph.get_node("0"), "Node 0 not found in root graph"
+
+ # Get the cluster and verify it has node1 and not node0
+ cluster_list = root_graph.get_subgraph("cluster_c0")
+ assert cluster_list, "Subgraph 'cluster_c0' not found"
+ cluster_graph = cluster_list[0] # Get the actual subgraph object
+
+ assert cluster_graph.get_node("1"), "Node 1 not found in cluster 'c0'"
+ assert not cluster_graph.get_node("0"), (
+ "Node 0 was incorrectly added to cluster"
+ )
+
+ assert not root_graph.get_node("1"), "Node 1 was incorrectly added to root"
+
+ @pytest.mark.unit
+ def test_add_cluster_to_parent_graph(self):
+ """Test that you can add a cluster to a parent graph."""
+ dag_builder = PyDotDAGBuilder()
+
+ # Level 0 (Root): Adds cluster on top of base graph
+ dag_builder.add_node("n_root", "node_root")
+
+ # Level 1 (c0): Add node on outer cluster
+ dag_builder.add_cluster("c0")
+ dag_builder.add_node("n_outer", "node_outer", cluster_uid="c0")
+
+ # Level 2 (c1): Add node on inner cluster
+ dag_builder.add_cluster("c1", cluster_uid="c0")
+ dag_builder.add_node("n_inner", "node_inner", cluster_uid="c1")
+
+ root_graph = dag_builder.graph
+
+ outer_cluster_list = root_graph.get_subgraph("cluster_c0")
+ assert outer_cluster_list, "Outer cluster 'c0' not found in root"
+ c0 = outer_cluster_list[0]
+
+ inner_cluster_list = c0.get_subgraph("cluster_c1")
+ assert inner_cluster_list, "Inner cluster 'c1' not found in 'c0'"
+ c1 = inner_cluster_list[0]
+
+ # Check Level 0 (Root)
+ assert root_graph.get_node("n_root"), "n_root not found in root"
+ assert root_graph.get_subgraph("cluster_c0"), "c0 not found in root"
+ assert not root_graph.get_node("n_outer"), "n_outer incorrectly found in root"
+ assert not root_graph.get_node("n_inner"), "n_inner incorrectly found in root"
+ assert not root_graph.get_subgraph("cluster_c1"), "c1 incorrectly found in root"
+
+ # Check Level 1 (c0)
+ assert c0.get_node("n_outer"), "n_outer not found in c0"
+ assert c0.get_subgraph("cluster_c1"), "c1 not found in c0"
+ assert not c0.get_node("n_root"), "n_root incorrectly found in c0"
+ assert not c0.get_node("n_inner"), "n_inner incorrectly found in c0"
+
+ # Check Level 2 (c1)
+ assert c1.get_node("n_inner"), "n_inner not found in c1"
+ assert not c1.get_node("n_root"), "n_root incorrectly found in c1"
+ assert not c1.get_node("n_outer"), "n_outer incorrectly found in c1"
+
+
+class TestAttributes:
+ """Tests that the attributes for elements in the graph are overridden correctly."""
+
+ @pytest.mark.unit
+ def test_default_graph_attrs(self):
+ """Test that default graph attributes can be set."""
+
+ dag_builder = PyDotDAGBuilder(attrs={"fontname": "Times"})
+
+ dag_builder.add_node("0", "node0")
+ node0 = dag_builder.graph.get_node("0")[0]
+ assert node0.get("fontname") == "Times"
+
+ dag_builder.add_cluster("1")
+ cluster = dag_builder.graph.get_subgraphs()[0]
+ assert cluster.get("fontname") == "Times"
+
+ @pytest.mark.unit
+ def test_add_node_with_attrs(self):
+ """Tests that default attributes are applied and can be overridden."""
+ dag_builder = PyDotDAGBuilder(attrs={"fillcolor": "lightblue", "penwidth": 3})
+
+ # Defaults
+ dag_builder.add_node("0", "node0")
+ node0 = dag_builder.graph.get_node("0")[0]
+ assert node0.get("fillcolor") == "lightblue"
+ assert node0.get("penwidth") == 3
+
+ # Make sure we can override
+ dag_builder.add_node("1", "node1", fillcolor="red", penwidth=4)
+ node1 = dag_builder.graph.get_node("1")[0]
+ assert node1.get("fillcolor") == "red"
+ assert node1.get("penwidth") == 4
+
+ @pytest.mark.unit
+ def test_add_edge_with_attrs(self):
+ """Tests that default attributes are applied and can be overridden."""
+ dag_builder = PyDotDAGBuilder(attrs={"color": "lightblue4", "penwidth": 3})
+
+ dag_builder.add_node("0", "node0")
+ dag_builder.add_node("1", "node1")
+ dag_builder.add_edge("0", "1")
+ edge = dag_builder.graph.get_edges()[0]
+ # Defaults defined earlier
+ assert edge.get("color") == "lightblue4"
+ assert edge.get("penwidth") == 3
+
+ # Make sure we can override
+ dag_builder.add_edge("0", "1", color="red", penwidth=4)
+ edge = dag_builder.graph.get_edges()[1]
+ assert edge.get("color") == "red"
+ assert edge.get("penwidth") == 4
+
+ @pytest.mark.unit
+ def test_add_cluster_with_attrs(self):
+ """Tests that default cluster attributes are applied and can be overridden."""
+ dag_builder = PyDotDAGBuilder(
+ attrs={
+ "style": "solid",
+ "fillcolor": None,
+ "penwidth": 2,
+ "fontname": "Helvetica",
+ }
+ )
+
+ dag_builder.add_cluster("0")
+ cluster1 = dag_builder.graph.get_subgraph("cluster_0")[0]
+
+ # Defaults
+ assert cluster1.get("style") == "solid"
+ assert cluster1.get("fillcolor") is None
+ assert cluster1.get("penwidth") == 2
+ assert cluster1.get("fontname") == "Helvetica"
+
+ dag_builder.add_cluster("1", style="filled", penwidth=10, fillcolor="red")
+ cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0]
+
+ # Make sure we can override
+ assert cluster2.get("style") == "filled"
+ assert cluster2.get("penwidth") == 10
+ assert cluster2.get("fillcolor") == "red"
+
+ # Check that other defaults are still present
+ assert cluster2.get("fontname") == "Helvetica"
+
+
+class TestProperties:
+ """Tests the properties."""
+
+ def test_nodes(self):
+ """Tests that nodes works."""
+ dag_builder = PyDotDAGBuilder()
+
+ dag_builder.add_node("0", "node0", fillcolor="red")
+ dag_builder.add_cluster("c0")
+ dag_builder.add_node("1", "node1", cluster_uid="c0")
+
+ nodes = dag_builder.nodes
+
+ assert len(nodes) == 2
+ assert len(nodes["0"]) == 4
+
+ assert nodes["0"]["uid"] == "0"
+ assert nodes["0"]["label"] == "node0"
+ assert nodes["0"]["cluster_uid"] == None
+ assert nodes["0"]["attrs"]["fillcolor"] == "red"
+
+ assert nodes["1"]["uid"] == "1"
+ assert nodes["1"]["label"] == "node1"
+ assert nodes["1"]["cluster_uid"] == "c0"
+
+ def test_edges(self):
+ """Tests that edges works."""
+
+ dag_builder = PyDotDAGBuilder()
+ dag_builder.add_node("0", "node0")
+ dag_builder.add_node("1", "node1")
+ dag_builder.add_edge("0", "1", penwidth=10)
+
+ edges = dag_builder.edges
+
+ assert len(edges) == 1
+
+ assert edges[0]["from_uid"] == "0"
+ assert edges[0]["to_uid"] == "1"
+ assert edges[0]["attrs"]["penwidth"] == 10
+
+ def test_clusters(self):
+ """Tests that clusters property works."""
+
+ dag_builder = PyDotDAGBuilder()
+ dag_builder.add_cluster("0", "my_cluster", penwidth=10)
+
+ clusters = dag_builder.clusters
+
+ dag_builder.add_cluster(
+ "1", "my_nested_cluster", cluster_uid="0",
+ )
+ clusters = dag_builder.clusters
+ assert len(clusters) == 2
+
+ assert len(clusters["0"]) == 4
+ assert clusters["0"]["uid"] == "0"
+ assert clusters["0"]["label"] == "my_cluster"
+ assert clusters["0"]["cluster_uid"] == None
+ assert clusters["0"]["attrs"]["penwidth"] == 10
+
+ assert len(clusters["1"]) == 4
+ assert clusters["1"]["uid"] == "1"
+ assert clusters["1"]["label"] == "my_nested_cluster"
+ assert clusters["1"]["cluster_uid"] == "0"
+
+
+class TestOutput:
+ """Test that the graph can be outputted correctly."""
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize(
+ "filename, format",
+ [("my_graph", None), ("my_graph", "png"), ("prototype.trial1", "png")],
+ )
+ def test_to_file(self, monkeypatch, filename, format):
+ """Tests that the `to_file` method works correctly."""
+ dag_builder = PyDotDAGBuilder()
+
+ # mock out the graph writing functionality
+ mock_write = MagicMock()
+ monkeypatch.setattr(dag_builder.graph, "write", mock_write)
+ dag_builder.to_file(filename + "." + (format or "png"))
+
+ # make sure the function handles extensions correctly
+ mock_write.assert_called_once_with(
+ filename + "." + (format or "png"), format=format or "png"
+ )
+
+ @pytest.mark.unit
+ @pytest.mark.parametrize("format", ["pdf", "svg", "jpeg"])
+ def test_other_supported_formats(self, monkeypatch, format):
+ """Tests that the `to_file` method works with other formats."""
+ dag_builder = PyDotDAGBuilder()
+
+ # mock out the graph writing functionality
+ mock_write = MagicMock()
+ monkeypatch.setattr(dag_builder.graph, "write", mock_write)
+ dag_builder.to_file(f"my_graph.{format}")
+
+ # make sure the function handles extensions correctly
+ mock_write.assert_called_once_with(f"my_graph.{format}", format=format)
+
+ @pytest.mark.unit
+ def test_to_string(self):
+ """Tests that the `to_string` method works correclty."""
+
+ dag_builder = PyDotDAGBuilder()
+ dag_builder.add_node("n0", "node0")
+ dag_builder.add_node("n1", "node1")
+ dag_builder.add_edge("n0", "n1")
+
+ string = dag_builder.to_string()
+ assert isinstance(string, str)
+
+ # make sure important things show up in the string
+ assert "digraph" in string
+ assert "n0" in string
+ assert "n1" in string
+ assert "n0 -> n1" in string
diff --git a/requirements.txt b/requirements.txt
index 2a2db94191..3fde869394 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,3 +34,4 @@ nbmake
# optional rt/test dependencies
pennylane-lightning-kokkos
amazon-braket-pennylane-plugin>1.27.1
+pydot