Skip to content

Commit

Permalink
[lax] Add more lax test cases (#93)
Browse files Browse the repository at this point in the history
* Add test cases for lax keep unit methods

* Add test cases for lax linalg methods

* Add test cases for lax misc methods

* Update _lax_keep_unit_test.py

* reformat

* Update _lax_keep_unit_test.py

* Update _lax_keep_unit_test.py

* Update _lax_keep_unit_test.py

* Update _lax_linalg.py
  • Loading branch information
Routhleck authored Jan 14, 2025
1 parent b5cc99c commit 4423578
Show file tree
Hide file tree
Showing 11 changed files with 1,095 additions and 70 deletions.
2 changes: 1 addition & 1 deletion brainunit/lax/_lax_accept_unitless_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,4 @@ def test_lax_collapse(self, value):
result = bulax_fun(q1, value2)

with pytest.raises(bu.UnitMismatchError):
result = bulax_fun(q1, value2, unit_to_scale=bu.second)
result = bulax_fun(q1, value2, unit_to_scale=bu.second)
9 changes: 4 additions & 5 deletions brainunit/lax/_lax_array_creation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,23 @@
# ==============================================================================


import jax.numpy as jnp
import jax.lax as lax
import pytest
import jax.numpy as jnp
from absl.testing import parameterized

import brainunit as bu
import brainunit.lax as bulax
from brainunit import meter, second
from brainunit._base import assert_quantity

lax_array_creation_given_array = [
'zeros_like_array',
]
]

lax_array_creation_misc = [
'iota', 'broadcasted_iota',
]


class TestLaxArrayCreation(parameterized.TestCase):
def __init__(self, *args, **kwargs):
super(TestLaxArrayCreation, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -95,4 +94,4 @@ def test_lax_array_creation_broadcasted_iota(self, shape, unit):

result = bulax_fun(float, shape, dimension, unit=unit)
expected = lax_fun(float, shape, dimension)
assert_quantity(result, expected, unit=unit)
assert_quantity(result, expected, unit=unit)
24 changes: 12 additions & 12 deletions brainunit/lax/_lax_change_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ def test_lax_change_unit_conv(self, shape, window_strides, padding):
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))

@parameterized.product(
shapes = [
shapes=[
dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape)
for lhs_shape, rhs_shape in [
((b, 10, i), (k, i, j))
for b, i, j, k in itertools.product(
[2, 3], [2, 3], [2, 3], [3,]
)
]
((b, 10, i), (k, i, j))
for b, i, j, k in itertools.product(
[2, 3], [2, 3], [2, 3], [3, ]
)
]
],
strides=[(1,), (2,)],
padding=["VALID", "SAME"],
Expand Down Expand Up @@ -251,17 +251,17 @@ def test_lax_change_unit_conv_transpose(self, shapes, strides, padding):
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))

@parameterized.product(
shapes = [
shapes=[
dict(
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dimension_numbers=dimension_numbers,
)
for lhs_shape, rhs_shape, dimension_numbers in [
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
]
],
)
def test_lax_change_unit_dot_general(self, shapes):
Expand All @@ -286,4 +286,4 @@ def test_lax_change_unit_dot_general(self, shapes):
q2 = rhs * meter
result = bulax_fun(q1, q2, dimension_numbers=dimension_numbers)
expected = lax_fun(jnp.array(lhs), jnp.array(rhs), dimension_numbers=dimension_numbers)
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))
assert_quantity(result, expected, unit=bulax_fun._unit_change_fun(bu.get_unit(q1), bu.get_unit(q2)))
15 changes: 10 additions & 5 deletions brainunit/lax/_lax_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax import lax
from jax._src.typing import Shape

from .._base import Quantity, maybe_decimal
from .._base import Quantity, maybe_decimal, has_same_unit
from .._misc import set_module_as
from ..math._fun_keep_unit import _fun_keep_unit_unary, _fun_keep_unit_binary

Expand Down Expand Up @@ -289,7 +289,7 @@ def gather(
"""
if isinstance(operand, Quantity) and isinstance(fill_value, Quantity):
return maybe_decimal(
Quantity(lax.gather(operand.value, start_indices, dimension_numbers, slice_sizes,
Quantity(lax.gather(operand.mantissa, start_indices, dimension_numbers, slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value.in_unit(operand.unit)),
unit=operand.unit)
Expand All @@ -298,7 +298,7 @@ def gather(
if fill_value is not None:
raise ValueError('fill_value must be a Quantity if operand is a Quantity')
return maybe_decimal(
Quantity(lax.gather(operand.value, start_indices, dimension_numbers, slice_sizes,
Quantity(lax.gather(operand.mantissa, start_indices, dimension_numbers, slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode), unit=operand.unit)
)
Expand Down Expand Up @@ -738,11 +738,16 @@ def _fun_lax_scatter(
unique_indices,
mode
) -> Union[Quantity, jax.Array]:
if isinstance(operand, Quantity):
if isinstance(operand, Quantity) and isinstance(updates, Quantity):
assert has_same_unit(operand,
updates), f'operand(unit:{operand.unit}) and updates(unit:{updates.unit}) do not have same unit'
return maybe_decimal(Quantity(fun(operand.mantissa, scatter_indices, updates.mantissa, dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode), unit=operand.unit))
elif isinstance(operand, Quantity) or isinstance(updates, Quantity):
raise AssertionError(
f'operand and updates should both be `Quantity` or Array, now we got {type(operand)} and {type(updates)}')
else:
return fun(operand, scatter_indices, updates, dimension_numbers,
indices_are_sorted=indices_are_sorted,
Expand Down Expand Up @@ -1271,7 +1276,7 @@ def clamp(
(min, x, max)):
return lax.clamp(min, x, max)
else:
raise ValueError('All inputs must be Quantity or jax.typing.ArrayLike')
raise AssertionError('All inputs must be Quantity or jax.typing.ArrayLike')


# math funcs keep unit (return Quantity and index)
Expand Down
Loading

0 comments on commit 4423578

Please sign in to comment.