Skip to content

Commit df38ce8

Browse files
BlaziusMaximusOrbax Authors
authored and
Orbax Authors
committed
Add tree_trim function for filtering a PyTree by another's structure.
PiperOrigin-RevId: 741639251
1 parent 67800d4 commit df38ce8

File tree

4 files changed

+460
-3
lines changed

4 files changed

+460
-3
lines changed

checkpoint/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010
### Added
1111

1212
- #v1 Add metadata free functions.
13+
- `tree_trim` function for filtering a PyTree by another's structure.
1314

1415
## [0.11.10] - 2025-03-20
1516

checkpoint/orbax/checkpoint/_src/tree/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ py_test(
2222
deps = [
2323
":utils",
2424
"//checkpoint/orbax/checkpoint:test_utils",
25+
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
2526
"//checkpoint/orbax/checkpoint/_src/testing:test_tree_utils",
2627
],
2728
)

checkpoint/orbax/checkpoint/_src/tree/utils.py

+217-3
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414

1515
"""Tree utilities."""
1616

17-
from typing import Any, Callable, Mapping, NamedTuple, Optional, Tuple, Union
17+
from collections import abc
18+
from typing import Any, Callable, Generic, Mapping, NamedTuple, Optional, Protocol, Tuple, TypeVar, Union
1819

1920
import jax
2021
from orbax.checkpoint._src.arrays import abstract_arrays
21-
from orbax.checkpoint._src.tree import types as tree_types
2222

2323

24-
PyTree = tree_types.PyTree
24+
T = TypeVar('T')
25+
26+
PyTree = Any
27+
# This won't help the type checker but at least allows us to use types to
28+
# document things like `PyTreeOf[ArrayDesc]`.
29+
PyTreeOf = PyTree | T
2530
PyTreeKey = (
2631
jax.tree_util.SequenceKey
2732
| jax.tree_util.DictKey
@@ -33,6 +38,23 @@
3338
to_shape_dtype_struct = abstract_arrays.to_shape_dtype_struct
3439

3540

41+
def is_jax_internal_node(x: Any) -> bool:
42+
return not jax.tree_util.all_leaves([x])
43+
44+
45+
def is_jax_internal_leaf(x: Any) -> bool:
46+
return jax.tree_util.all_leaves([x])
47+
48+
49+
def is_jax_internal_leaf_or_none(t: Any) -> bool:
50+
return is_jax_internal_leaf(t) or t is None
51+
52+
53+
def _internal_node_as_dict(x: Any) -> Mapping[str, Any]:
54+
keys_and_children, _ = tree_flatten_with_path_one_level(x)
55+
return {jax.tree_util.keystr(k): v for k, v in keys_and_children}
56+
57+
3658
def isinstance_of_namedtuple(value: Any) -> bool:
3759
"""Determines if the `value` is a NamedTuple."""
3860
return isinstance(value, tuple) and hasattr(value, '_fields')
@@ -458,3 +480,195 @@ def tree_difference(
458480
)
459481

460482
return None
483+
484+
485+
class TrimmedStructureCallback(Protocol, Generic[T]):
486+
487+
def __call__(
488+
self,
489+
path: tuple[str | int, ...],
490+
structure: PyTreeOf[T],
491+
) -> None:
492+
...
493+
494+
495+
# TODO(b/407092826): Substitute for full PartsOf version later.
496+
def tree_trim(
497+
template: PyTreeOf[Any],
498+
structure: PyTreeOf[T],
499+
*,
500+
trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
501+
strict: bool = True,
502+
) -> PyTreeOf[T]:
503+
"""Removes nodes in `structure` so that its shape matches that of `template`.
504+
505+
Only dictionary entries are trimmed; sequences are unchanged and the length
506+
of a sequence node in `structure` must match that of the corresponding node
507+
in `template`.
508+
509+
If `not strict`, any subtree of a mapping or named tuple node of `template`
510+
that is missing from the corresponding node of `structure` will be replaced
511+
with an appropriately-shaped subtree full of `...` placeholders (Ellipsis)
512+
instead of causing an error. In this mode, the tree structure of the result
513+
is guaranteed to match the tree structure of `template`.
514+
515+
Args:
516+
template: The tree whose shape is to be matched.
517+
structure: The tree to be trimmed.
518+
trimmed_structure_callback: If present, will be called with the path to, and
519+
value of, any node that is removed from `structure`.
520+
strict: Require every element of `template` to be matched by an element of
521+
`structure`.
522+
523+
Returns:
524+
A subset of `structure` that has the same shape as `template`.
525+
526+
Raises:
527+
TypeError: If the type of a node in `structure` does not match the
528+
type of the corresponding node in `template`.
529+
ValueError: If keys in a dictionary node in `template` are not present
530+
in the corresponding node in `structure`, or if the length of a sequence
531+
node in `structure` does not match the length of the corresponding
532+
sequence node in `template`, or if an internal node that isn't a
533+
sequence or dictionary is encountered.
534+
"""
535+
result = _tree_trim_impl(
536+
template,
537+
structure,
538+
trimmed_structure_callback=trimmed_structure_callback,
539+
strict=strict,
540+
)
541+
return result
542+
543+
544+
def _tree_trim_impl(
545+
template: PyTreeOf[Any],
546+
structure: PyTreeOf[T],
547+
*,
548+
trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
549+
strict: bool = True,
550+
) -> PyTreeOf[T]:
551+
"""Implementation of `tree_trim()` that always returns a `PartsOf`."""
552+
553+
# This is nested so as to capture `trimmed_structure_callback`.
554+
def _tree_trim(
555+
path: tuple[str | int, ...],
556+
template: PyTreeOf[Any],
557+
structure: PyTreeOf[T],
558+
) -> PyTreeOf[T]:
559+
match template:
560+
# This wants to be `case abc.Mapping()` but http://b/283787842.
561+
case mapping if isinstance(mapping, abc.Mapping):
562+
if isinstance_of_namedtuple(structure):
563+
structure_dict = structure._asdict() # pytype:disable=attribute-error
564+
elif isinstance(structure, abc.Mapping):
565+
structure_dict = structure
566+
elif structure is None:
567+
structure_dict = {}
568+
else:
569+
raise TypeError(
570+
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
571+
)
572+
573+
keep_items = []
574+
drop_items = []
575+
placeholder_items = []
576+
577+
if missing := [k for k in template if k not in structure_dict]:
578+
if strict:
579+
raise ValueError(
580+
f'{path}: missing {len(missing)} '
581+
f'keys, including: {missing[:10]}'
582+
)
583+
else:
584+
# Fill the result with placeholders
585+
placeholder_items.extend(
586+
(k, jax.tree.map(lambda x: ..., template[k])) for k in missing
587+
)
588+
589+
for k, n in structure_dict.items():
590+
(keep_items if k in template else drop_items).append((k, n))
591+
592+
if trimmed_structure_callback:
593+
for k, n in drop_items:
594+
trimmed_structure_callback((*path, k), n)
595+
596+
keep_dict = {
597+
k: _tree_trim((*path, k), template[k], v) for k, v in keep_items
598+
}
599+
return type(template)((*keep_dict.items(), *placeholder_items)) # pytype:disable=wrong-arg-count
600+
case named_tuple if isinstance_of_namedtuple(named_tuple):
601+
if structure is None:
602+
structure = ()
603+
if isinstance(structure, abc.Mapping):
604+
children_dict = _tree_trim(path, named_tuple._asdict(), structure)
605+
return type(template)(**children_dict)
606+
elif isinstance(structure, abc.Sequence):
607+
children_sequence = _tree_trim(path, tuple(named_tuple), structure)
608+
return type(template)(*children_sequence)
609+
else:
610+
raise TypeError(
611+
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
612+
)
613+
# This wants to be `case abc.Sequence()` but http://b/283787842.
614+
case sequence if isinstance(sequence, abc.Sequence):
615+
if structure is None:
616+
structure = ()
617+
elif not isinstance(structure, abc.Sequence):
618+
raise TypeError(
619+
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
620+
)
621+
if len(structure) != len(template):
622+
raise ValueError(
623+
f'{path}: length mismatch: {len(template)} vs {len(structure)}.'
624+
)
625+
elements = (
626+
_tree_trim((*path, i), t, s)
627+
for i, (t, s) in enumerate(zip(template, structure))
628+
)
629+
return type(template)(elements) # pytype:disable=wrong-arg-count
630+
case n if n is not None and is_jax_internal_node(n):
631+
s_children_dict = _internal_node_as_dict(structure)
632+
633+
t_keys_and_children, t_tree_def = tree_flatten_with_path_one_level(
634+
template
635+
)
636+
t_children_dict = {
637+
jax.tree_util.keystr(k): v for k, v in t_keys_and_children
638+
}
639+
640+
# Note: unlike other cases, this does not treat the children
641+
# individually. Instead we have effectively cast the structure and
642+
# the template to mappings and will deal with them in their entirety
643+
# by reusing the mapping case.
644+
children_dict = _tree_trim(path, t_children_dict, s_children_dict)
645+
# Now cast back to the result type.
646+
children = [
647+
children_dict[jax.tree_util.keystr(k)]
648+
for k, _ in t_keys_and_children
649+
]
650+
return jax.tree_util.tree_unflatten(t_tree_def, children)
651+
case None:
652+
# None is special: it's the only type of template tree node that can
653+
# match both leaves and internal nodes of the structure to be trimmed.
654+
if is_jax_internal_leaf(structure):
655+
if trimmed_structure_callback:
656+
trimmed_structure_callback(path, structure)
657+
return None
658+
else:
659+
# Make sure any callback is called appropriately on all elements of
660+
# `structure`.
661+
_tree_trim(path, {}, structure)
662+
return None
663+
case v if is_jax_internal_leaf(v):
664+
if not is_jax_internal_leaf_or_none(structure):
665+
raise TypeError(
666+
f'{path}: type mismatch: {type(template)} vs {type(structure)}.'
667+
)
668+
return structure
669+
case _:
670+
raise TypeError(
671+
f'{path}: Unknown internal node type {type(structure)}.'
672+
)
673+
674+
return _tree_trim((), template, structure)

0 commit comments

Comments
 (0)