Replace jax.tree_util.tree_map() with jax.tree_util.tree_multimap()#3
Open
oikosohn wants to merge 1 commit into
Open
Replace jax.tree_util.tree_map() with jax.tree_util.tree_multimap()#3oikosohn wants to merge 1 commit into
oikosohn wants to merge 1 commit into