Skip to content

Commit

Permalink
[REF] Minor polish
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jun 6, 2024
1 parent 5a2e2b4 commit 843bec7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions einconv/expressions/convNd_kfac_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def einsum_expression(
Einsum equation
Einsum operands in order un-grouped input, patterns, un-grouped input, \
patterns, normalization scaling
Output shape: ``[groups, in_channels //groups * tot_kernel_sizes,\
in_channels //groups * tot_kernel_sizes]``
Output shape: ``[groups, in_channels // groups * tot_kernel_sizes,\
in_channels // groups * tot_kernel_sizes]``
"""
N = x.dim() - 2

Expand Down
12 changes: 6 additions & 6 deletions einconv/expressions/conv_transposeNd_kfac_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def einsum_expression(
Einsum equation
Einsum operands in order un-grouped input, patterns, un-grouped input, \
patterns, normalization scaling
Output shape: `[groups, out_channels //groups * tot_kernel_sizes,\
out_channels //groups * tot_kernel_sizes]`
Output shape: `[groups, out_channels // groups * tot_kernel_sizes,\
out_channels // groups * tot_kernel_sizes]`
"""
N = x.dim() - 2

Expand Down Expand Up @@ -116,8 +116,8 @@ def einsum_expression(
if output_size is not None:
t_output_size = _tuple(output_size, N)
t_output_padding = tuple(
output_size - get_conv_input_size(conv_out_size, K, S, P, 0, D)
for output_size, conv_out_size, K, S, P, D in zip(
output_size - get_conv_input_size(I, K, S, P, 0, D)
for output_size, I, K, S, P, D in zip(
t_output_size,
conv_output_size,
t_kernel_size,
Expand All @@ -130,8 +130,8 @@ def einsum_expression(
t_output_padding = _tuple(output_padding, N)

conv_input_size = tuple(
get_conv_input_size(output_size, K, S, P, output_padding, D)
for output_size, K, S, P, output_padding, D in zip(
get_conv_input_size(O, K, S, P, output_padding, D)
for O, K, S, P, output_padding, D in zip(
conv_output_size,
t_kernel_size,
t_stride,
Expand Down
2 changes: 0 additions & 2 deletions test/functionals/test_unfold_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@ def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool):
inputs = input_fn().to(dev)

result = unfoldNd.unfold_transposeNd(inputs, kernel_size, **kwargs)

einconv_result = unfoldNd_transpose(
inputs, kernel_size, **kwargs, simplify=simplify
)

report_nonclose(result, einconv_result)


Expand Down

0 comments on commit 843bec7

Please sign in to comment.