diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 5e8ba0915..8654a2f8c 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -22,6 +22,9 @@ breakages. include an arbitrary `step_prefix` with any character(s) such as underscores. - Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`. - Fix using jax.eval_shape with StandardRestore +- Fix checkpoint restoration logic for integer dict keys. Ensures + these keys, which were being temporarily stored as strings, are properly + restored back to integers. ### Changed diff --git a/checkpoint/orbax/checkpoint/_src/metadata/key_types.py b/checkpoint/orbax/checkpoint/_src/metadata/key_types.py new file mode 100644 index 000000000..535840272 --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/metadata/key_types.py @@ -0,0 +1,72 @@ +# Copyright 2025 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Defines the KeyType and KeyPythonType enums and helper functions.""" + +from __future__ import annotations + +import enum +from typing import Any + +from orbax.checkpoint._src.tree import utils as tree_utils + + +class KeyType(enum.Enum): + """Enum representing PyTree key type.""" + + SEQUENCE = 1 + DICT = 2 + + def to_json(self) -> int: + return self.value + + @classmethod + def from_json(cls, value: int) -> KeyType: + return cls(value) + + @classmethod + def from_jax_tree_key(cls, key: Any) -> KeyType: + """Translates the JAX key class into a proto enum.""" + if tree_utils.is_sequence_key(key): + return cls.SEQUENCE + elif tree_utils.is_dict_key(key): + return cls.DICT + else: + raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"') + + +class KeyPythonType(enum.Enum): + """Enum representing the python type of the key.""" + + INT = 1 + STR = 2 + + def to_json(self) -> int: + return self.value + + @classmethod + def from_json(cls, value: int) -> KeyPythonType: + return cls(value) + + @classmethod + def from_jax_python_type(cls, key_python_type: Any) -> KeyPythonType: + """Translates a key's Python value (an int or a str) into an enum.""" + if isinstance(key_python_type, int): + return cls.INT + elif isinstance(key_python_type, str): + return cls.STR + else: + raise ValueError( + f'Unsupported KeyEntry: {type(key_python_type)}: "{key_python_type}"' + ) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree.py b/checkpoint/orbax/checkpoint/_src/metadata/tree.py index 8660b9c6d..28abc6e17 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree.py @@ -20,7 +20,6 @@ import collections import copy import dataclasses -import enum import functools import inspect import json @@ -33,6 +32,7 @@ import jax from orbax.checkpoint._src import asyncio_utils from orbax.checkpoint._src.metadata import empty_values +from orbax.checkpoint._src.metadata import key_types from orbax.checkpoint._src.metadata import 
pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.metadata import tree_rich_types from orbax.checkpoint._src.metadata import value as value_metadata @@ -56,6 +56,7 @@ _KEY_NAME = 'key' _KEY_TYPE = 'key_type' +_KEY_PYTHON_TYPE = 'key_python_type' _TREE_METADATA_KEY = 'tree_metadata' _KEY_METADATA_KEY = 'key_metadata' _VALUE_METADATA_KEY = 'value_metadata' @@ -65,36 +66,30 @@ _CUSTOM_METADATA = 'custom_metadata' -class KeyType(enum.Enum): - """Enum representing PyTree key type.""" - - SEQUENCE = 1 - DICT = 2 - - def to_json(self) -> int: - return self.value - - @classmethod - def from_json(cls, value: int) -> KeyType: - return cls(value) - - -def _get_key_metadata_type(key: Any) -> KeyType: - """Translates the JAX key class into a proto enum.""" - if tree_utils.is_sequence_key(key): - return KeyType.SEQUENCE - elif tree_utils.is_dict_key(key): - return KeyType.DICT - else: - raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"') +def _get_dict_key_name( + key_name: Union[str, int], key_python_type: key_types.KeyPythonType +) -> Union[str, int]: + if key_python_type is key_types.KeyPythonType.INT: + try: + return int(key_name) + except ValueError as e: + raise ValueError( + f"Key '{key_name}' was expected to be an int based on " + 'key_python_type but could not be converted.' 
+ ) from e + return key_name -def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any: +def _keypath_from_key_type( + key_name: Union[str, int], + key_type: key_types.KeyType, + key_python_type: key_types.KeyPythonType, +) -> Any: """Converts from Key in InternalTreeMetadata to JAX keypath class.""" - if key_type == KeyType.SEQUENCE: + if key_type is key_types.KeyType.SEQUENCE: return jax.tree_util.SequenceKey(int(key_name)) - elif key_type == KeyType.DICT: - return jax.tree_util.DictKey(key_name) + elif key_type is key_types.KeyType.DICT: + return jax.tree_util.DictKey(_get_dict_key_name(key_name, key_python_type)) else: raise ValueError(f'Unsupported KeyEntry: {key_type}') @@ -103,22 +98,33 @@ def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any: class NestedKeyMetadataEntry: """Represents a key at a single level of nesting.""" - nested_key_name: str - key_type: KeyType + nested_key_name: Union[str, int] + key_type: key_types.KeyType # Stores whether key is SEQUENCE or DICT + key_python_type: key_types.KeyPythonType # Stores whether key is INT or STR def to_json(self) -> Dict[str, Union[str, int]]: return { _KEY_NAME: self.nested_key_name, _KEY_TYPE: self.key_type.to_json(), + _KEY_PYTHON_TYPE: self.key_python_type.to_json(), } @classmethod def from_json( cls, json_dict: Dict[str, Union[str, int]] ) -> NestedKeyMetadataEntry: + """Creates a NestedKeyMetadataEntry from a JSON dictionary.""" + # Backward compatibility: Default to STR if key_python_type is not found. 
+ key_python_type = ( + key_types.KeyPythonType.from_json(json_dict[_KEY_PYTHON_TYPE]) + if _KEY_PYTHON_TYPE in json_dict + else key_types.KeyPythonType.STR + ) + return NestedKeyMetadataEntry( nested_key_name=json_dict[_KEY_NAME], - key_type=KeyType.from_json(json_dict[_KEY_TYPE]), + key_type=key_types.KeyType.from_json(json_dict[_KEY_TYPE]), + key_python_type=key_python_type, ) @@ -145,7 +151,11 @@ def from_json( def build(cls, keypath: KeyPath) -> KeyMetadataEntry: return KeyMetadataEntry([ NestedKeyMetadataEntry( - str(tree_utils.get_key_name(k)), _get_key_metadata_type(k) + nested_key_name=tree_utils.get_key_name(k), + key_type=key_types.KeyType.from_jax_tree_key(k), + key_python_type=key_types.KeyPythonType.from_jax_python_type( + tree_utils.get_key_name(k) + ), ) for k in keypath ]) @@ -201,7 +211,10 @@ def jax_keypath(self) -> KeyPath: for nested_key_entry in self.key_metadata.nested_key_metadata_entries: nested_key_name = nested_key_entry.nested_key_name key_type = nested_key_entry.key_type - keypath.append(_keypath_from_key_type(nested_key_name, key_type)) + key_python_type = nested_key_entry.key_python_type + keypath.append( + _keypath_from_key_type(nested_key_name, key_type, key_python_type) + ) return tuple(keypath) @@ -305,10 +318,16 @@ def to_json(self) -> Dict[str, Any]: _TREE_METADATA_KEY: { "(top_level_key, lower_level_key)": { _KEY_METADATA_KEY: ( - {_KEY_NAME: "top_level_key", _KEY_TYPE: <KeyType (int)>}, - {_KEY_NAME: "lower_level_key", _KEY_TYPE: <KeyType (int)>}, + { + _KEY_NAME: "top_level_key", + _KEY_TYPE: <KeyType (int)>, + _KEY_PYTHON_TYPE: <KeyPythonType (int)> + }, + { + _KEY_NAME: "lower_level_key", + _KEY_TYPE: <KeyType (int)>, + _KEY_PYTHON_TYPE: <KeyPythonType (int)> + }, ) _VALUE_METADATA_KEY: { _VALUE_TYPE: "jax.Array", @@ -322,47 +341,62 @@ def to_json(self) -> Dict[str, Any]: _CUSTOM_METADATA: ..., _VALUE_METADATA_TREE: '{ "mu_nu": { - "category": "namedtuple", - "module": "orbax.checkpoint._src.testing.test_tree_utils", - "clazz": "MuNu", - "entries": [ - { - "key": "mu", - "value": { - "category": "custom", - "clazz": 
"ValueMetadataEntry", - "data": { - "value_type": "jax.Array", - "skip_deserialize": false + "key_python_type": 2, + "value": { + "category": "namedtuple", + "module": "orbax.checkpoint._src.testing.test_tree_utils", + "clazz": "MuNu", + "entries": [ + { + "key": "mu", + "value": { + "key_python_type": 2, + "value": { + "category": "custom", + "clazz": "ValueMetadataEntry", + "data": { + "value_type": "jax.Array", + "skip_deserialize": false + } + } } - } - }, - { - "key": "nu", - "value": { - "category": "custom", - "clazz": "ValueMetadataEntry", - "data": { - "value_type": "np.ndarray", - "skip_deserialize": false + }, + { + "key": "nu", + "value": { + "key_python_type": 2, + "value": { + "category": "custom", + "clazz": "ValueMetadataEntry", + "data": { + "value_type": "np.ndarray", + "skip_deserialize": false + } + } } } - } - ] + ] + } }, "my_tuple": { - "category": "custom", - "clazz": "tuple", - "entries": [ - { - "category": "custom", - "clazz": "ValueMetadataEntry", - "data": { - "value_type": "np.ndarray", - "skip_deserialize": false + "key_python_type": 2, + "value": { + "category": "custom", + "clazz": "tuple", + "entries": [ + { + "key_python_type": 2, + "value": { + "category": "custom", + "clazz": "ValueMetadataEntry", + "data": { + "value_type": "np.ndarray", + "skip_deserialize": false + } + } } - } - ] + ] + } } }' } diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree_rich_types.py b/checkpoint/orbax/checkpoint/_src/metadata/tree_rich_types.py index 54f1fe15c..b0947190c 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree_rich_types.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree_rich_types.py @@ -21,6 +21,7 @@ from typing import Any, Iterable, Mapping, Sequence, Type, TypeAlias import jax +from orbax.checkpoint._src.metadata import key_types from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib from orbax.checkpoint._src.metadata import value_metadata_entry from 
orbax.checkpoint._src.tree import utils as tree_utils @@ -77,6 +78,10 @@ def _module_and_class_name(cls) -> tuple[str, str]: """Returns the module and class name of the given class instance.""" return cls.__module__, cls.__qualname__ +_DICT_META = 'dict_meta' +_KEY_PYTHON_TYPE = 'key_python_type' +_VALUE = 'value' + _VALUE_METADATA_ENTRY_CLAZZ = 'ValueMetadataEntry' _VALUE_METADATA_ENTRY_MODULE_AND_CLASS = _module_and_class_name( @@ -122,7 +127,15 @@ def _value_metadata_tree_for_json_dumps(obj: Any) -> Any: ) if isinstance(obj, Mapping): - return {k: _value_metadata_tree_for_json_dumps(v) for k, v in obj.items()} + return { + str(k): { + _KEY_PYTHON_TYPE: ( + key_types.KeyPythonType.from_jax_python_type(k).to_json() + ), + _VALUE: _value_metadata_tree_for_json_dumps(v), + } + for k, v in obj.items() + } if isinstance(obj, list): return [_value_metadata_tree_for_json_dumps(e) for e in obj] @@ -175,7 +188,29 @@ def _value_metadata_tree_for_json_loads(obj): ], ) - return {k: _value_metadata_tree_for_json_loads(v) for k, v in obj.items()} + # Deserialize dictionary mapping according to original key type. + restored_dict = {} + for k, v in obj.items(): + # Check if original key type is present in the metadata. + if isinstance(v, Mapping) and _KEY_PYTHON_TYPE in v: + key_val = k + key_python_type = key_types.KeyPythonType.from_json(v[_KEY_PYTHON_TYPE]) + + try: + if key_python_type is key_types.KeyPythonType.INT: + key_val = int(key_val) + elif key_python_type is key_types.KeyPythonType.STR: + pass + except ValueError as e: + raise ValueError( + f'Failed to restore key {key_val} to type {key_python_type}.' + ) from e + + restored_dict[key_val] = _value_metadata_tree_for_json_loads(v[_VALUE]) + else: + # If original key type is not present, then it is a standard dict entry. 
+ restored_dict[k] = _value_metadata_tree_for_json_loads(v) + return restored_dict def value_metadata_tree_to_json_str(tree: PyTree) -> str: @@ -185,12 +220,64 @@ def value_metadata_tree_to_json_str(tree: PyTree) -> str: ``` '{ "mu_nu": { - "category": "namedtuple", - "module": "orbax.checkpoint._src.testing.test_tree_utils", - "clazz": "MuNu", - "entries": [ - { - "key": "mu", + // enum value 2 -> str + "key_python_type": 2, + "value": { + "category": "namedtuple", + "module": "orbax.checkpoint._src.testing.test_tree_utils", + "clazz": "MuNu", + "entries": [ + { + "key": "mu", + "value": { + "category": "custom", + "clazz": "ValueMetadataEntry", + "data": { + "value_type": "jax.Array", + "skip_deserialize": false + } + } + }, + { + "key": "nu", + "value": { + "category": "custom", + "clazz": "ValueMetadataEntry", + "data": { + "value_type": "np.ndarray", + "skip_deserialize": false + } + } + } + ] + } + }, + "my_tuple": { + // enum value 2 -> str + "key_python_type": 2, + "value": { + "category": "custom", + "clazz": "tuple", + "entries": [ + { + "category": "custom", + "clazz": "ValueMetadataEntry", + "data": { + "value_type": "np.ndarray", + "skip_deserialize": false + } + } + ] + } + }, + "my_dict": { + // enum value 2 -> str + "key_python_type": 2, + "value": { + // This key ("0") will be restored as the integer 0. 
+ "0": { + // enum value 1 -> int + "key_python_type": 1, "value": { "category": "custom", "clazz": "ValueMetadataEntry", @@ -200,8 +287,9 @@ def value_metadata_tree_to_json_str(tree: PyTree) -> str: } } }, - { - "key": "nu", + "str_key": { + // enum value 2 -> str + "key_python_type": 2, "value": { "category": "custom", "clazz": "ValueMetadataEntry", @@ -211,21 +299,7 @@ def value_metadata_tree_to_json_str(tree: PyTree) -> str: } } } - ] - }, - "my_tuple": { - "category": "custom", - "clazz": "tuple", - "entries": [ - { - "category": "custom", - "clazz": "ValueMetadataEntry", - "data": { - "value_type": "np.ndarray", - "skip_deserialize": false - } - } - ] + } } }' ``` @@ -234,10 +308,7 @@ def value_metadata_tree_to_json_str(tree: PyTree) -> str: tree: A PyTree to be converted to JSON string. """ return simplejson.dumps( - tree, - default=_value_metadata_tree_for_json_dumps, - tuple_as_array=False, # Must be False to preserve tuples. - namedtuple_as_object=False, # Must be False to preserve namedtuples. 
+ _value_metadata_tree_for_json_dumps(tree), ) diff --git a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py index a143fd2a6..d0c4b4766 100644 --- a/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py +++ b/checkpoint/orbax/checkpoint/_src/metadata/tree_test.py @@ -90,6 +90,65 @@ def test_as_nested_tree( restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree() chex.assert_trees_all_equal(restored_tree_metadata, expected_tree_metadata) + @parameterized.product( + support_rich_types=(True, False), + ) + def test_integer_dict_keys(self, support_rich_types): + tree = {0: 'a', 1: 'b'} + pytree_metadata_options = tree_metadata_lib.PyTreeMetadataOptions( + support_rich_types=support_rich_types + ) + original_internal_tree_metadata = InternalTreeMetadata.build( + param_infos=_to_param_infos( + tree, pytree_metadata_options=pytree_metadata_options + ), + pytree_metadata_options=pytree_metadata_options, + ) + json_object = original_internal_tree_metadata.to_json() + restored_internal_tree_metadata = InternalTreeMetadata.from_json( + json_object, pytree_metadata_options=pytree_metadata_options + ) + restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree() + self.assertSequenceEqual([0, 1], list(restored_tree_metadata.keys())) + + @parameterized.product( + support_rich_types=(True, False), + ) + def test_missing_key_python_type(self, support_rich_types): + # JSON representation of a tree {'1': 1} with missing key_python_type using + # both data-rich and non data-rich format. 
+ json_object = { + 'tree_metadata': { + "('1',)": { + 'key_metadata': ({'key': '1', 'key_type': 2},), + 'value_metadata': { + 'value_type': 'scalar', + 'skip_deserialize': False, + }, + } + }, + 'use_zarr3': False, + 'store_array_data_equal_to_fill_value': False, + 'custom_metadata': None, + } + if support_rich_types: + json_object['value_metadata_tree'] = ( + '{"1": {"value": {"category": "custom",' + ' "clazz": "ValueMetadataEntry", "data": {"value_type": "scalar",' + ' "skip_deserialize": false}}}}' + ) + pytree_metadata_options = tree_metadata_lib.PyTreeMetadataOptions( + support_rich_types=support_rich_types + ) + restored_internal_tree_metadata = InternalTreeMetadata.from_json( + json_object, + pytree_metadata_options=pytree_metadata_options, + ) + restored_tree_metadata = restored_internal_tree_metadata.as_nested_tree() + # Should fall back to string keys when key_python_type is missing. + self.assertIn('1', restored_tree_metadata) + self.assertSequenceEqual(['1'], list(restored_tree_metadata.keys())) + @parameterized.product( test_pytree=test_tree_utils.TEST_PYTREES, pytree_metadata_options_switch=[ diff --git a/checkpoint/orbax/checkpoint/_src/tree/utils.py b/checkpoint/orbax/checkpoint/_src/tree/utils.py index 069d1d6b7..1d2e4a09a 100644 --- a/checkpoint/orbax/checkpoint/_src/tree/utils.py +++ b/checkpoint/orbax/checkpoint/_src/tree/utils.py @@ -86,7 +86,7 @@ def get_key_name(key: Any) -> Union[int, str]: if isinstance(key, jax.tree_util.SequenceKey): return key.idx elif isinstance(key, jax.tree_util.DictKey): - return str(key.key) + return key.key elif isinstance(key, jax.tree_util.GetAttrKey): return key.name elif isinstance(key, jax.tree_util.FlattenedIndexKey):