Skip to content

Commit 2a914d3

Browse files
fhoushmandGoogle-ML-Automation
authored andcommitted
breaking internal tests
Reverts ab600c3 PiperOrigin-RevId: 747865790
1 parent 3b359ba commit 2a914d3

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

jax/_src/tree_util.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from functools import partial
2222
import operator as op
2323
import textwrap
24-
from typing import Any, TypeVar, overload
24+
from typing import Any, NamedTuple, TypeVar, overload
2525

2626
from jax._src import traceback_util
2727
from jax._src.lib import pytree
@@ -762,6 +762,42 @@ def _simple_entrystr(key: KeyEntry) -> str:
762762
return str(key)
763763

764764

765+
# TODO(ivyzheng): remove this after another jaxlib release.
766+
class _RegistryWithKeypathsEntry(NamedTuple):
767+
flatten_with_keys: Callable[..., Any]
768+
unflatten_func: Callable[..., Any]
769+
770+
771+
def _register_keypaths(
772+
ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
773+
) -> None:
774+
def flatten_with_keys(xs):
775+
children, treedef = _registry[ty].to_iter(xs)
776+
return list(zip(handler(xs), children)), treedef
777+
if ty in _registry:
778+
_registry_with_keypaths[ty] = _RegistryWithKeypathsEntry(
779+
flatten_with_keys, _registry[ty].from_iter
780+
)
781+
782+
_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}
783+
784+
_register_keypaths(
785+
tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
786+
)
787+
_register_keypaths(
788+
list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
789+
)
790+
_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs)))
791+
792+
_register_keypaths(
793+
collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys())
794+
)
795+
796+
_register_keypaths(
797+
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
798+
)
799+
800+
765801
@export
766802
def register_pytree_with_keys(
767803
nodetype: type[T],
@@ -831,6 +867,9 @@ def flatten_func_impl(tree):
831867
register_pytree_node(
832868
nodetype, flatten_func, unflatten_func, flatten_with_keys
833869
)
870+
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
871+
flatten_with_keys, unflatten_func
872+
)
834873

835874

836875
@export
@@ -1023,6 +1062,11 @@ def register_dataclass(
10231062
msg += f" Unexpected fields: {unexpected}."
10241063
raise ValueError(msg)
10251064

1065+
def flatten_with_keys(x):
1066+
meta = tuple(getattr(x, name) for name in meta_fields)
1067+
data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)
1068+
return data, meta
1069+
10261070
def unflatten_func(meta, data):
10271071
meta_args = tuple(zip(meta_fields, meta))
10281072
data_args = tuple(zip(data_fields, data))
@@ -1038,6 +1082,9 @@ def flatten_func(x):
10381082
none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
10391083
dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
10401084
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
1085+
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
1086+
flatten_with_keys, unflatten_func
1087+
)
10411088
return nodetype
10421089

10431090

0 commit comments

Comments
 (0)