Add mask argument to lax.argmax #25623

Open
carlosgmartin opened this issue Dec 20, 2024 · 9 comments
Labels
enhancement New feature or request

Comments

@carlosgmartin
Contributor

carlosgmartin commented Dec 20, 2024

Feature request: Add a mask argument to jax.lax.argmax (and jax.lax.argmin) consisting of an array of booleans that indicate which elements to include in the computation.

Here is an example implementation:

from functools import partial

import jax
from jax import lax, numpy as jnp

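def min_value(dtype):
    # Stand-in for the author's .utils.min_value helper, which the issue does
    # not show: the smallest value of dtype, used as the reduction's init value.
    dtype = jnp.dtype(dtype)
    if jnp.issubdtype(dtype, jnp.floating):
        return jnp.array(-jnp.inf, dtype)
    if jnp.issubdtype(dtype, jnp.integer):
        return jnp.array(jnp.iinfo(dtype).min, dtype)
    if dtype == jnp.bool_:
        return jnp.array(False)
    raise TypeError(f"unsupported dtype: {dtype}")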


def tree_select(pred, on_true, on_false):
    return jax.tree.map(partial(lax.select, pred), on_true, on_false)


def argmax(operand, axis, index_dtype, mask=None):

    if mask is None:
        return lax.argmax(operand, axis, index_dtype)

    indices = lax.broadcasted_iota(index_dtype, operand.shape, axis)

    def computation(a, b):
        a_value, a_index, a_valid = a
        b_value, b_index, b_valid = b

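        # Prefer valid elements; among equally valid ones prefer the larger
        # value, and break value ties toward the smaller index (as lax.argmax
        # does). The index tie-break keeps the reducer commutative, which
        # matters because XLA does not fix the order in which elements are
        # combined.
        pick_b = (b_valid > a_valid) | (
            (b_valid == a_valid)
            & (
                (b_value > a_value)
                | ((b_value == a_value) & (b_index < a_index))
            )
        )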

        return tree_select(pick_b, b, a)

    value, index, valid = lax.reduce(
        (operand, indices, mask),
        (min_value(operand.dtype), jnp.array(-1, index_dtype), False),
        computation,
        (axis,),
    )
    # value and valid are also available here; for a fully masked slice,
    # valid is False and the returned index is not meaningful.
    return index


def main():
    x = jnp.array([0.0, 1.0, 3.0, 3.0, 3.0, 4.0])
    mask = jnp.array([False, True, False, True, True, False])
    i = argmax(x, 0, int, mask=mask)
    assert i == 3

    x = jnp.array([0.0, 0.0])
    mask = jnp.array([False, True])
    i = argmax(x, 0, int, mask=mask)
    assert i == 1

    x = jnp.array(
        [
            [0.0, 2.0, 1.0, 1.0, 1.0, 3.0],
            [0.0, 4.0, 2.0, 1.0, 3.0, 5.0],
        ]
    )
    mask = jnp.array(
        [
            [False, False, False, True, True, False],
            [False, True, False, True, False, False],
        ]
    )
    i = argmax(x, 1, int, mask=mask)
    assert (i == jnp.array([3, 1])).all()


if __name__ == "__main__":
    main()

I can submit a PR for this.

Related, but for jax.numpy: #20177.
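For comparison, a common workaround that needs no new API is to overwrite masked-out entries with the dtype's minimum and then call jnp.argmax. The sketch below (masked_argmax_fill is a hypothetical name, not part of JAX) handles only float and integer inputs and has two caveats: a fully masked slice returns index 0 rather than a sentinel, and a valid entry that itself equals the fill value can lose a tie to a masked entry at a lower index.

```python
import jax.numpy as jnp


def masked_argmax_fill(x, mask, axis):
    # Hypothetical helper (not a JAX API): masked-out entries are overwritten
    # with the dtype's minimum so that jnp.argmax cannot select them, except
    # in the tie cases described above.
    if jnp.issubdtype(x.dtype, jnp.floating):
        fill = -jnp.inf
    elif jnp.issubdtype(x.dtype, jnp.integer):
        fill = jnp.iinfo(x.dtype).min
    else:
        raise TypeError(f"unsupported dtype: {x.dtype}")
    return jnp.argmax(jnp.where(mask, x, fill), axis=axis)


x = jnp.array([0.0, 1.0, 3.0, 3.0, 3.0, 4.0])
mask = jnp.array([False, True, False, True, True, False])
print(masked_argmax_fill(x, mask, axis=0))  # 3
```

The reduce-based version in the issue avoids both caveats' root cause by tracking validity explicitly instead of relying on a sentinel value.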

A more general solution that could be re-used in the future for other operations would be to add a mask argument to jax.lax.reduce that controls which elements to include in the reduction. Example:

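def reduce(operands, init_values, computation, dimensions, mask=None):
    # mask must have the same shape as the operands; False entries are
    # skipped, and a reduced slice with no True entries yields init_values.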

    if mask is None:
        return lax.reduce(operands, init_values, computation, dimensions)

    def new_computation(a, b):
        a_value, a_valid = a
        b_value, b_valid = b
        return tree_select(
            b_valid,
            tree_select(
                a_valid,
                (computation(a_value, b_value), True),
                b,
            ),
            a,
        )

    value, valid = lax.reduce(
        (operands, mask),
        (init_values, False),
        new_computation,
        dimensions,
    )
    return value
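As a usage sketch of the generic reduce idea, the block below restates it as a self-contained function (masked_reduce is a hypothetical name; it follows the reduce() above but takes the mask as a required argument) and applies it to a masked max and a masked sum along one axis. As written, a slice with no valid entries comes back as the init value.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def tree_select(pred, on_true, on_false):
    return jax.tree.map(partial(lax.select, pred), on_true, on_false)


def masked_reduce(operands, init_values, computation, dimensions, mask):
    # Each partial result carries a validity flag; two partial results are
    # combined with `computation` only when both are valid.
    def new_computation(a, b):
        a_value, a_valid = a
        b_value, b_valid = b
        # When a_valid is True this branch keeps the flag True (it is a_valid).
        both = tree_select(a_valid, (computation(a_value, b_value), a_valid), b)
        return tree_select(b_valid, both, a)

    value, _ = lax.reduce(
        (operands, mask),
        (init_values, jnp.array(False)),
        new_computation,
        dimensions,
    )
    return value


x = jnp.array([[0.0, 2.0, 1.0], [5.0, 4.0, 3.0], [7.0, 8.0, 9.0]])
mask = jnp.array([[False, True, True], [False, False, True], [False, False, False]])
row_max = masked_reduce(x, jnp.array(-jnp.inf, x.dtype), lax.max, (1,), mask)
row_sum = masked_reduce(x, jnp.array(0.0, x.dtype), lax.add, (1,), mask)
print(row_max, row_sum)
```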
@carlosgmartin carlosgmartin added the enhancement New feature or request label Dec 20, 2024
@jakevdp
Collaborator

jakevdp commented Dec 20, 2024

Thanks for the request! I don't think lax is the right place to add implementations like this: it's a low-level module designed to be a bare-bones Python API around primitive operations that map to XLA/HLO ops. So for example, lax.argmax lowers directly to argmax_p:

>>> import jax
>>> from functools import partial
>>> argmax = partial(jax.lax.argmax, axis=0, index_dtype='int32')
>>> jax.make_jaxpr(argmax)(jax.numpy.arange(10))
{ lambda ; a:i32[10]. let
    b:i32[] = argmax[axes=(0,) index_dtype=int32] a
  in (b,) }

The functions you propose would lower to multiple primitives, so they're not really a great fit for jax.lax.
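To make the lowering point concrete: a masked argmax assembled from existing lax ops traces to more than one primitive. In this sketch (masked_argmax_lax is a hypothetical name, and filling with -inf is just one possible masking strategy), the jaxpr contains a select_n equation feeding argmax. The exact printed jaxpr varies across JAX versions, so it is not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax import lax


def masked_argmax_lax(x, mask):
    # Hypothetical masked argmax built from existing lax primitives:
    # replace masked-out entries with -inf, then take an ordinary argmax.
    filled = lax.select(mask, x, lax.full_like(x, -jnp.inf))
    return lax.argmax(filled, axis=0, index_dtype=jnp.int32)


x = jnp.arange(10, dtype=jnp.float32)
mask = x > 4
closed_jaxpr = jax.make_jaxpr(masked_argmax_lax)(x, mask)
print(closed_jaxpr)

# Names of the primitives the function lowers to (more than just argmax).
primitives = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitives)
```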

There are a couple of ways we could move forward with an idea like this:

  1. add mask to the argmax/argmin primitives. I wouldn't be in favor of this, because it would require those primitives to lower to a multi-part HLO program rather than to a single HLO reduce (lax primitives are, with a few exceptions, essentially direct Python API wrappers of HLO operations).
  2. add the mask argument to the jax.numpy layer, as you proposed in #20177 (Add where argument to argmax, argmin, ptp, cumsum, cumprod). Again, this isn't obviously in scope, because jax.numpy is an implementation of the NumPy API, and NumPy does not include a mask or where argument in these functions. Still, this would probably be more feasible than adding it to jax.lax.
  3. Create a convenience routine elsewhere in the package. I'm not entirely sure where that would be though.
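For contrast with option 2: NumPy's value reductions such as max and sum already accept a where mask, and jax.numpy mirrors them; the gap is specifically argmax/argmin. A minimal illustration, reusing the first example's data:

```python
import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 3.0, 3.0, 3.0, 4.0])
mask = jnp.array([False, True, False, True, True, False])

# max has no identity element, so jax.numpy requires `initial` alongside `where`.
masked_max = jnp.max(x, where=mask, initial=-jnp.inf)
# sum has an identity (0), so `where` alone is enough.
masked_sum = jnp.sum(x, where=mask)

# jnp.argmax has no `where` parameter, so there is no direct masked equivalent.
print(masked_max, masked_sum)
```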

What do you think?

@carlosgmartin
Contributor Author

carlosgmartin commented Dec 20, 2024

@jakevdp If there's no place to add such a function, I guess we'll have to wait for a where argument to be added to numpy.argmax:

It's a bit of a shame that we have to wait for NumPy to add useful functionality.

@jakevdp
Collaborator

jakevdp commented Dec 20, 2024

We don't have to wait for NumPy, but a concern here is that we add some keyword argument to JAX, and in the future NumPy adds a conflicting keyword argument. Then it becomes a somewhat painful deprecation cycle in order to match upstream semantics. We went through this in the past year with the Array API, and I'd like to avoid that if possible.

@carlosgmartin
Contributor Author

carlosgmartin commented Dec 20, 2024

That's a valid concern, I agree.

Is there a way to directly "petition" for the inclusion of a mask argument to argmax in the Array API? (And also apparently max, for that matter.)

To be honest, I'm not entirely clear on the orientation of design decisions for JAX, NumPy, and the Array API. NumPy itself is aiming to target the Array API. In the long term, is JAX aiming to target NumPy, or the Array API directly (which NumPy itself is aiming to target)?

My understanding is that the Array API is generated by an informal kind of "consensus" between popular libraries (such as NumPy, JAX, and PyTorch). Doesn't that make the whole thing kind of circular, a catch-22, if we're trying to make design decisions?

@jakevdp
Collaborator

jakevdp commented Dec 20, 2024

Historically, before the Array API existed, jax.numpy aimed to target the NumPy API. Once the Array API project was created, we made it a goal to also target the Array API. Because NumPy is targeting the Array API, we can do both pretty cleanly in one namespace. In some cases we have moved jax.numpy toward the Array API slightly faster than NumPy itself has been able to.

@jakevdp
Collaborator

jakevdp commented Dec 20, 2024

I don't think there's anything circular here: if the Array API adds a function or argument to its spec, JAX will adopt it, as will NumPy. If NumPy adds a function or argument to its implementation, JAX will adopt it.

There is some influence back up: for example the Array API maintainers have historically been careful not to introduce things that will conflict with existing symbols, but that's not a hard requirement (for example, JAX historically had an array.device() method that we had to deprecate when the Array API chose to make device an attribute).

@carlosgmartin
Contributor Author

Thanks for the clarification. Do you know the answer to this question, by any chance?

Is there a way to directly "petition" for the inclusion of a mask argument to argmax in the Array API? (And also apparently max, for that matter.)

IMO, it makes logical sense that any operation involving reduction(s) should take an optional mask argument, indicating which elements participate in the reduction(s). This includes the following operations:

@jakevdp
Collaborator

jakevdp commented Dec 23, 2024

The best way to propose this is probably to open an issue at https://github.com/data-apis/array-api proposing the change.

@carlosgmartin
Contributor Author

Done: data-apis/array-api#875.
