diff --git a/cirkit/new/symbolic/__init__.py b/cirkit/new/symbolic/__init__.py
new file mode 100644
index 00000000..29504a1b
--- /dev/null
+++ b/cirkit/new/symbolic/__init__.py
@@ -0,0 +1,9 @@
+# type: ignore
+# pylint: skip-file
+from .symbolic_circuit import SymbolicCircuit
+from .symbolic_layer import (
+    SymbolicInputLayer,
+    SymbolicLayer,
+    SymbolicProductLayer,
+    SymbolicSumLayer,
+)
diff --git a/cirkit/new/symbolic/symbolic_circuit.py b/cirkit/new/symbolic/symbolic_circuit.py
new file mode 100644
index 00000000..e2e36efa
--- /dev/null
+++ b/cirkit/new/symbolic/symbolic_circuit.py
@@ -0,0 +1,248 @@
+# type: ignore
+# pylint: skip-file
+from functools import cached_property
+from typing import Any, Dict, FrozenSet, Iterable, Optional, Set, Type
+
+from cirkit.layers.input.exp_family import ExpFamilyLayer
+from cirkit.layers.sum_product import SumProductLayer
+from cirkit.new.symbolic.symbolic_layer import (
+    SymbolicInputLayer,
+    SymbolicLayer,
+    SymbolicProductLayer,
+    SymbolicSumLayer,
+)
+from cirkit.region_graph import RegionGraph
+from cirkit.region_graph.rg_node import RGNode
+from cirkit.reparams.leaf import ReparamIdentity
+from cirkit.utils.type_aliases import ReparamFactory
+
+
+class SymbolicCircuit:
+    """The Symbolic Circuit, similar to cirkit.region_graph.RegionGraph."""
+
+    def __init__(
+        self,
+        region_graph: RegionGraph,
+        layer_cls: Type[SumProductLayer],
+        efamily_cls: Type[ExpFamilyLayer],
+        layer_kwargs: Optional[Dict[str, Any]] = None,
+        efamily_kwargs: Optional[Dict[str, Any]] = None,
+        reparam: ReparamFactory = ReparamIdentity,
+        num_inner_units: int = 2,
+        num_input_units: int = 2,
+        num_channels: int = 1,
+        num_classes: int = 1,
+    ) -> None:
+        """Construct symbolic circuit from a region graph.
+
+        Args:
+            region_graph (RegionGraph): The region graph to convert.
+            layer_cls (Type[SumProductLayer]): The layer class for inner layers.
+            efamily_cls (Type[ExpFamilyLayer]): The layer class for input layers.
+            layer_kwargs (Optional[Dict[str, Any]]): The parameters for inner layer class.
+            efamily_kwargs (Optional[Dict[str, Any]]): The parameters for input layer class.
+            reparam (ReparamFactory): The reparametrization function.
+            num_inner_units (int): Number of units for inner layers.
+            num_input_units (int): Number of units for input layers.
+            num_channels (int): Number of channels (e.g., 3 for RGB pixel) for input layers.
+            num_classes (int): Number of classes for the PC.
+
+        """
+        self.region_graph = region_graph
+
+        self._layers: Set[SymbolicLayer] = set()
+
+        existing_symbolic_layers: Dict[RGNode, SymbolicLayer] = {}
+
+        for input_node in region_graph.input_nodes:
+            rg_node_stack = [(input_node, None)]
+
+            while rg_node_stack:
+                rg_node, prev_symbolic_layer = rg_node_stack.pop()
+                already_built = rg_node in existing_symbolic_layers
+                if already_built:
+                    symbolic_layer = existing_symbolic_layers[rg_node]
+                else:
+                    # Construct a symbolic layer from the region node
+                    symbolic_layer = self._from_region_node(
+                        prev_symbolic_layer,
+                        rg_node,
+                        region_graph,
+                        layer_cls,
+                        efamily_cls,
+                        layer_kwargs,
+                        efamily_kwargs,
+                        reparam,
+                        num_inner_units,
+                        num_input_units,
+                        num_channels,
+                        num_classes,
+                    )
+                    existing_symbolic_layers[rg_node] = symbolic_layer
+                    # Register on creation so that a node without edges is kept
+                    self._layers.add(symbolic_layer)
+
+                # Connect previous symbolic layer to the current one
+                if prev_symbolic_layer is not None:
+                    self._add_edge(prev_symbolic_layer, symbolic_layer)
+
+                # The outputs of a node are expanded only on its first visit:
+                # every later visit would just re-add the same edges.
+                if not already_built:
+                    for output_rg_node in rg_node.outputs:
+                        rg_node_stack.append((output_rg_node, symbolic_layer))
+
+    def _from_region_node(
+        self,
+        prev_symbolic_layer: Optional[SymbolicLayer],
+        rg_node: RGNode,
+        region_graph: RegionGraph,
+        layer_cls: Type[SumProductLayer],
+        efamily_cls: Type[ExpFamilyLayer],
+        layer_kwargs: Optional[Dict[str, Any]],
+        efamily_kwargs: Optional[Dict[str, Any]],
+        reparam: ReparamFactory,
+        num_inner_units: int,
+        num_input_units: int,
+        num_channels: int,
+        num_classes: int,
+    ) -> SymbolicLayer:
+        """Create a symbolic layer based on the given region node.
+
+        Args:
+            prev_symbolic_layer (Optional[SymbolicLayer]): The symbolic layer the traversal
+                came from (None for an input node), i.e. an input of the new layer.
+            rg_node (RGNode): The current region graph node to convert to symbolic layer.
+            region_graph (RegionGraph): The region graph.
+            layer_cls (Type[SumProductLayer]): The layer class for inner layers.
+            efamily_cls (Type[ExpFamilyLayer]): The layer class for input layers.
+            layer_kwargs (Optional[Dict[str, Any]]): The parameters for inner layer class.
+            efamily_kwargs (Optional[Dict[str, Any]]): The parameters for input layer class.
+            reparam (ReparamFactory): The reparametrization function.
+            num_inner_units (int): Number of units for inner layers.
+            num_input_units (int): Number of units for input layers.
+            num_channels (int): Number of channels (e.g., 3 for RGB pixel) for input layers.
+            num_classes (int): Number of classes for the PC.
+
+        Returns:
+            SymbolicLayer: The constructed symbolic layer.
+
+        Raises:
+            ValueError: If the region node is not valid.
+        """
+        scope = rg_node.scope
+        inputs = rg_node.inputs
+        outputs = rg_node.outputs
+
+        if rg_node in region_graph.inner_region_nodes:
+            assert len(inputs) == 1, "Inner region nodes should have exactly one input."
+
+            output_units = num_classes if rg_node in region_graph.output_nodes else num_inner_units
+            input_units = (
+                num_input_units
+                if any(
+                    isinstance(layer, SymbolicInputLayer) for layer in prev_symbolic_layer.inputs
+                )
+                else num_inner_units
+            )
+
+            symbolic_layer = SymbolicSumLayer(scope, output_units, layer_cls, layer_kwargs)
+            symbolic_layer.set_placeholder_params(input_units, output_units, reparam)
+
+        elif rg_node in region_graph.partition_nodes:
+            assert len(inputs) == 2, "Partition nodes should have exactly two inputs."
+            assert len(outputs) > 0, "Partition nodes should have at least one output."
+
+            left_input_units = num_inner_units if inputs[0].inputs else num_input_units
+            right_input_units = num_inner_units if inputs[1].inputs else num_input_units
+
+            assert (
+                left_input_units == right_input_units
+            ), "Input units for partition nodes should match."
+
+            symbolic_layer = SymbolicProductLayer(scope, left_input_units, layer_cls)
+
+        elif rg_node in region_graph.input_nodes:
+            num_replicas = region_graph.num_replicas
+
+            symbolic_layer = SymbolicInputLayer(scope, num_input_units, efamily_cls, efamily_kwargs)
+            symbolic_layer.set_placeholder_params(num_channels, num_replicas, reparam)
+
+        else:
+            raise ValueError("Region node not valid.")
+
+        return symbolic_layer
+
+    def _add_edge(self, tail: SymbolicLayer, head: SymbolicLayer) -> None:
+        """Add edge and layer.
+
+        Args:
+            tail (SymbolicLayer): The layer the edge originates from.
+            head (SymbolicLayer): The layer the edge points to.
+        """
+        self._layers.add(tail)
+        self._layers.add(head)
+        tail.outputs.add(head)
+        head.inputs.add(tail)
+
+    ##########################    Properties    #########################
+
+    @property
+    def scope(self) -> FrozenSet[int]:
+        """Get the total scope of the circuit."""
+        scopes = [layer.scope for layer in self.output_layers]
+        return frozenset(set().union(*scopes))
+
+    @property
+    def layers(self) -> Iterable[SymbolicLayer]:
+        """Get all the layers in the circuit."""
+        return iter(self._layers)
+
+    @property
+    def input_layers(self) -> Iterable[SymbolicLayer]:
+        """Get input layers of the circuit."""
+        return (layer for layer in self.layers if isinstance(layer, SymbolicInputLayer))
+
+    @property
+    def output_layers(self) -> Iterable[SymbolicLayer]:
+        """Get output layer of the circuit."""
+        return (layer for layer in self.layers if not layer.outputs)
+
+    @property
+    def sum_layers(self) -> Iterable[SymbolicLayer]:
+        """Get inner sum layers of the circuit."""
+        return (layer for layer in self.layers if isinstance(layer, SymbolicSumLayer))
+
+    @property
+    def product_layers(self) -> Iterable[SymbolicLayer]:
+        """Get inner product layers of the circuit."""
+        return (layer for layer in self.layers if isinstance(layer, SymbolicProductLayer))
+
+    ##########################    Structural Properties    #########################
+
+    @cached_property
+    def is_smooth(self) -> bool:
+        """Test smoothness in symbolic circuit."""
+        return self.region_graph.is_smooth
+
+    @cached_property
+    def is_decomposable(self) -> bool:
+        """Test decomposability in symbolic circuit."""
+        return self.region_graph.is_decomposable
+
+    @cached_property
+    def is_structured_decomposable(self) -> bool:
+        """Test structural decomposability in symbolic circuit."""
+        return self.region_graph.is_structured_decomposable
+
+    def is_compatible(self, other: "SymbolicCircuit", x_scope: Iterable[int]) -> bool:
+        """Test compatibility, if self and other are compatible w.r.t x_scope.
+
+        Args:
+            other (SymbolicCircuit): Another symbolic circuit to test compatibility.
+            x_scope (Iterable[int]): The compatible scope.
+
+        Returns:
+            bool: The result of the compatibility test on the two underlying region graphs.
+        """
+        return self.region_graph.is_compatible(other.region_graph, x_scope)
diff --git a/cirkit/new/symbolic/symbolic_layer.py b/cirkit/new/symbolic/symbolic_layer.py
new file mode 100644
index 00000000..5044dc09
--- /dev/null
+++ b/cirkit/new/symbolic/symbolic_layer.py
@@ -0,0 +1,256 @@
+# type: ignore
+# pylint: skip-file
+from abc import ABC
+from typing import Any, Dict, Iterable, Optional, Set, Type
+
+from cirkit.layers.input.exp_family import (
+    BinomialLayer,
+    CategoricalLayer,
+    ExpFamilyLayer,
+    NormalLayer,
+)
+from cirkit.layers.sum_product import (
+    CollapsedCPLayer,
+    SharedCPLayer,
+    SumProductLayer,
+    TuckerLayer,
+    UncollapsedCPLayer,
+)
+from cirkit.reparams.leaf import ReparamIdentity
+from cirkit.utils.type_aliases import ReparamFactory
+
+
+class SymbolicLayer(ABC):
+    # pylint: disable=too-few-public-methods
+    """Base class for symbolic nodes in symbolic circuit."""
+
+    def __init__(self, scope: Iterable[int]) -> None:
+        """Construct the Symbolic Node.
+
+        Args:
+            scope (Iterable[int]): The scope of this node.
+        """
+        self.scope = frozenset(scope)
+        assert self.scope, "The scope of a node must be non-empty"
+
+        self.inputs: Set[Any] = set()
+        self.outputs: Set[Any] = set()
+
+    def __repr__(self) -> str:
+        """Generate the `repr` string of the node."""
+        class_name = self.__class__.__name__
+        scope = repr(set(self.scope))
+        return f"{class_name}:\nScope: {scope}\n"
+
+
+class SymbolicSumLayer(SymbolicLayer):
+    """Class representing sum nodes in the symbolic circuit."""
+
+    def __init__(
+        self,
+        scope: Iterable[int],
+        num_units: int,
+        layer_cls: Type[SumProductLayer],
+        layer_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Construct the Symbolic Sum Node.
+
+        Args:
+            scope (Iterable[int]): The scope of this node.
+            num_units (int): Number of output units in this node.
+            layer_cls (Type[SumProductLayer]): The inner (sum) layer class.
+            layer_kwargs (Optional[Dict[str, Any]]): The parameters for the inner layer class.
+
+        Raises:
+            NotImplementedError: If the shared uncollapsed CP is not implemented.
+        """
+        super().__init__(scope)
+        self.num_units = num_units
+        self.layer_kwargs = layer_kwargs
+        self.params = None
+        self.params_in = None
+        self.params_out = None
+
+        if layer_cls == TuckerLayer:
+            self.layer_cls = layer_cls
+        else:  # CP layer
+            # layer_kwargs defaults to None, so fall back to an empty mapping for lookups
+            kwargs = self.layer_kwargs or {}
+            collapsed = kwargs.get("collapsed", True)
+            shared = kwargs.get("shared", False)
+
+            if not shared and collapsed:
+                self.layer_cls = CollapsedCPLayer
+            elif not shared and not collapsed:
+                self.layer_cls = UncollapsedCPLayer
+            elif shared and collapsed:
+                self.layer_cls = SharedCPLayer
+            else:
+                raise NotImplementedError("The shared uncollapsed CP is not implemented.")
+
+    def set_placeholder_params(
+        self,
+        num_input_units: int,
+        num_units: int,
+        reparam: ReparamFactory = ReparamIdentity,
+    ) -> None:
+        """Set un-initialized parameter placeholders for the symbolic sum node.
+
+        Args:
+            num_input_units (int): Number of input units.
+            num_units (int): Number of output units.
+            reparam (ReparamFactory): Reparameterization function.
+
+        Raises:
+            NotImplementedError: If the shared uncollapsed CP is not implemented.
+        """
+        assert self.num_units == num_units
+
+        # Handling different layer types
+        if self.layer_cls == TuckerLayer:
+            # number of fold = 1
+            self.params = reparam((1, num_input_units, num_input_units, num_units), dim=(1, 2))
+        else:  # CP layer
+            # layer_kwargs may be None; fall back to an empty mapping for lookups
+            kwargs = self.layer_kwargs or {}
+            arity = kwargs.get("arity", 2)
+            # A missing fold_mask and an explicit fold_mask=None are both accepted
+            assert kwargs.get("fold_mask") is None, "Do not support fold_mask yet"
+
+            if self.layer_cls == CollapsedCPLayer:
+                self.params_in = reparam((1, arity, num_input_units, num_units), dim=-2, mask=None)
+            elif self.layer_cls == UncollapsedCPLayer:
+                self.params_in = reparam((1, arity, num_input_units, 1), dim=-2, mask=None)
+                self.params_out = reparam((1, 1, num_units), dim=-2, mask=None)
+            elif self.layer_cls == SharedCPLayer:
+                self.params_in = reparam((arity, num_input_units, num_units), dim=-2, mask=None)
+            else:
+                raise NotImplementedError("The shared uncollapsed CP is not implemented.")
+
+    def __repr__(self) -> str:
+        """Generate the `repr` string of the node."""
+        class_name = self.__class__.__name__
+        layer_cls_name = self.layer_cls.__name__ if self.layer_cls else "None"
+        params_shape = getattr(self.params, "shape", None) if hasattr(self, "params") else None
+
+        params_in_shape = (
+            getattr(self.params_in, "shape", None) if hasattr(self, "params_in") else None
+        )
+        params_out_shape = (
+            getattr(self.params_out, "shape", None) if hasattr(self, "params_out") else None
+        )
+
+        return (
+            f"{class_name}:\n"
+            f"Scope: {repr(self.scope)}\n"
+            f"Layer Class: {layer_cls_name}\n"
+            f"Layer KWArgs: {repr(self.layer_kwargs)}\n"
+            f"Number of Units: {repr(self.num_units)}\n"
+            f"Parameter Shape: {repr(params_shape)}\n"
+            f"CP Layer Parameter in Shape: {repr(params_in_shape)}\n"
+            f"CP Layer Parameter out Shape: {repr(params_out_shape)}\n"
+        )
+
+
+class SymbolicProductLayer(SymbolicLayer):
+    # pylint: disable=too-few-public-methods
+    """Class representing product nodes in the symbolic graph."""
+
+    def __init__(
+        self, scope: Iterable[int], num_units: int, layer_cls: Type[SumProductLayer]
+    ) -> None:
+        """Construct the Symbolic Product Node.
+
+        Args:
+            scope (Iterable[int]): The scope of this node.
+            num_units (int): Number of input units.
+            layer_cls (Type[SumProductLayer]): The inner (sum) layer class.
+        """
+        super().__init__(scope)
+        self.num_units = num_units
+        self.layer_cls = layer_cls
+
+    def __repr__(self) -> str:
+        """Generate the `repr` string of the node."""
+        class_name = self.__class__.__name__
+        layer_cls_name = self.layer_cls.__name__ if self.layer_cls else "None"
+
+        return (
+            f"{class_name}:\n"
+            f"Scope: {repr(self.scope)}\n"
+            f"Layer Class: {layer_cls_name}\n"
+            f"Number of Units: {repr(self.num_units)}\n"
+        )
+
+
+class SymbolicInputLayer(SymbolicLayer):
+    """Class representing input nodes in the symbolic graph."""
+
+    def __init__(
+        self,
+        scope: Iterable[int],
+        num_units: int,
+        efamily_cls: Type[ExpFamilyLayer],
+        efamily_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """Construct the Symbolic Input Node.
+
+        Args:
+            scope (Iterable[int]): The scope of this node.
+            num_units (int): Number of output units.
+            efamily_cls (Type[ExpFamilyLayer]): The exponential family class.
+            efamily_kwargs (Optional[Dict[str, Any]]): The parameters for
+                the exponential family class.
+        """
+        super().__init__(scope)
+        self.num_units = num_units
+        self.efamily_cls = efamily_cls
+        self.efamily_kwargs = efamily_kwargs
+        self.params = None
+
+    def set_placeholder_params(
+        self,
+        num_channels: int = 1,
+        num_replicas: int = 1,
+        reparam: ReparamFactory = ReparamIdentity,
+    ) -> None:
+        """Set un-initialized parameter placeholders for the input node.
+
+        Args:
+            num_channels (int): Number of channels.
+            num_replicas (int): Number of replicas.
+            reparam (ReparamFactory): Reparameterization function.
+
+        Raises:
+            NotImplementedError: Only support Normal, Categorical, and Binomial input layers.
+            AssertionError: If `num_categories` is missing for a Categorical input layer.
+        """
+        # Handling different exponential family layer types
+        if self.efamily_cls == NormalLayer:
+            num_suff_stats = 2 * num_channels
+        elif self.efamily_cls == CategoricalLayer:
+            # efamily_kwargs defaults to None, so guard the lookup against it
+            efamily_kwargs = self.efamily_kwargs or {}
+            assert "num_categories" in efamily_kwargs, "CategoricalLayer needs num_categories"
+            num_suff_stats = efamily_kwargs["num_categories"] * num_channels
+        elif self.efamily_cls == BinomialLayer:
+            num_suff_stats = num_channels
+        else:
+            raise NotImplementedError("Only support Normal, Categorical, and Binomial input layers")
+
+        self.params = reparam((1, self.num_units, num_replicas, num_suff_stats), dim=-1)
+
+    def __repr__(self) -> str:
+        """Generate the `repr` string of the node."""
+        class_name = self.__class__.__name__
+        efamily_cls_name = self.efamily_cls.__name__ if self.efamily_cls else "None"
+        params_shape = getattr(self.params, "shape", None) if hasattr(self, "params") else None
+
+        return (
+            f"{class_name}:\n"
+            f"Scope: {repr(self.scope)}\n"
+            f"Input Exp Family Class: {efamily_cls_name}\n"
+            f"Layer KWArgs: {repr(self.efamily_kwargs)}\n"
+            f"Number of Units: {repr(self.num_units)}\n"
+            f"Parameter Shape: {repr(params_shape)}\n"
+        )
diff --git a/tests/new/symbolic/__init__.py b/tests/new/symbolic/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/new/symbolic/test_symbolic_circuit.py b/tests/new/symbolic/test_symbolic_circuit.py
new file mode 100644
index 00000000..bab2eef8
--- /dev/null
+++ b/tests/new/symbolic/test_symbolic_circuit.py
@@ -0,0 +1,54 @@
+# type: ignore
+# pylint: skip-file
+
+import pytest
+
+from cirkit.layers.input.exp_family import CategoricalLayer
+from cirkit.layers.sum_product import BaseCPLayer
+from cirkit.new.symbolic.symbolic_circuit import SymbolicCircuit
+from cirkit.new.symbolic.symbolic_layer import (
+    SymbolicInputLayer,
+    SymbolicProductLayer,
+    SymbolicSumLayer,
+)
+from cirkit.region_graph import PartitionNode, RegionGraph, RegionNode
+from cirkit.region_graph.quad_tree import QuadTree
+from cirkit.reparams.leaf import ReparamExp
+
+efamily_cls = CategoricalLayer
+efamily_kwargs = {"num_categories": 256}
+layer_cls = BaseCPLayer
+layer_kwargs = {"rank": 1}
+reparam = ReparamExp
+
+num_units = 3
+
+
+def test_symbolic_circuit():
+    rg = RegionGraph()
+    node1 = RegionNode((1,))
+    node2 = RegionNode((2,))
+    partition = PartitionNode((1, 2))
+    region = RegionNode((1, 2))
+    rg.add_edge(node1, partition)
+    rg.add_edge(node2, partition)
+    rg.add_edge(partition, region)
+
+    circuit = SymbolicCircuit(
+        rg, layer_cls, efamily_cls, layer_kwargs, efamily_kwargs, reparam, 4, 4, 1, 1
+    )
+
+    assert len(list(circuit.layers)) == 4
+    assert any(isinstance(layer, SymbolicInputLayer) for layer in circuit.input_layers)
+    assert any(isinstance(layer, SymbolicSumLayer) for layer in circuit.output_layers)
+
+    rg_2 = QuadTree(4, 4, struct_decomp=True)
+
+    circuit_2 = SymbolicCircuit(
+        rg_2, layer_cls, efamily_cls, layer_kwargs, efamily_kwargs, reparam, 4, 4, 1, 1
+    )
+
+    assert len(list(circuit_2.layers)) == 46
+    assert len(list(circuit_2.input_layers)) == 16
+    assert len(list(circuit_2.output_layers)) == 1
+    assert (circuit_2.scope) == frozenset({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15})
diff --git a/tests/new/symbolic/test_symbolic_layer.py b/tests/new/symbolic/test_symbolic_layer.py
new file mode 100644
index 00000000..f34e7fbd
--- /dev/null
+++ b/tests/new/symbolic/test_symbolic_layer.py
@@ -0,0 +1,87 @@
+# type: ignore
+# pylint: skip-file
+
+import pytest
+
+from cirkit.layers.input.exp_family import CategoricalLayer
+from cirkit.layers.sum_product import BaseCPLayer, TuckerLayer
+from cirkit.new.symbolic.symbolic_layer import (
+    SymbolicInputLayer,
+    SymbolicLayer,
+    SymbolicProductLayer,
+    SymbolicSumLayer,
+)
+from cirkit.reparams.leaf import ReparamExp
+
+
+def test_symbolic_node() -> None:
+    scope = [1, 2]
+    node = SymbolicLayer(scope)
+    assert repr(node) == "SymbolicLayer:\nScope: {1, 2}\n"
+
+    with pytest.raises(AssertionError, match="The scope of a node must be non-empty"):
+        SymbolicLayer([])
+
+
+def test_symbolic_sum_node() -> None:
+    scope = [1, 2]
+    num_input_units = 2
+    num_units = 3
+    node = SymbolicSumLayer(scope, num_units, TuckerLayer, {})
+    node.set_placeholder_params(num_input_units, num_units, ReparamExp)
+    assert "SymbolicSumLayer" in repr(node)
+    assert "Scope: frozenset({1, 2})" in repr(node)
+    assert "Layer Class: TuckerLayer" in repr(node)
+    assert "Number of Units: 3" in repr(node)
+    assert "Parameter Shape: (1, 2, 2, 3)" in repr(node)
+    assert "CP Layer Parameter in Shape: None" in repr(node)
+    assert "CP Layer Parameter out Shape: None" in repr(node)
+
+
+def test_symbolic_sum_node_cp() -> None:
+    scope = [1, 2]
+    num_input_units = 2
+    num_units = 3
+    layer_kwargs = {"collapsed": False, "shared": False, "arity": 2}
+    node = SymbolicSumLayer(scope, num_units, BaseCPLayer, layer_kwargs)
+    node.set_placeholder_params(num_input_units, num_units, ReparamExp)
+    assert "SymbolicSumLayer" in repr(node)
+    assert "Scope: frozenset({1, 2})" in repr(node)
+    assert "Layer Class: UncollapsedCPLayer" in repr(node)
+    assert "Number of Units: 3" in repr(node)
+    assert "Parameter Shape: None" in repr(node)
+    assert "CP Layer Parameter in Shape: (1, 2, 2, 1)" in repr(node)
+    assert "CP Layer Parameter out Shape: (1, 1, 3)" in repr(node)
+
+
+def test_symbolic_sum_node_cp_default_kwargs() -> None:
+    # Neither omitted kwargs nor an explicit fold_mask=None may crash the CP path
+    for layer_kwargs in (None, {"fold_mask": None}):
+        node = SymbolicSumLayer([1, 2], 3, BaseCPLayer, layer_kwargs)
+        node.set_placeholder_params(2, 3, ReparamExp)
+        assert "Layer Class: CollapsedCPLayer" in repr(node)
+        assert "CP Layer Parameter in Shape: (1, 2, 2, 3)" in repr(node)
+
+
+def test_symbolic_product_node() -> None:
+    scope = [1, 2]
+    num_input_units = 2
+    node = SymbolicProductLayer(scope, num_input_units, TuckerLayer)
+    assert "SymbolicProductLayer" in repr(node)
+    assert "Scope: frozenset({1, 2})" in repr(node)
+    assert "Layer Class: TuckerLayer" in repr(node)
+    assert "Number of Units: 2" in repr(node)
+
+
+def test_symbolic_input_node() -> None:
+    scope = [1, 2]
+    num_units = 3
+    efamily_kwargs = {"num_categories": 5}
+    node = SymbolicInputLayer(scope, num_units, CategoricalLayer, efamily_kwargs)
+    node.set_placeholder_params(1, 1, ReparamExp)
+    assert "SymbolicInputLayer" in repr(node)
+    assert "Scope: frozenset({1, 2})" in repr(node)
+    assert "Input Exp Family Class: CategoricalLayer" in repr(node)
+    assert "Layer KWArgs: {'num_categories': 5}" in repr(node)
+    assert "Number of Units: 3" in repr(node)
+    assert "Parameter Shape: (1, 3, 1, 5)" in repr(node)