Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions checkpoint/orbax/checkpoint/_src/arrays/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import dataclasses
from typing import Optional, Sequence, TypeAlias
from typing import ClassVar, Optional, Sequence, TypeAlias

import jax
import numpy as np
Expand All @@ -43,7 +43,7 @@ def _index_from_ndarray(a: NpIndex) -> Index:


@dataclasses.dataclass(frozen=True, init=False)
class Fragment:
class _Fragment:
"""One of a collection of slices into the same (abstract or concrete) array.

Fields:
Expand Down Expand Up @@ -113,8 +113,8 @@ def shape(self) -> Shape:
def size(self) -> int:
  """Returns the number of elements covered by this fragment.

  `np.prod` yields a NumPy scalar (and a float `1.0` for an empty, rank-0
  shape), so the result is converted to a built-in `int` to honor the
  declared return type.
  """
  return int(np.prod(self.shape))

def __eq__(self, other: 'Fragment'):
if not isinstance(other, Fragment):
def __eq__(self, other: '_Fragment'):
if not isinstance(other, _Fragment):
return False
if not np.array_equal(self.np_index, other.np_index):
return False
Expand Down Expand Up @@ -159,15 +159,15 @@ def nbytes_astype(self, dtype: np.dtype) -> int:
def offset_by(
    self,
    delta: np.ndarray,  # shape=[{rank}], dtype=int
) -> '_Fragment':
  """Returns a copy of this fragment with its index shifted by `delta`.

  Args:
    delta: Per-dimension offset. It is added to the first two columns of
      `np_index` (start and stop); the third column (step) is left as is.

  Returns:
    A new fragment with the shifted index and the same `value` as this one.
  """
  # Work on a copy so that this fragment's own index is not modified.
  out_idx = self.np_index.copy()
  # `delta` has one entry per dimension; expanding it to a column lets it
  # broadcast across both the start and stop columns.
  out_idx[:, :2] += np.expand_dims(delta, axis=1)
  return _Fragment(np_index=out_idx, value=self.value)

def slice(
self,
np_index: NpIndex, # shape=[{rank}, 3], dtype=int
) -> Optional['Fragment']:
) -> Optional['_Fragment']:
"""Slices this fragment to find the part that overlaps the given NpIndex."""
if (self.step != 1).any() or (np_index[:, 2] != 1).any():
raise NotImplementedError('Coming ... soon?')
Expand All @@ -190,7 +190,7 @@ def slice_of_value(
start = self.start
stop = self.stop
# This is just a convenient way to construct the required tuple of slices.
f = Fragment(
f = _Fragment(
np_index=np.stack([
np.maximum(start, new_np_idx[:, 0]),
np.minimum(stop, new_np_idx[:, 1]),
Expand All @@ -201,7 +201,7 @@ def slice_of_value(


@dataclasses.dataclass(frozen=True)
class Fragments:
class _Fragments:
"""An abstract or concrete collection of fragments.

A `Fragments` is a lot like a `jax.Array` (or a `jax.ShapeDtypeStruct`) but
Expand All @@ -210,14 +210,20 @@ class Fragments:
of a `jax.Array` (fragments are not required to have the same shape, or to map
to a device mesh).
"""
# Keep printed representation the same as before the leading underscore
# was added. TODO(b/465183318): Remove this once there are separate
# classes for abstract and concrete fragments.
__qualname__ = 'Fragments'

FRAGMENT_T: ClassVar[type[_Fragment]] = _Fragment

shape: Shape
dtype: np.dtype
fragments: Sequence[Fragment]
fragments: Sequence[_Fragment]

def __post_init__(self):
for fragment in self.fragments:
if not isinstance(fragment, Fragment):
if not isinstance(fragment, _Fragment):
raise TypeError(
f'Fragments must contain Fragment, not {type(fragment)}.'
)
Expand Down Expand Up @@ -265,7 +271,7 @@ def __array__(self) -> np.ndarray:
def slice(
self,
index: NpIndex | Index, # shape=[{rank}, 3], dtype=int
) -> 'Fragments':
) -> '_Fragments':
"""Returns a slice of this object."""
if not isinstance(index, np.ndarray):
index = np_utils.resolve_slice(index, self.shape)
Expand All @@ -280,7 +286,7 @@ def slice(
f'with out-of-bounds index {_index_from_ndarray(index)}'
)

return Fragments(
return _Fragments(
tuple(d.item() for d in sliced_shape),
self.dtype,
[
Expand All @@ -291,7 +297,18 @@ def slice(
)


def _is_full(fragments: Fragments) -> bool:
# Backward-compatible public names for the underscore-prefixed classes.
# TODO(b/465188418): Remove these two aliases once all users have been migrated
# to the more specific ones.
Fragment: TypeAlias = _Fragment
Fragments: TypeAlias = _Fragments

# The abstract and concrete names currently alias the very same classes, so
# they are interchangeable at runtime (including in `isinstance` checks) and
# only document intent at call sites.
AbstractFragment: TypeAlias = _Fragment
AbstractFragments: TypeAlias = _Fragments
ConcreteFragment: TypeAlias = _Fragment
ConcreteFragments: TypeAlias = _Fragments


def _is_full(fragments: _Fragments) -> bool:
"""True iff every array element is covered by some fragment."""
present = np.zeros(fragments.shape, dtype=bool)
for f in fragments.fragments:
Expand Down Expand Up @@ -323,19 +340,31 @@ def addressable_shards(x: jax.Array | jax.ShapeDtypeStruct) -> list[Index]:


def abstract_fragments(
    x: jax.Array | jax.ShapeDtypeStruct | AbstractFragments | ConcreteFragments,
) -> AbstractFragments:
  """Returns abstract fragments matching the given array.

  Args:
    x: An array, an array spec, or an existing collection of fragments.

  Returns:
    Fragments with the same `shape` and `dtype` as `x` whose fragments all
    have `value=None`. If `x` is already a fragments collection with no
    values, it is returned unchanged; if any fragment carries a value, a new
    collection is built that keeps only each fragment's index. For any other
    input, one fragment is created per index from `addressable_shards(x)`.
  """
  if isinstance(x, _Fragments):
    # TODO(b/465183318): Replace this condition with an instance check
    # once AbstractFragments and ConcreteFragments are separate classes.
    if all(f.value is None for f in x.fragments):
      return x
    # At least one fragment carries a value: keep the indices, drop the values.
    return AbstractFragments(
        x.shape,
        x.dtype,
        [AbstractFragment(index=f.index, value=None) for f in x.fragments],
    )
  return AbstractFragments(
      x.shape,
      x.dtype,
      [
          AbstractFragment(index=index, value=None)
          for index in addressable_shards(x)
      ],
  )


def validate_fragments_can_be_stacked(fragments: Fragments) -> None:
def validate_fragments_can_be_stacked(fragments: ConcreteFragments) -> None:
"""Validates that the given fragments can be stacked."""
if not fragments.fragments:
raise ValueError('No fragments to stack.')
Expand All @@ -353,7 +382,7 @@ def validate_fragments_can_be_stacked(fragments: Fragments) -> None:
)


def stack_fragments(fragments: Fragments | None) -> np.ndarray | None:
def stack_fragments(fragments: ConcreteFragments | None) -> np.ndarray | None:
"""Stacks the given fragments, which must all have the same shape."""
if fragments is None:
return fragments
Expand Down
Loading