Skip to content

Commit 42e5c9d

Browse files
committed
Implement @shoyer's stricter equivalent() function
- Rewrite equivalent() to reject non-boolean comparison results - Accept numpy bool scalars (np.bool_) but reject other non-bool types - Simplify equivalent_attrs() since equivalent() now handles non-bool cases - Update tests to reflect stricter behavior with non-standard __eq__ methods This makes comparisons more predictable by rejecting ambiguous cases like Dataset comparisons, custom objects with weird __eq__, etc. The tradeoff is being less permissive than Python's standard 'if a == b:' behavior.
1 parent e273125 commit 42e5c9d

File tree

3 files changed

+37
-59
lines changed

3 files changed

+37
-59
lines changed

xarray/core/utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,18 +239,34 @@ def equivalent(first: T, second: T) -> bool:
239239
"""Compare two objects for equivalence (identity or equality), using
240240
array_equiv if either object is an ndarray. If both objects are lists,
241241
equivalent is sequentially called on all the elements.
242+
243+
Returns False for any comparison that doesn't return a boolean,
244+
making this function safer to use with objects that have non-standard
245+
__eq__ implementations.
242246
"""
243247
# TODO: refactor to avoid circular import
244248
from xarray.core import duck_array_ops
245249

246250
if first is second:
247251
return True
252+
248253
if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
249254
return duck_array_ops.array_equiv(first, second)
255+
250256
if isinstance(first, list) or isinstance(second, list):
251257
return list_equiv(first, second) # type: ignore[arg-type]
252258

253-
return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload]
259+
# For non-array/list types, use == but require boolean result
260+
result = first == second
261+
if not isinstance(result, bool):
262+
# Accept numpy bool scalars as well
263+
if isinstance(result, np.bool_):
264+
return bool(result)
265+
# Reject any other non-boolean type (Dataset, Series, custom objects, etc.)
266+
return False
267+
268+
# Check for NaN equivalence
269+
return result or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload]
254270

255271

256272
def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool:

xarray/structure/merge.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -614,31 +614,12 @@ def equivalent_attrs(a: Any, b: Any) -> bool:
614614
This handles cases like numpy arrays with ambiguous truth values
615615
and xarray Datasets which can't be directly converted to numpy arrays.
616616
617-
For non-boolean results, we use truthiness (consistent with `if a == b`).
618-
This is an imperfect but pragmatic choice:
619-
620-
Pros of truthiness:
621-
- Consistent with Python's normal `if a == b:` behavior
622-
- Preserves numpy scalars (np.bool_(True)) and similar types
623-
- More permissive for common use cases
624-
625-
Cons of truthiness:
626-
- Keeps attrs when __eq__ returns truthy non-bool (e.g., "error")
627-
- Drops attrs when __eq__ returns falsy non-bool (e.g., 0, [])
628-
629-
The alternative (strict bool checking) would be safer but would drop
630-
many legitimate comparisons. We choose consistency with Python's
631-
standard behavior, accepting edge cases with pathological __eq__ methods.
632-
633-
TODO: Revisit this behavior in the future - consider strict type checking
634-
or a more sophisticated approach to handling non-boolean comparisons.
617+
Since equivalent() now handles non-boolean returns by returning False,
618+
this wrapper mainly catches exceptions from comparisons that can't be
619+
evaluated at all.
635620
"""
636621
try:
637-
result = equivalent(a, b)
638-
# Use truthiness, consistent with `if a == b:` behavior
639-
# Note: This means non-boolean returns are interpreted by truthiness,
640-
# which can lead to false positives/negatives but is more permissive
641-
return bool(result)
622+
return equivalent(a, b)
642623
except (ValueError, TypeError):
643624
# These exceptions indicate the comparison is truly ambiguous
644625
# (e.g., numpy arrays that would raise "ambiguous truth value")

xarray/tests/test_merge.py

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,7 @@ def test_merge_attrs_drop_conflicts(self):
236236
assert_identical(actual, expected)
237237

238238
def test_merge_attrs_drop_conflicts_non_bool_eq(self):
239-
"""Test drop_conflicts behavior when __eq__ returns non-bool values.
240-
241-
When comparing attribute values, the _equivalent_drop_conflicts() function
242-
uses == which can return non-bool values. The new behavior treats ambiguous
243-
or falsy equality results as non-equivalent, dropping the attribute rather
244-
than raising an error.
245-
"""
239+
"""Test drop_conflicts behavior when __eq__ returns non-bool values."""
246240
import warnings
247241

248242
import numpy as np
@@ -295,14 +289,14 @@ def __repr__(self):
295289
with warnings.catch_warnings():
296290
warnings.filterwarnings("ignore", category=DeprecationWarning)
297291

298-
# With truthiness: objects returning [True] are kept (truthy)
292+
# Objects returning arrays are dropped (non-boolean return)
299293
actual = xr.merge([ds4, ds5], combine_attrs="drop_conflicts")
300-
assert "custom" in actual.attrs # Kept - [True] is truthy
294+
assert "custom" not in actual.attrs # Dropped - returns array, not bool
301295
assert actual.attrs["x"] == 1
302296

303-
# Objects with different values: equivalent returns False (bool), dropped
297+
# Different values also dropped (returns array, not bool)
304298
actual = xr.merge([ds4, ds6], combine_attrs="drop_conflicts")
305-
assert "custom" not in actual.attrs # Dropped - different values
299+
assert "custom" not in actual.attrs # Dropped - returns non-boolean
306300
assert actual.attrs["x"] == 1
307301
assert actual.attrs["y"] == 2
308302

@@ -426,10 +420,9 @@ def test_merge_attrs_drop_conflicts_pathological_cases(self):
426420
assert "dataset_attr" not in actual.attrs # Dropped due to TypeError
427421
assert actual.attrs["scalar"] == 42
428422

429-
# With truthiness: identical datasets are kept
430-
# The comparison returns a truthy Dataset, so they're treated as equal
423+
# Identical datasets are also dropped (comparison returns Dataset, not bool)
431424
actual = xr.merge([ds4, ds6], combine_attrs="drop_conflicts")
432-
assert "dataset_attr" in actual.attrs # Kept with truthiness approach
425+
assert "dataset_attr" not in actual.attrs # Dropped - returns Dataset, not bool
433426
assert actual.attrs["other"] == 99
434427

435428
# Test 3: Pandas Series (raises ValueError due to ambiguous truth value)
@@ -457,22 +450,16 @@ def test_merge_attrs_drop_conflicts_pathological_cases(self):
457450
assert "series" not in actual.attrs # Dropped due to ValueError
458451
assert actual.attrs["value"] == "a"
459452

460-
def test_merge_attrs_drop_conflicts_truthiness_edge_cases(self):
461-
"""Test edge cases demonstrating the truthiness tradeoff.
462-
463-
We deliberately use truthiness for consistency with Python's `if a == b:`
464-
behavior. This test documents the implications of this design choice
465-
with objects that have non-standard __eq__ methods.
466-
"""
453+
def test_merge_attrs_drop_conflicts_non_boolean_eq_returns(self):
454+
"""Test objects with non-boolean __eq__ returns are dropped."""
467455

468-
# Case 1: Objects whose __eq__ returns truthy non-booleans
469-
# These are kept because we respect truthiness
456+
# Case 1: Objects whose __eq__ returns non-boolean strings
470457
class ReturnsString:
471458
def __init__(self, value):
472459
self.value = value
473460

474461
def __eq__(self, other):
475-
# Always returns a string (truthy if non-empty)
462+
# Always returns a string (non-boolean)
476463
return "comparison result"
477464

478465
obj1 = ReturnsString("A")
@@ -483,18 +470,16 @@ def __eq__(self, other):
483470

484471
actual = xr.merge([ds1, ds2], combine_attrs="drop_conflicts")
485472

486-
# Truthiness behavior: keeps attribute because "comparison result" is truthy
487-
# This is the expected behavior when respecting truthiness
488-
assert "obj" in actual.attrs
473+
# Strict behavior: drops attribute because __eq__ returns non-boolean
474+
assert "obj" not in actual.attrs
489475

490-
# Case 2: Objects whose __eq__ returns falsy non-booleans
491-
# These are dropped because we respect truthiness
476+
# Case 2: Objects whose __eq__ returns numbers
492477
class ReturnsZero:
493478
def __init__(self, value):
494479
self.value = value
495480

496481
def __eq__(self, other):
497-
# Always returns 0 (falsy) even if values match
482+
# Always returns 0 (non-boolean)
498483
return 0
499484

500485
obj3 = ReturnsZero("same")
@@ -505,13 +490,9 @@ def __eq__(self, other):
505490

506491
actual = xr.merge([ds3, ds4], combine_attrs="drop_conflicts")
507492

508-
# Truthiness behavior: drops attribute because 0 is falsy
509-
# This is the expected behavior when respecting truthiness
493+
# Strict behavior: drops attribute because __eq__ returns non-boolean
510494
assert "zero" not in actual.attrs
511495

512-
# Note: These edge cases demonstrate the tradeoff of using truthiness.
513-
# Well-behaved __eq__ methods return booleans and work correctly.
514-
# We accept these edge cases for consistency with Python's standard behavior.
515496

516497
def test_merge_attrs_no_conflicts_compat_minimal(self):
517498
"""make sure compat="minimal" does not silence errors"""

0 commit comments

Comments
 (0)