Skip to content

Commit 58e3d77

Browse files
ChromeHeartsOrbax Authors
authored andcommitted
Improve V1 docstrings
PiperOrigin-RevId: 839968361
1 parent 525781c commit 58e3d77

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+549
-309
lines changed

checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class Context(epy.ContextManager):
4747
with ocp.Context(...):
4848
ocp.save_pytree(...)
4949
50-
Creating a new `Context` within an existing `Context` sets all parameters from
51-
scratch; it does not inherit properties from the parent `Context`. To achieve
52-
this, use::
50+
Creating a new :py:class:`.Context` within an existing :py:class:`.Context`
51+
sets all parameters from scratch; it does not inherit properties from the
52+
parent :py:class:`.Context`. To achieve this, use::
5353
5454
with Context(**some_properties) as outer_ctx:
5555
with Context(outer_ctx, **other) as inner_ctx:
@@ -59,7 +59,7 @@ class Context(epy.ContextManager):
5959
properties modified in the `dataclasses.replace` call.
6060
6161
NOTE: The context is not shared across threads. In other words, the whole
62-
context block must be executed in the same thread. Following example will
62+
context block must be executed in the same thread. The following example will
6363
not work as expected::
6464
6565
executor = ThreadPoolExecutor()

checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class AsyncOptions:
4343
post_finalization_callback:
4444
A function that is called after the async save operation is complete.
4545
create_directories_asynchronously:
46-
If true, create directories asynchronously in the background.
46+
If True, creates directories asynchronously in the background.
4747
"""
4848

4949
timeout_secs: int = 600 # 10 minutes.
@@ -67,15 +67,18 @@ class MultiprocessingOptions:
6767
all hosts will be considered as primary. It's useful in the case that all
6868
hosts are only working with local storage.
6969
active_processes:
70-
A set of process indices (corresponding to `multihost.process_index()`) over
71-
which `CheckpointManager` is expected to be called. This makes it possible
72-
to have a `CheckpointManager` instance that runs over a subset of processes,
73-
rather than all processes as it is normally expected to do. If specified,
74-
`primary_host` must belong to `active_processes`.
70+
A set of process indices (corresponding to
71+
:py:func:`~orbax.checkpoint._src.multihost.process_index`) over which
72+
:py:class:`~orbax.checkpoint.CheckpointManager` is expected to be called.
73+
This makes it possible to have a
74+
:py:class:`~orbax.checkpoint.CheckpointManager` instance that runs over a
75+
subset of processes, rather than all processes as it is normally expected to
76+
do. If specified, `primary_host` must belong to `active_processes`.
7577
barrier_sync_key_prefix:
7678
A string to be prepended to the barrier sync key used to synchronize
7779
processes. This is useful to avoid collisions with other barrier syncs if
78-
another CheckpointManager is being used concurrently.
80+
another :py:class:`~orbax.checkpoint.CheckpointManager` is being used
81+
concurrently.
7982
"""
8083

8184
primary_host: int | None = 0
@@ -102,12 +105,14 @@ class FileOptions:
102105
https://github.com/google/etils/blob/main/etils/epath/backend.py if your
103106
path is supported. default=None.
104107
temporary_path_class:
105-
A class that is used to create and finallize temporary paths, and to
106-
ensure atomicity.
108+
A class that is used to create and finalize temporary paths, and to ensure
109+
atomicity.
107110
path_class:
108-
The implementation of `path_types.Path` to use. Defaults to
109-
`etils.epath.Path`, but may be overridden to some other subclass of
110-
`path_types.Path`.
111+
The implementation of
112+
:py:class:`~orbax.checkpoint.experimental.v1._src.path.types.Path` to use.
113+
Defaults to :py:class:`~etils.epath.Path`, but may be overridden to some
114+
other subclass of
115+
:py:class:`~orbax.checkpoint.experimental.v1._src.path.types.Path`.
111116
"""
112117

113118
path_permission_mode: int | None = None
@@ -141,11 +146,12 @@ class Saving:
141146
142147
create_array_storage_options_fn:
143148
A function that is called in order to create
144-
`ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree, when it is
149+
:py:class:`.ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree,
150+
when it is
145151
being saved. It is called similar to:
146152
`jax.tree.map_with_path(create_array_storage_options_fn, pytree_to_save)`.
147153
If provided, it overrides any default settings in
148-
`ArrayOptions.Saving.StorageOptions`.
154+
:py:class:`.ArrayOptions.Saving.StorageOptions`.
149155
pytree_metadata_options: Options for managing PyTree metadata.
150156
"""
151157

@@ -230,19 +236,19 @@ class StorageOptions:
230236
231237
dtype:
232238
If provided, casts the parameter to the given dtype before saving.
233-
Note that the parameter must be compatible with the given type (e.g.
234-
jnp.bfloat16 is not compatible with np.ndarray).
239+
Note that the parameter must be compatible with the given type (e.g.,
240+
`jnp.bfloat16` is not compatible with `np.ndarray`).
235241
chunk_byte_size:
236242
This is an experimental feature that automatically chooses the largest
237-
chunk shape possible, while keeping the chunk byte size less than or
238-
equal to the specified chunk_byte_size. Both the write_chunk_shape and
239-
read_chunk_shape are automatically set to the chosen shape. This uses a
240-
greedy algorithm that prioritizes splitting the largest dimensions
243+
possible chunk shape while keeping the chunk byte size less than or
244+
equal to the specified `chunk_byte_size`. Both `write_chunk_shape` and
245+
`read_chunk_shape` are automatically set to the chosen shape. This uses
246+
a greedy algorithm that prioritizes splitting the largest dimensions
241247
first.
242248
shard_axes:
243-
An optional list of axes that should be prioritized when sharding array
244-
for storage. If empty, storage sharding implementation will prioritize
245-
axes which are already sharded.
249+
An optional list of axes that should be prioritized when sharding an
250+
array for storage. If empty, the storage sharding implementation will
251+
prioritize axes which are already sharded.
246252
"""
247253

248254
dtype: np.typing.DTypeLike | None = None
@@ -322,9 +328,9 @@ class CheckpointablesOptions:
322328
first because it is registered first.
323329
324330
Attributes:
325-
registry: A `CheckpointableHandlerRegistry` that is used to resolve
326-
`CheckpointableHandler` classes for each provided `checkpointable` during
327-
saving and loading.
331+
registry: A :py:class:`.CheckpointableHandlerRegistry` that is used to
332+
resolve :py:class:`.CheckpointableHandler` classes for each provided
333+
`checkpointable` during saving and loading.
328334
"""
329335

330336
registry: registration.CheckpointableHandlerRegistry = dataclasses.field(

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/compatibility.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
class _PathAwaitingCreation(path_types.PathAwaitingCreation):
34-
"""Implementation of `PathAwaitingCreation` that awaits contracted signals."""
34+
"""Implementation of :py:class:`~orbax.checkpoint.experimental.v1._src.path.types.PathAwaitingCreation` that awaits contracted signals."""
3535

3636
def __init__(self, path: path_types.Path, operation_id: str):
3737
self._path = path
@@ -56,7 +56,7 @@ def path(self) -> path_types.Path:
5656
class CompatibilityCheckpointHandler(
5757
async_checkpoint_handler.AsyncCheckpointHandler
5858
):
59-
"""A V0 CheckpointHandler that wraps a V1 CheckpointableHandler."""
59+
"""A V0 :py:class:`~orbax.checkpoint._src.handlers.async_checkpoint_handler.AsyncCheckpointHandler` that wraps a V1 :py:class:`~orbax.checkpoint.experimental.v1._src.handlers.types.CheckpointableHandler`."""
6060

6161
def __init__(self, handler: handler_types.CheckpointableHandler):
6262
self._handler = handler

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/composite_handler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Defines `CompositeHandler`, a helper component for saving and loading."""
15+
"""Defines :py:class:`.CompositeHandler`, a helper component for saving and loading."""
1616

1717
from __future__ import annotations
1818

@@ -74,8 +74,10 @@ async def _create_orbax_identifier_file(
7474
class CompositeHandler:
7575
"""CompositeHandler.
7676
77-
This class is a helper component for `save_checkpointables` and
78-
`load_checkpointables`. It performs a few core functions:
77+
This class is a helper component for
78+
:py:func:`~orbax.checkpoint.experimental.v1.save_checkpointables` and
79+
:py:func:`~orbax.checkpoint.experimental.v1.load_checkpointables`. It performs
80+
a few core functions:
7981
- Resolves handlers for saving and loading.
8082
- Saves and loads checkpointables to/from individual subdirectories by
8183
delegating to the resolved handlers.

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/json_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Implementation of `CheckpointableHandler` for PyTrees."""
15+
"""Implementation of :py:class:`.CheckpointableHandler` for PyTrees."""
1616

1717
from __future__ import annotations
1818

@@ -40,7 +40,7 @@ def _get_supported_filenames(filename: str | None = None) -> list[str]:
4040

4141

4242
class JsonHandler(CheckpointableHandler[JsonType, None]):
43-
"""An implementation of `CheckpointableHandler` for Json."""
43+
"""An implementation of :py:class:`.CheckpointableHandler` for Json."""
4444

4545
def __init__(self, filename: str | None = None):
4646
self._supported_filenames = _get_supported_filenames(filename)

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/proto_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
self,
3939
filename: str = _DEFAULT_FILENAME,
4040
):
41-
"""Initializes ProtoCheckpointHandler."""
41+
"""Initializes :py:class:`.ProtoHandler`."""
4242
self._filename = filename
4343

4444
async def _background_save(

checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Implementation of :py:class:`.CheckpointableHandler` for PyTrees."""
15+
"""Implementation of :py:class:`~orbax.checkpoint.experimental.v1._src.handlers.types.CheckpointableHandler` for PyTrees."""
1616

1717
from __future__ import annotations
1818

@@ -123,24 +123,24 @@ def create_v0_save_args(
123123
def _restore_type_by_abstract_type(
124124
abstract_checkpointable: Any,
125125
) -> Any:
126-
"""This is to allow users to override the restored type.
126+
"""Allows users to override the restored type.
127127
128-
When users pass in the `value` in the DeserializationParam, the PytreeHandler
129-
will try to restore to the specified type. T. This only supports the standard
128+
When users pass the `value` in the `DeserializationParam`, the `PyTreeHandler`
129+
will try to restore to the specified type `T`. This only supports the standard
130130
types supported by Orbax.
131131
For example:
132-
- jax.ShapeDtype -> jax.Array
133-
- NumpyAbstractType -> jax.Array
134-
- int | float | Type[int] | Type[float] -> int | float | int | float
132+
- `jax.ShapeDtype` -> `jax.Array`
133+
- `NumpyAbstractType` -> `jax.Array`
134+
- `int` | `float` | `Type[int]` | `Type[float]` -> `int` | `float` | `int` |
135+
`float`
135136
136137
Args:
137-
abstract_checkpointable: The abstract checkpointable that passed in by the
138-
user.
138+
abstract_checkpointable: The abstract checkpointable passed in by the user.
139139
140140
Returns:
141-
Return the restore_type parameter for the V0RestoreArgs. This is needed to
142-
determine which LeafHandler will eventually handle this
143-
abstract_checkpointable.
141+
Returns the `restore_type` parameter for `V0RestoreArgs`. This is needed to
142+
determine which `LeafHandler` will eventually handle this
143+
`abstract_checkpointable`.
144144
"""
145145

146146
if abstract_checkpointable is None:
@@ -315,12 +315,13 @@ async def load(
315315
abstract_checkpointable: The abstract checkpointable to load into. If
316316
None, the handler will attempt to load the entire checkpoint using the
317317
recorded metadata. Otherwise, the `abstract_checkpointable` is expected
318-
to be a PyTree of abstract leaves. See :py:class:`.LeafHandler` for more
319-
details. The abstract leaf may be a value of type `AbstractLeaf`,
320-
`Type[AbstractLeaf]`, or `None`. E.g. if the `AbstractLeaf` is
321-
`AbstractFoo`, it is always valid to pass `AbstractFoo()` or
322-
`AbstractFoo` or `None`. Passing the latter two indicates that metadata
323-
should be used to restore the leaf.
318+
to be a PyTree of abstract leaves. See
319+
:py:class:`~orbax.checkpoint.experimental.v1._src.serialization.types.LeafHandler`
320+
for more details. The abstract leaf may be a value of type
321+
`AbstractLeaf`, `Type[AbstractLeaf]`, or `None`. E.g. if the
322+
`AbstractLeaf` is `AbstractFoo`, it is always valid to pass
323+
`AbstractFoo()` or `AbstractFoo` or `None`. Passing the latter two
324+
indicates that metadata should be used to restore the leaf.
324325
325326
Returns:
326327
A awaitable which can be awaited to complete the load operation and

0 commit comments

Comments
 (0)