Skip to content

Commit

Permalink
fix introduced bug and add corresponding test
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jan 22, 2025
1 parent 0df8fe0 commit f569e9f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
37 changes: 28 additions & 9 deletions cuequivariance/cuequivariance/tensor_product_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@ class Buffer(int):


class InBuffer(Buffer):
pass
def __eq__(self, other: Any) -> bool:
return isinstance(other, InBuffer) and int(self) == int(other)

def __hash__(self) -> int:
return hash(("in", int(self)))


class OutBuffer(Buffer):
pass
def __eq__(self, other: Any) -> bool:
return isinstance(other, OutBuffer) and int(self) == int(other)

def __hash__(self) -> int:
return hash(("out", int(self)))


T = TypeVar("T")
Expand All @@ -44,9 +52,20 @@ def __init__(self, buffers: Sequence[Buffer]):
assert all(isinstance(b, Buffer) for b in self.buffers), self.buffers
assert sum(isinstance(b, OutBuffer) for b in self.buffers) == 1, self.buffers

def __eq__(self, other: Any) -> bool:
return isinstance(other, Computation) and self.buffers == other.buffers

def __hash__(self) -> int:
return hash(self.buffers)

def __repr__(self):
IVARS = "abcdefghijklmnopqrstuvwxyz"
OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

return " ".join(
IVARS[b] if isinstance(b, InBuffer) else OVARS[b] for b in self.buffers
)

@property
def num_operands(self) -> int:
return len(self.buffers)
Expand Down Expand Up @@ -103,6 +122,12 @@ class TensorProductExecution:
def __init__(self, computations: tuple[Computation, ...]):
self.computations = tuple(Computation(c) for c in computations)

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, TensorProductExecution)
and self.computations == other.computations
)

def __hash__(self) -> int:
return hash(self.computations)

Expand All @@ -117,13 +142,7 @@ def __repr__(self):
)
]
for comp in self.computations:
text += [
" "
+ " ".join(
IVARS[b] if isinstance(b, InBuffer) else OVARS[b]
for b in comp.buffers
)
]
text += [f" {comp}"]
return "\n".join(text)

@property
Expand Down
40 changes: 40 additions & 0 deletions cuequivariance/tests/tensor_product_execution_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cuequivariance as cue
from cuequivariance.tensor_product_execution import InBuffer, OutBuffer


def test_group_by_symmetries():
# x^3
exe = cue.TensorProductExecution(
[(InBuffer(0), InBuffer(0), InBuffer(0), OutBuffer(0))]
)
mul, exe = next(
exe.jvp([True]).group_by_symmetries(
[
(0, 1, 2, 3),
(0, 2, 1, 3),
(1, 0, 2, 3),
(1, 2, 0, 3),
(2, 0, 1, 3),
(2, 1, 0, 3),
]
)
)
# d/dx (x^3) = 3x^2
assert mul == 3
assert exe == cue.TensorProductExecution(
[(InBuffer(1), InBuffer(0), InBuffer(0), OutBuffer(0))]
)

0 comments on commit f569e9f

Please sign in to comment.