Skip to content

#v1 Partially port the existing checkpoint format page into V1 documentation, but most of it is rewritten and updated for V1. #1881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@
JsonType = tree_types.JsonType


_DATA_FILENAME = 'data.json'


class JsonHandler(CheckpointableHandler[JsonType, None]):
"""An implementation of `CheckpointableHandler` for Json."""

def __init__(self, filename: str = 'metadata'):
self._filename = filename
def __init__(self, filename: str | None = None):
self._filename = filename or _DATA_FILENAME
self._supported_filenames = [self._filename, _DATA_FILENAME, 'metadata']

async def _background_save(
self,
Expand Down Expand Up @@ -62,9 +66,14 @@ async def _background_load(
self,
directory: path_types.Path,
):
path = directory / self._filename
json_str = await asyncio.to_thread(path.read_text)
return json.loads(json_str)
for filename in self._supported_filenames:
path = directory / filename
if await asyncio.to_thread(path.exists):
return json.loads(await asyncio.to_thread(path.read_text))
raise FileNotFoundError(
f'Unable to parse JSON file in {directory}. Recognized filenames are:'
f' {self._supported_filenames}'
)

async def load(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,32 @@ def test_save_restore(self):
restored = handler.load(self.directory)
self.assertEqual(item, restored)

@parameterized.parameters(
    ('data.json',),
    ('metadata',),
    ('unrecognized.json',),
)
def test_supported_filenames(self, filename):
  """Checks which on-disk filenames the handler is able to load from.

  The item is saved with a default-constructed handler, which is expected to
  write `data.json`. That file is then renamed to `filename` (unless it
  already has that name) and a load is attempted. Loading must succeed when
  `filename` is among the handler's supported filenames, and must raise
  `FileNotFoundError` otherwise.
  """
  item = {'a': 1, 'b': 'test'}
  handler = JsonHandler()
  handler.save(directory=self.directory, checkpointable=item)

  # The default handler writes to 'data.json'.
  src = self.directory / 'data.json'
  self.assertTrue(src.exists())

  # Move the saved file to the filename under test.
  dst = self.directory / filename
  if src != dst:
    src.rename(dst)
  self.assertTrue(dst.exists())

  if filename in handler._handler._supported_filenames:
    self.assertEqual(item, handler.load(self.directory))
  else:
    with self.assertRaises(FileNotFoundError):
      handler.load(self.directory)

@parameterized.parameters(
('{"one": 1, "two": {"three": "3"}, "four": [4]}', True),
({'one': 1, 'two': {'three': '3'}, 'four': [4]}, True),
Expand Down
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from orbax.checkpoint.experimental.v1._src.handlers.proto_handler import (
ProtoHandler,
)
from orbax.checkpoint.experimental.v1._src.handlers.json_handler import (
JsonHandler,
)


from orbax.checkpoint.experimental.v1._src.handlers.registration import (
CheckpointableHandlerRegistry,
Expand Down
Loading
Loading