diff --git a/einconv/expressions/convNd_kfac_reduce.py b/einconv/expressions/convNd_kfac_reduce.py index 3fa55ee..dd6d766 100644 --- a/einconv/expressions/convNd_kfac_reduce.py +++ b/einconv/expressions/convNd_kfac_reduce.py @@ -68,8 +68,8 @@ def einsum_expression( Einsum equation Einsum operands in order un-grouped input, patterns, un-grouped input, \ patterns, normalization scaling - Output shape: ``[groups, in_channels //groups * tot_kernel_sizes,\ - in_channels //groups * tot_kernel_sizes]`` + Output shape: ``[groups, in_channels // groups * tot_kernel_sizes,\ + in_channels // groups * tot_kernel_sizes]`` """ N = x.dim() - 2 diff --git a/einconv/expressions/conv_transposeNd_kfac_reduce.py b/einconv/expressions/conv_transposeNd_kfac_reduce.py index 1bae4a0..450368c 100644 --- a/einconv/expressions/conv_transposeNd_kfac_reduce.py +++ b/einconv/expressions/conv_transposeNd_kfac_reduce.py @@ -85,8 +85,8 @@ def einsum_expression( Einsum equation Einsum operands in order un-grouped input, patterns, un-grouped input, \ patterns, normalization scaling - Output shape: `[groups, out_channels //groups * tot_kernel_sizes,\ - out_channels //groups * tot_kernel_sizes]` + Output shape: `[groups, out_channels // groups * tot_kernel_sizes,\ + out_channels // groups * tot_kernel_sizes]` """ N = x.dim() - 2 @@ -116,8 +116,8 @@ def einsum_expression( if output_size is not None: t_output_size = _tuple(output_size, N) t_output_padding = tuple( - output_size - get_conv_input_size(conv_out_size, K, S, P, 0, D) - for output_size, conv_out_size, K, S, P, D in zip( + output_size - get_conv_input_size(I, K, S, P, 0, D) + for output_size, I, K, S, P, D in zip( t_output_size, conv_output_size, t_kernel_size, @@ -130,8 +130,8 @@ def einsum_expression( t_output_padding = _tuple(output_padding, N) conv_input_size = tuple( - get_conv_input_size(output_size, K, S, P, output_padding, D) - for output_size, K, S, P, output_padding, D in zip( + get_conv_input_size(O, K, S, P, output_padding, D) + for O, K, S, P, output_padding, D in zip( conv_output_size, t_kernel_size, t_stride, diff --git a/test/functionals/test_unfold_transpose.py b/test/functionals/test_unfold_transpose.py index 184a282..9c15325 100644 --- a/test/functionals/test_unfold_transpose.py +++ b/test/functionals/test_unfold_transpose.py @@ -51,11 +51,9 @@ def test_unfoldNd_transpose(case: Dict, dev: device, simplify: bool): inputs = input_fn().to(dev) result = unfoldNd.unfold_transposeNd(inputs, kernel_size, **kwargs) - einconv_result = unfoldNd_transpose( inputs, kernel_size, **kwargs, simplify=simplify ) - report_nonclose(result, einconv_result)