[WIP] Clean class Computation to simplify potential changes in the future #69

Open · wants to merge 8 commits into base: main
97 changes: 66 additions & 31 deletions cuequivariance/cuequivariance/tensor_product_execution.py
@@ -24,42 +24,71 @@ class Buffer(int):
 
 
 class InBuffer(Buffer):
-    pass
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, InBuffer) and int(self) == int(other)
+
+    def __hash__(self) -> int:
+        return hash(("in", int(self)))
 
 
 class OutBuffer(Buffer):
-    pass
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, OutBuffer) and int(self) == int(other)
+
+    def __hash__(self) -> int:
+        return hash(("out", int(self)))
 
 
 T = TypeVar("T")
 
 
-class Computation(tuple):
-    def __new__(cls, elements):
-        elements = list(elements)
-        assert all(isinstance(b, Buffer) for b in elements), elements
-        assert sum(isinstance(b, OutBuffer) for b in elements) == 1, elements
-        return super().__new__(cls, elements)
+class Computation:
+    buffers: tuple[Buffer, ...]  # one buffer per operand
+
+    def __init__(self, buffers: Sequence[Buffer]):
+        if isinstance(buffers, Computation):
+            buffers = buffers.buffers
+        self.buffers = tuple(buffers)
+        assert all(isinstance(b, Buffer) for b in self.buffers), self.buffers
+        assert sum(isinstance(b, OutBuffer) for b in self.buffers) == 1, self.buffers
+
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, Computation) and self.buffers == other.buffers
+
+    def __hash__(self) -> int:
+        return hash(self.buffers)
+
+    def __repr__(self):
+        IVARS = "abcdefghijklmnopqrstuvwxyz"
+        OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+        return " ".join(
+            IVARS[b] if isinstance(b, InBuffer) else OVARS[b] for b in self.buffers
+        )
 
     @property
     def num_operands(self) -> int:
-        return len(self)
+        return len(self.buffers)
 
     @property
     def in_buffers(self) -> tuple[InBuffer, ...]:
-        return tuple(b for b in self if isinstance(b, InBuffer))
+        return tuple(b for b in self.buffers if isinstance(b, InBuffer))
 
     @property
     def out_buffer(self) -> OutBuffer:
-        return next(b for b in self if isinstance(b, OutBuffer))
+        return next(b for b in self.buffers if isinstance(b, OutBuffer))
 
     @property
     def in_operands(self) -> tuple[int, ...]:
-        return tuple(oid for oid, b in enumerate(self) if isinstance(b, InBuffer))
+        return tuple(
+            oid for oid, b in enumerate(self.buffers) if isinstance(b, InBuffer)
+        )
 
     @property
     def out_operand(self) -> int:
-        return next(oid for oid, b in enumerate(self) if isinstance(b, OutBuffer))
+        return next(
+            oid for oid, b in enumerate(self.buffers) if isinstance(b, OutBuffer)
+        )
 
     def map_operands(
         self,
@@ -68,12 +97,14 @@ def map_operands(
     ) -> list[Optional[T]]:
         in_buffers = list(in_buffers)
         if out_buffers is None:
-            return [in_buffers[b] if isinstance(b, InBuffer) else None for b in self]
+            return [
+                in_buffers[b] if isinstance(b, InBuffer) else None for b in self.buffers
+            ]
         else:
             out_buffers = list(out_buffers)
             return [
                 in_buffers[b] if isinstance(b, InBuffer) else out_buffers[b]
-                for b in self
+                for b in self.buffers
             ]
 
     def map_inputs(
@@ -91,6 +122,12 @@ class TensorProductExecution:
     def __init__(self, computations: tuple[Computation, ...]):
         self.computations = tuple(Computation(c) for c in computations)
 
+    def __eq__(self, other: Any) -> bool:
+        return (
+            isinstance(other, TensorProductExecution)
+            and self.computations == other.computations
+        )
+
     def __hash__(self) -> int:
         return hash(self.computations)
 
@@ -105,12 +142,7 @@ def __repr__(self):
             )
         ]
         for comp in self.computations:
-            text += [
-                " "
-                + " ".join(
-                    IVARS[b] if isinstance(b, InBuffer) else OVARS[b] for b in comp
-                )
-            ]
+            text += [f" {comp}"]
         return "\n".join(text)
 
     @property
@@ -121,7 +153,7 @@ def is_trivial(self) -> bool:
     def num_operands(self) -> int:
         assert not self.is_trivial
         for c in self.computations:
-            return len(c)
+            return c.num_operands
 
     @property
     def in_buffers(self) -> tuple[int, ...]:
@@ -182,7 +214,7 @@ def map_buffers(
                     if isinstance(b, InBuffer)
                     else OutBuffer(int(f_out(b)))
                 )
-                for b in comp
+                for b in comp.buffers
             )
             for comp in self.computations
         )
@@ -215,7 +247,7 @@ def jvp(self, has_tangent: list[bool]) -> "TensorProductExecution":
                 if bid is None:
                     continue  # the tangent is zero
 
-                c = list(computation)
+                c = list(computation.buffers)
                 c[oid] = InBuffer(bid)
                 new_computations.append(Computation(c))
 
@@ -257,18 +289,18 @@ def transpose(
                 continue  # cotangent is zero
 
             for oid in comp.in_operands:
-                if not is_undefined_primal[comp[oid]]:
+                if not is_undefined_primal[comp.buffers[oid]]:
                     continue  # nothing to transpose
 
-                c = [None] * len(comp)
+                c = [None] * comp.num_operands
                 # undefined primal -> output
-                c[oid] = OutBuffer(primals_new_bid[comp[oid]])
+                c[oid] = OutBuffer(primals_new_bid[comp.buffers[oid]])
                 # output -> cotangent input
                 c[comp.out_operand] = InBuffer(cotangents_new_bid[comp.out_buffer])
                 # rest of inputs
                 for i in range(comp.num_operands):
                     if i != oid and i != comp.out_operand:
-                        c[i] = InBuffer(primals_new_bid[comp[i]])
+                        c[i] = InBuffer(primals_new_bid[comp.buffers[i]])
 
                 new_computations.append(Computation(c))
 
@@ -289,8 +321,11 @@ def group_by_symmetries(
         for c in self.computations:
             found_bucket = False
             for bucket in buckets:
-                rep = bucket[0]
-                if any(Computation(rep[p] for p in perm) == c for perm in permutations):
+                rep: Computation = bucket[0]
+                if any(
+                    Computation(rep.buffers[p] for p in perm) == c
+                    for perm in permutations
+                ):
                     bucket.append(c)
                     found_bucket = True
                     break
@@ -311,7 +346,7 @@ def group_by_identical_buffers(
 
         def partition(computation: Computation) -> list[list[int]]:
             bid_to_oid = defaultdict(list)
-            for oid, b in enumerate(computation):
+            for oid, b in enumerate(computation.buffers):
                 b = (type(b), b)
                 bid_to_oid[b].append(oid)
             return sorted(map(sorted, bid_to_oid.values()))
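Note, not part of the diff: the net effect of the hunks above is that Computation wraps an explicit `buffers` tuple instead of subclassing `tuple`, and InBuffer/OutBuffer now compare by both kind and index. A minimal usage sketch of the refactored API, based only on the code shown in this file:

from cuequivariance.tensor_product_execution import Computation, InBuffer, OutBuffer

# Two input buffers (0 and 1) and one output buffer (0), one per operand.
comp = Computation([InBuffer(0), InBuffer(1), OutBuffer(0)])

# Operands are now reached through the explicit `buffers` tuple.
assert comp.buffers == (InBuffer(0), InBuffer(1), OutBuffer(0))
assert comp.num_operands == 3
assert comp.in_buffers == (InBuffer(0), InBuffer(1))
assert comp.out_operand == 2

# With the new __eq__/__hash__, an input and an output buffer with the
# same index no longer compare equal.
assert InBuffer(0) != OutBuffer(0)

# __repr__ maps input buffers to lowercase letters and output buffers to
# uppercase letters.
assert repr(comp) == "a b A"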
40 changes: 40 additions & 0 deletions cuequivariance/tests/tensor_product_execution_test.py
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cuequivariance as cue
from cuequivariance.tensor_product_execution import InBuffer, OutBuffer


def test_group_by_symmetries():
    # x^3
    exe = cue.TensorProductExecution(
        [(InBuffer(0), InBuffer(0), InBuffer(0), OutBuffer(0))]
    )
    mul, exe = next(
        exe.jvp([True]).group_by_symmetries(
            [
                (0, 1, 2, 3),
                (0, 2, 1, 3),
                (1, 0, 2, 3),
                (1, 2, 0, 3),
                (2, 0, 1, 3),
                (2, 1, 0, 3),
            ]
        )
    )
    # d/dx (x^3) = 3x^2
    assert mul == 3
    assert exe == cue.TensorProductExecution(
        [(InBuffer(1), InBuffer(0), InBuffer(0), OutBuffer(0))]
    )
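Reviewer note, not part of the PR: the multiplicity of 3 can be read off the jvp alone. Differentiating x·x·x with respect to x yields one tangent computation per occurrence of x, and the six permutations listed above identify those three computations with one another, so they collapse into a single bucket of size 3. A small sketch of that intermediate step, using the same imports as the test above (variable name illustrative; the tangent lands in input buffer 1, as the final assertion implies):

# Sketch only -- not part of the PR's test file.
jvp_exe = cue.TensorProductExecution(
    [(InBuffer(0), InBuffer(0), InBuffer(0), OutBuffer(0))]
).jvp([True])

# One tangent computation per occurrence of the repeated input x; the
# symmetry grouping in the test above collapses them into a single
# bucket of size 3, matching d/dx (x^3) = 3x^2.
assert len(jvp_exe.computations) == 3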