From 2a6637a2d5406dfb9d67c16cb907b02c54f4d2ed Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Wed, 3 Dec 2025 12:36:32 -0800 Subject: [PATCH] Internal change. PiperOrigin-RevId: 839869322 --- .../orbax/checkpoint/_src/arrays/fragments.py | 119 +++++++++++------- .../checkpoint/_src/arrays/fragments_test.py | 94 ++++++++++++++ 2 files changed, 165 insertions(+), 48 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py index 573cd63f7..92df0b14c 100644 --- a/checkpoint/orbax/checkpoint/_src/arrays/fragments.py +++ b/checkpoint/orbax/checkpoint/_src/arrays/fragments.py @@ -142,6 +142,66 @@ def offset_by( out_idx[:, :2] += np.expand_dims(delta, axis=1) return type(self)(np_index=out_idx, value=self.value) + def intersect( + self, + np_index: NpIndex, # shape=[{rank}, 3], dtype=int + ) -> _GenericFragment[A] | None: + """Intersects this fragment with the given NpIndex. + + The result is in this fragment's coordinate space. For example, + intersecting a fragment with its own index gives an identical fragment. + + Args: + np_index: The NpIndex to intersect with. + + Returns: + A new fragment representing the intersection, or None if there is no + overlap. + """ + if (self.step != 1).any() or (np_index[:, 2] != 1).any(): + raise NotImplementedError('index steps other than 1 are not supported.') + + out_np_index = np_index.copy() + start = out_np_index[:, 0] = np.maximum(out_np_index[:, 0], self.start) + stop = out_np_index[:, 1] = np.minimum(out_np_index[:, 1], self.stop) + if not (start < stop).all(): + return None + return type(self)( + np_index=out_np_index, value=self.slice_of_value(out_np_index) + ) + + def slice( + self, + np_index: NpIndex, # shape=[{rank}, 3], dtype=int + ) -> _GenericFragment[A] | None: # Use typing.Self once 3.11 is minimum. + """Slices this fragment by the given NpIndex. + + The result is in the slice's coordinate space. For example, slicing a + fragment by its own index gives a fragment whose start is zero. + + Args: + np_index: The NpIndex to slice by. + + Returns: + A new fragment representing the slice, or None if there is no overlap. + """ + intersection = self.intersect(np_index) + return intersection.offset_by(-np_index[:, 0]) if intersection else None + + def slice_of_value(self, np_index: NpIndex) -> A: + """Takes a slice of the value of this fragment. + + It is required that `np_index` has already been clamped to the fragment's + bounds; otherwise a ValueError will result. + + Args: + np_index: The NpIndex to slice by. + + Returns: + A slice of the fragment's value. + """ + raise NotImplementedError() + @dataclasses.dataclass(frozen=True, init=False, eq=False, repr=False) class AbstractFragment(_GenericFragment[type(None)]): @@ -178,21 +238,9 @@ def offset_by( out_idx[:, :2] += np.expand_dims(delta, axis=1) return type(self)(np_index=out_idx) - def slice( - self, - np_index: NpIndex, # shape=[{rank}, 3], dtype=int - ) -> AbstractFragment | None: # Use typing.Self once 3.11 is minimum. - """Slices this fragment to find the part that overlaps the given NpIndex.""" - if (self.step != 1).any() or (np_index[:, 2] != 1).any(): - raise NotImplementedError('Coming ... soon?') - - slice_shape = np_index[:, 1] - np_index[:, 0] - out = self.offset_by(-np_index[:, 0]) - start = out.start[:] = np.maximum(out.start, 0) - stop = out.stop[:] = np.minimum(out.stop, slice_shape) - if not (start < stop).all(): - return None - return out + def slice_of_value(self, np_index: NpIndex) -> None: + del np_index + return None @dataclasses.dataclass(frozen=True, init=False) @@ -230,39 +278,14 @@ def __array__(self) -> np.ndarray: def nbytes(self) -> int: return self.value.nbytes - def slice( - self, - np_index: NpIndex, # shape=[{rank}, 3], dtype=int - ) -> _ConcreteFragment | None: # Use typing.Self once 3.11 is minimum. - """Slices this fragment to find the part that overlaps the given NpIndex.""" - if (self.step != 1).any() or (np_index[:, 2] != 1).any(): - raise NotImplementedError('Coming ... soon?') - - slice_shape = np_index[:, 1] - np_index[:, 0] - out = self.offset_by(-np_index[:, 0]) - start = out.start[:] = np.maximum(out.start, 0) - stop = out.stop[:] = np.minimum(out.stop, slice_shape) - if not (start < stop).all(): - return None - return type(self)( - np_index=out.np_index, value=self.slice_of_value(np_index) - ) - - def slice_of_value( - self, - new_np_idx: NpIndex, - ) -> A: - """Returns a slice of `value`.""" - start = self.start - stop = self.stop + def slice_of_value(self, np_index: NpIndex) -> Aconcrete: # This is just a convenient way to construct the required tuple of slices. - f = AbstractFragment( - np_index=np.stack([ - np.maximum(start, new_np_idx[:, 0]), - np.minimum(stop, new_np_idx[:, 1]), - new_np_idx[:, 2], - ], axis=1) - ).offset_by(-start) + f = AbstractFragment(np_index=np_index).offset_by(-self.start) + if (f.start < 0).any() or (f.stop > self.value.shape).any(): + raise ValueError( + f'Attempt to slice fragment value of shape {self.shape} with' + f' out-of-bounds index {f}' + ) return self.value[f.index or ...] @@ -353,7 +376,7 @@ def __array__(self) -> np.ndarray: def slice( self, index: NpIndex | Index, # shape=[{rank}, 3], dtype=int - ) -> '_GenericFragments[F]': # Use typing.Self once 3.11 is minimum. + ) -> '_GenericFragments[_GenericFragment[A]]': # Use typing.Self once >=3.11. """Returns a slice of this object.""" if not isinstance(index, np.ndarray): index = np_utils.resolve_slice(index, self.shape) diff --git a/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py b/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py index 040840724..f4ef7ce51 100644 --- a/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py +++ b/checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py @@ -214,6 +214,66 @@ def test_nbytes_astype_of_abstract_fragment_uses_given_dtype(self): ).nbytes_astype(np.dtype(jax.numpy.bfloat16)), ) + @parameterized.named_parameters( + ('np_fragment', NpFragment), + ('jax_fragment', JaxFragment), + ) + def test_intersect( + self, + fragment_t: ConcreteFragmentT, + ): + np_api = fragment_t.NP_API + full_value = np_api.arange(8 * 9).reshape((8, 9)) + fragment_index = np.s_[4:8:1, 3:9:1] + + f = fragment_t(index=fragment_index, value=full_value[fragment_index]) + + with self.subTest('fully_within_fragment_index'): + bounds = np.s_[5:7:1, 4:8:1] + s = f.intersect(array_fragments._ndarray_from_index(bounds)) + self.assertEqual( + fragment_t(index=np.s_[5:7:1, 4:8:1], value=full_value[bounds]), + s, + ) + + with self.subTest('fully_enclosing_fragment_index'): + bounds = np.s_[2:10:1, 1:11:1] + s = f.intersect(array_fragments._ndarray_from_index(bounds)) + self.assertEqual(fragment_t(index=np.s_[4:8:1, 3:9:1], value=f.value), s) + + with self.subTest('spanning_fragment_start'): + bounds = np.s_[2:6:1, 2:4:1] + s = f.intersect(array_fragments._ndarray_from_index(bounds)) + self.assertEqual( + fragment_t(index=np.s_[4:6:1, 3:4:1], value=f.value[:2, :1]), s + ) + + with self.subTest('spanning_fragment_stop'): + bounds = np.s_[6:10:1, 6:10:1] + s = f.intersect(array_fragments._ndarray_from_index(bounds)) + self.assertEqual( + fragment_t(index=np.s_[6:8:1, 6:9:1], value=f.value[2:, 3:]), s + ) + + with self.subTest('with_no_overlap'): + self.assertIsNone( + f.intersect( + array_fragments._ndarray_from_index(np.s_[10:12:1, 10:12:1]) + ) + ) + # This is within the bounds of the fragment but spans no elements. + self.assertIsNone( + f.intersect(array_fragments._ndarray_from_index(np.s_[6:6:1, 3:9:1])) + ) + + with self.subTest('rank_0'): + s = fragment_t(index=(), value=np_api.ones([])).intersect( + np.zeros([0, 3], dtype=int) + ) + self.assertIsNotNone(s) + self.assertEqual((), s.index) + self.assertIsInstance(s.value, np_api.ndarray) + @parameterized.named_parameters( ('np_fragment', NpFragment), ('jax_fragment', JaxFragment), @@ -272,6 +332,40 @@ def test_slice( self.assertEqual((), s.index) self.assertIsInstance(s.value, np_api.ndarray) + @parameterized.named_parameters( + ('np_fragment', NpFragment), + ('jax_fragment', JaxFragment), + ) + def test_slice_of_value( + self, + fragment_t: ConcreteFragmentT, + ): + np_api = fragment_t.NP_API + full_value = np_api.arange(8 * 9).reshape((8, 9)) + fragment_index = np.s_[4:8:1, 3:9:1] + fragment = fragment_t( + index=fragment_index, value=full_value[fragment_index] + ) + + with self.subTest('returns_slice_of_value'): + np.testing.assert_array_equal( + full_value[np.s_[5:7:1, 4:8:1]], + fragment.slice_of_value( + array_fragments._ndarray_from_index(np.s_[5:7:1, 4:8:1]) + ), + ) + + with self.subTest('raises_if_slice_is_out_of_bounds'): + with self.assertRaises(ValueError): + fragment.slice_of_value( + array_fragments._ndarray_from_index(np.s_[2:6:1, 3:9:1]) + ) + + with self.assertRaises(ValueError): + fragment.slice_of_value( + array_fragments._ndarray_from_index(np.s_[4:8:1, 8:12:1]) + ) + @parameterized.named_parameters( ('abstract_fragments', AbstractFragments),