Skip to content

Commit 8fc8427

Browse files
committed
small fixes after rebase
1 parent 455ae90 commit 8fc8427

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

tests/testutils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# coding=utf-8
2+
3+
# SPDX-FileCopyrightText: Copyright (c) 2025 The torch-harmonics Authors. All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# Redistribution and use in source and binary forms, with or without
7+
# modification, are permitted provided that the following conditions are met:
8+
#
9+
# 1. Redistributions of source code must retain the above copyright notice, this
10+
# list of conditions and the following disclaimer.
11+
#
12+
# 2. Redistributions in binary form must reproduce the above copyright notice,
13+
# this list of conditions and the following disclaimer in the documentation
14+
# and/or other materials provided with the distribution.
15+
#
16+
# 3. Neither the name of the copyright holder nor the names of its
17+
# contributors may be used to endorse or promote products derived from
18+
# this software without specific prior written permission.
19+
#
20+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30+
#
31+
32+
import torch
33+
34+
def compare_tensors(msg, tensor, tensor_ref, rtol=1e-8, atol=1e-5, verbose=False):
35+
allclose = torch.allclose(tensor, tensor_ref, rtol=rtol, atol=atol)
36+
if (not allclose) and verbose:
37+
diff = torch.abs(tensor - tensor_ref)
38+
print(f"{msg} absolute tensor diff: min = {torch.min(diff)}, mean = {torch.mean(diff)}, max = {torch.max(diff)}.")
39+
reldiff = diff / torch.abs(tensor_ref)
40+
print(f"{msg} relative tensor diff: min = {torch.min(reldiff)}, mean = {torch.mean(reldiff)}, max = {torch.max(reldiff)}.")
41+
# find element with maximum difference
42+
index = torch.argmax(diff)
43+
print(f"{msg} element {index} with maximum difference: value = {tensor.flatten()[index]}, reference value = {tensor_ref.flatten()[index]}, diff = {diff.flatten()[index]}.")
44+
return allclose

torch_harmonics/disco/convolution.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,10 @@ def __init__(
370370
raise ValueError("Error, the number of output channels has to be an integer multiple of the group size")
371371
self.groupsize_in = in_channels // self.groups
372372
self.groupsize_out = out_channels // self.groups
373+
# keep this for backward compatibility
374+
self.groupsize = self.groupsize_in
373375
scale = math.sqrt(1.0 / self.groupsize_in / self.kernel_size)
374-
self.weight = nn.Parameter(scale * torch.randn(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size))
376+
self.weight = nn.Parameter(scale * torch.randn(self.groups * self.groupsize_out, self.groupsize_in, self.kernel_size))
375377

376378
if bias:
377379
self.bias = nn.Parameter(torch.zeros(out_channels))
@@ -496,7 +498,7 @@ def __init__(
496498
self.psi = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out)
497499

498500
def extra_repr(self):
499-
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
501+
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
500502

501503
@property
502504
def psi_idx(self):
@@ -524,7 +526,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
524526
# weight multiplication
525527
B, H, W, _, K = xpc.shape
526528
xpc = xpc.reshape(B, H, W, self.groups, self.groupsize_in, K)
527-
outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight)
529+
outp = torch.einsum("bxygck,gock->bxygo", xpc, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size))
528530
outp = outp.reshape(B, H, W, -1).contiguous()
529531

530532
# permute output
@@ -538,7 +540,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
538540
x = x.reshape(B, self.groups, self.groupsize_in, K, H, W)
539541

540542
# weight multiplication
541-
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight)
543+
out = torch.einsum("bgckxy,gock->bgoxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size))
542544
out = out.reshape(B, -1, H, W).contiguous()
543545

544546
if self.bias is not None:
@@ -657,7 +659,7 @@ def __init__(
657659
self.psi_st = _get_psi(self.kernel_size, self.psi_idx, self.psi_vals, self.nlat_in, self.nlon_in, self.nlat_out, self.nlon_out, semi_transposed=True)
658660

659661
def extra_repr(self):
660-
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
662+
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}, in_chans={self.groupsize_in * self.groups}, out_chans={self.weight.shape[0]}, filter_basis={self.filter_basis}, kernel_shape={self.kernel_shape}, groups={self.groups}"
661663

662664
@property
663665
def psi_idx(self):
@@ -674,7 +676,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
674676

675677
# weight multiplication
676678
xp = xp.reshape(B, H, W, self.groups, self.groupsize_in)
677-
xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight)
679+
xpc = torch.einsum("bxygc,gock->bxygok", xp, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size))
678680
xpc = xpc.reshape(B, H, W, -1, self.kernel_size).contiguous()
679681

680682
# disco contraction
@@ -695,7 +697,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
695697
else:
696698
# weight multiplication
697699
x = x.reshape(B, self.groups, self.groupsize_in, H, W)
698-
xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight)
700+
xc = torch.einsum("bgcxy,gock->bgokxy", x, self.weight.reshape(self.groups, self.groupsize_out, self.groupsize_in, self.kernel_size))
699701
xc = xc.reshape(B, self.groups* self.groupsize_out, -1, H, W).contiguous()
700702

701703
# disco contraction

0 commit comments

Comments
 (0)