Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689467667
  • Loading branch information
Jake VanderPlas authored and The oryx Authors committed Oct 24, 2024
1 parent 5272a3d commit 5050554
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
8 changes: 4 additions & 4 deletions oryx/core/interpreters/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1211,14 +1211,14 @@ def _calc_extra_inps(num_consts, params):
def _reap_pjit_rule(trace, *tracers, **params):
"""Reap pjit rule."""
if params['in_shardings'] and not any(
sharding_impls.is_unspecified(i) for i in params['in_shardings']
isinstance(i, pxla) for i in params['in_shardings']
):
raise ValueError(
'oryx only supports pjit which has no in_axis_resources '
f'specified. Got {params["in_shardings"]}'
)
if params['out_shardings'] and not any(
sharding_impls.is_unspecified(o) for o in params['out_shardings']
isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']
):
raise ValueError(
'oryx only supports pjit which has no out_axis_resources '
Expand Down Expand Up @@ -1648,14 +1648,14 @@ def _plant_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse,
def _plant_pjit_rule(trace, *tracers, **params):
"""Plant pjit rule."""
if params['in_shardings'] and not any(
sharding_impls.is_unspecified(i) for i in params['in_shardings']
isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings']
):
raise ValueError(
'oryx only supports pjit which has no in_axis_resources '
f'specified. Got {params["in_shardings"]}'
)
if params['out_shardings'] and not any(
sharding_impls.is_unspecified(o) for o in params['out_shardings']
isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']
):
raise ValueError(
'oryx only supports pjit which has no out_axis_resources '
Expand Down
5 changes: 2 additions & 3 deletions oryx/core/interpreters/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from jax._src import sharding_impls
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla

from oryx.core import pytree
from oryx.core import trace_util
Expand Down Expand Up @@ -367,10 +366,10 @@ def _pjit_propagate_rule(incells, outcells, **params):
"""Propagate rule for pjit primitive."""
# TODO(https://github.com/jax-ml/oryx/issues/29): Fix this rule so that it # pylint: disable=g-bad-todo
# works correct for in_sharding, out_shardings and donated_invars.
if not any(pxla._is_unspecified(i) for i in params['in_shardings']): # pylint: disable=protected-access
if not any(isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings']): # pylint: disable=protected-access
raise ValueError('oryx only supports pjit which has no in_axis_resources '
'specified.')
if not any(pxla._is_unspecified(o) for o in params['out_shardings']): # pylint: disable=protected-access
if not any(isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']): # pylint: disable=protected-access
raise ValueError('oryx only supports pjit which has no out_axis_resources '
'specified.')

Expand Down

0 comments on commit 5050554

Please sign in to comment.