Draft
Changes from all commits (32 commits)
63b6e58  Test for a new disco kernel using vectorized operations on channel-last (mauro-bis, Aug 22, 2025)
933cfb8  there is more work to do (azrael417, Aug 26, 2025)
2f673c2  dbg (azrael417, Aug 28, 2025)
5b4b035  compiling code (azrael417, Aug 28, 2025)
7be3eb5  almost working refactor (azrael417, Aug 28, 2025)
5110dd0  fixed a lot of issues (azrael417, Aug 29, 2025)
6f1cf6d  fixed fwd pass (azrael417, Aug 29, 2025)
fdb14dd  re-implementing transpose conv (azrael417, Aug 29, 2025)
8280510  working CPU branch (azrael417, Sep 1, 2025)
f36abb8  Added preliminary version of channel-last BWD kernel. (mauro-bis, Sep 3, 2025)
e22d50a  commenting out tests (azrael417, Sep 3, 2025)
92ae176  adding reshapes (azrael417, Sep 3, 2025)
88e8c68  adding contiguous call before passing to kernel (azrael417, Sep 3, 2025)
621ffa3  dbg (azrael417, Sep 8, 2025)
89abe84  fixing some selection criterium (azrael417, Sep 8, 2025)
21ff1d5  removing some debug prints (azrael417, Sep 8, 2025)
c95f35e  small fix (azrael417, Sep 8, 2025)
5eba6f9  cleaning up attention (azrael417, Sep 9, 2025)
3656a84  fixing raw meta kernels (azrael417, Sep 9, 2025)
9e8a038  better comments (azrael417, Sep 9, 2025)
58ad395  snapshot (azrael417, Oct 6, 2025)
9810143  Fixed absolute difference errors w.r.t. CPU and torch DISCO (mauro-bis, Oct 8, 2025)
f20d5be  adding permute with contig call (azrael417, Oct 9, 2025)
ade5461  dbg (azrael417, Oct 13, 2025)
c2ec904  after rebase (azrael417, Oct 20, 2025)
61eed2b  small fixes after rebase (azrael417, Oct 20, 2025)
abc0537  adding theta cutoff to test (azrael417, Oct 27, 2025)
381fb6a  cleanups (azrael417, Oct 28, 2025)
4b89056  removing prints (azrael417, Oct 28, 2025)
bfa4eb1  further cleanup (azrael417, Oct 29, 2025)
0cb7a49  changing max array length (azrael417, Oct 30, 2025)
0263f19  Increased max no. of element per thread to 20 to both fwd and bwd and (mauro-bis, Oct 30, 2025)
setup.py: 53 additions & 1 deletion
@@ -93,31 +93,77 @@ def get_ext_modules():
def get_ext_modules():
    """Get list of extension modules to compile."""

    # define setup dir
    setup_dir = os.path.abspath(os.path.dirname(__file__))

    ext_modules = []
    cmdclass = {}

    print(f"Compiling helper routines for torch-harmonics.")

    # Utility helpers
    ext_modules.append(
        CppExtension(
            "utility_helpers",
            [
                "torch_harmonics/utils/csrc/utils_helpers.cpp",
            ],
            extra_compile_args=get_helpers_compile_args(),
        )
    )

    # DISCO helpers
    ext_modules.append(
        CppExtension(
            "disco_helpers",
            [
                "torch_harmonics/disco/csrc/disco_helpers.cpp",
            ],
            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
            extra_compile_args=get_helpers_compile_args(),
        )
    )

    # Attention helpers
    ext_modules.append(
        CppExtension(
            "attention_helpers",
            [
                "torch_harmonics/attention/csrc/attention_helpers.cpp",
            ],
            include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
            extra_compile_args=get_helpers_compile_args(),
        )
    )

    if BUILD_CPP:
        # HELPERS
        utility_sources = [
            "torch_harmonics/utils/csrc/utils_interface.cpp",
            "torch_harmonics/utils/csrc/permute_cpu.cpp",
        ]

        if BUILD_CUDA:
            print(f"Compiling custom CUDA kernels for torch-harmonics.")
            utility_sources.extend([
                "torch_harmonics/utils/csrc/permute_cuda.cu",
            ])
            ext_modules.append(
                CUDAExtension(
                    "torch_harmonics.utils._C",
                    utility_sources,
                    extra_compile_args=get_compile_args("utils")
                )
            )
        else:
            ext_modules.append(
                CppExtension(
                    "torch_harmonics.utils._C",
                    utility_sources,
                    extra_compile_args=get_compile_args("utils")
                )
            )

        # DISCO
        # Create a single extension that includes both CPU and CUDA code
        disco_sources = [
@@ -128,13 +174,15 @@ def get_ext_modules():
        if BUILD_CUDA:
            print(f"Compiling custom CUDA kernels for torch-harmonics.")
            disco_sources.extend([
+                "torch_harmonics/utils/csrc/csr_cuda.cu",
                "torch_harmonics/disco/csrc/disco_cuda_fwd.cu",
                "torch_harmonics/disco/csrc/disco_cuda_bwd.cu",
            ])
            ext_modules.append(
                CUDAExtension(
                    "torch_harmonics.disco._C",
                    disco_sources,
+                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("disco")
                )
            )
@@ -143,10 +191,10 @@ def get_ext_modules():
                CppExtension(
                    "torch_harmonics.disco._C",
                    disco_sources,
+                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("disco")
                )
            )
-    cmdclass["build_ext"] = BuildExtension

        # ATTENTION
        # Create a single extension that includes both CPU and CUDA code
@@ -167,6 +215,7 @@ def get_ext_modules():
                CUDAExtension(
                    "torch_harmonics.attention._C",
                    attention_sources,
+                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("attention")
                )
            )
@@ -175,9 +224,12 @@ def get_ext_modules():
                CppExtension(
                    "torch_harmonics.attention._C",
                    attention_sources,
+                    include_dirs=[os.path.join(setup_dir, "torch_harmonics/utils/csrc")],
                    extra_compile_args=get_compile_args("attention")
                )
            )

+    # set cmdclass
+    cmdclass["build_ext"] = BuildExtension

    return ext_modules, cmdclass
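
Each compiled module above lands as torch_harmonics.<subpackage>._C. As a minimal sketch of how the Python side can then choose between compiled and pure-PyTorch paths (an illustration under assumptions, not the package's actual availability logic; the tests below go through helpers like optimized_kernels_is_available() for this):

# Hypothetical import guard, sketch only:
try:
    from torch_harmonics.disco import _C as _disco_kernels  # built by the extension above
    _OPTIMIZED_DISCO_AVAILABLE = True
except ImportError:
    _disco_kernels = None
    _OPTIMIZED_DISCO_AVAILABLE = False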
tests/test_attention.py: 6 additions & 5 deletions
@@ -54,8 +54,8 @@
# CPU results normalized to 16 OpenMP threads,
# GPU results normalized to V100 16 GB GPU
# this is just to detect performance regressions, not for absolute performance
_perf_test_thresholds = {"cpu": {"fwd_ms": 1000, "bwd_ms": 8000},
"cuda": {"fwd_ms": 50, "bwd_ms": 150}}
_perf_test_thresholds = {"cpu": {"fwd_ms": 800, "bwd_ms": 6000},
"cuda": {"fwd_ms": 10, "bwd_ms": 30}}
_run_perf_tests = (os.getenv("TORCH_HARMONICS_RUN_PERF_TESTS", "0") == "1")
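
These perf tests are opt-in: the thresholds above only guard against regressions when the environment variable is set, e.g. TORCH_HARMONICS_RUN_PERF_TESTS=1 pytest tests/test_attention.py -k test_perf (assuming pytest as the runner; plain unittest discovery respects the same switch via the skip decorator below).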


@@ -326,12 +326,12 @@ def test_optimized_pt2_compatibility(self, batch_size, channels, heads, in_shape
    @parameterized.expand(
        [
            # self attention
-            [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", 1e-5, 1e-5],
+            [1, 256, 1, (91, 180), (91, 180), "equiangular", "equiangular", None],
        ],
        skip_on_empty=True,
    )
    @unittest.skipUnless(optimized_kernels_is_available() and _run_perf_tests, "skipping performance test because optimized kernels are not available or perf tests are disabled")
-    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol, verbose=False):
+    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, theta_cutoff, verbose=False):

        if (self.device.type == "cuda") and (not cuda_kernels_is_available()):
            raise unittest.SkipTest("skipping test because CUDA kernels are not available")
@@ -350,7 +350,8 @@ def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, g

        att_optimized = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                                in_shape=in_shape, out_shape=out_shape,
-                                                grid_in=grid_in, grid_out=grid_out, bias=True,
+                                                grid_in=grid_in, grid_out=grid_out,
+                                                theta_cutoff=theta_cutoff, bias=True,
                                                optimized_kernel=True).to(self.device)

        # random weights
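
For context: theta_cutoff presumably controls the angular radius (in radians) of each output point's attention neighborhood on the sphere, with None falling back to the layer's default cutoff; the reworked parameterization above exercises that default path instead of passing tolerances the perf test never used.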
tests/test_cache.py: 40 additions & 1 deletion
@@ -34,8 +34,11 @@
import math
import torch

+from torch_harmonics.cache import lru_cache
+from testutils import compare_tensors

class TestCacheConsistency(unittest.TestCase):

    def test_consistency(self, verbose=False):
        if verbose:
            print("Testing that cached values do not get modified externally")
@@ -47,7 +50,43 @@ def test_consistency(self, verbose=False):
        # perform in-place modification of leg1
        leg1 *= -1.0
        leg2 = _precompute_legpoly(10, 10, cost)
-        self.assertFalse(torch.allclose(leg1, leg2))
+        self.assertFalse(compare_tensors("legendre", leg2, leg1, verbose=verbose))
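
compare_tensors comes from the local testutils module; judging from its call sites, it returns whether two tensors agree within tolerance and optionally prints diagnostics. A plausible minimal version (hypothetical; the real helper may report more detail):

import torch

def compare_tensors(name, ref, test, rtol=1e-5, atol=1e-8, verbose=False):
    # True when test matches ref within tolerance; report the largest deviation otherwise
    match = torch.allclose(ref, test, rtol=rtol, atol=atol)
    if verbose and not match:
        print(f"{name}: max abs diff = {(ref - test).abs().max().item():.3e}")
    return match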


+    def test_pytorch_tensors(self, verbose=False):
+        if verbose:
+            print("Testing that PyTorch tensors are cached")
+
+        @lru_cache(typed=True, copy=True)
+        def test_func(tens1, tens2):
+            return tens1, tens2
+
+        # initial tensors
+        tens1 = torch.randn(4, 4, dtype=torch.float32)
+        tens2 = torch.randn(4, 4, dtype=torch.float32)
+
+        # retrieve from cache
+        tens1c, tens2c = test_func(tens1, tens2)
+
+        # modify copies
+        tens1c *= -1.0
+        tens2c *= -1.0
+
+        # retrieve from cache again
+        tens1cc, tens2cc = test_func(tens1, tens2)
+
+        if verbose:
+            print("first tensor", tens1)
+            print("first tensor after modification", tens1c)
+            print("first tensor cached", tens1cc)
+            print("second tensor", tens2)
+            print("second tensor after modification", tens2c)
+            print("second tensor cached", tens2cc)
+
+        self.assertFalse(compare_tensors("first cached", tens1cc, tens1c, verbose=verbose))
+        self.assertFalse(compare_tensors("second cached", tens2cc, tens2c, verbose=verbose))
+        self.assertTrue(compare_tensors("first raw", tens1, tens1cc, verbose=verbose))
+        self.assertTrue(compare_tensors("second raw", tens2, tens2cc, verbose=verbose))


if __name__ == "__main__":