Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 71 additions & 48 deletions checkpoint/orbax/checkpoint/_src/arrays/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,66 @@ def offset_by(
out_idx[:, :2] += np.expand_dims(delta, axis=1)
return type(self)(np_index=out_idx, value=self.value)

def intersect(
self,
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
) -> _GenericFragment[A] | None:
"""Intersects this fragment with the given NpIndex.

The result is in this fragment's coordinate space. For example,
intersecting a fragment with its own index gives an identical fragment.

Args:
np_index: The NpIndex to intersect with.

Returns:
A new fragment representing the intersection, or None if there is no
overlap.
"""
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
raise NotImplementedError('index steps other than 1 are not supported.')

out_np_index = np_index.copy()
start = out_np_index[:, 0] = np.maximum(out_np_index[:, 0], self.start)
stop = out_np_index[:, 1] = np.minimum(out_np_index[:, 1], self.stop)
if not (start < stop).all():
return None
return type(self)(
np_index=out_np_index, value=self.slice_of_value(out_np_index)
)

def slice(
self,
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
) -> _GenericFragment[A] | None: # Use typing.Self once 3.11 is minimum.
"""Slices this fragment by the given NpIndex.

The result is in the slice's coordinate space. For example, slicing a
fragment by its own index gives a fragment whose start is zero.

Args:
np_index: The NpIndex to slice by.

Returns:
A new fragment representing the slice, or None if there is no overlap.
"""
intersection = self.intersect(np_index)
return intersection.offset_by(-np_index[:, 0]) if intersection else None

def slice_of_value(self, np_index: NpIndex) -> A:
"""Takes a slice of the value of this fragment.

It is required that `np_index` has already been clamped to the fragment's
bounds; otherwise a ValueError will result.

Args:
np_index: The NpIndex to slice by.

Returns:
A slice of the fragment's value.
"""
raise NotImplementedError()


@dataclasses.dataclass(frozen=True, init=False, eq=False, repr=False)
class AbstractFragment(_GenericFragment[type(None)]):
Expand Down Expand Up @@ -178,21 +238,9 @@ def offset_by(
out_idx[:, :2] += np.expand_dims(delta, axis=1)
return type(self)(np_index=out_idx)

def slice(
self,
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
) -> AbstractFragment | None: # Use typing.Self once 3.11 is minimum.
"""Slices this fragment to find the part that overlaps the given NpIndex."""
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
raise NotImplementedError('Coming ... soon?')

slice_shape = np_index[:, 1] - np_index[:, 0]
out = self.offset_by(-np_index[:, 0])
start = out.start[:] = np.maximum(out.start, 0)
stop = out.stop[:] = np.minimum(out.stop, slice_shape)
if not (start < stop).all():
return None
return out
def slice_of_value(self, np_index: NpIndex) -> None:
del np_index
return None


@dataclasses.dataclass(frozen=True, init=False)
Expand Down Expand Up @@ -230,39 +278,14 @@ def __array__(self) -> np.ndarray:
def nbytes(self) -> int:
return self.value.nbytes

def slice(
self,
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
) -> _ConcreteFragment | None: # Use typing.Self once 3.11 is minimum.
"""Slices this fragment to find the part that overlaps the given NpIndex."""
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
raise NotImplementedError('Coming ... soon?')

slice_shape = np_index[:, 1] - np_index[:, 0]
out = self.offset_by(-np_index[:, 0])
start = out.start[:] = np.maximum(out.start, 0)
stop = out.stop[:] = np.minimum(out.stop, slice_shape)
if not (start < stop).all():
return None
return type(self)(
np_index=out.np_index, value=self.slice_of_value(np_index)
)

def slice_of_value(
self,
new_np_idx: NpIndex,
) -> A:
"""Returns a slice of `value`."""
start = self.start
stop = self.stop
def slice_of_value(self, np_index: NpIndex) -> Aconcrete:
# This is just a convenient way to construct the required tuple of slices.
f = AbstractFragment(
np_index=np.stack([
np.maximum(start, new_np_idx[:, 0]),
np.minimum(stop, new_np_idx[:, 1]),
new_np_idx[:, 2],
], axis=1)
).offset_by(-start)
f = AbstractFragment(np_index=np_index).offset_by(-self.start)
if (f.start < 0).any() or (f.stop > self.value.shape).any():
raise ValueError(
f'Attempt to slice fragment value of shape {self.shape} with'
f' out-of-bounds index {f}'
)
return self.value[f.index or ...]


Expand Down Expand Up @@ -353,7 +376,7 @@ def __array__(self) -> np.ndarray:
def slice(
self,
index: NpIndex | Index, # shape=[{rank}, 3], dtype=int
) -> '_GenericFragments[F]': # Use typing.Self once 3.11 is minimum.
) -> '_GenericFragments[_GenericFragment[A]]': # Use typing.Self once >=3.11.
"""Returns a slice of this object."""
if not isinstance(index, np.ndarray):
index = np_utils.resolve_slice(index, self.shape)
Expand Down
94 changes: 94 additions & 0 deletions checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,66 @@ def test_nbytes_astype_of_abstract_fragment_uses_given_dtype(self):
).nbytes_astype(np.dtype(jax.numpy.bfloat16)),
)

@parameterized.named_parameters(
('np_fragment', NpFragment),
('jax_fragment', JaxFragment),
)
def test_intersect(
self,
fragment_t: ConcreteFragmentT,
):
np_api = fragment_t.NP_API
full_value = np_api.arange(8 * 9).reshape((8, 9))
fragment_index = np.s_[4:8:1, 3:9:1]

f = fragment_t(index=fragment_index, value=full_value[fragment_index])

with self.subTest('fully_within_fragment_index'):
bounds = np.s_[5:7:1, 4:8:1]
s = f.intersect(array_fragments._ndarray_from_index(bounds))
self.assertEqual(
fragment_t(index=np.s_[5:7:1, 4:8:1], value=full_value[bounds]),
s,
)

with self.subTest('fully_enclosing_fragment_index'):
bounds = np.s_[2:10:1, 1:11:1]
s = f.intersect(array_fragments._ndarray_from_index(bounds))
self.assertEqual(fragment_t(index=np.s_[4:8:1, 3:9:1], value=f.value), s)

with self.subTest('spanning_fragment_start'):
bounds = np.s_[2:6:1, 2:4:1]
s = f.intersect(array_fragments._ndarray_from_index(bounds))
self.assertEqual(
fragment_t(index=np.s_[4:6:1, 3:4:1], value=f.value[:2, :1]), s
)

with self.subTest('spanning_fragment_stop'):
bounds = np.s_[6:10:1, 6:10:1]
s = f.intersect(array_fragments._ndarray_from_index(bounds))
self.assertEqual(
fragment_t(index=np.s_[6:8:1, 6:9:1], value=f.value[2:, 3:]), s
)

with self.subTest('with_no_overlap'):
self.assertIsNone(
f.intersect(
array_fragments._ndarray_from_index(np.s_[10:12:1, 10:12:1])
)
)
# This is within the bounds of the fragment but spans no elements.
self.assertIsNone(
f.intersect(array_fragments._ndarray_from_index(np.s_[6:6:1, 3:9:1]))
)

with self.subTest('rank_0'):
s = fragment_t(index=(), value=np_api.ones([])).intersect(
np.zeros([0, 3], dtype=int)
)
self.assertIsNotNone(s)
self.assertEqual((), s.index)
self.assertIsInstance(s.value, np_api.ndarray)

@parameterized.named_parameters(
('np_fragment', NpFragment),
('jax_fragment', JaxFragment),
Expand Down Expand Up @@ -272,6 +332,40 @@ def test_slice(
self.assertEqual((), s.index)
self.assertIsInstance(s.value, np_api.ndarray)

@parameterized.named_parameters(
('np_fragment', NpFragment),
('jax_fragment', JaxFragment),
)
def test_slice_of_value(
self,
fragment_t: ConcreteFragmentT,
):
np_api = fragment_t.NP_API
full_value = np_api.arange(8 * 9).reshape((8, 9))
fragment_index = np.s_[4:8:1, 3:9:1]
fragment = fragment_t(
index=fragment_index, value=full_value[fragment_index]
)

with self.subTest('returns_slice_of_value'):
np.testing.assert_array_equal(
full_value[np.s_[5:7:1, 4:8:1]],
fragment.slice_of_value(
array_fragments._ndarray_from_index(np.s_[5:7:1, 4:8:1])
),
)

with self.subTest('raises_if_slice_is_out_of_bounds'):
with self.assertRaises(ValueError):
fragment.slice_of_value(
array_fragments._ndarray_from_index(np.s_[2:6:1, 3:9:1])
)

with self.assertRaises(ValueError):
fragment.slice_of_value(
array_fragments._ndarray_from_index(np.s_[4:8:1, 8:12:1])
)


@parameterized.named_parameters(
('abstract_fragments', AbstractFragments),
Expand Down
Loading