Replace jax.tree_util.tree_map() with jax.tree_util.tree_multimap()#3
Open
oikosohn wants to merge 1 commit intogordicaleksa:mainfrom
Open
Replace jax.tree_util.tree_map() with jax.tree_util.tree_multimap()#3oikosohn wants to merge 1 commit intogordicaleksa:mainfrom
oikosohn wants to merge 1 commit intogordicaleksa:mainfrom