From f394cf307ca761f3c9dfd091d59976297af893dd Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 5 Jun 2024 10:15:59 -0400 Subject: [PATCH 01/10] [ADD] Utility function to compute input sizes of a convolution --- einconv/expressions/convNd_unfold.py | 2 +- einconv/functionals/unfold.py | 1 + einconv/modules/unfold.py | 1 + einconv/utils.py | 45 ++++++++++++++++++++++++++++ makefile | 11 +++++++ test/expressions/test_utils.py | 1 - test/simplifications/test_opt.py | 1 - test/test_utils.py | 34 +++++++++++++++++++-- test/utils_cases.py | 31 +++++++++++++++++++ test/utils_jax.py | 8 +++-- 10 files changed, 126 insertions(+), 9 deletions(-) diff --git a/einconv/expressions/convNd_unfold.py b/einconv/expressions/convNd_unfold.py index b401dac..5f531ca 100644 --- a/einconv/expressions/convNd_unfold.py +++ b/einconv/expressions/convNd_unfold.py @@ -34,7 +34,7 @@ def einsum_expression( Returns: Einsum equation Einsum operands in order input, patterns - Output shape: ``[batch_size, in_channels, tot_output_size]`` + Output shape: ``[batch_size, in_channels * tot_kernel_size, tot_output_size]`` """ N = x.dim() - 2 diff --git a/einconv/functionals/unfold.py b/einconv/functionals/unfold.py index 2347c79..63f0edd 100644 --- a/einconv/functionals/unfold.py +++ b/einconv/functionals/unfold.py @@ -1,4 +1,5 @@ """Equivalent of ``torch.functional.unfold`` for arbitrary dimensions.""" + from typing import Tuple, Union from torch import Tensor, einsum diff --git a/einconv/modules/unfold.py b/einconv/modules/unfold.py index 9b91524..bcb0947 100644 --- a/einconv/modules/unfold.py +++ b/einconv/modules/unfold.py @@ -1,4 +1,5 @@ """PyTorch equivalent of ``nn.Unfold`` implemented as einsum.""" + from typing import Tuple, Union from torch import Tensor diff --git a/einconv/utils.py b/einconv/utils.py index 91004e7..5f5bc52 100644 --- a/einconv/utils.py +++ b/einconv/utils.py @@ -158,3 +158,48 @@ def get_conv_output_size( kernel_span = kernel_size + (kernel_size - 1) * (dilation - 1) return 1 + floor((input_size + padding_left + padding_right - kernel_span) / stride) + + +def get_conv_input_size( + output_size: int, + kernel_size: int, + stride: int, + padding: int, + output_padding: int, + dilation: int, +) -> int: + """Compute the input size of a convolution. + + Args: + output_size: Number of output pixels along dimension. + kernel_size: Kernel size of the convolution. + stride: Stride of the convolution. + padding: Padding of the convolution. + output_padding: Number of unused pixels at the end. + dilation: Dilation of the kernel. + + Returns: + Convolution input dimension. + + Raises: + ValueError: If the output size re-computed with the determined input size does + not match. This indicates that `output_padding` was set too large. + """ + input_size = ( + (output_size - 1) * stride + - 2 * padding + + dilation * (kernel_size - 1) + + 1 + + output_padding + ) + + output_size_recomputed = get_conv_output_size( + input_size, kernel_size, stride, padding, dilation + ) + if output_size_recomputed != output_size: + raise ValueError( + f"Output size {output_size} does not match re-computed output size " + f"{output_size_recomputed}." + ) + + return input_size diff --git a/makefile b/makefile index 69a85c0..59061ed 100644 --- a/makefile +++ b/makefile @@ -5,6 +5,8 @@ help: @echo " Install einconv and dependencies" @echo "uninstall" @echo " Unstall einconv" + @echo "lint" + @echo " Run all linting actions" @echo "install-dev" @echo " Install all development tools" @echo "install-test" @@ -102,3 +104,12 @@ pydocstyle-check: conda-env: @conda env create --file .conda_env.yml + +.PHONY: lint + +lint: + make black-check + make isort-check + make flake8 + make darglint-check + make pydocstyle-check diff --git a/test/expressions/test_utils.py b/test/expressions/test_utils.py index c0bbfaa..241590b 100644 --- a/test/expressions/test_utils.py +++ b/test/expressions/test_utils.py @@ -1,6 +1,5 @@ """Tests ``einconv.expressions.utils``.""" - from einconv.expressions.utils import get_letters, translate_to_torch diff --git a/test/simplifications/test_opt.py b/test/simplifications/test_opt.py index 00c9701..59ad418 100644 --- a/test/simplifications/test_opt.py +++ b/test/simplifications/test_opt.py @@ -1,6 +1,5 @@ """Test ``einconv.simplifications.opt``.""" - from torch import allclose, einsum, float32, manual_seed, rand from einconv import index_pattern diff --git a/test/test_utils.py b/test/test_utils.py index 559fd48..50eaf75 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,13 +1,18 @@ """Tests for ``einconv/utils``'s.""" -from test.utils_cases import OUTPUT_SIZE_CASES, OUTPUT_SIZE_IDS +from test.utils_cases import ( + INPUT_SIZE_CASES, + INPUT_SIZE_IDS, + OUTPUT_SIZE_CASES, + OUTPUT_SIZE_IDS, +) from typing import Dict from pytest import mark from torch import zeros -from torch.nn.functional import conv1d +from torch.nn.functional import conv1d, conv_transpose1d -from einconv.utils import get_conv_output_size +from einconv.utils import get_conv_input_size, get_conv_output_size @mark.parametrize("case", OUTPUT_SIZE_CASES, ids=OUTPUT_SIZE_IDS) @@ -30,3 +35,26 @@ def test_get_conv_output_size(case: Dict): output_size = get_conv_output_size(**case) assert output_size_torch == output_size + + +@mark.parametrize("case", INPUT_SIZE_CASES, ids=INPUT_SIZE_IDS) +def test_get_conv_input_size(case: Dict): + """Test input size computation of a convolution. + + Args: + case: Dictionary describing the test case. + """ + input_torch = conv_transpose1d( + zeros(1, 1, case["output_size"]), # [N, C_out, O] + zeros(1, 1, case["kernel_size"]), # [C_out, C_in, K] + bias=None, + stride=case["stride"], + padding=case["padding"], + output_padding=case["output_padding"], + dilation=case["dilation"], + ) + input_size_torch = input_torch.shape[2] + + input_size = get_conv_input_size(**case) + + assert input_size_torch == input_size diff --git a/test/utils_cases.py b/test/utils_cases.py index 921e360..b36c906 100644 --- a/test/utils_cases.py +++ b/test/utils_cases.py @@ -12,3 +12,34 @@ ] OUTPUT_SIZE_IDS = [make_id(case) for case in OUTPUT_SIZE_CASES] + +INPUT_SIZE_CASES = [ + # default hyperparameters + { + "output_size": 10, + "kernel_size": 3, + "stride": 1, + "padding": 1, + "dilation": 1, + "output_padding": 0, + }, + # nontrivial hyperparameters + { + "output_size": 11, + "kernel_size": 3, + "stride": 2, + "padding": 10, + "dilation": 2, + "output_padding": 1, + }, + # nontrivial hyperparameters (non-overlapping patches) + { + "output_size": 11, + "kernel_size": 4, + "stride": 2, + "padding": 10, + "dilation": 5, + "output_padding": 0, + }, +] +INPUT_SIZE_IDS = [make_id(case) for case in INPUT_SIZE_CASES] diff --git a/test/utils_jax.py b/test/utils_jax.py index 8d153b0..0ec25ba 100644 --- a/test/utils_jax.py +++ b/test/utils_jax.py @@ -104,9 +104,11 @@ def ConvNd_jax(input: torch.Tensor) -> torch.Tensor: return jax_convNd( input, einconv_module.weight.data.clone(), - bias=None - if einconv_module.bias is None - else einconv_module.bias.data.clone(), + bias=( + None + if einconv_module.bias is None + else einconv_module.bias.data.clone() + ), stride=einconv_module.stride, padding=einconv_module.padding, dilation=einconv_module.dilation, From 01d90d3c01ea28dac6b8dde46d73aa20aebe73fc Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 5 Jun 2024 11:39:13 -0400 Subject: [PATCH 02/10] [ADD] Einsum expression and functional for transpose input unfolding --- docs/api/expressions.md | 1 + docs/api/functionals.md | 1 + .../expressions/conv_transposeNd_unfold.py | 120 ++++++++++++++++++ einconv/functionals/__init__.py | 2 + einconv/functionals/unfold_transpose.py | 69 ++++++++++ einconv/simplifications/opt.py | 6 +- test/functionals/test_unfold_transpose.py | 72 +++++++++++ test/functionals/transpose_unfold_cases.py | 66 ++++++++++ 8 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 einconv/expressions/conv_transposeNd_unfold.py create mode 100644 einconv/functionals/unfold_transpose.py create mode 100644 test/functionals/test_unfold_transpose.py create mode 100644 test/functionals/transpose_unfold_cases.py diff --git a/docs/api/expressions.md b/docs/api/expressions.md index cff415b..7a753d3 100644 --- a/docs/api/expressions.md +++ b/docs/api/expressions.md @@ -4,3 +4,4 @@ :::einconv.expressions.convNd_unfold :::einconv.expressions.convNd_kfc :::einconv.expressions.convNd_kfac_reduce +:::einconv.expressions.conv_transposeNd_unfold diff --git a/docs/api/functionals.md b/docs/api/functionals.md index 0203384..c00120a 100644 --- a/docs/api/functionals.md +++ b/docs/api/functionals.md @@ -1,2 +1,3 @@ ::: einconv.functionals.convNd ::: einconv.functionals.unfoldNd +::: einconv.functionals.unfoldNd_transpose diff --git a/einconv/expressions/conv_transposeNd_unfold.py b/einconv/expressions/conv_transposeNd_unfold.py new file mode 100644 index 0000000..beaf28a --- /dev/null +++ b/einconv/expressions/conv_transposeNd_unfold.py @@ -0,0 +1,120 @@ +from typing import List, Tuple, Union + +from torch import Tensor + +import einconv +from einconv.expressions.utils import create_conv_index_patterns, translate_to_torch +from einconv.utils import _tuple, get_conv_input_size + + +def einsum_expression( + x: Tensor, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + output_padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + simplify: bool = True, +) -> Tuple[str, List[Tensor], Tuple[int, ...]]: + """Generate einsum expression to unfold the input of a transpose convolution. + + The unfolded input for a transpose convolution flattens and concatenates all + elements of the input tensor into a matrix such that the transpose convolution + can be written as matrix multiplication between the unfolded input and the + matricized kernel. + + We will use the associated convolution's hyper-parameters to describe all arguments. + Consider an `N`d convolution which maps an input tensor of shape + `[batch_size, in_channels, *input_sizes]` to an output tensor of shape + `[batch_size, out_channels, *output_sizes]`. The transpose convolution's input + has shape `[batch_size, out_channels, *output_sizes]` and the output has shape + `[batch_size, in_channels, *input_sizes]`. + + Args: + x: Input to the `N`d transpose convolution. Has shape + `[batch_size, in_channels, *input_sizes]` where `len(input_sizes) == N`. + kernel_size: Kernel dimensions. Can be a single integer (shared along all + spatial dimensions), or an `N`-tuple of integers. + stride: Stride of the associated convolution. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. Default: `1`. + padding: Padding of the associated convolution. Can be a single integer (shared + along all spatial dimensions) or an `N`-tuple of integers, Default: `0`. + output_padding: The associated convolution's number of unused pixels at the end + of a spatial dimension. This is required to resolve the ambiguity that a + convolution can produce the same output shape for different input shapes if + it has non-unit stride. Can be a single integer (shared along all spatial + dimensions), or an `N`-tuple of integers. Default: `0`. + dilation: Dilation of the associated convolution. Can be a single integer + (shared along all spatial dimensions), or an `N`-tuple of integers. + Default: `1`. + simplify: Whether to simplify the einsum expression. Default: `True`. + + Returns: + Einsum equation + Einsum operands in order input, patterns + Output shape: `[batch_size, out_channels * tot_kernel_size, tot_input_size]` + """ + N = x.dim() - 2 + + # construct einsum equation + x_str = "n c_out " + " ".join([f"o{i}" for i in range(N)]) + pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)] + lhs = ",".join([x_str, *pattern_strs]) + + rhs = ( + "n c_out " + + " ".join([f"k{i}" for i in range(N)]) + + " " + + " ".join([f"i{i}" for i in range(N)]) + ) + + equation = "->".join([lhs, rhs]) + equation = translate_to_torch(equation) + + # compute input sizes + t_output_size = x.shape[2:] + t_stride = _tuple(stride, N) + t_kernel_size = _tuple(kernel_size, N) + t_padding = _tuple(padding, N) + t_dilation = _tuple(dilation, N) + t_output_padding = _tuple(output_padding, N) + t_input_size = tuple( + get_conv_input_size( + output_size, kernel_size, stride, padding, output_padding, dilation + ) + for output_size, kernel_size, stride, padding, output_padding, dilation in zip( + t_output_size, + t_kernel_size, + t_stride, + t_padding, + t_output_padding, + t_dilation, + ) + ) + + # construct einsum operands + patterns = create_conv_index_patterns( + N, + input_size=t_input_size, + kernel_size=t_kernel_size, + stride=t_stride, + padding=t_padding, + dilation=t_dilation, + device=x.device, + dtype=x.dtype, + ) + operands = [x, *patterns] + + # construct shape + input_tot_size = int(Tensor(t_input_size).int().prod()) + kernel_tot_size = int(Tensor(t_kernel_size).int().prod()) + batch_size, out_channels = x.shape[:2] + shape = (batch_size, out_channels * kernel_tot_size, input_tot_size) + + print(equation) + print([op.shape for op in operands]) + + if simplify: + equation, operands = einconv.simplify(equation, operands) + + return equation, operands, shape diff --git a/einconv/functionals/__init__.py b/einconv/functionals/__init__.py index a02c831..c98abff 100644 --- a/einconv/functionals/__init__.py +++ b/einconv/functionals/__init__.py @@ -2,8 +2,10 @@ from einconv.functionals.conv import convNd from einconv.functionals.unfold import unfoldNd +from einconv.functionals.unfold_transpose import unfoldNd_transpose __all__ = [ "convNd", "unfoldNd", + "unfoldNd_transpose", ] diff --git a/einconv/functionals/unfold_transpose.py b/einconv/functionals/unfold_transpose.py new file mode 100644 index 0000000..f1cc136 --- /dev/null +++ b/einconv/functionals/unfold_transpose.py @@ -0,0 +1,69 @@ +"""Implements a `torch.nn.functional` for input unfolding of a transpose convolution.""" + +from typing import Tuple, Union + +from torch import Tensor, einsum + +from einconv.expressions import conv_transposeNd_unfold + + +def unfoldNd_transpose( + x: Tensor, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + output_padding: Union[int, Tuple[int, ...]] = 0, + dilation: Union[int, Tuple[int, ...]] = 1, + simplify: bool = True, +) -> Tensor: + """Torch functional for N-dimensional input unfolding of a transpose convolution. + + Extracts elements that overlap with the transpose convolution's kernel at a time + into a matrix. This matrix can then be used to formulate transpose convolution as + matrix multiplication between the unfolded input and the matricized kernel. + + This function uses `einsum` under the hood and does not have a PyTorch equivalent. + + We will use the hyper-parameters of an `N`d convolution which maps an input of shape + `[batch_size, in_channels, *input_sizes]` to an output of shape + `[batch_size, out_channels, *output_sizes]`. The transpose convolution's input has + shape `[batch_size, out_channels, *output_sizes]` and the output has shape + `[batch_size, in_channels, *input_sizes]`. + + Args: + x: Input to the `N`d transpose convolution. Has shape + `[batch_size, in_channels, *input_sizes]` where `len(input_sizes) == N`. + kernel_size: Kernel dimensions. Can be a single integer (shared along all + spatial dimensions), or an `N`-tuple of integers. + stride: Stride of the associated convolution. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. Default: `1`. + padding: Padding of the associated convolution. Can be a single integer (shared + along all spatial dimensions) or an `N`-tuple of integers, Default: `0`. + output_padding: The associated convolution's number of unused pixels at the end + of a spatial dimension. This is required to resolve the ambiguity that a + convolution can produce the same output shape for different input shapes if + it has non-unit stride. Can be a single integer (shared along all spatial + dimensions), or an `N`-tuple of integers. Default: `0`. + dilation: Dilation of the associated convolution. Can be a single integer + (shared along all spatial dimensions), or an `N`-tuple of integers. + Default: `1`. + simplify: Whether to simplify the einsum equation before evaluating it. + Default: `True`. + + Returns: + Unfolded input tensor of shape \ + shape `[batch_size, in_channels * tot_kernel_size, tot_input_size]` where \ + `tot_kernel_size`, `tot_input_size` are the total number of kernel elements and + spatial input elements to the associated convolution. In `einops` notation, the + index structure is `n (c_out k1 k2 ...) (i1 i2 ...)`. + """ + equation, operands, shape = conv_transposeNd_unfold.einsum_expression( + x, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + output_padding=output_padding, + simplify=simplify, + ) + return einsum(equation, *operands).reshape(shape) diff --git a/einconv/simplifications/opt.py b/einconv/simplifications/opt.py index 6d4bae5..1e9f7fe 100644 --- a/einconv/simplifications/opt.py +++ b/einconv/simplifications/opt.py @@ -151,7 +151,11 @@ def squeeze(self) -> None: for pos in range(len(self.input_indices)): for idx in maybe_squeeze: - if idx in self.input_indices[pos] and len(self.input_indices[pos]) > 1: + if ( + idx in self.input_indices[pos] + and len(self.input_indices[pos]) > 1 + and not isinstance(self.operands[pos], Identity) + ): idx_pos = self.input_indices[pos].index(idx) self.input_indices[pos] = self.input_indices[pos].replace(idx, "") diff --git a/test/functionals/test_unfold_transpose.py b/test/functionals/test_unfold_transpose.py new file mode 100644 index 0000000..4148fe9 --- /dev/null +++ b/test/functionals/test_unfold_transpose.py @@ -0,0 +1,72 @@ +"""Contains tests for ``einconv/expressions/conv_transposeNd_unfold.py``.""" + +from test.functionals.transpose_unfold_cases import ( + TRANSPOSE_UNFOLD_1D_CASES, + TRANSPOSE_UNFOLD_1D_IDS, + TRANSPOSE_UNFOLD_2D_CASES, + TRANSPOSE_UNFOLD_2D_IDS, + TRANSPOSE_UNFOLD_3D_CASES, + TRANSPOSE_UNFOLD_3D_IDS, +) +from test.utils import DEVICE_IDS, DEVICES, SIMPLIFIES, SIMPLIFY_IDS, report_nonclose +from typing import Dict + +from einops import einsum, rearrange +from pytest import mark +from torch import ( + conv_transpose1d, + conv_transpose2d, + conv_transpose3d, + device, + manual_seed, + rand, +) + +from einconv.functionals import unfoldNd_transpose +from einconv.utils import _tuple + + +@mark.parametrize("simplify", SIMPLIFIES, ids=SIMPLIFY_IDS) +@mark.parametrize( + "case", + TRANSPOSE_UNFOLD_1D_CASES + TRANSPOSE_UNFOLD_2D_CASES + TRANSPOSE_UNFOLD_3D_CASES, + ids=TRANSPOSE_UNFOLD_1D_IDS + TRANSPOSE_UNFOLD_2D_IDS + TRANSPOSE_UNFOLD_3D_IDS, +) +@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) +def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool): + """Compare transpose convolution via matrix-multiplication with built-in one. + + Args: + case: Dictionary describing the test case. + dev: Device to execute the test on. + simplify: Whether to use a simplified einsum expression. + """ + seed = case["seed"] + input_fn = case["input_fn"] + kernel_size = case["kernel_size"] + kwargs = case["kwargs"] + + manual_seed(seed) + inputs = input_fn().to(dev) + N = inputs.ndim - 2 + batch_size, C_out = inputs.shape[:2] + C_in, G = 3, 1 # hard-coded for now + t_kernel_size = _tuple(kernel_size, N) + weight = rand(C_out, C_in // G, *t_kernel_size).to(dev) + + # ground truth: PyTorch's built-in transpose convolution + conv_func = {1: conv_transpose1d, 2: conv_transpose2d, 3: conv_transpose3d}[N] + result = conv_func(inputs, weight, **kwargs) + + # transpose convolution via matrix-multiplication perspective using unfolded input + # and matricized kernel + inputs_unfolded = unfoldNd_transpose( + inputs, kernel_size, **kwargs, simplify=simplify + ) + k_s = " ".join([f"k{i}" for i in range(N)]) + weight_mat = rearrange(weight, f"c_out c_in {k_s} -> c_in (c_out {k_s})") + result_mat = einsum( + inputs_unfolded, weight_mat, "n c_out_k i, c_in c_out_k -> n c_in i" + ) + + report_nonclose(result, result_mat.reshape(batch_size, C_in, *result.shape[2:])) diff --git a/test/functionals/transpose_unfold_cases.py b/test/functionals/transpose_unfold_cases.py new file mode 100644 index 0000000..f065ec6 --- /dev/null +++ b/test/functionals/transpose_unfold_cases.py @@ -0,0 +1,66 @@ +"""Contains cases to test transpose convolution unfolding functional.""" + +from test.utils import make_id + +from torch import rand + +TRANSPOSE_UNFOLD_1D_CASES = [ + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 50), + "kernel_size": 1, + "kwargs": {}, + }, + { + "seed": 0, + "input_fn": lambda: rand(5, 2, 9), + "kernel_size": 3, + "kwargs": {"stride": 2, "padding": 1, "output_padding": 1, "dilation": 2}, + }, +] + +TRANSPOSE_UNFOLD_1D_IDS = [make_id(problem) for problem in TRANSPOSE_UNFOLD_1D_CASES] + +TRANSPOSE_UNFOLD_2D_CASES = [ + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 10, 10), + "kernel_size": 1, + "kwargs": {}, + }, + { + "seed": 0, + "input_fn": lambda: rand(5, 2, 9, 7), + "kernel_size": (3, 2), + "kwargs": { + "stride": (2, 1), + "padding": (1, 0), + "output_padding": (1, 0), + "dilation": (2, 2), + }, + }, +] + +TRANSPOSE_UNFOLD_2D_IDS = [make_id(problem) for problem in TRANSPOSE_UNFOLD_2D_CASES] + +TRANSPOSE_UNFOLD_3D_CASES = [ + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 10, 10, 10), + "kernel_size": 1, + "kwargs": {}, + }, + { + "seed": 0, + "input_fn": lambda: rand(5, 2, 9, 7, 6), + "kernel_size": (3, 2, 4), + "kwargs": { + "stride": (2, 1, 3), + "padding": (1, 0, 2), + "output_padding": (1, 0, 2), + "dilation": (2, 2, 1), + }, + }, +] + +TRANSPOSE_UNFOLD_3D_IDS = [make_id(problem) for problem in TRANSPOSE_UNFOLD_3D_CASES] From 7bc0da3593c5b6bfb1789cd1721577418dfd595f Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 5 Jun 2024 11:40:18 -0400 Subject: [PATCH 03/10] [DEL] Print statements --- einconv/expressions/conv_transposeNd_unfold.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/einconv/expressions/conv_transposeNd_unfold.py b/einconv/expressions/conv_transposeNd_unfold.py index beaf28a..dd2a45c 100644 --- a/einconv/expressions/conv_transposeNd_unfold.py +++ b/einconv/expressions/conv_transposeNd_unfold.py @@ -111,9 +111,6 @@ def einsum_expression( batch_size, out_channels = x.shape[:2] shape = (batch_size, out_channels * kernel_tot_size, input_tot_size) - print(equation) - print([op.shape for op in operands]) - if simplify: equation, operands = einconv.simplify(equation, operands) From 86870e26f473a7f93079a29ddd6884264693907a Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 5 Jun 2024 11:51:02 -0400 Subject: [PATCH 04/10] [DOC] Polish docstrings --- einconv/expressions/conv_transposeNd_unfold.py | 6 +++--- einconv/functionals/unfold_transpose.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/einconv/expressions/conv_transposeNd_unfold.py b/einconv/expressions/conv_transposeNd_unfold.py index dd2a45c..317293f 100644 --- a/einconv/expressions/conv_transposeNd_unfold.py +++ b/einconv/expressions/conv_transposeNd_unfold.py @@ -19,9 +19,9 @@ def einsum_expression( """Generate einsum expression to unfold the input of a transpose convolution. The unfolded input for a transpose convolution flattens and concatenates all - elements of the input tensor into a matrix such that the transpose convolution - can be written as matrix multiplication between the unfolded input and the - matricized kernel. + elements of the input tensor that overlap with the kernel for a specific output + location into a matrix such that the transpose convolution can be written as matrix + multiplication between the unfolded input and the matricized kernel. We will use the associated convolution's hyper-parameters to describe all arguments. Consider an `N`d convolution which maps an input tensor of shape diff --git a/einconv/functionals/unfold_transpose.py b/einconv/functionals/unfold_transpose.py index f1cc136..d037faa 100644 --- a/einconv/functionals/unfold_transpose.py +++ b/einconv/functionals/unfold_transpose.py @@ -22,7 +22,8 @@ def unfoldNd_transpose( into a matrix. This matrix can then be used to formulate transpose convolution as matrix multiplication between the unfolded input and the matricized kernel. - This function uses `einsum` under the hood and does not have a PyTorch equivalent. + Note: + This function uses `einsum` under the hood and does not have a PyTorch equivalent. We will use the hyper-parameters of an `N`d convolution which maps an input of shape `[batch_size, in_channels, *input_sizes]` to an output of shape From 0cc64786ac8e44791bb87cd7896ae0ffe9e1e9c2 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 5 Jun 2024 11:59:20 -0400 Subject: [PATCH 05/10] [DOC] Minor polish --- einconv/expressions/conv_transposeNd_unfold.py | 6 ++---- einconv/functionals/unfold_transpose.py | 4 ++-- test/functionals/test_unfold_transpose.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/einconv/expressions/conv_transposeNd_unfold.py b/einconv/expressions/conv_transposeNd_unfold.py index 317293f..09d7172 100644 --- a/einconv/expressions/conv_transposeNd_unfold.py +++ b/einconv/expressions/conv_transposeNd_unfold.py @@ -38,7 +38,7 @@ def einsum_expression( stride: Stride of the associated convolution. Can be a single integer (shared along all spatial dimensions), or an `N`-tuple of integers. Default: `1`. padding: Padding of the associated convolution. Can be a single integer (shared - along all spatial dimensions) or an `N`-tuple of integers, Default: `0`. + along all spatial dimensions) or an `N`-tuple of integers, Default: `0`. output_padding: The associated convolution's number of unused pixels at the end of a spatial dimension. This is required to resolve the ambiguity that a convolution can produce the same output shape for different input shapes if @@ -54,20 +54,18 @@ def einsum_expression( Einsum operands in order input, patterns Output shape: `[batch_size, out_channels * tot_kernel_size, tot_input_size]` """ - N = x.dim() - 2 + N = x.ndim - 2 # construct einsum equation x_str = "n c_out " + " ".join([f"o{i}" for i in range(N)]) pattern_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)] lhs = ",".join([x_str, *pattern_strs]) - rhs = ( "n c_out " + " ".join([f"k{i}" for i in range(N)]) + " " + " ".join([f"i{i}" for i in range(N)]) ) - equation = "->".join([lhs, rhs]) equation = translate_to_torch(equation) diff --git a/einconv/functionals/unfold_transpose.py b/einconv/functionals/unfold_transpose.py index d037faa..5b19bee 100644 --- a/einconv/functionals/unfold_transpose.py +++ b/einconv/functionals/unfold_transpose.py @@ -39,7 +39,7 @@ def unfoldNd_transpose( stride: Stride of the associated convolution. Can be a single integer (shared along all spatial dimensions), or an `N`-tuple of integers. Default: `1`. padding: Padding of the associated convolution. Can be a single integer (shared - along all spatial dimensions) or an `N`-tuple of integers, Default: `0`. + along all spatial dimensions) or an `N`-tuple of integers, Default: `0`. output_padding: The associated convolution's number of unused pixels at the end of a spatial dimension. This is required to resolve the ambiguity that a convolution can produce the same output shape for different input shapes if @@ -53,7 +53,7 @@ def unfoldNd_transpose( Returns: Unfolded input tensor of shape \ - shape `[batch_size, in_channels * tot_kernel_size, tot_input_size]` where \ + `[batch_size, in_channels * tot_kernel_size, tot_input_size]` where \ `tot_kernel_size`, `tot_input_size` are the total number of kernel elements and spatial input elements to the associated convolution. In `einops` notation, the index structure is `n (c_out k1 k2 ...) (i1 i2 ...)`. diff --git a/test/functionals/test_unfold_transpose.py b/test/functionals/test_unfold_transpose.py index 4148fe9..e2778c2 100644 --- a/test/functionals/test_unfold_transpose.py +++ b/test/functionals/test_unfold_transpose.py @@ -34,7 +34,7 @@ ) @mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool): - """Compare transpose convolution via matrix-multiplication with built-in one. + """Compare transpose convolution via input unfolding with built-in one. Args: case: Dictionary describing the test case. From a227e7d874384aba77b2acbe85d6a926869525bb Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Wed, 5 Jun 2024 12:00:05 -0400 Subject: [PATCH 06/10] [FIX] Too long lines --- einconv/functionals/unfold_transpose.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/einconv/functionals/unfold_transpose.py b/einconv/functionals/unfold_transpose.py index 5b19bee..0a39b32 100644 --- a/einconv/functionals/unfold_transpose.py +++ b/einconv/functionals/unfold_transpose.py @@ -23,7 +23,7 @@ def unfoldNd_transpose( matrix multiplication between the unfolded input and the matricized kernel. Note: - This function uses `einsum` under the hood and does not have a PyTorch equivalent. + This function uses `einsum` under the hood and has no PyTorch equivalent. We will use the hyper-parameters of an `N`d convolution which maps an input of shape `[batch_size, in_channels, *input_sizes]` to an output of shape From a16b67f2a4277a969849da8977a8d901eed54dc0 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 6 Jun 2024 09:29:32 -0400 Subject: [PATCH 07/10] [ADD] einsum expression for transpose convolution's KFAC-reduce --- docs/api/expressions.md | 1 + einconv/expressions/convNd_kfac_reduce.py | 28 ++- einconv/expressions/convNd_kfc.py | 24 ++- .../conv_transposeNd_kfac_reduce.py | 173 ++++++++++++++++++ .../expressions/conv_transposeNd_unfold.py | 21 +++ mkdocs.yml | 4 +- .../conv_transposeNd_kfac_reduce_cases.py | 117 ++++++++++++ test/expressions/test_convNd_kfac_reduce.py | 2 +- .../test_conv_transposeNd_kfac_reduce.py | 63 +++++++ test/functionals/test_unfold_transpose.py | 33 ++++ 10 files changed, 459 insertions(+), 7 deletions(-) create mode 100644 einconv/expressions/conv_transposeNd_kfac_reduce.py create mode 100644 test/expressions/conv_transposeNd_kfac_reduce_cases.py create mode 100644 test/expressions/test_conv_transposeNd_kfac_reduce.py diff --git a/docs/api/expressions.md b/docs/api/expressions.md index 7a753d3..f177d89 100644 --- a/docs/api/expressions.md +++ b/docs/api/expressions.md @@ -5,3 +5,4 @@ :::einconv.expressions.convNd_kfc :::einconv.expressions.convNd_kfac_reduce :::einconv.expressions.conv_transposeNd_unfold +:::einconv.expressions.conv_transposeNd_kfac_reduce diff --git a/einconv/expressions/convNd_kfac_reduce.py b/einconv/expressions/convNd_kfac_reduce.py index 6ad10db..3fa55ee 100644 --- a/einconv/expressions/convNd_kfac_reduce.py +++ b/einconv/expressions/convNd_kfac_reduce.py @@ -2,8 +2,10 @@ KFAC-reduce was introduced by: -- Eschenhagen, R. (2022). Kronecker-factored approximate curvature for linear - weight-sharing layers, Master thesis. +- [Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., & Hennig, P. + (2023). Kronecker-factored approximate curvature for modern neural network + architectures. In Advances in Neural Information Processing Systems (NeurIPS)]\ +(https://arxiv.org/abs/2311.00636). """ from typing import List, Tuple, Union @@ -27,6 +29,26 @@ def einsum_expression( ) -> Tuple[str, List[Tensor], Tuple[int, ...]]: """Generate einsum expression of input-based KFAC-reduce factor for convolution. + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{in}\\times I_1\\times I_2\\times\\dots}$ + denote the input of a convolution. The unfolded input $[[\\mathbf{X}]]$ + has dimension $(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times (O_1 \\cdot O_2 + \\cdots)$ where $K_i$ and $O_i$ are the kernel and output sizes of the convolution. + The input-based KFAC-reduce factor is the batch-averaged outer product + of the column-averaged unfolded input, + + $$ + \\hat{\\mathbf{\\Omega}} = + \\frac{1}{B \\cdot (O_1 \\cdot O_2 \\cdots)^2} \\sum_{b=1}^B + ( [[\\mathbf{X}_b]]^\\top \\mathbf{1} ) + ( [[\\mathbf{X}_b]]^\\top \\mathbf{1} )^\\top + \\in \\mathbb{R}^{(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots)} + \\,, + $$ + + where $B$ is the batch size and $\\mathbf{X}_b$ is the convolution's input from the + $b$th data point. + Args: x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]`` where ``len(input_sizes) == N``. @@ -83,7 +105,7 @@ def einsum_expression( x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups) output_tot_size = Tensor([p.shape[1] for p in patterns]).int().prod() batch_size = x.shape[0] - scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device).to(x.dtype) + scale = Tensor([1.0 / (batch_size * output_tot_size**2)]).to(x.device, x.dtype) operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale] # construct output shape diff --git a/einconv/expressions/convNd_kfc.py b/einconv/expressions/convNd_kfc.py index bb422e1..20c0071 100644 --- a/einconv/expressions/convNd_kfc.py +++ b/einconv/expressions/convNd_kfc.py @@ -2,8 +2,9 @@ KFC was introduced by: -- Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix - for convolution layers. International Conference on Machine Learning (ICML). +- [Grosse, R., & Martens, J. (2016). A Kronecker-factored approximate Fisher matrix + for convolution layers. International Conference on Machine Learning (ICML).]\ +(https://arxiv.org/abs/1602.01407) """ from typing import List, Tuple, Union @@ -27,6 +28,25 @@ def einsum_expression( ) -> Tuple[str, List[Tensor], Tuple[int, ...]]: """Generate einsum expression of input-based KFC factor for convolution. + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{in}\\times I_1\\times I_2\\times\\dots}$ + denote the input of a convolution. The unfolded input $[[\\mathbf{X}]]$ + has dimension $(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times (O_1 \\cdot O_2 + \\cdots)$ where $K_i$ and $O_i$ are the kernel and output sizes of the convolution. + The input-based KFC factor is the batch-averaged outer product of the unfolded + input, + + $$ + \\mathbf{\\Omega} = + \\frac{1}{B} \\sum_{b=1}^B + [[\\mathbf{X}_b]] [[\\mathbf{X}_b]]^\\top + \\in \\mathbb{R}^{(C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (C_\\text{in} \\cdot K_1 \\cdot K_2 \\cdots)} + \\,, + $$ + + where $B$ is the batch size and $\\mathbf{X}_b$ is the convolution's input from the + $b$th data point. + Args: x: Convolution input. Has shape ``[batch_size, in_channels, *input_sizes]`` where ``len(input_sizes) == N``. diff --git a/einconv/expressions/conv_transposeNd_kfac_reduce.py b/einconv/expressions/conv_transposeNd_kfac_reduce.py new file mode 100644 index 0000000..1bae4a0 --- /dev/null +++ b/einconv/expressions/conv_transposeNd_kfac_reduce.py @@ -0,0 +1,173 @@ +"""Input-based factor of the KFAC-reduce approximation for transpose convolutions. + +KFAC-reduce was introduced by: + +- [Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., & Hennig, P. + (2023). Kronecker-factored approximate curvature for modern neural network + architectures. In Advances in Neural Information Processing Systems (NeurIPS)]\ +(https://arxiv.org/abs/2311.00636). +""" + +from typing import List, Optional, Tuple, Union + +from einops import rearrange +from torch import Tensor + +import einconv +from einconv.expressions.utils import create_conv_index_patterns, translate_to_torch +from einconv.utils import _tuple, get_conv_input_size + + +def einsum_expression( + x: Tensor, + kernel_size: Union[int, Tuple[int, ...]], + stride: Union[int, Tuple[int, ...]] = 1, + padding: Union[int, Tuple[int, ...]] = 0, + output_padding: Union[int, Tuple[int, ...]] = 0, + output_size: Optional[Union[int, Tuple[int, ...]]] = None, + dilation: Union[int, Tuple[int, ...]] = 1, + groups: int = 1, + simplify: bool = True, +) -> Tuple[str, List[Tensor], Tuple[int, ...]]: + """Generate einsum expr. of input-based KFAC-reduce factor for transp. convolution. + + We describe the `N`d transpose convolution using its associated `N`d convolution + which maps an input of shape `[batch_size, in_channels, *input_sizes]` to an output + of shape `[batch_size, out_channels, *output_sizes]`. The transpose convolution's + input has shape `[batch_size, out_channels, *output_sizes]` and the output has shape + `[batch_size, in_channels, *input_sizes]`. + + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{out}\\times O_1\\times O_2\\times\\dots}$ + denote the input of a transpose convolution. The unfolded input $[[\\mathbf{X} + ]]_\\top$ has dimension $(C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (I_1 \\cdot I_2 \\cdots)$ where $K_i$ and $I_i$ are the kernel and input sizes of + the associated convolution. The input-based KFAC-reduce factor is the batch-averaged + outer product of the column-averaged unfolded input, + + $$ + \\hat{\\mathbf{\\Omega}} = + \\frac{1}{B \\cdot (I_1 \\cdot I_2 \\cdots)^2} \\sum_{b=1}^B + ( [[\\mathbf{X}_b]]^\\top_\\top \\mathbf{1} ) + ( [[\\mathbf{X}_b]]^\\top_\\top \\mathbf{1} )^\\top + \\in \\mathbb{R}^{(C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots) \\times + (C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots)} + \\,, + $$ + + where $B$ is the batch size and $\\mathbf{X}_b$ is the transpose convolution's + input from the $b$th data point. + + Args: + x: Input tensor of shape `[batch_size, out_channels, *output_sizes]`. + kernel_size: Size of the convolutional kernel. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. + stride: Stride of the associated convolution. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. Default: `1`. + padding: Padding of the associated convolution. Can be a single integer (shared + along all spatial dimensions), or an `N`-tuple of integers. Default: `0`. + output_padding: Number of unused pixels at the end of the spatial domain. + This is used to resolve the ambiguity that a convolution can map different + input sizes to the same output size if its stride is different from 1. + Instead of specifying this argument, you can directly specify the output + size of the transpose convolution (i.e. the input size of the associated + convolution via the `output_size` argument). Can be a single integer + (shared along all spatial dimensions), or an `N`-tuple. Default: `0`. + output_size: Size of the output of the transpose convolution (i.e. the input + size of the associated convolution). Specifying this argument will override + the `output_padding` argument. Can be a single integer (shared along all + spatial dimensions), or an `N`-tuple of integers. Default: `None`. + dilation: Dilation of the convolution. Can be a single integer (shared along + all spatial dimensions), or an `N`-tuple of integers. Default: `1`. + groups: In how many groups to split the channels. Default: `1`. + simplify: Whether to simplify the einsum expression. Default: `True`. + + Returns: + Einsum equation + Einsum operands in order un-grouped input, patterns, un-grouped input, \ + patterns, normalization scaling + Output shape: `[groups, out_channels //groups * tot_kernel_sizes,\ + out_channels //groups * tot_kernel_sizes]` + """ + N = x.dim() - 2 + + # construct einsum equation + x1_str = "n g c_out " + " ".join([f"o{i}" for i in range(N)]) + x2_str = "n g c_out_ " + " ".join([f"o{i}_" for i in range(N)]) + pattern1_strs: List[str] = [f"k{i} o{i} i{i}" for i in range(N)] + pattern2_strs: List[str] = [f"k{i}_ o{i}_ i{i}_" for i in range(N)] + scale_str = "s" + lhs = ",".join([x1_str, *pattern1_strs, *pattern2_strs, x2_str, scale_str]) + rhs = ( + "g c_out " + + " ".join([f"k{i}" for i in range(N)]) + + " c_out_ " + + " ".join([f"k{i}_" for i in range(N)]) + ) + equation = "->".join([lhs, rhs]) + equation = translate_to_torch(equation) + + conv_output_size = x.shape[2:] + t_kernel_size = _tuple(kernel_size, N) + t_stride = _tuple(stride, N) + t_padding = _tuple(padding, N) + t_dilation = _tuple(dilation, N) + + # infer output_padding from convolution's input size + if output_size is not None: + t_output_size = _tuple(output_size, N) + t_output_padding = tuple( + output_size - get_conv_input_size(conv_out_size, K, S, P, 0, D) + for output_size, conv_out_size, K, S, P, D in zip( + t_output_size, + conv_output_size, + t_kernel_size, + t_stride, + t_padding, + t_dilation, + ) + ) + else: + t_output_padding = _tuple(output_padding, N) + + conv_input_size = tuple( + get_conv_input_size(output_size, K, S, P, output_padding, D) + for output_size, K, S, P, output_padding, D in zip( + conv_output_size, + t_kernel_size, + t_stride, + t_padding, + t_output_padding, + t_dilation, + ) + ) + + # construct einsum operands + patterns = create_conv_index_patterns( + N, + input_size=conv_input_size, + kernel_size=t_kernel_size, + stride=t_stride, + padding=t_padding, + dilation=dilation, + device=x.device, + dtype=x.dtype, + ) + x_ungrouped = rearrange(x, "n (g c_in) ... -> n g c_in ...", g=groups) + conv_input_tot_size = Tensor(conv_input_size).int().prod() + batch_size, out_channels = x.shape[:2] + scale = Tensor([1.0 / (batch_size * conv_input_tot_size**2)]).to(x.device, x.dtype) + operands = [x_ungrouped, *patterns, *patterns, x_ungrouped, scale] + + # construct output shape + t_kernel_size = _tuple(kernel_size, N) + kernel_tot_size = int(Tensor(t_kernel_size).int().prod()) + shape = ( + groups, + out_channels // groups * kernel_tot_size, + out_channels // groups * kernel_tot_size, + ) + + if simplify: + equation, operands = einconv.simplify(equation, operands) + + return equation, operands, shape diff --git a/einconv/expressions/conv_transposeNd_unfold.py b/einconv/expressions/conv_transposeNd_unfold.py index 09d7172..9c10779 100644 --- a/einconv/expressions/conv_transposeNd_unfold.py +++ b/einconv/expressions/conv_transposeNd_unfold.py @@ -30,6 +30,27 @@ def einsum_expression( has shape `[batch_size, out_channels, *output_sizes]` and the output has shape `[batch_size, in_channels, *input_sizes]`. + Let $\\mathbf{X}\\in\\mathbb{R}^{C_\\text{out}\\times O_1\\times O_2\\times\\dots}$ + denote the input of a transpose convolution, $\\mathbf{W} \\in \\mathbb{R}^{ + C_\\text{out} \\times C_\\text{in} \\times K_1\\times K_2\\times\\dots}$ its kernel + and $\\mathbf{Y}\\in\\mathbb{R}^{C_\\text{in}\\times I_1\\times I_2\\times\\dots}$ + its output. The unfolded input $[[\\mathbf{X}]]_\\top$ has dimension + $(C_\\text{out} \\cdot K_1 \\cdot K_2 \\cdots) \\times (I_1 \\cdot I_2 \\cdots)$ and + can be used to express transpose convolution as matrix multiplication, + + $$ + \\mathrm{mat}(\\mathbf{Y}) + = + \\mathrm{mat}(\\mathbf{W}) + [[\\mathbf{X})]]_\\top + \\,, + $$ + + where $\\mathrm{mat}(\\mathbf{Y}) \\in \\mathbb{R}^{C_\\text{in}\\times (I_1\\cdot + I_2 \\cdots)}$ and $\\mathrm{mat}(\\mathbf{W}) \\in \\mathbb{R}^{C_\\text{in}\\times + (C_\\text{out} \\cdot K_1\\cdot K_2 \\cdots)}$ are matrix views of $\\mathbf{Y}, + \\mathbf{W}$ (note that $\\mathbf{W}$ must also be transposed before matricizing). + Args: x: Input to the `N`d transpose convolution. Has shape `[batch_size, in_channels, *input_sizes]` where `len(input_sizes) == N`. diff --git a/mkdocs.yml b/mkdocs.yml index 1c29bef..a950584 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,8 +1,10 @@ site_name: Einconv -site_url: https://example.com # TODO Fill in the link from the hosting platform +site_url: https://einconv.readthedocs.io repo_url: https://github.com/f-dangel/einconv/ repo_name: f-dangel/einconv site_author: Felix Dangel +watch: + - einconv nav: - Getting Started: index.md - Tutorials: diff --git a/test/expressions/conv_transposeNd_kfac_reduce_cases.py b/test/expressions/conv_transposeNd_kfac_reduce_cases.py new file mode 100644 index 0000000..afa010a --- /dev/null +++ b/test/expressions/conv_transposeNd_kfac_reduce_cases.py @@ -0,0 +1,117 @@ +"""Test cases for einsum expression of input-based TRANSPOSE_KFAC-reduce factor of convolution.""" + +from test.utils import make_id + +from torch import rand + +TRANSPOSE_KFAC_REDUCE_1D_CASES = [ + # no kwargs + { + "seed": 0, + # (batch_size, in_channels, num_pixels) + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {}, + }, + # non-default stride + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"stride": 2}, + }, + # non-default stride, groups + { + "seed": 0, + "input_fn": lambda: rand(2, 4, 8), + "kernel_size": 5, + "kwargs": {"stride": 3, "groups": 2}, + }, + # non-default padding + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"padding": 2}, + }, + # non-default output padding + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"output_padding": 1, "stride": 2}, + }, + # non-default dilation + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": 5, + "kwargs": {"dilation": 2}, + }, + # non-default arguments supplied as tuple + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8), + "kernel_size": (3,), + "kwargs": { + "padding": (1,), + "stride": (2,), + "dilation": (1,), + "output_padding": (1,), + }, + }, +] +TRANSPOSE_KFAC_REDUCE_1D_IDS = [ + make_id(case) for case in TRANSPOSE_KFAC_REDUCE_1D_CASES +] + +TRANSPOSE_KFAC_REDUCE_2D_CASES = [ + # no kwargs + { + "seed": 0, + # (batch_size, in_channels, num_pixels_h, num_pixels_w) + "input_fn": lambda: rand(2, 3, 8, 7), + "kernel_size": 5, + "kwargs": {}, + }, + # non-default kwargs supplied as tuple + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8, 7), + "kernel_size": (5, 3), + "kwargs": { + "padding": (1, 2), + "stride": (2, 3), + "dilation": (2, 1), + "output_padding": (1, 2), + }, + }, +] +TRANSPOSE_KFAC_REDUCE_2D_IDS = [ + make_id(case) for case in TRANSPOSE_KFAC_REDUCE_2D_CASES +] + +TRANSPOSE_KFAC_REDUCE_3D_CASES = [ + # no kwargs + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8, 7, 6), + "kernel_size": 5, + "kwargs": {}, + }, + # non-default kwargs supplied as tuple + { + "seed": 0, + "input_fn": lambda: rand(2, 3, 8, 7, 6), + "kernel_size": (5, 3, 2), + "kwargs": { + "padding": (0, 1, 2), + "stride": (3, 2, 1), + "dilation": (1, 2, 3), + "output_padding": (2, 0, 0), + }, + }, +] +TRANSPOSE_KFAC_REDUCE_3D_IDS = [ + make_id(case) for case in TRANSPOSE_KFAC_REDUCE_3D_CASES +] diff --git a/test/expressions/test_convNd_kfac_reduce.py b/test/expressions/test_convNd_kfac_reduce.py index e6c7677..083e142 100644 --- a/test/expressions/test_convNd_kfac_reduce.py +++ b/test/expressions/test_convNd_kfac_reduce.py @@ -50,7 +50,7 @@ def test_einsum_expression(case: Dict, device: torch.device, simplify: bool): avg_unfolded_x = unfolded_x.mean(dim=-1) groups = kwargs.get("groups", 1) avg_unfolded_x = rearrange(avg_unfolded_x, "n (g c_in_k) -> n g c_in_k", g=groups) - kfac_unfold = einsum("ngi,ngj->gij", avg_unfolded_x, avg_unfolded_x) / (batch_size) + kfac_unfold = einsum("ngi,ngj->gij", avg_unfolded_x, avg_unfolded_x) / batch_size equation, operands, shape = convNd_kfac_reduce.einsum_expression( x, kernel_size, **kwargs, simplify=simplify diff --git a/test/expressions/test_conv_transposeNd_kfac_reduce.py b/test/expressions/test_conv_transposeNd_kfac_reduce.py new file mode 100644 index 0000000..79fb637 --- /dev/null +++ b/test/expressions/test_conv_transposeNd_kfac_reduce.py @@ -0,0 +1,63 @@ +"""Tests ``einconv.expressions.conv_transposeNd_kfac_reduce``.""" + +from test.expressions.conv_transposeNd_kfac_reduce_cases import ( + TRANSPOSE_KFAC_REDUCE_1D_CASES, + TRANSPOSE_KFAC_REDUCE_1D_IDS, + TRANSPOSE_KFAC_REDUCE_2D_CASES, + TRANSPOSE_KFAC_REDUCE_2D_IDS, + TRANSPOSE_KFAC_REDUCE_3D_CASES, + TRANSPOSE_KFAC_REDUCE_3D_IDS, +) +from test.utils import DEVICE_IDS, DEVICES, SIMPLIFIES, SIMPLIFY_IDS, report_nonclose +from typing import Dict + +import unfoldNd +from einops import rearrange +from pytest import mark +from torch import device, einsum, manual_seed + +from einconv.expressions import conv_transposeNd_kfac_reduce + + +@mark.parametrize("simplify", SIMPLIFIES, ids=SIMPLIFY_IDS) +@mark.parametrize( + "case", + TRANSPOSE_KFAC_REDUCE_1D_CASES + + TRANSPOSE_KFAC_REDUCE_2D_CASES + + TRANSPOSE_KFAC_REDUCE_3D_CASES, + ids=TRANSPOSE_KFAC_REDUCE_1D_IDS + + TRANSPOSE_KFAC_REDUCE_2D_IDS + + TRANSPOSE_KFAC_REDUCE_3D_IDS, +) +@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) +def test_einsum_expression(case: Dict, dev: device, simplify: bool): + """Compare einsum expression of KFAC reduce with implementation via unfolding. + + Args: + case: Dictionary describing the test case. + dev: Device to execute the test on. + simplify: Whether to simplify the einsum expression. + """ + seed = case["seed"] + input_fn = case["input_fn"] + kernel_size = case["kernel_size"] + kwargs = case["kwargs"] + + manual_seed(seed) + x = input_fn().to(dev) + batch_size = x.shape[0] + + # ground truth + unfold_kwargs = {key: value for key, value in kwargs.items() if key != "groups"} + unfolded_x = unfoldNd.unfold_transposeNd(x, kernel_size, **unfold_kwargs) + avg_unfolded_x = unfolded_x.mean(dim=-1) + groups = kwargs.get("groups", 1) + avg_unfolded_x = rearrange(avg_unfolded_x, "n (g c_in_k) -> n g c_in_k", g=groups) + kfac_unfold = einsum("ngi,ngj->gij", avg_unfolded_x, avg_unfolded_x) / batch_size + + equation, operands, shape = conv_transposeNd_kfac_reduce.einsum_expression( + x, kernel_size, **kwargs, simplify=simplify + ) + kfac_einconv = einsum(equation, *operands).reshape(shape) + + report_nonclose(kfac_unfold, kfac_einconv) diff --git a/test/functionals/test_unfold_transpose.py b/test/functionals/test_unfold_transpose.py index e2778c2..184a282 100644 --- a/test/functionals/test_unfold_transpose.py +++ b/test/functionals/test_unfold_transpose.py @@ -11,6 +11,7 @@ from test.utils import DEVICE_IDS, DEVICES, SIMPLIFIES, SIMPLIFY_IDS, report_nonclose from typing import Dict +import unfoldNd from einops import einsum, rearrange from pytest import mark from torch import ( @@ -34,6 +35,38 @@ ) @mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool): + """Compare input unfolding for transpose convolution with `unfoldNd` package. + + Args: + case: Dictionary describing the test case. + dev: Device to execute the test on. + simplify: Whether to use a simplified einsum expression. + """ + seed = case["seed"] + input_fn = case["input_fn"] + kernel_size = case["kernel_size"] + kwargs = case["kwargs"] + + manual_seed(seed) + inputs = input_fn().to(dev) + + result = unfoldNd.unfold_transposeNd(inputs, kernel_size, **kwargs) + + einconv_result = unfoldNd_transpose( + inputs, kernel_size, **kwargs, simplify=simplify + ) + + report_nonclose(result, einconv_result) + + +@mark.parametrize("simplify", SIMPLIFIES, ids=SIMPLIFY_IDS) +@mark.parametrize( + "case", + TRANSPOSE_UNFOLD_1D_CASES + TRANSPOSE_UNFOLD_2D_CASES + TRANSPOSE_UNFOLD_3D_CASES, + ids=TRANSPOSE_UNFOLD_1D_IDS + TRANSPOSE_UNFOLD_2D_IDS + TRANSPOSE_UNFOLD_3D_IDS, +) +@mark.parametrize("dev", DEVICES, ids=DEVICE_IDS) +def test_unfoldNd_transpose_via_conv_transpose(case: Dict, dev: device, simplify: bool): """Compare transpose convolution via input unfolding with built-in one. Args: From 5a2e2b4dd178d85f31176aadc6829f99538e6c7b Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 6 Jun 2024 11:22:29 -0400 Subject: [PATCH 08/10] [FIX] Long line --- test/expressions/conv_transposeNd_kfac_reduce_cases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/expressions/conv_transposeNd_kfac_reduce_cases.py b/test/expressions/conv_transposeNd_kfac_reduce_cases.py index afa010a..6ec5bb2 100644 --- a/test/expressions/conv_transposeNd_kfac_reduce_cases.py +++ b/test/expressions/conv_transposeNd_kfac_reduce_cases.py @@ -1,4 +1,4 @@ -"""Test cases for einsum expression of input-based TRANSPOSE_KFAC-reduce factor of convolution.""" +"""Test cases for einsum expression of input-based KFAC-reduce for conv. transpose.""" from test.utils import make_id From 843bec7b916d00a3a227245aeb8ef68684f15836 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 6 Jun 2024 11:33:22 -0400 Subject: [PATCH 09/10] [REF] Minor polish --- einconv/expressions/convNd_kfac_reduce.py | 4 ++-- einconv/expressions/conv_transposeNd_kfac_reduce.py | 12 ++++++------ test/functionals/test_unfold_transpose.py | 2 -- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/einconv/expressions/convNd_kfac_reduce.py b/einconv/expressions/convNd_kfac_reduce.py index 3fa55ee..dd6d766 100644 --- a/einconv/expressions/convNd_kfac_reduce.py +++ b/einconv/expressions/convNd_kfac_reduce.py @@ -68,8 +68,8 @@ def einsum_expression( Einsum equation Einsum operands in order un-grouped input, patterns, un-grouped input, \ patterns, normalization scaling - Output shape: ``[groups, in_channels //groups * tot_kernel_sizes,\ - in_channels //groups * tot_kernel_sizes]`` + Output shape: ``[groups, in_channels // groups * tot_kernel_sizes,\ + in_channels // groups * tot_kernel_sizes]`` """ N = x.dim() - 2 diff --git a/einconv/expressions/conv_transposeNd_kfac_reduce.py b/einconv/expressions/conv_transposeNd_kfac_reduce.py index 1bae4a0..450368c 100644 --- a/einconv/expressions/conv_transposeNd_kfac_reduce.py +++ b/einconv/expressions/conv_transposeNd_kfac_reduce.py @@ -85,8 +85,8 @@ def einsum_expression( Einsum equation Einsum operands in order un-grouped input, patterns, un-grouped input, \ patterns, normalization scaling - Output shape: `[groups, out_channels //groups * tot_kernel_sizes,\ - out_channels //groups * tot_kernel_sizes]` + Output shape: `[groups, out_channels // groups * tot_kernel_sizes,\ + out_channels // groups * tot_kernel_sizes]` """ N = x.dim() - 2 @@ -116,8 +116,8 @@ def einsum_expression( if output_size is not None: t_output_size = _tuple(output_size, N) t_output_padding = tuple( - output_size - get_conv_input_size(conv_out_size, K, S, P, 0, D) - for output_size, conv_out_size, K, S, P, D in zip( + output_size - get_conv_input_size(I, K, S, P, 0, D) + for output_size, I, K, S, P, D in zip( t_output_size, conv_output_size, t_kernel_size, @@ -130,8 +130,8 @@ def einsum_expression( t_output_padding = _tuple(output_padding, N) conv_input_size = tuple( - get_conv_input_size(output_size, K, S, P, output_padding, D) - for output_size, K, S, P, output_padding, D in zip( + get_conv_input_size(O, K, S, P, output_padding, D) + for O, K, S, P, output_padding, D in zip( conv_output_size, t_kernel_size, t_stride, diff --git a/test/functionals/test_unfold_transpose.py b/test/functionals/test_unfold_transpose.py index 184a282..9c15325 100644 --- a/test/functionals/test_unfold_transpose.py +++ b/test/functionals/test_unfold_transpose.py @@ -51,11 +51,9 @@ def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool): inputs = input_fn().to(dev) result = unfoldNd.unfold_transposeNd(inputs, kernel_size, **kwargs) - einconv_result = unfoldNd_transpose( inputs, kernel_size, **kwargs, simplify=simplify ) - report_nonclose(result, einconv_result) From 897af8dc895c4b5a5ddde1aed1a353ba575bfdf9 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Thu, 6 Jun 2024 11:38:18 -0400 Subject: [PATCH 10/10] [FIX] flake8 --- einconv/expressions/conv_transposeNd_kfac_reduce.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/einconv/expressions/conv_transposeNd_kfac_reduce.py b/einconv/expressions/conv_transposeNd_kfac_reduce.py index 450368c..ba397d0 100644 --- a/einconv/expressions/conv_transposeNd_kfac_reduce.py +++ b/einconv/expressions/conv_transposeNd_kfac_reduce.py @@ -116,8 +116,8 @@ def einsum_expression( if output_size is not None: t_output_size = _tuple(output_size, N) t_output_padding = tuple( - output_size - get_conv_input_size(I, K, S, P, 0, D) - for output_size, I, K, S, P, D in zip( + output_size - get_conv_input_size(out, K, S, P, 0, D) + for output_size, out, K, S, P, D in zip( t_output_size, conv_output_size, t_kernel_size, @@ -130,8 +130,8 @@ def einsum_expression( t_output_padding = _tuple(output_padding, N) conv_input_size = tuple( - get_conv_input_size(O, K, S, P, output_padding, D) - for O, K, S, P, output_padding, D in zip( + get_conv_input_size(out, K, S, P, output_padding, D) + for out, K, S, P, output_padding, D in zip( conv_output_size, t_kernel_size, t_stride,