21
21
from functools import partial
22
22
import operator as op
23
23
import textwrap
24
- from typing import Any , TypeVar , overload
24
+ from typing import Any , NamedTuple , TypeVar , overload
25
25
26
26
from jax ._src import traceback_util
27
27
from jax ._src .lib import pytree
@@ -762,6 +762,42 @@ def _simple_entrystr(key: KeyEntry) -> str:
762
762
return str (key )
763
763
764
764
765
+ # TODO(ivyzheng): remove this after another jaxlib release.
766
+ class _RegistryWithKeypathsEntry (NamedTuple ):
767
+ flatten_with_keys : Callable [..., Any ]
768
+ unflatten_func : Callable [..., Any ]
769
+
770
+
771
+ def _register_keypaths (
772
+ ty : type [T ], handler : Callable [[T ], tuple [KeyEntry , ...]]
773
+ ) -> None :
774
+ def flatten_with_keys (xs ):
775
+ children , treedef = _registry [ty ].to_iter (xs )
776
+ return list (zip (handler (xs ), children )), treedef
777
+ if ty in _registry :
778
+ _registry_with_keypaths [ty ] = _RegistryWithKeypathsEntry (
779
+ flatten_with_keys , _registry [ty ].from_iter
780
+ )
781
+
782
+ _registry_with_keypaths : dict [type [Any ], _RegistryWithKeypathsEntry ] = {}
783
+
784
+ _register_keypaths (
785
+ tuple , lambda xs : tuple (SequenceKey (i ) for i in range (len (xs )))
786
+ )
787
+ _register_keypaths (
788
+ list , lambda xs : tuple (SequenceKey (i ) for i in range (len (xs )))
789
+ )
790
+ _register_keypaths (dict , lambda xs : tuple (DictKey (k ) for k in sorted (xs )))
791
+
792
+ _register_keypaths (
793
+ collections .defaultdict , lambda x : tuple (DictKey (k ) for k in x .keys ())
794
+ )
795
+
796
+ _register_keypaths (
797
+ collections .OrderedDict , lambda x : tuple (DictKey (k ) for k in x .keys ())
798
+ )
799
+
800
+
765
801
@export
766
802
def register_pytree_with_keys (
767
803
nodetype : type [T ],
@@ -831,6 +867,9 @@ def flatten_func_impl(tree):
831
867
register_pytree_node (
832
868
nodetype , flatten_func , unflatten_func , flatten_with_keys
833
869
)
870
+ _registry_with_keypaths [nodetype ] = _RegistryWithKeypathsEntry (
871
+ flatten_with_keys , unflatten_func
872
+ )
834
873
835
874
836
875
@export
@@ -1023,6 +1062,11 @@ def register_dataclass(
1023
1062
msg += f" Unexpected fields: { unexpected } ."
1024
1063
raise ValueError (msg )
1025
1064
1065
+ def flatten_with_keys (x ):
1066
+ meta = tuple (getattr (x , name ) for name in meta_fields )
1067
+ data = tuple ((GetAttrKey (name ), getattr (x , name )) for name in data_fields )
1068
+ return data , meta
1069
+
1026
1070
def unflatten_func (meta , data ):
1027
1071
meta_args = tuple (zip (meta_fields , meta ))
1028
1072
data_args = tuple (zip (data_fields , data ))
@@ -1038,6 +1082,9 @@ def flatten_func(x):
1038
1082
none_leaf_registry .register_dataclass_node (nodetype , list (data_fields ), list (meta_fields ))
1039
1083
dispatch_registry .register_dataclass_node (nodetype , list (data_fields ), list (meta_fields ))
1040
1084
_registry [nodetype ] = _RegistryEntry (flatten_func , unflatten_func )
1085
+ _registry_with_keypaths [nodetype ] = _RegistryWithKeypathsEntry (
1086
+ flatten_with_keys , unflatten_func
1087
+ )
1041
1088
return nodetype
1042
1089
1043
1090
0 commit comments