Merge pull request #167 from april-tools/tensorized_circuit
Build TensorizedCircuit
Showing 10 changed files with 283 additions and 61 deletions.
@@ -0,0 +1 @@
+from .tensorized_circuit import TensorizedCircuit as TensorizedCircuit
@@ -0,0 +1,142 @@
from typing import Dict, Optional

import torch
from torch import Tensor, nn

from cirkit.new.layers import InputLayer, Layer, SumProductLayer
from cirkit.new.symbolic import (
    SymbolicLayer,
    SymbolicProductLayer,
    SymbolicSumLayer,
    SymbolicTensorizedCircuit,
)


class TensorizedCircuit(nn.Module):
    """The tensorized circuit with concrete computational graph in PyTorch.

    This class is aimed at computation, and therefore does not include excessive structural \
    properties. If those are really needed, use the properties of TensorizedCircuit.symb_circuit.
    """

    # TODO: do we also move num_channels to SymbolicTensorizedCircuit?
    def __init__(self, symb_circuit: SymbolicTensorizedCircuit, *, num_channels: int) -> None:
        """Init class.

        All configuration other than num_channels should be provided to the symbolic form.

        Args:
            symb_circuit (SymbolicTensorizedCircuit): The symbolic version of the circuit.
            num_channels (int): The number of channels in the input.
        """
        super().__init__()
        self.symb_circuit = symb_circuit
        self.scope = symb_circuit.scope
        self.num_vars = symb_circuit.num_vars

        self.layers = nn.ModuleList()  # Automatic layer registry, also publicly available.

        # TODO: or do we store edges in Layer?
        # The actual internal container for forward.
        self._symb_to_layers: Dict[SymbolicLayer, Optional[Layer]] = {}

        for symb_layer in symb_circuit.layers:
            layer: Optional[Layer]
            # Ignore: all SymbolicLayer contains Any.
            # Ignore: Unavoidable for kwargs.
            if issubclass(symb_layer.layer_cls, SumProductLayer) and isinstance(
                symb_layer, SymbolicProductLayer  # type: ignore[misc]
            ):  # Sum-product fusion at prod: build the actual layer with arity of prod.
                # len(symb_layer.outputs) == 1 should be guaranteed by PartitionNode.
                next_layer = symb_layer.outputs[0]  # There should be exactly one SymbSum output.
                assert (
                    isinstance(next_layer, SymbolicSumLayer)  # type: ignore[misc]
                    and next_layer.layer_cls == symb_layer.layer_cls
                ), "Sum-product fusion inconsistent."
                layer = symb_layer.layer_cls(
                    # TODO: is it good to use only [0]?
                    num_input_units=symb_layer.inputs[0].num_units,
                    num_output_units=next_layer.num_units,
                    arity=symb_layer.arity,
                    reparam=next_layer.reparam,
                    **next_layer.layer_kwargs,  # type: ignore[misc]
                )
            elif issubclass(symb_layer.layer_cls, SumProductLayer) and isinstance(
                symb_layer, SymbolicSumLayer  # type: ignore[misc]
            ):  # Sum-product fusion at sum: just run checks and fill a placeholder.
                prev_layer = symb_layer.inputs[0]  # There should be exactly one SymbProd input.
                assert (
                    len(symb_layer.inputs) == 1  # I.e., symb_layer.arity == 1.
                    and isinstance(prev_layer, SymbolicProductLayer)  # type: ignore[misc]
                    and prev_layer.layer_cls == symb_layer.layer_cls
                ), "Sum-product fusion inconsistent."
                layer = None
            elif not issubclass(symb_layer.layer_cls, SumProductLayer):  # Normal layers.
                layer = symb_layer.layer_cls(
                    # TODO: is it good to use only [0]?
                    num_input_units=(  # num_channels for InputLayers or num_units of prev layer.
                        symb_layer.inputs[0].num_units if symb_layer.inputs else num_channels
                    ),
                    num_output_units=symb_layer.num_units,
                    arity=symb_layer.arity,
                    reparam=symb_layer.reparam,
                    **symb_layer.layer_kwargs,  # type: ignore[misc]
                )
            else:
                # NOTE: In the above if/elif, we made all conditions explicit to make it more
                # readable and also easier for static analysis inside the blocks. Yet the
                # completeness cannot be inferred and is only guaranteed by the larger picture.
                # Also, should anything really go wrong, we will hit this guard statement
                # instead of going into a wrong branch.
                assert False, "This should not happen."
            if layer is not None:  # Only register actual layers.
                self.layers.append(layer)
            self._symb_to_layers[symb_layer] = layer  # But keep a complete mapping.

    def __call__(self, x: Tensor) -> Tensor:
        """Invoke the forward function.

        Args:
            x (Tensor): The input of the circuit, shape (*B, D, C).

        Returns:
            Tensor: The output of the circuit, shape (*B, num_out, num_cls).
        """  # TODO: single letter name?
        # Ignore: Idiom for nn.Module.__call__.
        return super().__call__(x)  # type: ignore[no-any-return,misc]

    # TODO: do we accept each variable separately?
    def forward(self, x: Tensor) -> Tensor:
        """Invoke the forward function.

        Args:
            x (Tensor): The input of the circuit, shape (*B, D, C).

        Returns:
            Tensor: The output of the circuit, shape (*B, num_out, num_cls).
        """
        layer_outputs: Dict[SymbolicLayer, Tensor] = {}  # shape (*B, K).

        for symb_layer, layer in self._symb_to_layers.items():
            if layer is None:
                assert (
                    len(symb_layer.inputs) == 1
                ), "Only symbolic layers with arity=1 can be implemented by a placeholder."
                layer_outputs[symb_layer] = layer_outputs[symb_layer.inputs[0]]
                continue

            # Disable: Ternary will be too long for readability.
            if isinstance(layer, InputLayer):  # pylint: disable=consider-ternary-expression
                # TODO: mypy bug? tuple(symb_layer.scope) is inferred to Any
                layer_input = x[..., tuple(symb_layer.scope), :].movedim(  # type: ignore[misc]
                    -2, 0
                )  # shape (*B, D, C) -> (H=D, *B, K=C).
            else:
                layer_input = torch.stack(
                    [layer_outputs[layer_in] for layer_in in symb_layer.inputs], dim=0
                )  # shape H * (*B, K) -> (H, *B, K).
            layer_outputs[symb_layer] = layer(layer_input)

        return torch.stack(
            [layer_outputs[layer_out] for layer_out in self.symb_circuit.output_layers], dim=-2
        )  # shape num_out * (*B, K) -> (*B, num_out, num_cls=K).
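For orientation, here is a minimal usage sketch of the new class (not part of this diff). It assumes a SymbolicTensorizedCircuit named symb_circuit has already been built from a region graph, and that the one-line __init__.py above is importable as cirkit.new.model; both the data and the import path are assumptions for illustration, not something this PR pins down.

import torch

from cirkit.new.model import TensorizedCircuit  # import path assumed from the __init__ above

# symb_circuit: a SymbolicTensorizedCircuit built elsewhere; all structural
# configuration lives in the symbolic form, only num_channels is given here.
circuit = TensorizedCircuit(symb_circuit, num_channels=1)

# Input shape (*B, D, C): batch dims, D = circuit.num_vars variables, C = num_channels.
x = torch.randn(32, circuit.num_vars, 1)

# Output shape (*B, num_out, num_cls): one slice per output layer of the circuit.
y = circuit(x)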
@@ -1,6 +1,6 @@
-from typing import Any, Dict, Iterable, Iterator, Optional, Type
+from typing import Any, Dict, Iterable, Iterator, Optional, Type, Union
 
-from cirkit.new.layers import InnerLayer, InputLayer
+from cirkit.new.layers import InputLayer, MixingLayer, ProductLayer, SumLayer, SumProductLayer
 from cirkit.new.region_graph import PartitionNode, RegionGraph, RegionNode, RGNode
 from cirkit.new.reparams import Reparameterization
 from cirkit.new.symbolic.symbolic_layer import (
@@ -9,13 +9,13 @@
     SymbolicProductLayer,
     SymbolicSumLayer,
 )
-from cirkit.new.utils import Scope
+from cirkit.new.utils import OrderedSet, Scope
 
 # TODO: double check docs and __repr__
 
 
 # Disable: It's designed to have these many attributes.
-class SymbolicCircuit:  # pylint: disable=too-many-instance-attributes
+class SymbolicTensorizedCircuit:  # pylint: disable=too-many-instance-attributes
     """The symbolic representation of a tensorized circuit."""
 
     # TODO: how to design interface? require kwargs only?
@@ -31,31 +31,37 @@ def __init__(  # type: ignore[misc]  # Ignore: Unavoidable for kwargs.
         input_layer_cls: Type[InputLayer],
         input_layer_kwargs: Optional[Dict[str, Any]] = None,
         input_reparam: Optional[Reparameterization] = None,
-        sum_layer_cls: Type[InnerLayer],  # TODO: more specific?
+        sum_layer_cls: Type[Union[SumLayer, SumProductLayer]],
         sum_layer_kwargs: Optional[Dict[str, Any]] = None,
         sum_reparam: Reparameterization,
-        prod_layer_cls: Type[InnerLayer],  # TODO: more specific?
+        prod_layer_cls: Type[Union[ProductLayer, SumProductLayer]],
         prod_layer_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """Construct symbolic circuit from a region graph.
 
         Args:
             region_graph (RegionGraph): The region graph to convert.
-            num_input_units (int): _description_
-            num_sum_units (int): _description_
-            num_classes (int, optional): _description_. Defaults to 1.
+            num_input_units (int): The number of units in the input layer.
+            num_sum_units (int): The number of units in the sum layer. Will also be used to infer \
+                the number of product units.
+            num_classes (int, optional): The number of classes of the circuit output, i.e., the \
+                number of units in the output layer. Defaults to 1.
             input_layer_cls (Type[InputLayer]): The layer class for input layers.
             input_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for \
                 input layer class. Defaults to None.
             input_reparam (Optional[Reparameterization], optional): The reparameterization for \
                 input layer parameters, can be None if it has no params. Defaults to None.
-            sum_layer_cls (Type[InnerLayer]): The layer class for sum layers.
+            sum_layer_cls (Type[Union[SumLayer, SumProductLayer]]): The layer class for sum \
+                layers, can be either just a class of SumLayer, or a class of SumProductLayer to \
+                indicate layer fusion.
             sum_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for sum \
                 layer class. Defaults to None.
             sum_reparam (Reparameterization): The reparameterization for sum layer parameters.
-            prod_layer_cls (Type[InnerLayer]): The layer class for product layers.
+            prod_layer_cls (Type[Union[ProductLayer, SumProductLayer]]): The layer class for \
+                product layers, can be either just a class of ProductLayer, or a class of \
+                SumProductLayer to indicate layer fusion.
             prod_layer_kwargs (Optional[Dict[str, Any]], optional): The additional kwargs for \
-                product layer class. Defaults to None.
+                product layer class, will be ignored if SumProductLayer is used. Defaults to None.
         """
         self.region_graph = region_graph
         self.scope = region_graph.scope
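To make the new sum_layer_cls/prod_layer_cls typing concrete, a hedged construction sketch: passing the same SumProductLayer subclass for both arguments requests layer fusion, while separate SumLayer/ProductLayer subclasses keep sums and products as distinct layers. CategoricalLayer, CPLayer and the reparameterization placeholders below are assumed names for illustration and are not introduced by this diff.

from cirkit.new.symbolic import SymbolicTensorizedCircuit

# region_graph: a RegionGraph built elsewhere; input_reparam / sum_reparam are
# Reparameterization instances; CategoricalLayer (an InputLayer subclass) and
# CPLayer (a SumProductLayer subclass) are assumed concrete classes.
symb_circuit = SymbolicTensorizedCircuit(
    region_graph,
    num_input_units=16,
    num_sum_units=16,
    num_classes=1,
    input_layer_cls=CategoricalLayer,
    input_reparam=input_reparam,
    sum_layer_cls=CPLayer,   # SumProductLayer subclass -> fused sum-product layers
    sum_reparam=sum_reparam,
    prod_layer_cls=CPLayer,  # must be the same SumProductLayer class for fusion
    # prod_layer_kwargs is ignored when a SumProductLayer is used.
)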
@@ -65,32 +71,72 @@ def __init__(  # type: ignore[misc]  # Ignore: Unavoidable for kwargs.
         self.is_structured_decomposable = region_graph.is_structured_decomposable
         self.is_omni_compatible = region_graph.is_omni_compatible
 
-        node_layer: Dict[RGNode, SymbolicLayer] = {}
+        self._layers: OrderedSet[SymbolicLayer] = OrderedSet()
+        # The RGNode and SymbolicLayer do not map 1-to-1 but 1-to-many. This still leads to a
+        # deterministic order: SymbolicLayers of the same RGNode are adjacent, and ordered based
+        # on the order of edges in the RG.
+
+        node_to_layer: Dict[RGNode, SymbolicLayer] = {}  # Map RGNode to its "output" SymbolicLayer.
 
         for rg_node in region_graph.nodes:
-            layers_in = (node_layer[node_in] for node_in in rg_node.inputs)
-            layer: SymbolicLayer
+            # Cannot use a generator as layers_in, because it's used twice.
+            layers_in = [node_to_layer[node_in] for node_in in rg_node.inputs]
+            layer_out: SymbolicLayer
             # Ignore: Unavoidable for kwargs.
-            if isinstance(rg_node, RegionNode) and not rg_node.inputs:  # Input node.
-                layer = SymbolicInputLayer(
+            if isinstance(rg_node, RegionNode) and not rg_node.inputs:  # Input region.
+                layers_in = [
+                    SymbolicInputLayer(
+                        rg_node,
+                        (),  # Old layers_in should be empty.
+                        num_units=num_input_units,
+                        layer_cls=input_layer_cls,
+                        layer_kwargs=input_layer_kwargs,  # type: ignore[misc]
+                        reparam=input_reparam,
+                    )
+                ]
+                # This also works when the input is also output, in which case num_classes is used.
+                layer_out = SymbolicSumLayer(
                     rg_node,
                     layers_in,
-                    num_units=num_input_units,
-                    layer_cls=input_layer_cls,
-                    layer_kwargs=input_layer_kwargs,  # type: ignore[misc]
-                    reparam=input_reparam,
+                    num_units=num_sum_units if rg_node.outputs else num_classes,
+                    layer_cls=sum_layer_cls,
+                    layer_kwargs=sum_layer_kwargs,  # type: ignore[misc]
+                    reparam=sum_reparam,
                 )
-            elif isinstance(rg_node, RegionNode) and rg_node.inputs:  # Inner region node.
-                layer = SymbolicSumLayer(
+            elif isinstance(rg_node, RegionNode) and len(rg_node.inputs) == 1:  # Simple inner.
+                # layers_in keeps the same.
+                layer_out = SymbolicSumLayer(
                     rg_node,
                     layers_in,
                     num_units=num_sum_units if rg_node.outputs else num_classes,
                     layer_cls=sum_layer_cls,
                     layer_kwargs=sum_layer_kwargs,  # type: ignore[misc]
                     reparam=sum_reparam,
                 )
-            elif isinstance(rg_node, PartitionNode):  # Partition node.
-                layer = SymbolicProductLayer(
+            elif isinstance(rg_node, RegionNode) and len(rg_node.inputs) > 1:  # Inner with mixture.
+                # MixingLayer cannot change number of units, so must project early.
+                layers_in = [
+                    SymbolicSumLayer(
+                        rg_node,
+                        (layer_in,),
+                        num_units=num_sum_units if rg_node.outputs else num_classes,
+                        layer_cls=sum_layer_cls,
+                        layer_kwargs=sum_layer_kwargs,  # type: ignore[misc]
+                        reparam=sum_reparam,
+                    )
+                    for layer_in in layers_in
+                ]
+                layer_out = SymbolicSumLayer(
+                    rg_node,
+                    layers_in,
+                    num_units=num_sum_units if rg_node.outputs else num_classes,
+                    layer_cls=MixingLayer,
+                    layer_kwargs={},  # type: ignore[misc]
+                    reparam=sum_reparam,  # TODO: use a constant reparam here?
+                )
+            elif isinstance(rg_node, PartitionNode):
+                # layers_in keeps the same.
+                layer_out = SymbolicProductLayer(
                     rg_node,
                     layers_in,
                     num_units=prod_layer_cls._infer_num_prod_units(
@@ -101,10 +147,18 @@ def __init__(  # type: ignore[misc]  # Ignore: Unavoidable for kwargs.
                     reparam=None,
                 )
             else:
+                # NOTE: In the above if/elif, we made all conditions explicit to make it more
+                # readable and also easier for static analysis inside the blocks. Yet the
+                # completeness cannot be inferred and is only guaranteed by the larger picture.
+                # Also, should anything really go wrong, we will hit this guard statement
+                # instead of going into a wrong branch.
                 assert False, "This should not happen."
-            node_layer[rg_node] = layer
-
-        self._node_layer = node_layer  # Insertion order is preserved by dict in Python 3.7+.
+            # layers_in may be existing layers (from node_layer) which will be de-duplicated by
+            # OrderedSet, or newly constructed layers to be added.
+            self._layers.extend(layers_in)
+            # layer_out is what will be connected to the output of rg_node.
+            self._layers.append(layer_out)
+            node_to_layer[rg_node] = layer_out
 
         ####################################### Properties #######################################
         # Here are the basic properties and some structural properties of the SymbC. Some of them are
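As a reading aid for the new len(rg_node.inputs) > 1 branch above (not code from this PR): for a region fed by more than one partition, each partition's product output is first projected by its own SymbolicSumLayer, and only then are the projections combined by a MixingLayer, since MixingLayer cannot change the number of units. A rough sketch for two partitions (names hypothetical):

#   SymbolicProductLayer(partition 1) -> SymbolicSumLayer(sum_layer_cls, K units) --+
#                                                                                   +--> SymbolicSumLayer(MixingLayer, K units)
#   SymbolicProductLayer(partition 2) -> SymbolicSumLayer(sum_layer_cls, K units) --+
#
# where K = num_sum_units for inner regions and K = num_classes for output regions.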
@@ -130,7 +184,7 @@ def __init__(  # type: ignore[misc]  # Ignore: Unavoidable for kwargs.
"""Whether the SymbC is omni-compatible, i.e., compatible to all circuits of the same scope.""" | ||
|
||
def is_compatible( | ||
self, other: "SymbolicCircuit", *, scope: Optional[Iterable[int]] = None | ||
self, other: "SymbolicTensorizedCircuit", *, scope: Optional[Iterable[int]] = None | ||
) -> bool: | ||
"""Test compatibility with another symbolic circuit over the given scope. | ||
|
@@ -153,7 +207,7 @@ def is_compatible(
     @property
     def layers(self) -> Iterator[SymbolicLayer]:
         """All layers in the circuit."""
-        return iter(self._node_layer.values())
+        return iter(self._layers)
 
     @property
     def sum_layers(self) -> Iterator[SymbolicSumLayer]:
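Finally, a small hedged sketch of the public surface touched by the rename and the new _layers container; symb_circuit and other_circuit stand for two SymbolicTensorizedCircuit instances built elsewhere.

# Layers iterate in their deterministic construction order (backed by OrderedSet).
for symb_layer in symb_circuit.layers:
    print(type(symb_layer).__name__, symb_layer.num_units)

# Structural compatibility against another circuit, optionally restricted to a sub-scope.
compatible = symb_circuit.is_compatible(other_circuit, scope={0, 1})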