feat: add scalar support to where #860

Open: wants to merge 4 commits into base: main

betatim
Member

@betatim betatim commented Nov 25, 2024

Towards #807

This adds wording to the doc string and function signature to allow scalars in addition to arrays for the second and third arguments.

Not super happy with the phrasing for the description of condition, maybe someone else has a suggestion for how to explain it without using a lot of words. We can then hopefully reuse that in the description of out.

There are a lot more functions that need updating (#807 (comment)). I think it makes sense to get this one done, and then copy&paste for the others (instead of having a giant diff while discussing things).

There is data-apis/array-api-strict#78 which implements this in array-api-strict.

@asmeurer
Member

The rule for scalars should match https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars. We should generalize that section to functions, so that we can refer to it in all the function definitions.

@asmeurer
Member

In other words, I think trying to word different conditions for each argument based on whether each other argument is an array or scalar is too wordy and confusing. The rule should be that scalar arguments are implicitly converted into arrays (by the rules stated in the updated version of that particular section). Then we can just talk about each argument as if it were an array. We also need to state that the behavior is undefined when all arguments are scalars. The question is really how much of that needs to be repeated in each function and how much of it we can just write once in some section and refer back to (this is not a straightforward question IMO; a lot of things in the standard are repeated for each function, since that makes it easier to read).
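To make the implicit-conversion rule concrete, here is a minimal sketch using NumPy as a stand-in array library. The helper name `as_array_like` is hypothetical and not part of the standard, and the sketch does not check that the scalar's kind is compatible with the array's data type.

```python
import numpy as np

PY_SCALARS = (bool, int, float, complex)


def as_array_like(value, other):
    """Hypothetical helper: turn a Python scalar into a 0-D array with ``other``'s dtype."""
    if isinstance(value, PY_SCALARS):
        # Assumption: the scalar kind is compatible with other.dtype (not checked here).
        return np.asarray(value, dtype=other.dtype)
    return value


x = np.ones(3, dtype=np.float32)
y = as_array_like(2, x)
print(y.shape, y.dtype)  # () float32
print((x + y).dtype)     # float32
```

Once the scalar has been converted this way, the rest of a function's description can treat every argument as an array.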

@kgryte
Contributor

kgryte commented Nov 25, 2024

I've opened a related PR for element-wise functions: #862

@betatim
Member Author

betatim commented Nov 26, 2024

> I think trying to word different conditions for each argument based on whether each other argument is an array or scalar is too wordy and confusing.

I agree. I like your suggestion of referring to a central explanation. In particular because I think that most people's intuition about how this will work is 95% correct and for the other 5% you need a lot of words to really explain it.

In the case of where I had assumed that the first argument always has to be an array. So the "all arguments are scalars" case can't happen.

Maybe the way to do this is a "Notes" section in the doc string that says something like "For the rules on how to handle scalar arguments see link_to_central_place." That central place could be https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars, maybe it needs a bit of generalising to remove/rewrite the scalar OP array/array OP scalar bit.

def where(
condition: array,
x1: Union[array, int, float, bool],
x2: Union[array, int, float, bool],

Can also be complex, as below.

@rgommers rgommers added this to the v2024 milestone Dec 11, 2024
@kgryte kgryte changed the title from "Allow scalar arguments to where()" to "feat: allow scalar support to where" Dec 12, 2024
@kgryte kgryte added the API change Changes to existing functions or objects in the API. label Dec 12, 2024
x2: array
second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
x1: Union[array, int, float, complex, bool]
first input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
Contributor

Suggested change
first input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).

Aligning with #862.

x1: Union[array, int, float, complex, bool]
first input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
x2: Union[array, int, float, complex, bool]
second input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
Contributor

Suggested change
second input array or scalar. Scalar values are treated like an array filled with this value. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).

Aligning with #862.

@@ -139,21 +139,33 @@ def searchsorted(
"""


def where(condition: array, x1: array, x2: array, /) -> array:
def where(
condition: array,
Contributor

Suggested change
condition: array,
condition: Union[array, bool],

Might be better to require that condition is an array ISTM. x1 and x2 can be arrays or scalars, but there should be at least one array, and condition it is.

In [7]: torch.where(True, 3, 4)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 torch.where(True, 3, 4)

TypeError: where() received an invalid combination of arguments - got (bool, int, int), but expected one of:
 * (Tensor condition)
 * (Tensor condition, Tensor input, Tensor other, *, Tensor out = None)
 * (Tensor condition, Number self, Tensor other)
      didn't match because some of the arguments have invalid types: (bool, int, int)
 * (Tensor condition, Tensor input, Number other)
      didn't match because some of the arguments have invalid types: (bool, int, int)
 * (Tensor condition, Number self, Number other)
      didn't match because some of the arguments have invalid types: (bool, int, int)

In [2]: cupy.where(True, 3, 4)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 1
----> 1 cupy.where(True, 3, 4)

...

AttributeError: 'bool' object has no attribute 'astype'

Contributor

Yes, there should be at least one array, but it is not clear to me that condition should be required to be an array. While Torch does not support a scalar condition, NumPy does.

In [8]: np.where(True,np.ones((3,3)),0)
Out[8]:
array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]])

It seems odd to me to be as restrictive as PyTorch when PyTorch supports

In [9]: torch.where(torch.asarray(True),torch.ones((3,3)),0)
Out[9]:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

From my (non-extensive) checks, the variant "condition is an array, x1 and x2 are arrays or scalars" covers all major array libraries. And it seems to be the most straightforward to explain in words or legalese, and most useful downstream, FWIW.

The In [9] variant above seems to be covered, too? As long as the algorithm is

  • convert x1 and x2 to arrays
  • broadcast all of condition, x1, x2
  • profit!
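The steps above can be sketched as follows, using NumPy as a stand-in array library. The function name `where_sketch` is hypothetical, and the sketch assumes that `condition` is an array and that at least one of `x1` and `x2` is an array, so the scalar can borrow the other operand's data type.

```python
import numpy as np


def where_sketch(condition, x1, x2):
    # Step 1: convert a scalar x1 or x2 to a 0-D array with the other operand's dtype.
    # Assumption: at least one of x1 and x2 is already an array.
    if not isinstance(x1, np.ndarray):
        x1 = np.asarray(x1, dtype=x2.dtype)
    elif not isinstance(x2, np.ndarray):
        x2 = np.asarray(x2, dtype=x1.dtype)

    # Step 2: broadcast condition, x1, and x2 to a common shape.
    cond, a, b = np.broadcast_arrays(condition, x1, x2)

    # Step 3: start from x2 and overwrite the positions where condition is True.
    out = np.array(b, dtype=np.result_type(a.dtype, b.dtype))
    out[cond] = a[cond]
    return out


result = where_sketch(np.asarray(True), np.ones((3, 3)), 0)
print(result.shape, result.dtype)  # (3, 3) float64
```

A zero-dimensional `condition` needs no special handling here, because broadcasting stretches it to the shape of the other operands.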

Contributor

@kgryte kgryte Dec 12, 2024

Yes, the point I was trying to convey was that

torch.where(True,torch.ones((3,3)),0)

can be considered sugar for

torch.where(torch.asarray(True),torch.ones((3,3)),0)

Not clear, from a standards perspective, why the former should be prohibited a priori, as in PyTorch.

Contributor

as per #860 (comment) I think this is fine as is


Notes
-----
See :ref:`mixing-scalars-and-arrays` on compatibility requirements and handling of scalar arguments for ``x1`` and ``x2``.
Contributor

@kgryte kgryte Dec 12, 2024

Suggested change
See :ref:`mixing-scalars-and-arrays` on compatibility requirements and handling of scalar arguments for ``x1`` and ``x2``.
- At least one of ``condition``, ``x1``, and ``x2`` must be an array.
- If both ``x1`` and ``x2`` are scalar values, the returned array must have a data type which is equivalent to separately passing both ``x1`` and ``x2`` to :func:`~array_api.asarray` and computing the resulting data type using :func:`~array_api.result_type` (i.e., both scalars must be converted to arrays whose data types are used to determine the output data type according to :ref:`type-promotion` rules).
- If either ``x1`` or ``x2`` is a scalar value, the returned array must have a data type determined according to :ref:`mixing-scalars-and-arrays`.

Same as above: seems simpler to require that condition is always an array. (line 168)

Contributor

Requiring condition to be an array doesn't simplify L169 and L170, as condition has no bearing on the output data type.

See :ref:`mixing-scalars-and-arrays` on compatibility requirements and handling of scalar arguments for ``x1`` and ``x2``.

.. versionchanged:: 2024.12
``x1`` and ``x2`` may be scalars.
Contributor

Suggested change
``x1`` and ``x2`` may be scalars.
Added support for scalar arguments.

@kgryte
Contributor

kgryte commented Dec 12, 2024

I made a few suggestions. Namely,

  1. similar to #862 (feat: add scalar support to element-wise functions), we can keep normative language to assume array arguments. Only the type annotations need to be updated to convey scalar argument support.

  2. we can allow condition to be a scalar provided that at least one of condition, x1, or x2 is an array.

  3. in the case where both x1 and x2 are scalar arguments, the resulting data type should be equivalent to

    xp.result_type(xp.asarray(x1), xp.asarray(x2))
  4. when only one of x1 or x2 is an array, then we can refer to the document on mixing arrays and scalars.
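As an illustration of item 3, the expression can be evaluated with NumPy standing in for `xp`. The helper name `scalar_pair_dtype` is hypothetical, and the data types shown are NumPy's defaults, which another array library may not share.

```python
import numpy as np


def scalar_pair_dtype(x1, x2):
    """Data type proposed in item 3 for the case where x1 and x2 are both Python scalars."""
    return np.result_type(np.asarray(x1), np.asarray(x2))


print(scalar_pair_dtype(1, 2.0))     # float64
print(scalar_pair_dtype(True, 1.5))  # float64
print(scalar_pair_dtype(1.0, 2j))    # complex128
```

Because each scalar is converted on its own, the result depends on the library's default data types rather than on any array argument.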

Comment on lines +144 to +145
x1: Union[array, int, float, bool],
x2: Union[array, int, float, bool],
Contributor

Suggested change
x1: Union[array, int, float, bool],
x2: Union[array, int, float, bool],
x1: Union[array, int, float, complex, bool],
x2: Union[array, int, float, complex, bool],

@kgryte kgryte changed the title from "feat: allow scalar support to where" to "feat: add scalar support to where" Dec 12, 2024
@kgryte kgryte added the Needs Changes Pull request which needs changes before being merged. label Dec 12, 2024
@rgommers
Member

  • we can allow condition to be a scalar provided that at least one of condition, x1, or x2 is an array.

  • in the case where both x1 and x2 are scalar arguments, the resulting data type should be equivalent to

We discussed these points in the community meeting yesterday, and agreed on the following:

  1. condition should be an array. Rationale: (a) if it's a Python scalar, where is reduced to a simple if-statement so it doesn't seem very useful, (b) at least PyTorch doesn't support it, (c) if a use case turns up, we can add scalar support later, while the reverse isn't possible.
  2. One of x1 or x2 should be an array. Rationale: otherwise there is too much flexibility in dtype determination. We require the same from binary_op(x1, x2), and where is basically a binary operation from the perspective of dtype determination.
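The two agreed rules could be expressed as an argument check along the following lines. This is only a sketch: `check_where_args` and `FakeArray` are hypothetical names, any object that is not a Python scalar is treated as an array, and the standard does not mandate raising an error for unsupported combinations.

```python
PY_SCALARS = (bool, int, float, complex)


class FakeArray:
    """Hypothetical stand-in for an array object, used only for this illustration."""


def check_where_args(condition, x1, x2):
    """Hypothetical validator for the two rules agreed above."""
    # Rule 1: condition must be an array.
    if isinstance(condition, PY_SCALARS):
        raise TypeError("condition must be an array, not a Python scalar")
    # Rule 2: at least one of x1 and x2 must be an array.
    if isinstance(x1, PY_SCALARS) and isinstance(x2, PY_SCALARS):
        raise TypeError("at least one of x1 and x2 must be an array")


check_where_args(FakeArray(), 1, FakeArray())    # accepted: x2 is an array
check_where_args(FakeArray(), FakeArray(), 2.0)  # accepted: x1 is an array
```

Calling `check_where_args(True, FakeArray(), FakeArray())` or `check_where_args(FakeArray(), 1, 2.0)` raises `TypeError` in this sketch.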
