Skip to content

Commit 35d4d93

Browse files
committed
comparison reduction ops support
1 parent 48d2f42 commit 35d4d93

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,8 @@ template <>
502502
inline npy_byte
503503
from_quad<npy_byte>(quad_value x, QuadBackendType backend)
504504
{
505+
// reduction ops often give warning, we can handle the NAN casting
506+
// this behaviour might apply to all casting
505507
if (backend == BACKEND_SLEEF) {
506508
return (npy_byte)Sleef_cast_to_int64q1(x.sleef_value);
507509
}

quaddtype/numpy_quaddtype/src/umath/comparison_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
200200
return -1;
201201
}
202202

203-
PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArray_BoolDType);
203+
PyObject *DTypes = PyTuple_Pack(3, Py_None, Py_None, Py_None);
204204
if (DTypes == 0) {
205205
Py_DECREF(promoter_capsule);
206206
return -1;

quaddtype/tests/test_quaddtype.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,70 @@ def test_array_minmax(op, a, b):
385385
assert np.signbit(float_res) == np.signbit(
386386
quad_res), f"Zero sign mismatch for {op}({a}, {b})"
387387

388+
class TestComparisonReductionOps:
389+
"""Test suite for comparison reduction operations on QuadPrecision arrays."""
390+
391+
@pytest.mark.parametrize("op", ["all", "any"])
392+
@pytest.mark.parametrize("input_array", [
393+
(["1.0", "2.0", "3.0"]),
394+
(["1.0", "0.0", "3.0"]),
395+
(["0.0", "0.0", "0.0"]),
396+
# Including negative zero
397+
(["-0.0", "0.0"]),
398+
# Including NaN (should be treated as true)
399+
(["nan", "1.0"]),
400+
(["nan", "0.0"]),
401+
(["nan", "nan"]),
402+
# inf cases
403+
(["inf", "1.0"]),
404+
(["-inf", "0.0"]),
405+
(["inf", "-inf"]),
406+
# Mixed cases
407+
(["1.0", "-0.0", "nan", "inf"]),
408+
(["0.0", "-0.0", "nan", "-inf"]),
409+
])
410+
def test_reduction_ops(self, op, input_array):
411+
"""Test all and any reduction operations."""
412+
quad_array = np.array([QuadPrecision(x) for x in input_array])
413+
float_array = np.array([float(x) for x in input_array])
414+
if op == "all":
415+
result = np.all(quad_array)
416+
expected = np.all(float_array)
417+
else: # op == "any"
418+
result = np.any(quad_array)
419+
expected = np.any(float_array)
420+
421+
assert result == expected, (
422+
f"Reduction op '{op}' failed for input {input_array}: "
423+
f"expected {expected}, got {result}"
424+
)
425+
426+
@pytest.mark.parametrize("val_str", [
427+
"0.0",
428+
"-0.0",
429+
"1.0",
430+
"-1.0",
431+
"nan",
432+
"inf",
433+
"-inf",
434+
])
435+
def test_scalar_reduction_ops(self, val_str):
436+
"""Test reduction operations on scalar QuadPrecision values."""
437+
quad_val = QuadPrecision(val_str)
438+
float_val = float(val_str)
439+
440+
result_all = np.all(quad_val)
441+
expected_all_result = np.all(float_val)
442+
assert result_all == expected_all_result, (
443+
f"Scalar all failed for {val_str}: expected {expected_all_result}, got {result_all}"
444+
)
445+
446+
result_any = np.any(quad_val)
447+
expected_any_result = np.any(float_val)
448+
assert result_any == expected_any_result, (
449+
f"Scalar any failed for {val_str}: expected {expected_any_result}, got {result_any}"
450+
)
451+
388452

389453
# Logical operations tests
390454
@pytest.mark.parametrize("op", ["logical_and", "logical_or", "logical_xor"])

0 commit comments

Comments
 (0)