Skip to content

Commit

Permalink
Merge branch 'routhleck-patch-250114' into lax-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 14, 2025
2 parents 0423d65 + 87ffd83 commit 7074057
Showing 1 changed file with 125 additions and 121 deletions.
246 changes: 125 additions & 121 deletions brainunit/lax/_lax_keep_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,123 +135,61 @@ def test_dynamic_update_slice(self, shape, indices, update_shape):
assert_quantity(result_q, expected, u.second)


# @parameterized.product(
# [dict(shape=shape, idxs=idxs, dnums=dnums, slice_sizes=slice_sizes)
# for shape, idxs, dnums, slice_sizes in [
# ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
# offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
# (1,)),
# ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
# offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
# (2,)),
# ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
# offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
# (1, 3)),
# ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
# offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
# (1, 3)),
# ((2, 5), np.array([[[0], [2]], [[1], [1]]]),
# lax.GatherDimensionNumbers(
# offset_dims=(), collapsed_slice_dims=(1,),
# start_index_map=(1,), operand_batching_dims=(0,),
# start_indices_batching_dims=(0,)),
# (1, 1)),
# ((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
# lax.GatherDimensionNumbers(
# offset_dims=(2,), collapsed_slice_dims=(),
# start_index_map=(2,), operand_batching_dims=(0, 1),
# start_indices_batching_dims=(1, 0)),
# (1, 1, 3))
# ]],
# )
# @unittest.skipIf(sys.version_info < (3, 10), "JAX now do not support the python version below 3.10")
def test_gather(self):
if sys.version_info < (3, 10):
return
test_cases = [
dict(
shape=(5,),
idxs=np.array([[0], [2]]),
dnums=lax.GatherDimensionNumbers(
offset_dims=(),
collapsed_slice_dims=(0,),
start_index_map=(0,)
),
slice_sizes=(1,)
),
dict(
shape=(10,),
idxs=np.array([[0], [0], [0]]),
dnums=lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(),
start_index_map=(0,)
),
slice_sizes=(2,)
),
dict(
shape=(10, 5),
idxs=np.array([[0], [2], [1]]),
dnums=lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(0,),
start_index_map=(0,)
),
slice_sizes=(1, 3)
),
dict(
shape=(10, 5),
idxs=np.array([[0, 2], [1, 0]]),
dnums=lax.GatherDimensionNumbers(
offset_dims=(1,),
collapsed_slice_dims=(0,),
start_index_map=(0, 1)
),
slice_sizes=(1, 3)
),
dict(
shape=(2, 5),
idxs=np.array([[[0], [2]], [[1], [1]]]),
dnums=lax.GatherDimensionNumbers(
offset_dims=(),
collapsed_slice_dims=(1,),
start_index_map=(1,),
operand_batching_dims=(0,),
start_indices_batching_dims=(0,)
),
slice_sizes=(1, 1)
),
dict(
shape=(2, 3, 10),
idxs=np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
dnums=lax.GatherDimensionNumbers(
offset_dims=(2,),
collapsed_slice_dims=(),
start_index_map=(2,),
operand_batching_dims=(0, 1),
start_indices_batching_dims=(1, 0)
),
slice_sizes=(1, 1, 3)
)
]

for case in test_cases:
with self.subTest(**case):
shape = case["shape"]
idxs = case["idxs"]
dnums = case["dnums"]
slice_sizes = case["slice_sizes"]

rand_idxs = bst.random.randint(0., high=max(shape), size=idxs.shape)
array = bst.random.random(shape)
@parameterized.product(
[dict(shape=shape, idxs=idxs, dnums=dnums, slice_sizes=slice_sizes)
for shape, idxs, dnums, slice_sizes in [
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
((2, 5), np.array([[[0], [2]], [[1], [1]]]),
lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(1,),
start_index_map=(1,), operand_batching_dims=(0,),
start_indices_batching_dims=(0,)),
(1, 1)),
((2, 3, 10), np.array([[[0], [1]], [[2], [3]], [[4], [5]]]),
lax.GatherDimensionNumbers(
offset_dims=(2,), collapsed_slice_dims=(),
start_index_map=(2,), operand_batching_dims=(0, 1),
start_indices_batching_dims=(1, 0)),
(1, 1, 3))
]] if sys.version_info >= (3, 10) else [
dict(shape=shape, idxs=idxs, dnums=dnums, slice_sizes=slice_sizes)
for shape, idxs, dnums, slice_sizes in [
((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1,)),
((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
(2,)),
((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
(1, 3)),
((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1)),
(1, 3)),
]],
)
def test_gather(self, shape, idxs, dnums, slice_sizes):
rand_idxs = bst.random.randint(0., high=max(shape), size=idxs.shape)
array = bst.random.random(shape)

result = ulax.gather(array, rand_idxs, dimension_numbers=dnums, slice_sizes=slice_sizes)
expected = lax.gather(array, rand_idxs, dimension_numbers=dnums, slice_sizes=slice_sizes)
self.assertTrue(jnp.all(result == expected))
result = ulax.gather(array, rand_idxs, dimension_numbers=dnums, slice_sizes=slice_sizes)
expected = lax.gather(array, rand_idxs, dimension_numbers=dnums, slice_sizes=slice_sizes)
self.assertTrue(jnp.all(result == expected))

array = array * u.second
result_q = ulax.gather(array, rand_idxs, dimension_numbers=dnums, slice_sizes=slice_sizes)
assert_quantity(result_q, expected, u.second)
array = array * u.second
result_q = ulax.gather(array, rand_idxs, dimension_numbers=dnums, slice_sizes=slice_sizes)
assert_quantity(result_q, expected, u.second)

@parameterized.product(
[dict(shape=shape, idxs=idxs, axes=axes)
Expand Down Expand Up @@ -414,7 +352,21 @@ def test_lax_keep_unit_math_binary(self, value, unit):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down Expand Up @@ -456,9 +408,23 @@ def test_scatter(self, arg_shape, idxs, update_shape, dnums):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]],
mode=["clip", "fill", None],
op=['scatter_add', 'scatter_sub'],
op=['scatter_add', 'scatter_sub'] if sys.version_info >= (3, 10) else ['scatter_add'],
)
def test_scatter_add_sub(self, arg_shape, idxs, update_shape, dnums, mode, op):
ulax_op = getattr(ulax, op)
Expand Down Expand Up @@ -507,7 +473,19 @@ def test_scatter_mul(self):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter_min(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down Expand Up @@ -549,7 +527,20 @@ def test_scatter_min(self, arg_shape, idxs, update_shape, dnums):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter_max(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down Expand Up @@ -591,7 +582,20 @@ def test_scatter_max(self, arg_shape, idxs, update_shape, dnums):
update_window_dims=(2,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(2,), operand_batching_dims=(0, 1),
scatter_indices_batching_dims=(1, 0)))
]],
]] if sys.version_info >= (3, 10) else [
dict(arg_shape=arg_shape, idxs=idxs, update_shape=update_shape,
dnums=dnums)
for arg_shape, idxs, update_shape, dnums in [
((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
update_window_dims=(), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
((10,), np.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(),
scatter_dims_to_operand_dims=(0,))),
((10, 5), np.array([[0], [2], [1]], dtype=np.uint64), (3, 3), lax.ScatterDimensionNumbers(
update_window_dims=(1,), inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,))),
]]
)
def test_scatter_apply(self, arg_shape, idxs, update_shape, dnums):
array = bst.random.random(arg_shape)
Expand Down

0 comments on commit 7074057

Please sign in to comment.