Skip to content

Commit

Permalink
redefine rgnode and layer order
Browse files Browse the repository at this point in the history
  • Loading branch information
lkct committed Dec 11, 2023
1 parent 69c3fdd commit 744c008
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 22 deletions.
7 changes: 4 additions & 3 deletions cirkit/new/model/tensorized_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ def __init__(self, symb_circuit: SymbolicTensorizedCircuit, *, num_channels: int
self.num_channels = num_channels
self.num_classes = symb_circuit.num_classes

self.layers = nn.ModuleList() # Automatic layer registry, also publically available.
# Automatic nn.Module registry, also in publicly available children names.
self.layers = nn.ModuleList()

# TODO: or do we store edges in Layer?
# The actual internal container for forward.
# The actual internal container for forward, preserves insertion order.
self._symb_to_layers: Dict[SymbolicLayer, Optional[Layer]] = {}

# Both containers with have a consistent layer order by this loop.
for symb_layer in symb_circuit.layers:
layer: Optional[Layer]
# Ignore: all SymbolicLayer contains Any.
Expand Down
10 changes: 8 additions & 2 deletions cirkit/new/region_graph/region_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def add_edge(self, tail: RGNode, head: RGNode) -> None:
# add_node will check for _is_frozen.
self.add_node(tail)
self.add_node(head)
# TODO: this insertion order may be different from add_node order
assert tail.outputs.append(head), "The edges in RG should not be repeated."
head.inputs.append(tail) # Only need to check duplicate in one direction.

Expand Down Expand Up @@ -116,10 +115,18 @@ def freeze(self) -> Self:

def _sort_nodes(self) -> None:
"""Sort the OrderedSet of RGNode for node list and edge tables."""
# Now rg nodes have no sort_key. With a stable sort, equal nodes keep insertion order.
self._nodes.sort()
# Now the nodes are in an order determined solely by the construction algorithm.
for i, node in enumerate(self._nodes):
# Ignore: Unavoidable for Dict[str, Any].
# Disable: It is designed to be accessed.
node._metadata["sort_key"] = i # type: ignore[misc] # pylint: disable=protected-access
# Now the nodes have total ordering based on the original order.
for node in self._nodes:
node.inputs.sort()
node.outputs.sort()
# Now all containers are consistently sorted by the order decided by sort_key.

# TODO: do we need these return? or just assert?
def _validate(self) -> str: # pylint: disable=too-many-return-statements
Expand Down Expand Up @@ -410,7 +417,6 @@ def load(filename: str) -> "RegionGraph":
regions_in = [idx_region[idx_in] for idx_in in (partition["l"], partition["r"])]
region_out = idx_region[partition["p"]]

# TODO: is the order of edge table saved in nodes preserved?
graph.add_partitioning(region_out, regions_in)

return graph.freeze()
34 changes: 20 additions & 14 deletions cirkit/new/region_graph/rg_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, cast

from cirkit.new.utils import OrderedSet, Scope

Expand All @@ -23,7 +23,7 @@ def __init__(self, scope: Iterable[int]) -> None:
self.inputs: OrderedSet[RGNode] = OrderedSet()
self.outputs: OrderedSet[RGNode] = OrderedSet()

# TODO: we might want to save something, but this is not used yet.
# TODO: refine typing to what's actually used?
self._metadata: Dict[str, Any] = {} # type: ignore[misc]

def __repr__(self) -> str:
Expand All @@ -41,21 +41,18 @@ def __repr__(self) -> str:
def __lt__(self, other: "RGNode") -> bool:
"""Compare the node with another node, for < operator implicitly used in sorting.
TODO: the following is currently NOT correct because the sorting rule is not complete.
It is guaranteed that exactly one of a == b, a < b, a > b is True. Can be used for \
sorting and order is guaranteed to be always stable.
TODO: alternative if we ignore the above todo:
Note that a != b does not imply a < b or b < a, as the order within the the same type of \
node with the same scope is not defined, in which case a == b, a < b, b < a are all false. \
Yet although there's no total ordering, sorting can still be performed.
The comparison between two RGNode is:
- If they have different scopes, the one with a smaller scope is smaller;
- With the same scope, PartitionNode is smaller than RegionNode;
- For same type of node and same scope, __lt__ is always False, indicating "equality for \
the purpose of sorting".
- For same type of node and same scope, an extra sort_key can be provided in \
self._metadata to define the order;
- If the above cannot compare, __lt__ is always False, indicating "equality for the \
purpose of sorting".
With the extra sorting key provided, it is guaranteed to have total ordering, i.e., \
exactly one of a == b, a < b, a > b is True, and will lead to a deterministic sorted order.
This comparison guarantees the topological order in a (smooth and decomposable) RG:
This comparison also guarantees the topological order in a (smooth and decomposable) RG:
- For a RegionNode->PartitionNode edge, Region.scope < Partition.scope;
- For a PartitionNode->RegionNode edge, they have the same scope and Partition < Region.
Expand All @@ -67,7 +64,16 @@ def __lt__(self, other: "RGNode") -> bool:
"""
# A trick to compare classes: if the class name is equal, then the class is the same;
# otherwise "P" < "R" and PartitionNode < RegionNode, so comparison of class names works.
return (self.scope, self.__class__.__name__) < (other.scope, other.__class__.__name__)
# Ignore: Unavoidable for Dict[str, Any].
return (
self.scope,
self.__class__.__name__,
cast(int, self._metadata.get("sort_key", -1)), # type: ignore[misc]
) < (
other.scope,
other.__class__.__name__,
cast(int, other._metadata.get("sort_key", -1)), # type: ignore[misc]
)


# Disable: It's intended for RegionNode to only have few methods.
Expand Down
5 changes: 3 additions & 2 deletions cirkit/new/symbolic/symbolic_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.

self._layers: OrderedSet[SymbolicLayer] = OrderedSet()
# The RGNode and SymbolicLayer does not map 1-to-1 but 1-to-many. This still leads to a
# deterministic order: SymbolicLayer of the same RGNode are adjcent, and ordered based on
# the order of edges in the RG.
# deterministic order: SymbolicLayer of different RGNode will be naturally sorted by the
# RGNode order; SymbolicLayer of the same RGNode are adjcent, and ordered based on the order
# of edges in the RGNode.

node_to_layer: Dict[RGNode, SymbolicLayer] = {} # Map RGNode to its "output" SymbolicLayer.

Expand Down
3 changes: 2 additions & 1 deletion cirkit/new/symbolic/symbolic_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
self.scope = scope

# self.inputs is filled using layers_in, while self.outputs is empty until self appears in
# another layer's layers_in. No need to de-duplicate, so prefer list over OrderedSet.
# another layer's layers_in. No need to de-duplicate, so prefer list over OrderedSet. Both
# lists automatically gain a consistent ordering with RGNode edge tables by design.
self.inputs: List[SymbolicLayer] = []
self.outputs: List[SymbolicLayer] = []
for layer_in in layers_in:
Expand Down

0 comments on commit 744c008

Please sign in to comment.