Skip to content
78 changes: 42 additions & 36 deletions xarray/computation/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@

import functools
from collections import Counter
from collections.abc import (
Callable,
Hashable,
)
from collections.abc import Callable, Hashable
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import numpy as np
Expand All @@ -23,10 +20,7 @@
from xarray.core.duck_array_ops import datetime_to_numeric
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import (
is_scalar,
parse_dims_as_set,
)
from xarray.core.utils import is_scalar, parse_dims_as_set
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand Down Expand Up @@ -903,7 +897,9 @@ def _calc_idxminmax(
if not array.ndim:
raise ValueError("This function does not apply for scalars")

if dim is not None:
if dim is Ellipsis:
dim = array.dims
elif dim is not None:
pass # Use the dim if available
elif array.ndim == 1:
# it is okay to guess the dim if there is only 1
Expand All @@ -912,14 +908,19 @@ def _calc_idxminmax(
# The dim is not specified and ambiguous. Don't guess.
raise ValueError("Must supply 'dim' argument for multidimensional arrays")

if dim not in array.dims:
raise KeyError(
f"Dimension {dim!r} not found in array dimensions {array.dims!r}"
)
if dim not in array.coords:
raise KeyError(
f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)
dim_is_str = isinstance(dim, str)
# Standardize to an iterable format
dims = [dim] if dim_is_str else dim

for _dim in dims:
if _dim not in array.dims:
raise KeyError(
f"Dimension {_dim!r} not found in array dimensions {array.dims!r}"
)
if _dim not in array.coords:
raise KeyError(
f"Dimension {_dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
)

# These are dtypes with NaN values argmin and argmax can handle
na_dtypes = "cfO"
Expand All @@ -931,25 +932,30 @@ def _calc_idxminmax(

# This will run argmin or argmax.
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
# Force dictionary format in case of single dim so that we can iterate over it in for loop below
if dim_is_str:
indx = {dim: indx}

res = {}
for _dim, _da_idx in zip(dims, indx.values(), strict=False):
# Handle chunked arrays (e.g. dask).
coord = array[_dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[_dim].data, chunks=((array.sizes[_dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[_dim].data, array.data))

# Handle chunked arrays (e.g. dask).
coord = array[dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
coord_array = chunkmanager.from_array(
array[dim].data, chunks=((array.sizes[dim],),)
)
coord = coord.copy(data=coord_array)
else:
coord = coord.copy(data=to_like_array(array[dim].data, array.data))

res = indx._replace(coord[(indx.variable,)]).rename(dim)

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
res = res.where(~allna, fill_value)

# Copy attributes from argmin/argmax, if any
res.attrs = indx.attrs
_res = _da_idx._replace(coord[(_da_idx.variable,)]).rename(_dim)
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
_res = _res.where(~allna, fill_value)
_res.attrs = _da_idx.attrs
res[_dim] = _res

if dim_is_str:
return res[dim]
return res
Loading
Loading