From 8cdf3b797ea026e0fe1f2d63e06b409286442ac3 Mon Sep 17 00:00:00 2001
From: Adam Cogdell
Date: Fri, 28 Mar 2025 13:58:53 -0700
Subject: [PATCH] Add `tree_trim` function for filtering a PyTree by another's structure.

PiperOrigin-RevId: 741639251
---
 checkpoint/CHANGELOG.md                       |   1 +
 checkpoint/orbax/checkpoint/_src/tree/BUILD   |   1 +
 .../orbax/checkpoint/_src/tree/utils.py       | 225 +++++++++++++++-
 .../orbax/checkpoint/_src/tree/utils_test.py  | 247 ++++++++++++++++++
 4 files changed, 471 insertions(+), 3 deletions(-)

diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md
index 6f2ae899d..d4cb4cf1c 100644
--- a/checkpoint/CHANGELOG.md
+++ b/checkpoint/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - #v1 Add metadata free functions.
+- `tree_trim` function for filtering a PyTree by another's structure.
 
 ## [0.11.10] - 2025-03-20
 
diff --git a/checkpoint/orbax/checkpoint/_src/tree/BUILD b/checkpoint/orbax/checkpoint/_src/tree/BUILD
index ebec4049a..7788b6697 100644
--- a/checkpoint/orbax/checkpoint/_src/tree/BUILD
+++ b/checkpoint/orbax/checkpoint/_src/tree/BUILD
@@ -22,6 +22,7 @@ py_test(
     deps = [
         ":utils",
         "//checkpoint/orbax/checkpoint:test_utils",
+        "//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
         "//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
     ],
 )
diff --git a/checkpoint/orbax/checkpoint/_src/tree/utils.py b/checkpoint/orbax/checkpoint/_src/tree/utils.py
index 7fc967413..d4dc08d6c 100644
--- a/checkpoint/orbax/checkpoint/_src/tree/utils.py
+++ b/checkpoint/orbax/checkpoint/_src/tree/utils.py
@@ -14,14 +14,19 @@
 
 """Tree utilities."""
 
-from typing import Any, Callable, Mapping, NamedTuple, Optional, Tuple, Union
+from collections import abc
+from typing import Any, Callable, Generic, Mapping, NamedTuple, Optional, Protocol, Tuple, TypeVar, Union
 
 import jax
 from orbax.checkpoint._src.arrays import abstract_arrays
-from orbax.checkpoint._src.tree import types as tree_types
 
 
-PyTree = tree_types.PyTree
+T = TypeVar('T')
+
+PyTree = Any
+# This won't help the type checker but at least allows us to use types to
+# document things like `PyTreeOf[ArrayDesc]`.
+PyTreeOf = PyTree | T
 PyTreeKey = (
     jax.tree_util.SequenceKey
     | jax.tree_util.DictKey
@@ -33,6 +38,23 @@
 to_shape_dtype_struct = abstract_arrays.to_shape_dtype_struct
 
 
+def is_jax_internal_node(x: Any) -> bool:
+  return not jax.tree_util.all_leaves([x])
+
+
+def is_jax_internal_leaf(x: Any) -> bool:
+  return jax.tree_util.all_leaves([x])
+
+
+def is_jax_internal_leaf_or_none(t: Any) -> bool:
+  return is_jax_internal_leaf(t) or t is None
+
+
+def _internal_node_as_dict(x: Any) -> Mapping[str, Any]:
+  keys_and_children, _ = tree_flatten_with_path_one_level(x)
+  return {jax.tree_util.keystr(k): v for k, v in keys_and_children}
+
+
 def isinstance_of_namedtuple(value: Any) -> bool:
   """Determines if the `value` is a NamedTuple."""
   return isinstance(value, tuple) and hasattr(value, '_fields')
@@ -458,3 +480,200 @@ def tree_difference(
     )
 
   return None
+
+
+class TrimmedStructureCallback(Protocol, Generic[T]):
+
+  def __call__(
+      self,
+      path: tuple[str | int, ...],
+      structure: PyTreeOf[T],
+  ) -> None:
+    ...
+
+
+# TODO(b/407092826): Substitute for full PartsOf version later.
+def tree_trim(
+    template: PyTreeOf[Any],
+    structure: PyTreeOf[T],
+    *,
+    trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
+    strict: bool = True,
+) -> PyTreeOf[T]:
+  """Removes nodes in `structure` so that its shape matches that of `template`.
+
+  Only dictionary entries are trimmed; sequences are unchanged and the length
+  of a sequence node in `structure` must match that of the corresponding node
+  in `template`.
+
+  If `not strict`, any subtree of a mapping or named tuple node of `template`
+  that is missing from the corresponding node of `structure` will be replaced
+  with an appropriately-shaped subtree full of `...` placeholders (Ellipsis)
+  instead of causing an error. In this mode, the tree structure of the result
+  is guaranteed to match the tree structure of `template`.
+
+  Args:
+    template: The tree whose shape is to be matched.
+    structure: The tree to be trimmed.
+    trimmed_structure_callback: If present, will be called with the path to, and
+      value of, any node that is removed from `structure`.
+    strict: Require every element of `template` to be matched by an element of
+      `structure`.
+
+  Returns:
+    A subset of `structure` that has the same shape as `template`.
+
+  Raises:
+    TypeError: If the type of a node in `structure` does not match the
+      type of the corresponding node in `template`.
+    ValueError: If keys in a dictionary node in `template` are not present
+      in the corresponding node in `structure`, or if the length of a sequence
+      node in `structure` does not match the length of the corresponding
+      sequence node in `template`, or if an internal node that isn't a
+      sequence or dictionary is encountered.
+  """
+  result = _tree_trim_impl(
+      template,
+      structure,
+      trimmed_structure_callback=trimmed_structure_callback,
+      strict=strict,
+  )
+  return result
+
+
+def _tree_trim_impl(
+    template: PyTreeOf[Any],
+    structure: PyTreeOf[T],
+    *,
+    trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
+    strict: bool = True,
+) -> PyTreeOf[T]:
+  """Implementation of `tree_trim()`; see that function for the contract."""
+
+  # This is nested so as to capture `trimmed_structure_callback`.
+  def _tree_trim(
+      path: tuple[str | int, ...],
+      template: PyTreeOf[Any],
+      structure: PyTreeOf[T],
+  ) -> PyTreeOf[T]:
+    match template:
+      # This wants to be `case abc.Mapping()` but http://b/283787842.
+      case mapping if isinstance(mapping, abc.Mapping):
+        if isinstance_of_namedtuple(structure):
+          structure_dict = structure._asdict()  # pytype:disable=attribute-error
+        elif isinstance(structure, abc.Mapping):
+          structure_dict = structure
+        elif structure is None:
+          structure_dict = {}
+        else:
+          raise TypeError(
+              f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
+          )
+
+        keep_items = []
+        drop_items = []
+        placeholder_items = []
+
+        if missing := [k for k in template if k not in structure_dict]:
+          if strict:
+            raise ValueError(
+                f'{path}: missing {len(missing)} '
+                f'keys, including: {missing[:10]}'
+            )
+          else:
+            # Fill the result with placeholders
+            placeholder_items.extend(
+                (k, jax.tree.map(lambda x: ..., template[k])) for k in missing
+            )
+
+        for k, n in structure_dict.items():
+          (keep_items if k in template else drop_items).append((k, n))
+
+        if trimmed_structure_callback:
+          for k, n in drop_items:
+            trimmed_structure_callback((*path, k), n)
+
+        keep_dict = {
+            k: _tree_trim((*path, k), template[k], v) for k, v in keep_items
+        }
+        return type(template)((*keep_dict.items(), *placeholder_items))  # pytype:disable=wrong-arg-count
+      case named_tuple if isinstance_of_namedtuple(named_tuple):
+        if structure is None:
+          structure = ()
+        if isinstance(structure, abc.Mapping):
+          children_dict = _tree_trim(path, named_tuple._asdict(), structure)
+          return type(template)(**children_dict)
+        elif isinstance(structure, abc.Sequence):
+          children_sequence = _tree_trim(path, tuple(named_tuple), structure)
+          return type(template)(*children_sequence)
+        else:
+          raise TypeError(
+              f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
+          )
+      # This wants to be `case abc.Sequence()` but http://b/283787842.
+      # `str` and `bytes` are leaves, not containers: a `str` element is itself
+      # a `str`, so recursing into one would never terminate.
+      case sequence if (
+          isinstance(sequence, abc.Sequence)
+          and not isinstance(sequence, (str, bytes))
+      ):
+        if structure is None:
+          structure = ()
+        elif not isinstance(structure, abc.Sequence):
+          raise TypeError(
+              f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
+          )
+        if len(structure) != len(template):
+          raise ValueError(
+              f'{path}: length mismatch: {len(template)} vs {len(structure)}.'
+          )
+        elements = (
+            _tree_trim((*path, i), t, s)
+            for i, (t, s) in enumerate(zip(template, structure))
+        )
+        return type(template)(elements)  # pytype:disable=wrong-arg-count
+      case n if n is not None and is_jax_internal_node(n):
+        s_children_dict = _internal_node_as_dict(structure)
+
+        t_keys_and_children, t_tree_def = tree_flatten_with_path_one_level(
+            template
+        )
+        t_children_dict = {
+            jax.tree_util.keystr(k): v for k, v in t_keys_and_children
+        }
+
+        # Note: unlike other cases, this does not treat the children
+        # individually. Instead we have effectively cast the structure and
+        # the template to mappings and will deal with them in their entirety
+        # by reusing the mapping case.
+        children_dict = _tree_trim(path, t_children_dict, s_children_dict)
+        # Now cast back to the result type.
+        children = [
+            children_dict[jax.tree_util.keystr(k)]
+            for k, _ in t_keys_and_children
+        ]
+        return jax.tree_util.tree_unflatten(t_tree_def, children)
+      case None:
+        # None is special: it's the only type of template tree node that can
+        # match both leaves and internal nodes of the structure to be trimmed.
+        if is_jax_internal_leaf(structure):
+          if trimmed_structure_callback:
+            trimmed_structure_callback(path, structure)
+          return None
+        else:
+          # Make sure any callback is called appropriately on all elements of
+          # `structure`.
+          _tree_trim(path, {}, structure)
+          return None
+      case v if is_jax_internal_leaf(v):
+        if not is_jax_internal_leaf_or_none(structure):
+          raise TypeError(
+              f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
+          )
+        return structure
+      case _:
+        raise TypeError(
+            f'{path}: Unknown internal node type {type(structure)}.'
+        )
+
+  return _tree_trim((), template, structure)
diff --git a/checkpoint/orbax/checkpoint/_src/tree/utils_test.py b/checkpoint/orbax/checkpoint/_src/tree/utils_test.py
index c1c39250a..42ee9a485 100644
--- a/checkpoint/orbax/checkpoint/_src/tree/utils_test.py
+++ b/checkpoint/orbax/checkpoint/_src/tree/utils_test.py
@@ -15,6 +15,7 @@
 """Test for utils module."""
 
 from typing import Any, Mapping, NamedTuple, Sequence
+from unittest import mock
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -24,6 +25,7 @@
 import numpy as np
 import optax
 from orbax.checkpoint import test_utils
+from orbax.checkpoint._src.serialization import type_handlers
 from orbax.checkpoint._src.testing import test_tree_utils
 from orbax.checkpoint._src.tree import utils as tree_utils
 
@@ -466,5 +468,250 @@ def custom_is_leaf(x: tree_utils.PyTree) -> bool:
     )
 
 
+@flax.struct.dataclass
+class FlaxRecord:
+  alpha: Any
+  beta: Any
+
+
+@flax.struct.dataclass
+class FlaxWiderRecord:
+  alpha: Any
+  beta: Any
+  gamma: Any
+
+
+class TreeTrimTest(parameterized.TestCase):
+
+  def test_recursively_trims_structure_to_match_template(self):
+    structure = {
+        'a': 1,
+        'b': [2, {'c': [3, 3.25, 3.5, 3.75], 'd': 4}],
+        'e': (5, 6),
+    }
+    template = {
+        # drop ('a',)
+        'b': [4, {'d': 16}],  # drop ('b', 1, 'c')
+        'e': (25, 36),
+    }
+
+    dropped_subtree_callback = mock.Mock()
+    trimmed_structure = tree_utils.tree_trim(
+        template, structure, trimmed_structure_callback=dropped_subtree_callback
+    )
+
+    jax.tree.map(
+        lambda x, xx: self.assertEqual(x * x, xx), trimmed_structure, template
+    )
+
+    self.assertCountEqual(
+        [mock.call(('a',), 1), mock.call(('b', 1, 'c'), [3, 3.25, 3.5, 3.75])],
+        dropped_subtree_callback.call_args_list,
+    )
+
+  def test_does_not_copy_leaves(self):
+    structure = {
+        'a': np.asarray(1),
+        'b': [
+            np.asarray(2),
+            {
+                'c': [
+                    np.asarray(3),
+                    np.asarray(3.25),
+                    np.asarray(3.5),
+                    np.asarray(3.75),
+                ],
+                'd': np.asarray(4),
+            },
+        ],
+        'e': (np.asarray(5), np.asarray(6)),
+    }
+    template = {
+        # drop ('a',)
+        'b': [4, {'d': 16}],  # drop ('b', 1, 'c')
+        'e': (25, 36),
+    }
+
+    trimmed_structure = tree_utils.tree_trim(template, structure)
+    self.assertIs(structure['b'][0], trimmed_structure['b'][0])
+    self.assertIs(structure['b'][1]['d'], trimmed_structure['b'][1]['d'])
+    self.assertIs(structure['e'][0], trimmed_structure['e'][0])
+    self.assertIs(structure['e'][1], trimmed_structure['e'][1])
+
+  def test_preserves_mapping_type_specified_in_template(self):
+
+    class MyDict(dict):
+      pass
+
+    structure = [dict(a=1, b=2, c=3)]
+    template = [MyDict(a=1, b=4)]  # drop (0, 'c')
+
+    dropped_subtree_callback = mock.Mock()
+    trimmed_structure = tree_utils.tree_trim(
+        template, structure, trimmed_structure_callback=dropped_subtree_callback
+    )
+
+    self.assertIsInstance(trimmed_structure[0], MyDict)
+
+    self.assertCountEqual(
+        [mock.call((0, 'c'), 3)],
+        dropped_subtree_callback.call_args_list,
+    )
+
+  def test_preserves_named_tuple_type_specified_in_template(self):
+    class MyNamedTuple(NamedTuple):
+      a: int
+
+    structure = {'a': 1, 'b': (2,), 'c': 3}
+    template = {'b': MyNamedTuple(a=4), 'c': 9}  # drop ('a',)
+
+    dropped_subtree_callback = mock.Mock()
+    trimmed_structure = tree_utils.tree_trim(
+        template, structure, trimmed_structure_callback=dropped_subtree_callback
+    )
+
+    self.assertIsInstance(trimmed_structure['b'], MyNamedTuple)
+
+    self.assertCountEqual(
+        [mock.call(('a',), 1)],
+        dropped_subtree_callback.call_args_list,
+    )
+
+  def test_can_trim_dict_structure_with_named_tuple_template(self):
+    # `tree_proto` encodes `NamedTuple` instances as dictionaries:
+    #
+    #
+    # Consequently, this case is important for safely loading state
+    # (particularly optimiser state) that contains `NamedTuple` nodes when
+    # `want_rich_internal_node_types=False` has been passed to
+    # `checkpoint.load_index()` (which is encouraged).
+    class MyNamedTuple(NamedTuple):
+      x: int
+      y: float
+
+    structure = {'a': 1, 'b': {'x': 2, 'y': 5.0, 'q': 'dropme'}, 'c': 3}
+    template = {
+        'b': MyNamedTuple(x=4, y=3.0),
+        'c': 9,
+    }  # drop ('a',), ('b', 'q')
+
+    dropped_subtree_callback = mock.Mock()
+    trimmed_structure = tree_utils.tree_trim(
+        template, structure, trimmed_structure_callback=dropped_subtree_callback
+    )
+
+    self.assertIsInstance(trimmed_structure['b'], MyNamedTuple)
+
+    self.assertCountEqual(
+        [mock.call(('a',), 1), mock.call(('b', 'q'), 'dropme')],
+        dropped_subtree_callback.call_args_list,
+    )
+
+    with self.subTest('requires_all_named_tuple_fields'):
+      del structure['b']['x']
+      with self.assertRaisesRegex(
+          ValueError, r'\(\'b\',\): missing 1 keys, including: \[\'x\'\]'
+      ):
+        tree_utils.tree_trim(template, structure)
+
+  def test_can_trim_flax_struct_with_other_flax_struct(self):
+    structure = FlaxWiderRecord(
+        alpha=[1, 2],
+        beta={'three': 3, 'four': 4},
+        gamma=[5, 6],
+    )
+    template = FlaxRecord(
+        alpha=[9, 9],
+        beta={'three': 9},  # drop 'four'
+        # Also drop 'gamma'
+    )
+
+    dropped_subtree_callback = mock.Mock()
+    trimmed_structure = tree_utils.tree_trim(
+        template, structure, trimmed_structure_callback=dropped_subtree_callback
+    )
+
+    self.assertIsInstance(trimmed_structure, FlaxRecord)
+    self.assertSameStructure(
+        FlaxRecord(
+            alpha=[1, 2],
+            beta={'three': 3},
+        ),
+        trimmed_structure,
+    )
+
+    self.assertCountEqual(
+        [mock.call(('.gamma',), [5, 6]), mock.call(('.beta', 'four'), 4)],
+        dropped_subtree_callback.call_args_list,
+    )
+
+  @parameterized.parameters([
+      ({}, 1),
+      (1, {}),
+      ([], 1),
+      (1, []),
+      ([], {}),
+      ({}, []),
+  ])
+  def test_raises_if_nodes_have_mismatched_types(self, template, structure):
+    with self.assertRaisesRegex(TypeError, r'\(\'a\',\): type mismatch'):
+      tree_utils.tree_trim({'a': template}, {'a': structure})
+
+  def test_raises_if_structure_is_missing_keys(self):
+    with self.assertRaisesRegex(
+        ValueError, r'\(\'a\',\): missing 1 keys, including: \[\'b\'\]'
+    ):
+      tree_utils.tree_trim({'a': {'b': 1}}, {'a': {}})
+
+  def test_non_strict_inserts_placeholders_if_structure_is_missing_keys(self):
+    ph = type_handlers.PLACEHOLDER
+
+    self.assertSameStructure(
+        {'a': {'b': (ph, ph)}},
+        tree_utils.tree_trim({'a': {'b': (1, 2)}}, {'a': {}}, strict=False),
+    )
+    self.assertSameStructure(
+        Record((ph, ph), [ph, ph]),
+        tree_utils.tree_trim(Record((1, 2), [3, 4]), {}, strict=False),
+    )
+
+  def test_raises_if_sequence_lengths_do_not_match(self):
+    with self.assertRaisesRegex(
+        ValueError, r'\(\'a\',\): length mismatch: 2 vs 1'
+    ):
+      tree_utils.tree_trim({'a': [1, 2]}, {'a': [1]})
+
+  def test_treats_str_and_bytes_in_template_as_leaves(self):
+    # `str`/`bytes` are `abc.Sequence`s but must not be recursed into.
+    structure = {'a': 'kept', 'b': b'raw', 'c': 'dropped'}
+    template = {'a': 'x', 'b': b'y'}  # drop ('c',)
+
+    self.assertEqual(
+        {'a': 'kept', 'b': b'raw'},
+        tree_utils.tree_trim(template, structure),
+    )
+
+  def test_handles_none_leaves_in_template(self):
+    dropped_subtree_callback = mock.Mock()
+    trimmed = tree_utils.tree_trim(
+        template={'a': None, 'b': None, 'c': None},
+        structure={'a': None, 'b': {'x': 2, 'y': 3}, 'c': 12},
+        strict=False,
+        trimmed_structure_callback=dropped_subtree_callback,
+    )
+    self.assertSameStructure(
+        {'a': None, 'b': None, 'c': None},
+        trimmed,
+    )
+    self.assertCountEqual(
+        [
+            mock.call(('b', 'x'), 2),
+            mock.call(('b', 'y'), 3),
+            mock.call(('c',), 12),
+        ],
+        dropped_subtree_callback.call_args_list,
+    )
+
+
 if __name__ == '__main__':
   absltest.main()