Add Tucker and Tensor-Train / MPS factorization templates #339

Merged 9 commits on Feb 5, 2025
29 changes: 16 additions & 13 deletions cirkit/backend/torch/layers/inner.py
@@ -155,12 +155,8 @@ def __init__(
num_folds: The number of folds.

Raises:
NotImplementedError: If the arity is not 2.
ValueError: If the number of input units is not the same as the number of output units.
"""
# TODO: generalize kronecker layer as to support a greater arity
if arity != 2:
raise NotImplementedError("Kronecker only implemented for binary product units.")
super().__init__(
num_input_units,
num_input_units**arity,
@@ -177,18 +173,25 @@ def config(self) -> Mapping[str, Any]:
}

def forward(self, x: Tensor) -> Tensor:
x0 = x[:, 0].unsqueeze(dim=-1) # shape (F, B, Ki, 1).
x1 = x[:, 1].unsqueeze(dim=-2) # shape (F, B, 1, Ki).
# shape (F, B, Ki, Ki) -> (F, B, Ko=Ki**2).
return self.semiring.mul(x0, x1).flatten(start_dim=-2)
# x: (F, H, B, Ki)
y0 = x[:, 0]
for i in range(1, x.shape[1]):
y0 = y0.unsqueeze(dim=-1) # (F, B, K, 1).
y1 = x[:, i].unsqueeze(dim=-2) # (F, B, 1, Ki).
# y0: (F, B, K=K * Ki).
y0 = torch.flatten(self.semiring.mul(y0, y1), start_dim=-2)
# y0: (F, B, Ko=Ki ** arity)
return y0

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
# x: (F, H, C, K, num_samples, D)
x0 = x[:, 0].unsqueeze(dim=3) # (F, C, Ki, 1, num_samples, D)
x1 = x[:, 1].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
# shape (F, C, Ki, Ki, num_samples, D) -> (F, C, Ko=Ki**2, num_samples, D)
x = x0 + x1
return torch.flatten(x, start_dim=2, end_dim=3), None
y0 = x[:, 0]
for i in range(1, x.shape[1]):
y0 = y0.unsqueeze(dim=3) # (F, C, K, 1, num_samples, D)
y1 = x[:, i].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
y0 = torch.flatten(y0 + y1, start_dim=2, end_dim=3)
# y0: (F, C, Ko=Ki ** arity, num_samples, D)
return y0, None
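The generalized `TorchKroneckerLayer.forward` above keeps a running Kronecker product: at each step the accumulated tensor with K units is outer-multiplied with the next Ki-unit input and flattened back to K * Ki units, so after `arity` steps the output has Ki ** arity units (`sample` follows the same unsqueeze-and-flatten pattern with addition). A minimal standalone sketch, assuming the ordinary sum-product semiring where `semiring.mul` is elementwise multiplication, checks the loop against `torch.kron` applied per fold and batch element:

```python
import torch

def kron_forward(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the generalized TorchKroneckerLayer.forward under the sum-product semiring:
    # x: (F, H, B, Ki) -> (F, B, Ki ** H)
    y0 = x[:, 0]
    for i in range(1, x.shape[1]):
        # (F, B, K, 1) * (F, B, 1, Ki) -> (F, B, K * Ki)
        y0 = torch.flatten(y0.unsqueeze(dim=-1) * x[:, i].unsqueeze(dim=-2), start_dim=-2)
    return y0

F, H, B, Ki = 2, 3, 4, 5
x = torch.randn(F, H, B, Ki)
# Reference: left-associated Kronecker product of the three slices, per fold and batch element
ref = torch.stack([
    torch.stack([
        torch.kron(torch.kron(x[f, 0, b], x[f, 1, b]), x[f, 2, b]) for b in range(B)
    ])
    for f in range(F)
])
assert torch.allclose(kron_forward(x), ref, atol=1e-6)
```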


class TorchSumLayer(TorchInnerLayer):
31 changes: 20 additions & 11 deletions cirkit/backend/torch/layers/optimized.py
@@ -28,20 +28,18 @@ def __init__(
Args:
num_input_units: The number of input units.
num_output_units: The number of output units.
arity: The arity of the layer, must be 2. Defaults to 2.
arity: The arity of the layer. Defaults to 2.
weight: The weight parameter, which must have shape $(F, K_o, K_i^{\text{arity}})$,
where $F$ is the number of folds, $K_o$ is the number of output units,
and $K_i$ is the number of input units.

Raises:
NotImplementedError: If the arity is not equal to 2. Future versions of cirkit
will support Tucker layers having arity greater than 2.
ValueError: If the arity is less than two.
ValueError: If the number of input and output units are incompatible with the
shape of the weight parameter.
"""
# TODO: Generalize Tucker layer to have any arity greater or equal 2
if arity != 2:
raise NotImplementedError("The Tucker layer is only implemented with arity=2")
if arity < 2:
raise ValueError("The arity should be at least 2")
super().__init__(
num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
)
@@ -52,6 +50,16 @@ def __init__(
f"{weight.num_folds} and {weight.shape}, respectively"
)
self.weight = weight
# Construct the einsum expression that the Tucker layer computes
# For instance, if arity == 2 then we have that
# self._einsum = ((0, 1, 2), (0, 1, 3), (0, 4, 2, 3), (0, 1, 4))
# Also, if arity == 3 then we have that
# self._einsum = ((0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 5, 2, 3, 4), (0, 1, 5))
self._einsum = (
tuple((0, 1, i + 2) for i in range(arity))
+ ((0, arity + 2, *tuple(i + 2 for i in range(arity))),)
+ ((0, 1, arity + 2),)
)

def _valid_weight_shape(self, w: TorchParameter) -> bool:
if w.num_folds != self.num_folds:
@@ -60,7 +68,7 @@ def _valid_weight_shape(self, w: TorchParameter) -> bool:

@property
def _weight_shape(self) -> tuple[int, ...]:
return self.num_output_units, self.num_input_units * self.num_input_units
return self.num_output_units, self.num_input_units**self.arity

@property
def config(self) -> Mapping[str, Any]:
@@ -75,14 +83,15 @@ def params(self) -> Mapping[str, TorchParameter]:
return {"weight": self.weight}

def forward(self, x: Tensor) -> Tensor:
# weight: (F, Ko, Ki * Ki) -> (F, Ko, Ki, Ki)
# x: (F, H, B, Ki)
# weight: (F, Ko, Ki ** arity) -> (F, Ko, Ki, ..., Ki)
weight = self.weight().view(
-1, self.num_output_units, self.num_input_units, self.num_input_units
-1, self.num_output_units, *(self.num_input_units for _ in range(self.arity))
)
return self.semiring.einsum(
"fbi,fbj,foij->fbo",
self._einsum,
inputs=x.unbind(dim=1),
operands=(weight,),
inputs=(x[:, 0], x[:, 1]),
dim=-1,
keepdim=True,
)
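In the `_einsum` sublists constructed above, index 0 is the fold axis, 1 the batch axis, 2..arity+1 the unit axes of the individual inputs, and arity+2 the output-unit axis carried by the weight. A short sketch (with a hypothetical `tucker_einsum_indices` helper and plain `torch.einsum` standing in for the semiring version) reproduces the construction and checks that for arity 2 it matches the previous `"fbi,fbj,foij->fbo"` contraction:

```python
import torch

def tucker_einsum_indices(arity: int) -> tuple[tuple[int, ...], ...]:
    # Same construction as in TorchTuckerLayer.__init__
    return (
        tuple((0, 1, i + 2) for i in range(arity))
        + ((0, arity + 2, *tuple(i + 2 for i in range(arity))),)
        + ((0, 1, arity + 2),)
    )

print(tucker_einsum_indices(2))  # ((0, 1, 2), (0, 1, 3), (0, 4, 2, 3), (0, 1, 4))
print(tucker_einsum_indices(3))  # ((0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 5, 2, 3, 4), (0, 1, 5))

F, B, Ki, Ko = 2, 3, 4, 5
x = torch.randn(F, 2, B, Ki)         # (F, H=2, B, Ki)
weight = torch.randn(F, Ko, Ki, Ki)  # (F, Ko, Ki, Ki)

ix = tucker_einsum_indices(2)
# torch.einsum sublist mode: tensors interleaved with their index lists, output indices last
y_sublist = torch.einsum(
    x[:, 0], list(ix[0]), x[:, 1], list(ix[1]), weight, list(ix[2]), list(ix[3])
)
y_string = torch.einsum("fbi,fbj,foij->fbo", x[:, 0], x[:, 1], weight)
assert torch.allclose(y_sublist, y_string, atol=1e-5)
```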
40 changes: 31 additions & 9 deletions cirkit/backend/torch/semiring.py
@@ -1,4 +1,5 @@
import functools
import itertools
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
from typing import ClassVar, TypeVar, cast
@@ -153,19 +154,40 @@ def __new__(cls) -> "SemiringImpl":
@classmethod
def einsum(
cls,
equation: str,
equation: str | Sequence[Sequence[int]],
*,
inputs: tuple[Tensor, ...],
operands: tuple[Tensor, ...],
inputs: tuple[Tensor, ...] | None = None,
operands: tuple[Tensor, ...] | None = None,
dim: int,
keepdim: bool,
) -> Tensor:
operands = tuple(cls.cast(opd) for opd in operands)

def _einsum_func(*xs: Tensor) -> Tensor:
return torch.einsum(equation, *xs, *operands)

return cls.apply_reduce(_einsum_func, *inputs, dim=dim, keepdim=keepdim)
# TODO (LL): We need to remove this super general yet extremely complicated and hard
# to maintain einsum definition, which depends on the semiring. A future version of the
# compiler in cirkit will be able to emit pytorch code for every layer at compile time
match equation:
case str():

def _einsum_str_func(*xs: Tensor) -> Tensor:
opds = tuple(cls.cast(opd) for opd in operands)
return torch.einsum(equation, *xs, *opds)

einsum_func = _einsum_str_func
case Sequence():

def _einsum_seq_func(*xs: Tensor) -> Tensor:
opds = tuple(cls.cast(opd) for opd in operands)
einsum_args = tuple(
itertools.chain.from_iterable(zip(xs + opds, equation[:-1]))
)
return torch.einsum(*einsum_args, equation[-1])

einsum_func = _einsum_seq_func
case _:
raise ValueError(
"The einsum expression must be either a string or a sequence of int sequences"
)

return cls.apply_reduce(einsum_func, *inputs, dim=dim, keepdim=keepdim)

# NOTE: Subclasses should not touch any of the above final static methods but should implement
# all the following abstract class methods, and subclasses should be @final.
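The `Sequence` branch above relies on `torch.einsum`'s integer-sublist mode: each operand is interleaved with its index tuple and the output indices are appended last. A standalone sketch of the same interleaving, with made-up shapes and without the semiring cast:

```python
import itertools

import torch

# Index tuples as the Tucker layer would produce for arity 2 (example values only)
equation = ((0, 1, 2), (0, 1, 3), (0, 4, 2, 3), (0, 1, 4))
x0, x1 = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
weight = torch.randn(2, 5, 4, 4)
xs = (x0, x1, weight)

# Interleave (tensor, indices, tensor, indices, ...) exactly as _einsum_seq_func does
einsum_args = tuple(itertools.chain.from_iterable(zip(xs, equation[:-1])))
y = torch.einsum(*einsum_args, equation[-1])
print(y.shape)  # torch.Size([2, 3, 5])
```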