Skip to content

Commit 2a1554f

Browse files
committed
Design 2->4
1 parent 93cc035 commit 2a1554f

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

Diff for: src/array_api_extra/_lib/_funcs.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ def setdiff1d(
552552
/,
553553
*,
554554
assume_unique: bool = False,
555+
size: int | None = None,
555556
fill_value: object | None = None,
556557
xp: ModuleType | None = None,
557558
) -> Array:
@@ -569,11 +570,16 @@ def setdiff1d(
569570
assume_unique : bool
570571
If ``True``, the input arrays are both assumed to be unique, which
571572
can speed up the calculation. Default is ``False``.
572-
fill_value : object, optional
573-
Pad the output array with this value.
573+
size : int, optional
574+
The size of the output array. This is exclusively used inside the JAX JIT, and
575+
only for as long as JAX does not support arrays of unknown size inside it. In
576+
all other cases, it is disregarded.
577+
Returned elements will be clipped if they are more than size, and padded with
578+
`fill_value` if they are less. Default: raise if inside ``jax.jit``.
574579
575-
This is exclusively used for JAX arrays when running inside ``jax.jit``,
576-
where all array shapes need to be known in advance.
580+
fill_value : object, optional
581+
Pad the output array with this value. This is exclusively used for JAX arrays
582+
when running inside ``jax.jit``. Default: 0.
577583
xp : array_namespace, optional
578584
The standard-compatible namespace for `x1` and `x2`. Default: infer.
579585
@@ -639,7 +645,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
639645
return x1 if assume_unique else xp.unique_values(x1)
640646

641647
def _jax_jit_impl(
642-
x1: Array, x2: Array, fill_value: object | None
648+
x1: Array, x2: Array, size: int | None, fill_value: object | None
643649
) -> Array: # numpydoc ignore=PR01,RT01
644650
"""
645651
JAX implementation inside jax.jit.
@@ -648,9 +654,9 @@ def _jax_jit_impl(
648654
and not being able to filter by a boolean mask.
649655
Returns array the same size as x1, padded with fill_value.
650656
"""
651-
# unique_values inside jax.jit is not supported unless it's got a fixed size
652-
mask = _x1_not_in_x2(x1, x2)
653-
657+
if size is None:
658+
msg = "`size` is mandatory when running inside `jax.jit`."
659+
raise ValueError(msg)
654660
if fill_value is None:
655661
fill_value = xp.zeros((), dtype=x1.dtype)
656662
else:
@@ -659,9 +665,13 @@ def _jax_jit_impl(
659665
msg = "`fill_value` must be a scalar."
660666
raise ValueError(msg)
661667

668+
# unique_values inside jax.jit is not supported unless it's got a fixed size
669+
mask = _x1_not_in_x2(x1, x2)
662670
x1 = xp.where(mask, x1, fill_value)
663-
# Note: jnp.unique_values sorts
664-
return xp.unique_values(x1, size=x1.size, fill_value=fill_value)
671+
# Move fill_value to the right
672+
x1 = xp.take(x1, xp.argsort(~mask, stable=True))
673+
x1 = x1[:size]
674+
x1 = xp.unique_values(x1, size=size, fill_value=fill_value)
665675

666676
if is_dask_namespace(xp):
667677
return _dask_impl(x1, x2)
@@ -675,7 +685,7 @@ def _jax_jit_impl(
675685
jax.errors.ConcretizationTypeError,
676686
jax.errors.NonConcreteBooleanIndexError,
677687
):
678-
return _jax_jit_impl(x1, x2, fill_value) # inside jax.jit
688+
return _jax_jit_impl(x1, x2, size, fill_value) # inside jax.jit
679689

680690
return _generic_impl(x1, x2)
681691

0 commit comments

Comments
 (0)