Skip to content

Add tree_trim function for filtering a PyTree by another's structure. #1764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- #v1 Add metadata free functions.
- `tree_trim` function for filtering a PyTree by another's structure.

## [0.11.10] - 2025-03-20

Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/_src/tree/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ py_test(
deps = [
":utils",
"//checkpoint/orbax/checkpoint:test_utils",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
"//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
],
)
220 changes: 217 additions & 3 deletions checkpoint/orbax/checkpoint/_src/tree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

"""Tree utilities."""

from typing import Any, Callable, Mapping, NamedTuple, Optional, Tuple, Union
from collections import abc
from typing import Any, Callable, Generic, Mapping, NamedTuple, Optional, Protocol, Tuple, TypeVar, Union

import jax
from orbax.checkpoint._src.arrays import abstract_arrays
from orbax.checkpoint._src.tree import types as tree_types


PyTree = tree_types.PyTree
T = TypeVar('T')

PyTree = Any
# This won't help the type checker but at least allows us to use types to
# document things like `PyTreeOf[ArrayDesc]`.
PyTreeOf = PyTree | T
PyTreeKey = (
jax.tree_util.SequenceKey
| jax.tree_util.DictKey
Expand All @@ -33,6 +38,23 @@
to_shape_dtype_struct = abstract_arrays.to_shape_dtype_struct


def is_jax_internal_node(x: Any) -> bool:
  """Returns True when `jax.tree_util.all_leaves` does not report `x` as a leaf."""
  only_leaves = jax.tree_util.all_leaves([x])
  return not only_leaves


def is_jax_internal_leaf(x: Any) -> bool:
  """Returns True when `jax.tree_util.all_leaves` reports `x` as a leaf."""
  singleton = [x]
  return jax.tree_util.all_leaves(singleton)


def is_jax_internal_leaf_or_none(t: Any) -> bool:
  """Returns True if `t` is a leaf per `is_jax_internal_leaf`, or is `None`."""
  if is_jax_internal_leaf(t):
    return True
  return t is None


def _internal_node_as_dict(x: Any) -> Mapping[str, Any]:
  """Returns the immediate children of `x` keyed by their stringified keys.

  The (key, child) pairs come from `tree_flatten_with_path_one_level(x)`; each
  key is converted to a string with `jax.tree_util.keystr`.
  """
  keyed_children, _ = tree_flatten_with_path_one_level(x)
  children_by_name = {}
  for key, child in keyed_children:
    children_by_name[jax.tree_util.keystr(key)] = child
  return children_by_name


def isinstance_of_namedtuple(value: Any) -> bool:
  """Reports whether `value` looks like a NamedTuple instance.

  A value qualifies when it is a `tuple` that also has a `_fields` attribute.
  """
  if not isinstance(value, tuple):
    return False
  return hasattr(value, '_fields')
Expand Down Expand Up @@ -458,3 +480,195 @@ def tree_difference(
)

return None


class TrimmedStructureCallback(Protocol, Generic[T]):
  """Callable protocol for reporting nodes removed from a tree by `tree_trim`."""

  def __call__(
      self,
      path: tuple[str | int, ...],
      structure: PyTreeOf[T],
  ) -> None:
    """Handles one removed node.

    Args:
      path: Keys and/or sequence indices leading from the root to the removed
        node.
      structure: The removed node itself (a leaf or a whole subtree).
    """
    ...


# TODO(b/407092826): Substitute for full PartsOf version later.
def tree_trim(
    template: PyTreeOf[Any],
    structure: PyTreeOf[T],
    *,
    trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
    strict: bool = True,
) -> PyTreeOf[T]:
  """Trims `structure` down to the shape described by `template`.

  Trimming only ever drops dictionary entries. Sequence nodes are passed
  through as-is, and a sequence node in `structure` must have the same length
  as the matching node in `template`.

  With `strict=False`, a subtree that a mapping or named tuple node of
  `template` expects but the matching node of `structure` lacks does not raise
  an error. It is instead filled in with a subtree of the expected shape whose
  leaves are all `...` (Ellipsis). The result is then guaranteed to have the
  same tree structure as `template`.

  Args:
    template: Tree that defines the desired shape.
    structure: Tree from which nodes are removed.
    trimmed_structure_callback: Optional hook invoked with the path and value
      of each node dropped from `structure`.
    strict: Whether every element of `template` must have a counterpart in
      `structure`.

  Returns:
    The portion of `structure` whose shape matches `template`.

  Raises:
    TypeError: If a node in `structure` has a type that does not match the
      type of the corresponding `template` node.
    ValueError: If a dictionary node of `template` has keys absent from the
      corresponding node of `structure`, if a sequence node of `structure`
      differs in length from the corresponding sequence node of `template`,
      or if an internal node that is neither a sequence nor a dictionary is
      encountered.
  """
  return _tree_trim_impl(
      template,
      structure,
      trimmed_structure_callback=trimmed_structure_callback,
      strict=strict,
  )


def _tree_trim_impl(
template: PyTreeOf[Any],
structure: PyTreeOf[T],
*,
trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
strict: bool = True,
) -> PyTreeOf[T]:
"""Implementation of `tree_trim()` that always returns a `PartsOf`."""

# This is nested so as to capture `trimmed_structure_callback`.
def _tree_trim(
path: tuple[str | int, ...],
template: PyTreeOf[Any],
structure: PyTreeOf[T],
) -> PyTreeOf[T]:
match template:
# This wants to be `case abc.Mapping()` but http://b/283787842.
case mapping if isinstance(mapping, abc.Mapping):
if isinstance_of_namedtuple(structure):
structure_dict = structure._asdict() # pytype:disable=attribute-error
elif isinstance(structure, abc.Mapping):
structure_dict = structure
elif structure is None:
structure_dict = {}
else:
raise TypeError(
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
)

keep_items = []
drop_items = []
placeholder_items = []

if missing := [k for k in template if k not in structure_dict]:
if strict:
raise ValueError(
f'{path}: missing {len(missing)} '
f'keys, including: {missing[:10]}'
)
else:
# Fill the result with placeholders
placeholder_items.extend(
(k, jax.tree.map(lambda x: ..., template[k])) for k in missing
)

for k, n in structure_dict.items():
(keep_items if k in template else drop_items).append((k, n))

if trimmed_structure_callback:
for k, n in drop_items:
trimmed_structure_callback((*path, k), n)

keep_dict = {
k: _tree_trim((*path, k), template[k], v) for k, v in keep_items
}
return type(template)((*keep_dict.items(), *placeholder_items)) # pytype:disable=wrong-arg-count
case named_tuple if isinstance_of_namedtuple(named_tuple):
if structure is None:
structure = ()
if isinstance(structure, abc.Mapping):
children_dict = _tree_trim(path, named_tuple._asdict(), structure)
return type(template)(**children_dict)
elif isinstance(structure, abc.Sequence):
children_sequence = _tree_trim(path, tuple(named_tuple), structure)
return type(template)(*children_sequence)
else:
raise TypeError(
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
)
# This wants to be `case abc.Sequence()` but http://b/283787842.
case sequence if isinstance(sequence, abc.Sequence):
if structure is None:
structure = ()
elif not isinstance(structure, abc.Sequence):
raise TypeError(
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
)
if len(structure) != len(template):
raise ValueError(
f'{path}: length mismatch: {len(template)} vs {len(structure)}.'
)
elements = (
_tree_trim((*path, i), t, s)
for i, (t, s) in enumerate(zip(template, structure))
)
return type(template)(elements) # pytype:disable=wrong-arg-count
case n if n is not None and is_jax_internal_node(n):
s_children_dict = _internal_node_as_dict(structure)

t_keys_and_children, t_tree_def = tree_flatten_with_path_one_level(
template
)
t_children_dict = {
jax.tree_util.keystr(k): v for k, v in t_keys_and_children
}

# Note: unlike other cases, this does not treat the children
# individually. Instead we have effectively cast the structure and
# the template to mappings and will deal with them in their entirety
# by reusing the mapping case.
children_dict = _tree_trim(path, t_children_dict, s_children_dict)
# Now cast back to the result type.
children = [
children_dict[jax.tree_util.keystr(k)]
for k, _ in t_keys_and_children
]
return jax.tree_util.tree_unflatten(t_tree_def, children)
case None:
# None is special: it's the only type of template tree node that can
# match both leaves and internal nodes of the structure to be trimmed.
if is_jax_internal_leaf(structure):
if trimmed_structure_callback:
trimmed_structure_callback(path, structure)
return None
else:
# Make sure any callback is called appropriately on all elements of
# `structure`.
_tree_trim(path, {}, structure)
return None
case v if is_jax_internal_leaf(v):
if not is_jax_internal_leaf_or_none(structure):
raise TypeError(
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
)
return structure
case _:
raise TypeError(
f'{path}: Unknown internal node type {type(structure)}.'
)

return _tree_trim((), template, structure)
Loading