|
14 | 14 |
|
15 | 15 | """Tree utilities."""
|
16 | 16 |
|
17 |
| -from typing import Any, Callable, Mapping, NamedTuple, Optional, Tuple, Union |
| 17 | +from collections import abc |
| 18 | +from typing import Any, Callable, Generic, Mapping, NamedTuple, Optional, Protocol, Tuple, TypeVar, Union |
18 | 19 |
|
19 | 20 | import jax
|
20 | 21 | from orbax.checkpoint._src.arrays import abstract_arrays
|
21 |
| -from orbax.checkpoint._src.tree import types as tree_types |
22 | 22 |
|
23 | 23 |
|
24 |
| -PyTree = tree_types.PyTree |
# Type variable standing for the element type named in `PyTreeOf[...]`.
T = TypeVar('T')

# A PyTree is annotated as `Any`: no structural constraint is expressed.
PyTree = Any
# This won't help the type checker but at least allows us to use types to
# document things like `PyTreeOf[ArrayDesc]`.
PyTreeOf = PyTree | T
25 | 30 | PyTreeKey = (
|
26 | 31 | jax.tree_util.SequenceKey
|
27 | 32 | | jax.tree_util.DictKey
|
|
33 | 38 | to_shape_dtype_struct = abstract_arrays.to_shape_dtype_struct
|
34 | 39 |
|
35 | 40 |
|
def is_jax_internal_node(x: Any) -> bool:
  """Returns True if `jax.tree_util.all_leaves` does not consider `x` a leaf."""
  counts_as_leaf = jax.tree_util.all_leaves([x])
  return not counts_as_leaf
| 43 | + |
| 44 | + |
def is_jax_internal_leaf(x: Any) -> bool:
  """Returns True if `jax.tree_util.all_leaves` considers `x` a leaf."""
  counts_as_leaf = jax.tree_util.all_leaves([x])
  return counts_as_leaf
| 47 | + |
| 48 | + |
def is_jax_internal_leaf_or_none(t: Any) -> bool:
  """Returns True if `t` is `None` or passes `is_jax_internal_leaf`."""
  if t is None:
    return True
  return is_jax_internal_leaf(t)
| 51 | + |
| 52 | + |
def _internal_node_as_dict(x: Any) -> Mapping[str, Any]:
  """Returns the one-level children of `x`, keyed by `keystr` of their key.

  The children and their keys come from `tree_flatten_with_path_one_level`;
  each key is converted to a string with `jax.tree_util.keystr`.
  """
  children_by_name = {}
  keyed_children, _ = tree_flatten_with_path_one_level(x)
  for key, child in keyed_children:
    children_by_name[jax.tree_util.keystr(key)] = child
  return children_by_name
| 56 | + |
| 57 | + |
def isinstance_of_namedtuple(value: Any) -> bool:
  """Determines if the `value` is a NamedTuple."""
  # A named tuple is a tuple that additionally exposes `_fields`.
  if not isinstance(value, tuple):
    return False
  return hasattr(value, '_fields')
|
@@ -458,3 +480,195 @@ def tree_difference(
|
458 | 480 | )
|
459 | 481 |
|
460 | 482 | return None
|
| 483 | + |
| 484 | + |
class TrimmedStructureCallback(Protocol, Generic[T]):
  """Call signature of the `trimmed_structure_callback` taken by `tree_trim`."""

  def __call__(
      self,
      path: tuple[str | int, ...],
      structure: PyTreeOf[T],
  ) -> None:
    """Receives the `path` to a removed node and the removed `structure`."""
    ...
| 493 | + |
| 494 | + |
# TODO(b/407092826): Substitute for full PartsOf version later.
def tree_trim(
    template: PyTreeOf[Any],
    structure: PyTreeOf[T],
    *,
    trimmed_structure_callback: TrimmedStructureCallback[T] | None = None,
    strict: bool = True,
) -> PyTreeOf[T]:
  """Trims `structure` so that its shape matches that of `template`.

  Trimming only ever removes dictionary entries. Sequences are never
  shortened: a sequence node in `structure` must have the same length as the
  corresponding node in `template`.

  When `strict` is False, a subtree that a mapping or named tuple node of
  `template` has but the corresponding node of `structure` lacks does not
  cause an error. It is instead filled in with a subtree of the same shape
  whose leaves are all `...` (Ellipsis), so that the tree structure of the
  result is guaranteed to match the tree structure of `template`.

  Args:
    template: The tree whose shape is to be matched.
    structure: The tree to be trimmed.
    trimmed_structure_callback: Optional. Invoked with the path to, and the
      value of, each node removed from `structure`.
    strict: Whether every element of `template` must be matched by an element
      of `structure`.

  Returns:
    A subset of `structure` that has the same shape as `template`.

  Raises:
    TypeError: If a node in `structure` has a type that does not match the
      type of the corresponding node in `template`.
    ValueError: If a dictionary node in `template` has keys that are absent
      from the corresponding node in `structure`, or if a sequence node in
      `structure` differs in length from the corresponding sequence node in
      `template`, or if an internal node that isn't a sequence or dictionary
      is encountered.
  """
  return _tree_trim_impl(
      template,
      structure,
      trimmed_structure_callback=trimmed_structure_callback,
      strict=strict,
  )
| 542 | + |
| 543 | + |
| 544 | +def _tree_trim_impl( |
| 545 | + template: PyTreeOf[Any], |
| 546 | + structure: PyTreeOf[T], |
| 547 | + *, |
| 548 | + trimmed_structure_callback: TrimmedStructureCallback[T] | None = None, |
| 549 | + strict: bool = True, |
| 550 | +) -> PyTreeOf[T]: |
| 551 | + """Implementation of `tree_trim()` that always returns a `PartsOf`.""" |
| 552 | + |
| 553 | + # This is nested so as to capture `trimmed_structure_callback`. |
| 554 | + def _tree_trim( |
| 555 | + path: tuple[str | int, ...], |
| 556 | + template: PyTreeOf[Any], |
| 557 | + structure: PyTreeOf[T], |
| 558 | + ) -> PyTreeOf[T]: |
| 559 | + match template: |
| 560 | + # This wants to be `case abc.Mapping()` but http://b/283787842. |
| 561 | + case mapping if isinstance(mapping, abc.Mapping): |
| 562 | + if isinstance_of_namedtuple(structure): |
| 563 | + structure_dict = structure._asdict() # pytype:disable=attribute-error |
| 564 | + elif isinstance(structure, abc.Mapping): |
| 565 | + structure_dict = structure |
| 566 | + elif structure is None: |
| 567 | + structure_dict = {} |
| 568 | + else: |
| 569 | + raise TypeError( |
| 570 | + f'{path}: type mismatch: {type(template)} vs {type(structure)}.' |
| 571 | + ) |
| 572 | + |
| 573 | + keep_items = [] |
| 574 | + drop_items = [] |
| 575 | + placeholder_items = [] |
| 576 | + |
| 577 | + if missing := [k for k in template if k not in structure_dict]: |
| 578 | + if strict: |
| 579 | + raise ValueError( |
| 580 | + f'{path}: missing {len(missing)} ' |
| 581 | + f'keys, including: {missing[:10]}' |
| 582 | + ) |
| 583 | + else: |
| 584 | + # Fill the result with placeholders |
| 585 | + placeholder_items.extend( |
| 586 | + (k, jax.tree.map(lambda x: ..., template[k])) for k in missing |
| 587 | + ) |
| 588 | + |
| 589 | + for k, n in structure_dict.items(): |
| 590 | + (keep_items if k in template else drop_items).append((k, n)) |
| 591 | + |
| 592 | + if trimmed_structure_callback: |
| 593 | + for k, n in drop_items: |
| 594 | + trimmed_structure_callback((*path, k), n) |
| 595 | + |
| 596 | + keep_dict = { |
| 597 | + k: _tree_trim((*path, k), template[k], v) for k, v in keep_items |
| 598 | + } |
| 599 | + return type(template)((*keep_dict.items(), *placeholder_items)) # pytype:disable=wrong-arg-count |
| 600 | + case named_tuple if isinstance_of_namedtuple(named_tuple): |
| 601 | + if structure is None: |
| 602 | + structure = () |
| 603 | + if isinstance(structure, abc.Mapping): |
| 604 | + children_dict = _tree_trim(path, named_tuple._asdict(), structure) |
| 605 | + return type(template)(**children_dict) |
| 606 | + elif isinstance(structure, abc.Sequence): |
| 607 | + children_sequence = _tree_trim(path, tuple(named_tuple), structure) |
| 608 | + return type(template)(*children_sequence) |
| 609 | + else: |
| 610 | + raise TypeError( |
| 611 | + f'{path}: type mismatch: {type(template)} vs {type(structure)}.' |
| 612 | + ) |
| 613 | + # This wants to be `case abc.Sequence()` but http://b/283787842. |
| 614 | + case sequence if isinstance(sequence, abc.Sequence): |
| 615 | + if structure is None: |
| 616 | + structure = () |
| 617 | + elif not isinstance(structure, abc.Sequence): |
| 618 | + raise TypeError( |
| 619 | + f'{path}: type mismatch: {type(template)} vs {type(structure)}.' |
| 620 | + ) |
| 621 | + if len(structure) != len(template): |
| 622 | + raise ValueError( |
| 623 | + f'{path}: length mismatch: {len(template)} vs {len(structure)}.' |
| 624 | + ) |
| 625 | + elements = ( |
| 626 | + _tree_trim((*path, i), t, s) |
| 627 | + for i, (t, s) in enumerate(zip(template, structure)) |
| 628 | + ) |
| 629 | + return type(template)(elements) # pytype:disable=wrong-arg-count |
| 630 | + case n if n is not None and is_jax_internal_node(n): |
| 631 | + s_children_dict = _internal_node_as_dict(structure) |
| 632 | + |
| 633 | + t_keys_and_children, t_tree_def = tree_flatten_with_path_one_level( |
| 634 | + template |
| 635 | + ) |
| 636 | + t_children_dict = { |
| 637 | + jax.tree_util.keystr(k): v for k, v in t_keys_and_children |
| 638 | + } |
| 639 | + |
| 640 | + # Note: unlike other cases, this does not treat the children |
| 641 | + # individually. Instead we have effectively cast the structure and |
| 642 | + # the template to mappings and will deal with them in their entirety |
| 643 | + # by reusing the mapping case. |
| 644 | + children_dict = _tree_trim(path, t_children_dict, s_children_dict) |
| 645 | + # Now cast back to the result type. |
| 646 | + children = [ |
| 647 | + children_dict[jax.tree_util.keystr(k)] |
| 648 | + for k, _ in t_keys_and_children |
| 649 | + ] |
| 650 | + return jax.tree_util.tree_unflatten(t_tree_def, children) |
| 651 | + case None: |
| 652 | + # None is special: it's the only type of template tree node that can |
| 653 | + # match both leaves and internal nodes of the structure to be trimmed. |
| 654 | + if is_jax_internal_leaf(structure): |
| 655 | + if trimmed_structure_callback: |
| 656 | + trimmed_structure_callback(path, structure) |
| 657 | + return None |
| 658 | + else: |
| 659 | + # Make sure any callback is called appropriately on all elements of |
| 660 | + # `structure`. |
| 661 | + _tree_trim(path, {}, structure) |
| 662 | + return None |
| 663 | + case v if is_jax_internal_leaf(v): |
| 664 | + if not is_jax_internal_leaf_or_none(structure): |
| 665 | + raise TypeError( |
| 666 | + f'{path}: type mismatch: {type(template)} vs {type(structure)}.' |
| 667 | + ) |
| 668 | + return structure |
| 669 | + case _: |
| 670 | + raise TypeError( |
| 671 | + f'{path}: Unknown internal node type {type(structure)}.' |
| 672 | + ) |
| 673 | + |
| 674 | + return _tree_trim((), template, structure) |
0 commit comments