From d0bbc88319bdb4fcaf69c27ccf7ffcfa7f8c32fd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 18 Dec 2024 10:13:17 +0000 Subject: [PATCH] Reuse alternate index syntax --- src/array_api_extra/_funcs.py | 3 +-- tests/test_at.py | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_funcs.py b/src/array_api_extra/_funcs.py index 092d0ea..887bebe 100644 --- a/src/array_api_extra/_funcs.py +++ b/src/array_api_extra/_funcs.py @@ -672,8 +672,7 @@ def __getitem__(self, idx: Index, /) -> "at": # numpydoc ignore=PR01,RT01 if self._idx is not _undef: msg = "Index has already been set" raise ValueError(msg) - self._idx = idx - return self + return at(self._x, idx) def _update_common( self, diff --git a/tests/test_at.py b/tests/test_at.py index 87d7555..c5ce3f4 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -127,6 +127,11 @@ def test_alternate_index_syntax(): a = np.asarray([1, 2, 3]) assert_array_equal(at(a, 0).set(4), [4, 2, 3]) assert_array_equal(at(a)[0].set(4), [4, 2, 3]) + + a_at = at(a) + assert_array_equal(a_at[0].add(1), [2, 2, 3]) + assert_array_equal(a_at[1].add(2), [1, 4, 3]) + with pytest.raises(ValueError, match="Index"): at(a).set(4) with pytest.raises(ValueError, match="Index"):