Merge pull request #167 from april-tools/tensorized_circuit
Build TensorizedCircuit
lkct authored Dec 13, 2023
2 parents bc9d592 + c8aa97f commit a108536
Showing 10 changed files with 283 additions and 61 deletions.
7 changes: 6 additions & 1 deletion cirkit/new/layers/inner/sum/mixing.py
@@ -6,7 +6,12 @@


class MixingLayer(SumLayer):
"""The sum layer for mixture among layers."""
"""The sum layer for mixture among layers.
It can also be used as a sparse sum within a layer when arity=1.
"""

# TODO: do we use another name for another purpose?

def __init__(
self,
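For intuition on the arity=1 note above: a mixing layer computes a per-unit weighted sum over its H stacked inputs, all sharing the same unit count K. Below is a minimal plain-PyTorch sketch of that computation; it is not the actual MixingLayer, whose weights come from a Reparameterization, and all dimensions are illustrative.

```python
import torch

# Mixture over H=2 input layers with K units each: a per-unit weighted sum
# along the arity axis. With arity H=1 the same operation reduces to a
# (sparse) sum within a single layer, matching the docstring note above.
H, B, K = 2, 32, 16
x = torch.rand(H, B, K)  # stacked inputs, shape (H, *B, K)
weight = torch.softmax(torch.rand(K, H), dim=-1)  # one convex combination per unit
y = torch.einsum("kh,hbk->bk", weight, x)  # shape (*B, K)
assert y.shape == (B, K)
```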
1 change: 1 addition & 0 deletions cirkit/new/model/__init__.py
@@ -0,0 +1 @@
from .tensorized_circuit import TensorizedCircuit as TensorizedCircuit
142 changes: 142 additions & 0 deletions cirkit/new/model/tensorized_circuit.py
@@ -0,0 +1,142 @@
from typing import Dict, Optional

import torch
from torch import Tensor, nn

from cirkit.new.layers import InputLayer, Layer, SumProductLayer
from cirkit.new.symbolic import (
SymbolicLayer,
SymbolicProductLayer,
SymbolicSumLayer,
SymbolicTensorizedCircuit,
)


class TensorizedCircuit(nn.Module):
"""The tensorized circuit with concrete computational graph in PyTorch.
This class is aimed at computation, and therefore does not include excessive structural \
properties. If those are really needed, use the properties of TensorizedCircuit.symb_circuit.
"""

# TODO: do we also move num_channels to SymbolicTensorizedCircuit?
def __init__(self, symb_circuit: SymbolicTensorizedCircuit, *, num_channels: int) -> None:
"""Init class.
All configuration other than num_channels should be provided to the symbolic form.
Args:
symb_circuit (SymbolicTensorizedCircuit): The symbolic version of the circuit.
num_channels (int): The number of channels in the input.
"""
super().__init__()
self.symb_circuit = symb_circuit
self.scope = symb_circuit.scope
self.num_vars = symb_circuit.num_vars

self.layers = nn.ModuleList() # Automatic layer registry, also publicly available.

# TODO: or do we store edges in Layer?
# The actual internal container for forward.
self._symb_to_layers: Dict[SymbolicLayer, Optional[Layer]] = {}

for symb_layer in symb_circuit.layers:
layer: Optional[Layer]
# Ignore: all SymbolicLayer contains Any.
# Ignore: Unavoidable for kwargs.
if issubclass(symb_layer.layer_cls, SumProductLayer) and isinstance(
symb_layer, SymbolicProductLayer # type: ignore[misc]
): # Sum-product fusion at prod: build the actual layer with arity of prod.
# len(symb_layer.outputs) == 1 should be guaranteed by PartitionNode.
next_layer = symb_layer.outputs[0] # There should be exactly one SymbSum output.
assert (
isinstance(next_layer, SymbolicSumLayer) # type: ignore[misc]
and next_layer.layer_cls == symb_layer.layer_cls
), "Sum-product fusion inconsistent."
layer = symb_layer.layer_cls(
# TODO: is it good to use only [0]?
num_input_units=symb_layer.inputs[0].num_units,
num_output_units=next_layer.num_units,
arity=symb_layer.arity,
reparam=next_layer.reparam,
**next_layer.layer_kwargs, # type: ignore[misc]
)
elif issubclass(symb_layer.layer_cls, SumProductLayer) and isinstance(
symb_layer, SymbolicSumLayer # type: ignore[misc]
): # Sum-product fusion at sum: just run checks and fill a placeholder.
prev_layer = symb_layer.inputs[0] # There should be exactly one SymbProd input.
assert (
len(symb_layer.inputs) == 1 # I.e., symb_layer.arity == 1.
and isinstance(prev_layer, SymbolicProductLayer) # type: ignore[misc]
and prev_layer.layer_cls == symb_layer.layer_cls
), "Sum-product fusion inconsistent."
layer = None
elif not issubclass(symb_layer.layer_cls, SumProductLayer): # Normal layers.
layer = symb_layer.layer_cls(
# TODO: is it good to use only [0]?
num_input_units=( # num_channels for InputLayers or num_units of prev layer.
symb_layer.inputs[0].num_units if symb_layer.inputs else num_channels
),
num_output_units=symb_layer.num_units,
arity=symb_layer.arity,
reparam=symb_layer.reparam,
**symb_layer.layer_kwargs, # type: ignore[misc]
)
else:
# NOTE: In the above if/elif, we made all conditions explicit to make it more
# readable and also easier for static analysis inside the blocks. Yet the
# completeness cannot be inferred and is only guaranteed by larger picture.
# Also, should anything really go wrong, we will hit this guard statement
# instead of going into a wrong branch.
assert False, "This should not happen."
if layer is not None: # Only register actual layers.
self.layers.append(layer)
self._symb_to_layers[symb_layer] = layer # But keep a complete mapping.

def __call__(self, x: Tensor) -> Tensor:
"""Invoke the forward function.
Args:
x (Tensor): The input of the circuit, shape (*B, D, C).
Returns:
Tensor: The output of the circuit, shape (*B, num_out, num_cls).
""" # TODO: single letter name?
# Ignore: Idiom for nn.Module.__call__.
return super().__call__(x) # type: ignore[no-any-return,misc]

# TODO: do we accept each variable separately?
def forward(self, x: Tensor) -> Tensor:
"""Invoke the forward function.
Args:
x (Tensor): The input of the circuit, shape (*B, D, C).
Returns:
Tensor: The output of the circuit, shape (*B, num_out, num_cls).
"""
layer_outputs: Dict[SymbolicLayer, Tensor] = {} # shape (*B, K).

for symb_layer, layer in self._symb_to_layers.items():
if layer is None:
assert (
len(symb_layer.inputs) == 1
), "Only symbolic layers with arity=1 can be implemented by a place-holder."
layer_outputs[symb_layer] = layer_outputs[symb_layer.inputs[0]]
continue

# Disable: Ternary will be too long for readability.
if isinstance(layer, InputLayer): # pylint: disable=consider-ternary-expression
# TODO: mypy bug? tuple(symb_layer.scope) is inferred to Any
layer_input = x[..., tuple(symb_layer.scope), :].movedim( # type: ignore[misc]
-2, 0
) # shape (*B, D, C) -> (H=D, *B, K=C).
else:
layer_input = torch.stack(
[layer_outputs[layer_in] for layer_in in symb_layer.inputs], dim=0
) # shape H * (*B, K) -> (H, *B, K).
layer_outputs[symb_layer] = layer(layer_input)

return torch.stack(
[layer_outputs[layer_out] for layer_out in self.symb_circuit.output_layers], dim=-2
) # shape num_out * (*B, K) -> (*B, num_out, num_cls=K).
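To make the shape conventions in forward concrete, here is a standalone check of the two reshaping idioms used above: scope indexing for input layers, and stacking for inner layers. The dimensions and scope are arbitrary illustrative values.

```python
import torch

B, D, C = 32, 8, 3  # batch, number of variables, number of channels
x = torch.rand(B, D, C)  # circuit input, shape (*B, D, C)

# Input layers: select this layer's scope and move the variable axis to the
# front, (*B, D, C) -> (H=len(scope), *B, K=C).
scope = (0, 2, 5)  # hypothetical scope of one input layer
layer_input = x[..., scope, :].movedim(-2, 0)
assert layer_input.shape == (len(scope), B, C)

# Inner layers: stack the outputs of the H input layers,
# H * (*B, K) -> (H, *B, K).
K = 16
stacked = torch.stack([torch.rand(B, K) for _ in range(2)], dim=0)
assert stacked.shape == (2, B, K)
```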
11 changes: 8 additions & 3 deletions cirkit/new/region_graph/region_graph.py
@@ -77,8 +77,9 @@ def add_edge(self, tail: RGNode, head: RGNode) -> None:
# add_node will check for _is_frozen.
self.add_node(tail)
self.add_node(head)
tail.outputs.append(head) # TODO: this insertion order may be different from add_node order
head.inputs.append(tail)
# TODO: this insertion order may be different from add_node order
assert tail.outputs.append(head), "The edges in RG should not be repeated."
head.inputs.append(tail) # Only need to check duplicate in one direction.

def add_partitioning(self, region: RegionNode, sub_regions: Iterable[RegionNode]) -> None:
"""Add a partitioning structure to the graph, with a PartitionNode constructed internally.
@@ -120,7 +121,8 @@ def _sort_nodes(self) -> None:
node.inputs.sort()
node.outputs.sort()

def _validate(self) -> str:
# TODO: do we need these return? or just assert?
def _validate(self) -> str: # pylint: disable=too-many-return-statements
"""Validate the RG structure to make sure it's a legal computational graph.
Returns:
@@ -132,6 +134,9 @@ def _validate(self) -> str:
if next(self.output_nodes, None) is None:
return "RG must have at least one output node"

# Also guarantees the input/output nodes are all regions.
if not all(partition.inputs for partition in self.partition_nodes):
return "PartitionNode must have at least one input"
if any(len(partition.outputs) != 1 for partition in self.partition_nodes):
return "PartitionNode can only have one output RegionNode"

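The assert in add_edge relies on append reporting whether the element was actually inserted. A minimal stand-in for the contract assumed of cirkit.new.utils.OrderedSet (the real implementation may differ):

```python
from typing import Dict, Generic, Iterator, TypeVar

T = TypeVar("T")


class OrderedSetSketch(Generic[T]):
    """Illustrative stand-in: append returns True iff the element was newly
    added, so a duplicate edge trips the assert in add_edge above."""

    def __init__(self) -> None:
        self._items: Dict[T, None] = {}  # dict keys preserve insertion order

    def append(self, item: T) -> bool:
        if item in self._items:
            return False  # duplicate: nothing inserted
        self._items[item] = None
        return True

    def __iter__(self) -> Iterator[T]:
        return iter(self._items)
```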
2 changes: 1 addition & 1 deletion cirkit/new/symbolic/__init__.py
@@ -1,4 +1,4 @@
from .symbolic_circuit import SymbolicCircuit as SymbolicCircuit
from .symbolic_circuit import SymbolicTensorizedCircuit as SymbolicTensorizedCircuit
from .symbolic_layer import SymbolicInputLayer as SymbolicInputLayer
from .symbolic_layer import SymbolicLayer as SymbolicLayer
from .symbolic_layer import SymbolicProductLayer as SymbolicProductLayer
114 changes: 84 additions & 30 deletions cirkit/new/symbolic/symbolic_circuit.py
@@ -1,6 +1,6 @@
from typing import Any, Dict, Iterable, Iterator, Optional, Type
from typing import Any, Dict, Iterable, Iterator, Optional, Type, Union

from cirkit.new.layers import InnerLayer, InputLayer
from cirkit.new.layers import InputLayer, MixingLayer, ProductLayer, SumLayer, SumProductLayer
from cirkit.new.region_graph import PartitionNode, RegionGraph, RegionNode, RGNode
from cirkit.new.reparams import Reparameterization
from cirkit.new.symbolic.symbolic_layer import (
@@ -9,13 +9,13 @@
SymbolicProductLayer,
SymbolicSumLayer,
)
from cirkit.new.utils import Scope
from cirkit.new.utils import OrderedSet, Scope

# TODO: double check docs and __repr__


# Disable: It's designed to have these many attributes.
class SymbolicCircuit: # pylint: disable=too-many-instance-attributes
class SymbolicTensorizedCircuit: # pylint: disable=too-many-instance-attributes
"""The symbolic representation of a tensorized circuit."""

# TODO: how to design interface? require kwargs only?
@@ -31,31 +31,37 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
input_layer_cls: Type[InputLayer],
input_layer_kwargs: Optional[Dict[str, Any]] = None,
input_reparam: Optional[Reparameterization] = None,
sum_layer_cls: Type[InnerLayer], # TODO: more specific?
sum_layer_cls: Type[Union[SumLayer, SumProductLayer]],
sum_layer_kwargs: Optional[Dict[str, Any]] = None,
sum_reparam: Reparameterization,
prod_layer_cls: Type[InnerLayer], # TODO: more specific?
prod_layer_cls: Type[Union[ProductLayer, SumProductLayer]],
prod_layer_kwargs: Optional[Dict[str, Any]] = None,
):
"""Construct symbolic circuit from a region graph.
Args:
region_graph (RegionGraph): The region graph to convert.
num_input_units (int): _description_
num_sum_units (int): _description_
num_classes (int, optional): _description_. Defaults to 1.
num_input_units (int): The number of units in the input layer.
num_sum_units (int): The number of units in the sum layer. Will also be used to infer \
the number of product units.
num_classes (int, optional): The number of classes of the circuit output, i.e., the \
number of units in the output layer. Defaults to 1.
input_layer_cls (Type[InputLayer]): The layer class for input layers.
input_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for \
input layer class. Defaults to None.
input_reparam (Optional[Reparameterization], optional): The reparameterization for \
input layer parameters, can be None if it has no params. Defaults to None.
sum_layer_cls (Type[InnerLayer]): The layer class for sum layers.
sum_layer_cls (Type[Union[SumLayer, SumProductLayer]]): The layer class for sum \
layers, can be either just a class of SumLayer, or a class of SumProductLayer to \
indicate layer fusion.
sum_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for sum \
layer class. Defaults to None.
sum_reparam (Reparameterization): The reparameterization for sum layer parameters.
prod_layer_cls (Type[InnerLayer]): The layer class for product layers.
prod_layer_cls (Type[Union[ProductLayer, SumProductLayer]]): The layer class for \
product layers, can be either just a class of ProductLayer, or a class of \
SumProductLayer to indicate layer fusion.
prod_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for \
product layer class. Defaults to None.
product layer class, will be ignored if SumProductLayer is used. Defaults to None.
"""
self.region_graph = region_graph
self.scope = region_graph.scope
@@ -65,32 +71,72 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
self.is_structured_decomposable = region_graph.is_structured_decomposable
self.is_omni_compatible = region_graph.is_omni_compatible

node_layer: Dict[RGNode, SymbolicLayer] = {}
self._layers: OrderedSet[SymbolicLayer] = OrderedSet()
# RGNode and SymbolicLayer do not map 1-to-1 but 1-to-many. This still leads to a
# deterministic order: SymbolicLayers of the same RGNode are adjacent, and ordered based on
# the order of edges in the RG.

node_to_layer: Dict[RGNode, SymbolicLayer] = {} # Map RGNode to its "output" SymbolicLayer.

for rg_node in region_graph.nodes:
layers_in = (node_layer[node_in] for node_in in rg_node.inputs)
layer: SymbolicLayer
# Cannot use a generator as layers_in, because it's used twice.
layers_in = [node_to_layer[node_in] for node_in in rg_node.inputs]
layer_out: SymbolicLayer
# Ignore: Unavoidable for kwargs.
if isinstance(rg_node, RegionNode) and not rg_node.inputs: # Input node.
layer = SymbolicInputLayer(
if isinstance(rg_node, RegionNode) and not rg_node.inputs: # Input region.
layers_in = [
SymbolicInputLayer(
rg_node,
(), # Old layers_in should be empty.
num_units=num_input_units,
layer_cls=input_layer_cls,
layer_kwargs=input_layer_kwargs, # type: ignore[misc]
reparam=input_reparam,
)
]
# This also works when the input region is also an output, in which case num_classes is used.
layer_out = SymbolicSumLayer(
rg_node,
layers_in,
num_units=num_input_units,
layer_cls=input_layer_cls,
layer_kwargs=input_layer_kwargs, # type: ignore[misc]
reparam=input_reparam,
num_units=num_sum_units if rg_node.outputs else num_classes,
layer_cls=sum_layer_cls,
layer_kwargs=sum_layer_kwargs, # type: ignore[misc]
reparam=sum_reparam,
)
elif isinstance(rg_node, RegionNode) and rg_node.inputs: # Inner region node.
layer = SymbolicSumLayer(
elif isinstance(rg_node, RegionNode) and len(rg_node.inputs) == 1: # Simple inner.
# layers_in keeps the same.
layer_out = SymbolicSumLayer(
rg_node,
layers_in,
num_units=num_sum_units if rg_node.outputs else num_classes,
layer_cls=sum_layer_cls,
layer_kwargs=sum_layer_kwargs, # type: ignore[misc]
reparam=sum_reparam,
)
elif isinstance(rg_node, PartitionNode): # Partition node.
layer = SymbolicProductLayer(
elif isinstance(rg_node, RegionNode) and len(rg_node.inputs) > 1: # Inner with mixture.
# MixingLayer cannot change the number of units, so we must project early.
layers_in = [
SymbolicSumLayer(
rg_node,
(layer_in,),
num_units=num_sum_units if rg_node.outputs else num_classes,
layer_cls=sum_layer_cls,
layer_kwargs=sum_layer_kwargs, # type: ignore[misc]
reparam=sum_reparam,
)
for layer_in in layers_in
]
layer_out = SymbolicSumLayer(
rg_node,
layers_in,
num_units=num_sum_units if rg_node.outputs else num_classes,
layer_cls=MixingLayer,
layer_kwargs={}, # type: ignore[misc]
reparam=sum_reparam, # TODO: use a constant reparam here?
)
elif isinstance(rg_node, PartitionNode):
# layers_in keeps the same.
layer_out = SymbolicProductLayer(
rg_node,
layers_in,
num_units=prod_layer_cls._infer_num_prod_units(
@@ -101,10 +147,18 @@
reparam=None,
)
else:
# NOTE: In the above if/elif, we made all conditions explicit to make it more
# readable and also easier for static analysis inside the blocks. Yet the
# completeness cannot be inferred and is only guaranteed by larger picture.
# Also, should anything really go wrong, we will hit this guard statement
# instead of going into a wrong branch.
assert False, "This should not happen."
node_layer[rg_node] = layer

self._node_layer = node_layer # Insertion order is preserved by dict@py3.7+.
# layers_in may be existing layers (from node_layer) which will be de-duplicated by
# OrderedSet, or newly constructed layers to be added.
self._layers.extend(layers_in)
# layer_out is what will be connected to the output of rg_node.
self._layers.append(layer_out)
node_to_layer[rg_node] = layer_out

####################################### Properties #######################################
# Here are the basic properties and some structural properties of the SymbC. Some of them are
@@ -130,7 +184,7 @@ def __init__( # type: ignore[misc] # Ignore: Unavoidable for kwargs.
"""Whether the SymbC is omni-compatible, i.e., compatible to all circuits of the same scope."""

def is_compatible(
self, other: "SymbolicCircuit", *, scope: Optional[Iterable[int]] = None
self, other: "SymbolicTensorizedCircuit", *, scope: Optional[Iterable[int]] = None
) -> bool:
"""Test compatibility with another symbolic circuit over the given scope.
@@ -153,7 +207,7 @@ def is_compatible(
@property
def layers(self) -> Iterator[SymbolicLayer]:
"""All layers in the circuit."""
return iter(self._node_layer.values())
return iter(self._layers)

@property
def sum_layers(self) -> Iterator[SymbolicSumLayer]:
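On the "project early" branch above (a RegionNode with more than one partition): a MixingLayer cannot change the number of units, so each partition's output is first projected to the region's unit count by its own sum layer and only then mixed. A plain-PyTorch sketch of that two-step shape logic, with dense matrices standing in for the parameterized layers:

```python
import torch

B, K_in, K = 32, 12, 16  # batch, units per partition, units of the region
part1, part2 = torch.rand(B, K_in), torch.rand(B, K_in)  # two partition outputs

# Step 1: per-partition sum layers project K_in -> K (the early projection).
proj1, proj2 = torch.rand(K, K_in), torch.rand(K, K_in)
stacked = torch.stack([part1 @ proj1.T, part2 @ proj2.T], dim=0)  # (H=2, B, K)

# Step 2: the mixture combines the H projected inputs per unit.
mix_weight = torch.softmax(torch.rand(K, 2), dim=-1)
y = torch.einsum("kh,hbk->bk", mix_weight, stacked)  # shape (B, K)
assert y.shape == (B, K)
```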
