Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ breakages.
include an arbitrary `step_prefix` with any character(s) such as underscores.
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
- Fix using jax.eval_shape with StandardRestore
- Fix checkpoint restoration logic for integer dict keys. Ensures
these keys, which were being temporarily stored as strings, are properly
restored back to integers.

### Changed

Expand Down
72 changes: 72 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/key_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the KeyType and KeyPythonType enums and their helper methods."""

from __future__ import annotations

import enum
from typing import Any

from orbax.checkpoint._src.tree import utils as tree_utils


class KeyType(enum.Enum):
  """Enum representing PyTree key type.

  Members are serialized to and from JSON as their integer values via
  `to_json` and `from_json`.
  """

  SEQUENCE = 1
  DICT = 2

  def to_json(self) -> int:
    """Returns the integer value of this member for JSON serialization."""
    return self.value

  @classmethod
  def from_json(cls, value: int) -> KeyType:
    """Returns the member whose value equals `value`.

    Args:
      value: Integer previously produced by `to_json`.

    Raises:
      ValueError: If `value` does not correspond to any member.
    """
    return cls(value)

  @classmethod
  def from_jax_tree_key(cls, key: Any) -> KeyType:
    """Translates a JAX tree key into the matching `KeyType` member.

    Args:
      key: A single keypath entry. It is classified with
        `tree_utils.is_sequence_key` and `tree_utils.is_dict_key`, checked in
        that order.

    Returns:
      `SEQUENCE` if the key is a sequence key, `DICT` if it is a dict key.

    Raises:
      ValueError: If `key` is neither a sequence key nor a dict key.
    """
    if tree_utils.is_sequence_key(key):
      return cls.SEQUENCE
    if tree_utils.is_dict_key(key):
      return cls.DICT
    raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"')


class KeyPythonType(enum.Enum):
  """Enum representing the python type of the key.

  Members are serialized to and from JSON as their integer values via
  `to_json` and `from_json`.
  """

  INT = 1
  STR = 2

  def to_json(self) -> int:
    """Returns the integer value of this member for JSON serialization."""
    return self.value

  @classmethod
  def from_json(cls, value: int) -> KeyPythonType:
    """Returns the member whose value equals `value`.

    Args:
      value: Integer previously produced by `to_json`.

    Raises:
      ValueError: If `value` does not correspond to any member.
    """
    return cls(value)

  @classmethod
  def from_jax_python_type(cls, key_python_type: Any) -> KeyPythonType:
    """Classifies a key by its Python type.

    Args:
      key_python_type: Despite the name, this is the key value itself (not a
        type object); it is inspected with `isinstance`.

    Returns:
      `INT` for `int` instances, `STR` for `str` instances. The `int` check
      runs first, and because `bool` is a subclass of `int`, boolean keys are
      classified as `INT`.

    Raises:
      ValueError: If the key is neither an `int` nor a `str`.
    """
    if isinstance(key_python_type, int):
      return cls.INT
    if isinstance(key_python_type, str):
      return cls.STR
    raise ValueError(
        f'Unsupported KeyEntry: {type(key_python_type)}: "{key_python_type}"'
    )
176 changes: 105 additions & 71 deletions checkpoint/orbax/checkpoint/_src/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import collections
import copy
import dataclasses
import enum
import functools
import inspect
import json
Expand All @@ -33,6 +32,7 @@
import jax
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import key_types
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
from orbax.checkpoint._src.metadata import tree_rich_types
from orbax.checkpoint._src.metadata import value as value_metadata
Expand All @@ -56,6 +56,7 @@

_KEY_NAME = 'key'
_KEY_TYPE = 'key_type'
_KEY_PYTHON_TYPE = 'key_python_type'
_TREE_METADATA_KEY = 'tree_metadata'
_KEY_METADATA_KEY = 'key_metadata'
_VALUE_METADATA_KEY = 'value_metadata'
Expand All @@ -65,36 +66,30 @@
_CUSTOM_METADATA = 'custom_metadata'


class KeyType(enum.Enum):
"""Enum representing PyTree key type."""

SEQUENCE = 1
DICT = 2

def to_json(self) -> int:
return self.value

@classmethod
def from_json(cls, value: int) -> KeyType:
return cls(value)


def _get_key_metadata_type(key: Any) -> KeyType:
"""Translates the JAX key class into a proto enum."""
if tree_utils.is_sequence_key(key):
return KeyType.SEQUENCE
elif tree_utils.is_dict_key(key):
return KeyType.DICT
else:
raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"')
def _get_dict_key_name(
    key_name: Union[str, int], key_python_type: key_types.KeyPythonType
) -> Union[str, int]:
  """Returns a dict key name converted to its recorded Python type.

  Args:
    key_name: The stored key name.
    key_python_type: The Python type recorded for the key.

  Returns:
    `int(key_name)` when `key_python_type` is `INT`; otherwise `key_name`
    unchanged.

  Raises:
    ValueError: If the key is recorded as `INT` but `int(key_name)` raises a
      `ValueError`.
  """
  # Only INT keys need conversion; every other type is passed through as-is.
  if key_python_type is not key_types.KeyPythonType.INT:
    return key_name
  try:
    return int(key_name)
  except ValueError as e:
    raise ValueError(
        f"Key '{key_name}' was expected to be an int based on "
        'key_python_type but could not be converted.'
    ) from e


def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any:
def _keypath_from_key_type(
key_name: Union[str, int],
key_type: key_types.KeyType,
key_python_type: key_types.KeyPythonType,
) -> Any:
"""Converts from Key in InternalTreeMetadata to JAX keypath class."""
if key_type == KeyType.SEQUENCE:
if key_type is key_types.KeyType.SEQUENCE:
return jax.tree_util.SequenceKey(int(key_name))
elif key_type == KeyType.DICT:
return jax.tree_util.DictKey(key_name)
elif key_type is key_types.KeyType.DICT:
return jax.tree_util.DictKey(_get_dict_key_name(key_name, key_python_type))
else:
raise ValueError(f'Unsupported KeyEntry: {key_type}')

Expand All @@ -103,22 +98,33 @@ def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any:
class NestedKeyMetadataEntry:
"""Represents a key at a single level of nesting."""

nested_key_name: str
key_type: KeyType
nested_key_name: Union[str, int]
key_type: key_types.KeyType # Stores whether key is SEQUENCE or DICT
key_python_type: key_types.KeyPythonType # Stores whether key is INT or STR

def to_json(self) -> Dict[str, Union[str, int]]:
return {
_KEY_NAME: self.nested_key_name,
_KEY_TYPE: self.key_type.to_json(),
_KEY_PYTHON_TYPE: self.key_python_type.to_json(),
}

@classmethod
def from_json(
cls, json_dict: Dict[str, Union[str, int]]
) -> NestedKeyMetadataEntry:
"""Creates a NestedKeyMetadataEntry from a JSON dictionary."""
# Backward compatibility: Default to STR if key_python_type is not found.
key_python_type = (
key_types.KeyPythonType.from_json(json_dict[_KEY_PYTHON_TYPE])
if _KEY_PYTHON_TYPE in json_dict
else key_types.KeyPythonType.STR
)

return NestedKeyMetadataEntry(
nested_key_name=json_dict[_KEY_NAME],
key_type=KeyType.from_json(json_dict[_KEY_TYPE]),
key_type=key_types.KeyType.from_json(json_dict[_KEY_TYPE]),
key_python_type=key_python_type,
)


Expand All @@ -145,7 +151,11 @@ def from_json(
def build(cls, keypath: KeyPath) -> KeyMetadataEntry:
return KeyMetadataEntry([
NestedKeyMetadataEntry(
str(tree_utils.get_key_name(k)), _get_key_metadata_type(k)
nested_key_name=tree_utils.get_key_name(k),
key_type=key_types.KeyType.from_jax_tree_key(k),
key_python_type=key_types.KeyPythonType.from_jax_python_type(
tree_utils.get_key_name(k)
),
)
for k in keypath
])
Expand Down Expand Up @@ -201,7 +211,10 @@ def jax_keypath(self) -> KeyPath:
for nested_key_entry in self.key_metadata.nested_key_metadata_entries:
nested_key_name = nested_key_entry.nested_key_name
key_type = nested_key_entry.key_type
keypath.append(_keypath_from_key_type(nested_key_name, key_type))
key_python_type = nested_key_entry.key_python_type
keypath.append(
_keypath_from_key_type(nested_key_name, key_type, key_python_type)
)
return tuple(keypath)


Expand Down Expand Up @@ -305,10 +318,16 @@ def to_json(self) -> Dict[str, Any]:
_TREE_METADATA_KEY: {
"(top_level_key, lower_level_key)": {
_KEY_METADATA_KEY: (
{_KEY_NAME: "top_level_key", _KEY_TYPE: <KeyType
(int)>},
{_KEY_NAME: "lower_level_key", _KEY_TYPE: <KeyType
(int)>},
{
_KEY_NAME: "top_level_key",
_KEY_TYPE: <KeyType (int)>,
_KEY_PYTHON_TYPE: <KeyPythonType (int)>
},
{
_KEY_NAME: "lower_level_key",
_KEY_TYPE: <KeyType (int)>,
_KEY_PYTHON_TYPE: <KeyPythonType (int)>
},
)
_VALUE_METADATA_KEY: {
_VALUE_TYPE: "jax.Array",
Expand All @@ -322,47 +341,62 @@ def to_json(self) -> Dict[str, Any]:
_CUSTOM_METADATA: ...,
_VALUE_METADATA_TREE: '{
"mu_nu": {
"category": "namedtuple",
"module": "orbax.checkpoint._src.testing.test_tree_utils",
"clazz": "MuNu",
"entries": [
{
"key": "mu",
"value": {
"category": "custom",
"clazz": "ValueMetadataEntry",
"data": {
"value_type": "jax.Array",
"skip_deserialize": false
"key_python_type": 2,
"value": {
"category": "namedtuple",
"module": "orbax.checkpoint._src.testing.test_tree_utils",
"clazz": "MuNu",
"entries": [
{
"key": "mu",
"value": {
"key_python_type": 2,
"value": {
"category": "custom",
"clazz": "ValueMetadataEntry",
"data": {
"value_type": "jax.Array",
"skip_deserialize": false
}
}
}
}
},
{
"key": "nu",
"value": {
"category": "custom",
"clazz": "ValueMetadataEntry",
"data": {
"value_type": "np.ndarray",
"skip_deserialize": false
},
{
"key": "nu",
"value": {
"key_python_type": 2,
"value": {
"category": "custom",
"clazz": "ValueMetadataEntry",
"data": {
"value_type": "np.ndarray",
"skip_deserialize": false
}
}
}
}
}
]
]
}
},
"my_tuple": {
"category": "custom",
"clazz": "tuple",
"entries": [
{
"category": "custom",
"clazz": "ValueMetadataEntry",
"data": {
"value_type": "np.ndarray",
"skip_deserialize": false
"key_python_type": 2,
"value": {
"category": "custom",
"clazz": "tuple",
"entries": [
{
"key_python_type": 2,
"value": {
"category": "custom",
"clazz": "ValueMetadataEntry",
"data": {
"value_type": "np.ndarray",
"skip_deserialize": false
}
}
}
}
]
]
}
}
}'
}
Expand Down
Loading
Loading