Commit
Merge pull request #170 from april-tools/new_integrate
Add integration and partition func for new circuit
lkct authored Dec 13, 2023
2 parents a27aa9c + a71a2e7 commit 5da2028
Showing 24 changed files with 406 additions and 38 deletions.
5 changes: 2 additions & 3 deletions cirkit/new/layers/inner/sum/dense.py
@@ -35,9 +35,8 @@ def __init__(
         )
 
         self.params = reparam
-        self.params.materialize((num_output_units, num_input_units), dim=1)
-
-        self.reset_parameters()
+        if self.params.materialize((num_output_units, num_input_units), dim=1):
+            self.reset_parameters()  # Only reset if newly materialized.
 
     def _forward_linear(self, x: Tensor) -> Tensor:
         return torch.einsum("oi,...i->...o", self.params(), x)  # shape (*B, Ki) -> (*B, Ko).
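The same change is repeated in the mixing, Tucker and exponential-family layers below: materialize() now reports whether it actually created the tensors, so a layer constructed with an already-materialized (shared) reparameterization skips reset_parameters() and keeps the existing values. A hypothetical sketch of the sharing scenario this enables (the DenseLayer name and constructor arguments are inferred from the file path, not shown in this diff):

    shared = SomeReparam()  # hypothetical concrete Reparameterization
    layer_a = DenseLayer(num_input_units=8, num_output_units=4, reparam=shared)  # materializes and resets
    layer_b = DenseLayer(num_input_units=8, num_output_units=4, reparam=shared)  # reuses, no reset
    assert layer_a.params is layer_b.params  # both layers point at the same shared parameters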
5 changes: 2 additions & 3 deletions cirkit/new/layers/inner/sum/mixing.py
@@ -40,9 +40,8 @@ def __init__(
         )
 
         self.params = reparam
-        self.params.materialize((num_output_units, arity), dim=1)
-
-        self.reset_parameters()
+        if self.params.materialize((num_output_units, arity), dim=1):
+            self.reset_parameters()  # Only reset if newly materialized.
 
     def _forward_linear(self, x: Tensor) -> Tensor:
         return torch.einsum("kh,h...k->...k", self.params(), x)  # shape (H, *B, K) -> (*B, K).
3 changes: 2 additions & 1 deletion cirkit/new/layers/inner/sum_product/cp.py
@@ -73,7 +73,8 @@ def forward(self, x: Tensor) -> Tensor:
         Returns:
             Tensor: The output of this layer, shape (*B, K).
         """
-        return self.sum(self.prod(x))  # shape (H, *B, K) -> (*B, K) -> (*B, K).
+        # shape (H, *B, K) -> (*B, K) -> (H, *B, K) -> (*B, K).
+        return self.sum(self.prod(x).unsqueeze(dim=0))
 
 
 # TODO: Uncollapsed?
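The unsqueeze restores the arity dimension that the product collapses, so the inner sum layer receives input in its expected (H, *B, K) layout with H = 1. A minimal shape-only illustration (plain tensor ops standing in for the prod/sum sublayers):

    import torch
    x = torch.randn(2, 8, 16)           # (H, *B, K): arity 2, batch 8, 16 units
    prod_out = x.prod(dim=0)            # stand-in for self.prod(x): shape (*B, K)
    sum_in = prod_out.unsqueeze(dim=0)  # re-add the arity dim: shape (1, *B, K)
    print(sum_in.shape)                 # torch.Size([1, 8, 16])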
7 changes: 4 additions & 3 deletions cirkit/new/layers/inner/sum_product/tucker.py
@@ -39,9 +39,10 @@ def __init__(
         )
 
         self.params = reparam
-        self.params.materialize((num_output_units, num_input_units, num_input_units), dim=(1, 2))
-
-        self.reset_parameters()
+        if self.params.materialize(
+            (num_output_units, num_input_units, num_input_units), dim=(1, 2)
+        ):
+            self.reset_parameters()  # Only reset if newly materialized.
 
     @classmethod
     def _infer_num_prod_units(cls, num_input_units: int, arity: int = 2) -> int:
1 change: 1 addition & 0 deletions cirkit/new/layers/input/__init__.py
@@ -1,3 +1,4 @@
+from .constant import ConstantLayer as ConstantLayer
 from .exp_family import CategoricalLayer as CategoricalLayer
 from .exp_family import ExpFamilyLayer as ExpFamilyLayer
 from .exp_family import NormalLayer as NormalLayer
79 changes: 79 additions & 0 deletions cirkit/new/layers/input/constant.py
@@ -0,0 +1,79 @@
+from typing import Dict, Literal, Optional, Tuple, Type
+
+import torch
+from torch import Tensor
+
+from cirkit.new.layers.input.input import InputLayer
+from cirkit.new.reparams import Reparameterization
+
+
+class ConstantLayer(InputLayer):
+    """The constant input layer, with no parameters."""
+
+    # Disable: This __init__ is designed to have these arguments.
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        *,
+        num_input_units: int,
+        num_output_units: int,
+        arity: Literal[1] = 1,
+        reparam: Optional[Reparameterization] = None,
+        const_value: float,
+    ) -> None:
+        """Init class.
+
+        Args:
+            num_input_units (int): The number of input units, i.e. number of channels for variables.
+            num_output_units (int): The number of output units.
+            arity (Literal[1], optional): The arity of the layer, must be 1. Defaults to 1.
+            reparam (Optional[Reparameterization], optional): Ignored. This layer has no params. \
+                Defaults to None.
+            const_value (float): The constant value, in linear space.
+        """
+        super().__init__(
+            num_input_units=num_input_units,
+            num_output_units=num_output_units,
+            arity=arity,
+            reparam=reparam,
+        )
+
+        self.const_value = const_value
+
+    def reset_parameters(self) -> None:
+        """Do nothing, as this layer does not have parameters."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Run forward pass.
+
+        Args:
+            x (Tensor): The input to this layer, shape (H, *B, K).
+
+        Returns:
+            Tensor: The output of this layer, shape (*B, K).
+        """
+        return (
+            self.comp_space.from_linear(torch.tensor(self.const_value))
+            .to(x)
+            .expand(*x.shape[1:-1], self.num_output_units)
+        )
+
+    # Disable/Ignore: It's defined with this signature.  # TODO: consider TypedDict?
+    @classmethod
+    def get_integral(  # type: ignore[override]  # pylint: disable=arguments-differ
+        cls, const_value: float
+    ) -> Tuple[Type[InputLayer], Dict[str, float]]:
+        """Get the config to construct the integral of the input layer.
+
+        Args:
+            const_value (float): The const_value in __init__.
+
+        Raises:
+            ValueError: When const_value != 0, in which case the integral is infinity.
+
+        Returns:
+            Tuple[Type[InputLayer], Dict[str, float]]: The class of the integral layer and its \
+                additional kwargs.
+        """
+        if const_value:
+            raise ValueError("The integral of ConstantLayer with const_value != 0 is infinity.")
+        return ConstantLayer, {"const_value": 0.0}
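A hedged sketch of how the (class, kwargs) pair returned by get_integral could be consumed when building an integral circuit (the surrounding compile machinery is not part of this diff; the constructor call only uses arguments documented above):

    layer_cls, extra_kwargs = ConstantLayer.get_integral(const_value=0.0)
    integral_layer = layer_cls(
        num_input_units=1, num_output_units=4, arity=1, reparam=None, **extra_kwargs
    )  # a ConstantLayer that always outputs 0.0 (mapped into the layer's computational space)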
25 changes: 21 additions & 4 deletions cirkit/new/layers/input/exp_family/exp_family.py
@@ -1,10 +1,11 @@
 import functools
 from abc import abstractmethod
-from typing import Literal, Tuple
+from typing import Any, Dict, Literal, Tuple, Type
 
 import torch
 from torch import Tensor, nn
 
+from cirkit.new.layers.input.constant import ConstantLayer
 from cirkit.new.layers.input.input import InputLayer
 from cirkit.new.reparams import Reparameterization
 
@@ -56,9 +57,8 @@ def __init__(
         )
 
         self.params = reparam
-        self.params.materialize((arity, num_output_units, *self.suff_stats_shape), dim=-1)
-
-        self.reset_parameters()
+        if self.params.materialize((arity, num_output_units, *self.suff_stats_shape), dim=-1):
+            self.reset_parameters()  # Only reset if newly materialized.
 
     @torch.no_grad()
     def reset_parameters(self) -> None:
@@ -94,6 +94,23 @@ def forward(self, x: Tensor) -> Tensor:
         )  # shape (*B, H, K) -> (*B, K).
         return self.comp_space.from_log(log_p)
 
+    @classmethod
+    def get_integral(  # type: ignore[misc]  # Ignore: Unavoidable for kwargs.
+        cls, **kwargs: Any
+    ) -> Tuple[Type[InputLayer], Dict[str, float]]:
+        """Get the config to construct the integral of the input layer.
+
+        Args:
+            **kwargs (Any): The additional kwargs for this layer.
+
+        Returns:
+            Tuple[Type[InputLayer], Dict[str, float]]: The class of the integral layer and its \
+                additional kwargs.
+        """
+        # TODO: for unnormalized EF, should be ParameterizedConstantLayer
+        # We have already normalized with log_partition in forward().
+        return ConstantLayer, {"const_value": 1.0}
+
     @abstractmethod
     def sufficient_stats(self, x: Tensor) -> Tensor:
         """Calculate sufficient statistics T from input x.
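Because forward() already normalizes with log_partition, a normalized exponential-family input integrates to one over its domain, which is why the integral config is simply a ConstantLayer with value 1 in linear space. A tiny check of the returned config (assuming get_integral is inherited unchanged by the concrete EF layers):

    layer_cls, kwargs = ExpFamilyLayer.get_integral()
    assert layer_cls is ConstantLayer and kwargs == {"const_value": 1.0}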
18 changes: 17 additions & 1 deletion cirkit/new/layers/input/input.py
@@ -1,4 +1,5 @@
-from typing import Literal, Optional
+from abc import abstractmethod
+from typing import Any, Dict, Literal, Optional, Tuple, Type
 
 from cirkit.new.layers.layer import Layer
 from cirkit.new.reparams import Reparameterization
@@ -40,3 +41,18 @@ def __init__(
             arity=arity,
             reparam=reparam,
         )
+
+    @classmethod
+    @abstractmethod
+    def get_integral(  # type: ignore[misc]  # Ignore: Unavoidable for kwargs.
+        cls, **kwargs: Any
+    ) -> Tuple[Type["InputLayer"], Dict[str, Any]]:
+        """Get the config to construct the integral of the input layer.
+
+        Args:
+            **kwargs (Any): The additional kwargs for this layer.
+
+        Returns:
+            Tuple[Type[InputLayer], Dict[str, Any]]: The class of the integral layer and its \
+                additional kwargs.
+        """
1 change: 1 addition & 0 deletions cirkit/new/model/__init__.py
@@ -1 +1,2 @@
+from . import functional as functional
 from .tensorized_circuit import TensorizedCircuit as TensorizedCircuit
20 changes: 20 additions & 0 deletions cirkit/new/model/functional.py
@@ -0,0 +1,20 @@
+from typing import TYPE_CHECKING, Iterable, Optional
+
+if TYPE_CHECKING:  # Only imported for static type checking but not runtime, to avoid cyclic import.
+    from cirkit.new.model.tensorized_circuit import TensorizedCircuit
+
+
+def integrate(
+    self: "TensorizedCircuit", *, scope: Optional[Iterable[int]] = None
+) -> "TensorizedCircuit":
+    """Integrate the circuit over the variables specified by the given scope.
+
+    Args:
+        self (TensorizedCircuit): The circuit to integrate.
+        scope (Optional[Iterable[int]], optional): The scope over which to integrate, or None for \
+            the whole scope of the circuit. Defaults to None.
+
+    Returns:
+        TensorizedCircuit: The circuit giving the integral.
+    """
+    return self.__class__(self.symb_circuit.integrate(scope=scope), num_channels=self.num_channels)
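A hypothetical usage sketch (circuit, x and the scope values are illustrative; only the integrate signature above comes from this diff):

    marginal = circuit.integrate(scope=[0, 1])  # new TensorizedCircuit with variables 0 and 1 integrated out
    y = marginal(x)                             # same (*B, D, C) input convention as circuit(x)

The new circuit is compiled from the integrated symbolic circuit with the same number of channels; the materialize()-returns-bool change above is presumably what lets its layers reuse the already-materialized parameters instead of re-initializing them.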
19 changes: 19 additions & 0 deletions cirkit/new/model/tensorized_circuit.py
@@ -1,9 +1,11 @@
+from functools import cached_property
 from typing import Dict, Optional
 
 import torch
 from torch import Tensor, nn
 
 from cirkit.new.layers import InputLayer, Layer, SumProductLayer
+from cirkit.new.model.functional import integrate
 from cirkit.new.symbolic import (
     SymbolicLayer,
     SymbolicProductLayer,
@@ -33,6 +35,8 @@ def __init__(self, symb_circuit: SymbolicTensorizedCircuit, *, num_channels: int
         self.symb_circuit = symb_circuit
         self.scope = symb_circuit.scope
         self.num_vars = symb_circuit.num_vars
+        self.num_channels = num_channels
+        self.num_classes = symb_circuit.num_classes
 
         self.layers = nn.ModuleList()  # Automatic layer registry, also publically available.
 
@@ -140,3 +144,18 @@ def forward(self, x: Tensor) -> Tensor:
         return torch.stack(
             [layer_outputs[layer_out] for layer_out in self.symb_circuit.output_layers], dim=-2
         )  # shape num_out * (*B, K) -> (*B, num_out, num_cls=K).
+
+    integrate = integrate
+
+    # Use cached_property to lazily construct the circuit for partition function.
+    @cached_property
+    def partition_circuit(self) -> "TensorizedCircuit":
+        """The circuit calculating the partition function."""
+        return self.integrate(scope=self.scope)
+
+    @property
+    def partition_func(self) -> Tensor:  # TODO: is this the correct shape?
+        """The partition function of the circuit, shape (num_out, num_cls)."""
+        # For partition_circuit, the input is irrelevant, so just use zeros.
+        # shape (*B, D, C) -> (*B, num_out, num_cls) where *B = ().
+        return self.partition_circuit(torch.zeros((self.num_vars, self.num_channels)))
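A hedged sketch of how the new properties might be used to normalize circuit outputs (it assumes the circuit computes in log space, so normalization is a subtraction; shapes follow the comments above):

    log_scores = circuit(x)         # shape (*B, num_out, num_cls)
    log_z = circuit.partition_func  # shape (num_out, num_cls); evaluates the cached partition_circuit
    log_prob = log_scores - log_z   # broadcasts over the batch dims

partition_circuit is a cached_property, so the fully integrated circuit is built lazily on first access and reused; partition_func re-evaluates it on each access.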
14 changes: 12 additions & 2 deletions cirkit/new/reparams/composed.py
@@ -70,9 +70,13 @@ def materialize(
         dim: Union[int, Sequence[int]],
         mask: Optional[Tensor] = None,
         log_mask: Optional[Tensor] = None,
-    ) -> None:
+    ) -> bool:
         """Materialize the internal parameter tensors with given shape.
 
+        If it is already materialized, False will be returned to indicate no materialization. \
+        However, a second call to materialize must give the same config, so that the underlying \
+        params can indeed be reused.
+
         The initial value of the parameter after materialization is not guaranteed, and explicit \
         initialization is expected.
 
@@ -89,8 +93,13 @@ def materialize(
             log_mask (Optional[Tensor], optional): The -inf/0 mask for normalization positions. \
                 None for no masking. The shape must be broadcastable to shape if not None. \
                 Defaults to None.
+
+        Returns:
+            bool: Whether the materialization is done.
         """
-        super().materialize(shape, dim=dim, mask=mask, log_mask=log_mask)
+        if not super().materialize(shape, dim=dim, mask=mask, log_mask=log_mask):
+            return False
+
         for reparam in self.reparams:
             if not reparam.is_materialized:
                 # NOTE: Passing shape to all children reparams may not be always wanted. In that
@@ -99,6 +108,7 @@
                 reparam.materialize(shape, dim=dim, mask=mask, log_mask=log_mask)
 
         assert self().shape == self.shape, "The actual shape does not match the given one."
+        return True
 
     def initialize(self, initializer_: Callable[[Tensor], Tensor]) -> None:
         """Initialize the internal parameter tensors with the given initializer.
14 changes: 12 additions & 2 deletions cirkit/new/reparams/leaf.py
@@ -29,18 +29,28 @@ def device(self) -> torch.device:
         """The device of the output parameter."""
         return self.param.device
 
-    def materialize(self, shape: Sequence[int], /, **_kwargs: Unpack[MaterializeKwargs]) -> None:
+    def materialize(self, shape: Sequence[int], /, **_kwargs: Unpack[MaterializeKwargs]) -> bool:
         """Materialize the internal parameter tensors with given shape.
 
+        If it is already materialized, False will be returned to indicate no materialization. \
+        However, a second call to materialize must give the same config, so that the underlying \
+        params can indeed be reused.
+
         The initial value of the parameter after materialization is not guaranteed, and explicit \
         initialization is expected.
 
         Args:
             shape (Sequence[int]): The shape of the output parameter.
             **_kwargs (Unpack[MaterializeKwargs]): Unused. See Reparameterization.materialize().
+
+        Returns:
+            bool: Whether the materialization is done.
         """
-        super().materialize(shape, dim=())
+        if not super().materialize(shape, dim=()):
+            return False
+        # Not materialized before, i.e., self.param is still nn.UninitializedParameter.
         self.param.materialize(self.shape)
+        return True
 
     def initialize(self, initializer_: Callable[[Tensor], Tensor]) -> None:
         """Initialize the internal parameter tensors with the given initializer.
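The reparameterization-side contract described in these docstrings can be summarized in a small sketch (the concrete subclass name is hypothetical; only the documented return-value behaviour is assumed):

    import torch
    p = SomeLeafReparam()                    # hypothetical concrete Reparameterization subclass
    assert p.materialize((4, 8), dim=1)      # first call: tensors created, returns True
    assert not p.materialize((4, 8), dim=1)  # same config again: reused, returns False
    p.initialize(torch.nn.init.normal_)      # explicit initialization is expected after materializing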