diff --git a/jax/sharding.py b/jax/sharding.py index 757d9b538cb8..b00ba430d178 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -31,6 +31,7 @@ AbstractMesh as AbstractMesh, AxisType as AxisType, get_abstract_mesh as get_abstract_mesh, + get_concrete_mesh as get_concrete_mesh, use_abstract_mesh as use_abstract_mesh, )