Skip to content

Commit

Permalink
Add test cases for lax keep unit methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 13, 2025
1 parent bc87b02 commit ed69c29
Show file tree
Hide file tree
Showing 4 changed files with 781 additions and 14 deletions.
13 changes: 8 additions & 5 deletions brainunit/lax/_lax_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax import lax
from jax._src.typing import Shape

from .._base import Quantity, maybe_decimal
from .._base import Quantity, maybe_decimal, has_same_unit
from .._misc import set_module_as
from ..math._fun_keep_unit import _fun_keep_unit_unary, _fun_keep_unit_binary

Expand Down Expand Up @@ -289,7 +289,7 @@ def gather(
"""
if isinstance(operand, Quantity) and isinstance(fill_value, Quantity):
return maybe_decimal(
Quantity(lax.gather(operand.value, start_indices, dimension_numbers, slice_sizes,
Quantity(lax.gather(operand.mantissa, start_indices, dimension_numbers, slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value.in_unit(operand.unit)),
unit=operand.unit)
Expand All @@ -298,7 +298,7 @@ def gather(
if fill_value is not None:
raise ValueError('fill_value must be a Quantity if operand is a Quantity')
return maybe_decimal(
Quantity(lax.gather(operand.value, start_indices, dimension_numbers, slice_sizes,
Quantity(lax.gather(operand.mantissa, start_indices, dimension_numbers, slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode), unit=operand.unit)
)
Expand Down Expand Up @@ -738,11 +738,14 @@ def _fun_lax_scatter(
unique_indices,
mode
) -> Union[Quantity, jax.Array]:
if isinstance(operand, Quantity):
if isinstance(operand, Quantity) and isinstance(updates, Quantity):
assert has_same_unit(operand, updates), f'operand(unit:{operand.unit}) and updates(unit:{updates.unit}) do not have same unit'
return maybe_decimal(Quantity(fun(operand.mantissa, scatter_indices, updates.mantissa, dimension_numbers,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode), unit=operand.unit))
elif isinstance(operand, Quantity) or isinstance(updates, Quantity):
raise AssertionError(f'operand and updates should both be `Quantity` or Array, now we got {type(operand)} and {type(updates)}')
else:
return fun(operand, scatter_indices, updates, dimension_numbers,
indices_are_sorted=indices_are_sorted,
Expand Down Expand Up @@ -1271,7 +1274,7 @@ def clamp(
(min, x, max)):
return lax.clamp(min, x, max)
else:
raise ValueError('All inputs must be Quantity or jax.typing.ArrayLike')
raise AssertionError('All inputs must be Quantity or jax.typing.ArrayLike')


# math funcs keep unit (return Quantity and index)
Expand Down
Loading

0 comments on commit ed69c29

Please sign in to comment.