@@ -552,6 +552,7 @@ def setdiff1d(
552
552
/ ,
553
553
* ,
554
554
assume_unique : bool = False ,
555
+ size : int | None = None ,
555
556
fill_value : object | None = None ,
556
557
xp : ModuleType | None = None ,
557
558
) -> Array :
@@ -569,11 +570,16 @@ def setdiff1d(
569
570
assume_unique : bool
570
571
If ``True``, the input arrays are both assumed to be unique, which
571
572
can speed up the calculation. Default is ``False``.
572
- fill_value : object, optional
573
- Pad the output array with this value.
573
+ size : int, optional
574
+ The size of the output array. This is exclusively used inside the JAX JIT, and
575
+ only for as long as JAX does not support arrays of unknown size inside it. In
576
+ all other cases, it is disregarded.
577
+ Returned elements will be clipped if they are more than size, and padded with
578
+ `fill_value` if they are less. Default: raise if inside ``jax.jit``.
574
579
575
- This is exclusively used for JAX arrays when running inside ``jax.jit``,
576
- where all array shapes need to be known in advance.
580
+ fill_value : object, optional
581
+ Pad the output array with this value. This is exclusively used for JAX arrays
582
+ when running inside ``jax.jit``. Default: 0.
577
583
xp : array_namespace, optional
578
584
The standard-compatible namespace for `x1` and `x2`. Default: infer.
579
585
@@ -639,7 +645,7 @@ def _dask_impl(x1: Array, x2: Array) -> Array: # numpydoc ignore=PR01,RT01
639
645
return x1 if assume_unique else xp .unique_values (x1 )
640
646
641
647
def _jax_jit_impl (
642
- x1 : Array , x2 : Array , fill_value : object | None
648
+ x1 : Array , x2 : Array , size : int | None , fill_value : object | None
643
649
) -> Array : # numpydoc ignore=PR01,RT01
644
650
"""
645
651
JAX implementation inside jax.jit.
@@ -648,9 +654,9 @@ def _jax_jit_impl(
648
654
and not being able to filter by a boolean mask.
649
655
Returns array the same size as x1, padded with fill_value.
650
656
"""
651
- # unique_values inside jax.jit is not supported unless it's got a fixed size
652
- mask = _x1_not_in_x2 ( x1 , x2 )
653
-
657
+ if size is None :
658
+ msg = "`size` is mandatory when running inside `jax.jit`."
659
+ raise ValueError ( msg )
654
660
if fill_value is None :
655
661
fill_value = xp .zeros ((), dtype = x1 .dtype )
656
662
else :
@@ -659,9 +665,13 @@ def _jax_jit_impl(
659
665
msg = "`fill_value` must be a scalar."
660
666
raise ValueError (msg )
661
667
668
+ # unique_values inside jax.jit is not supported unless it's got a fixed size
669
+ mask = _x1_not_in_x2 (x1 , x2 )
662
670
x1 = xp .where (mask , x1 , fill_value )
663
- # Note: jnp.unique_values sorts
664
- return xp .unique_values (x1 , size = x1 .size , fill_value = fill_value )
671
+ # Move fill_value to the right
672
+ x1 = xp .take (x1 , xp .argsort (~ mask , stable = True ))
673
+ x1 = x1 [:size ]
674
+ x1 = xp .unique_values (x1 , size = size , fill_value = fill_value )
665
675
666
676
if is_dask_namespace (xp ):
667
677
return _dask_impl (x1 , x2 )
@@ -675,7 +685,7 @@ def _jax_jit_impl(
675
685
jax .errors .ConcretizationTypeError ,
676
686
jax .errors .NonConcreteBooleanIndexError ,
677
687
):
678
- return _jax_jit_impl (x1 , x2 , fill_value ) # inside jax.jit
688
+ return _jax_jit_impl (x1 , x2 , size , fill_value ) # inside jax.jit
679
689
680
690
return _generic_impl (x1 , x2 )
681
691
0 commit comments