diff --git a/ml_dtypes/_src/common.h b/ml_dtypes/_src/common.h index 0f237f85..97efed26 100644 --- a/ml_dtypes/_src/common.h +++ b/ml_dtypes/_src/common.h @@ -157,6 +157,14 @@ struct TypeDescriptor> { static int Dtype() { return NPY_CLONGDOUBLE; } }; +template +struct is_complex : std::false_type {}; +template +struct is_complex> : std::true_type {}; + +template +inline constexpr bool is_complex_v = is_complex::value; + } // namespace ml_dtypes #endif // ML_DTYPES_COMMON_H_ diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index 14232dba..c3afc74a 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -599,12 +599,11 @@ int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, template float CastToFloat(T value) { - return static_cast(value); -} - -template -float CastToFloat(std::complex value) { - return CastToFloat(value.real()); + if constexpr (is_complex_v) { + return CastToFloat(value.real()); + } else { + return static_cast(value); + } } // Performs a NumPy array cast from type 'From' to 'To'. diff --git a/ml_dtypes/_src/int4_numpy.h b/ml_dtypes/_src/int4_numpy.h index 2b065628..0d3f954a 100644 --- a/ml_dtypes/_src/int4_numpy.h +++ b/ml_dtypes/_src/int4_numpy.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef ML_DTYPES_INT4_NUMPY_H_ #define ML_DTYPES_INT4_NUMPY_H_ +#include #include // Must be included first @@ -55,7 +56,7 @@ int Int4TypeDescriptor::npy_type = NPY_NOTYPE; template PyObject* Int4TypeDescriptor::type_ptr = nullptr; -// Representation of a Python custom float object. +// Representation of a Python custom integer object. template struct PyInt4 { PyObject_HEAD; // Python object header @@ -96,7 +97,7 @@ Safe_PyObjectPtr PyInt4_FromValue(T x) { return ref; } -// Converts a Python object to a reduced float value. Returns true on success, +// Converts a Python object to a reduced integer value. Returns true on success, // returns false and reports a Python error on failure. template bool CastToInt4(PyObject* arg, T* output) { @@ -143,8 +144,8 @@ bool CastToInt4(PyObject* arg, T* output) { *output = T(v); return true; } - if (PyArray_IsScalar(arg, Float)) { - float f; + auto floating_conversion = [&](auto type) -> bool { + decltype(type) f; PyArray_ScalarAsCtype(arg, &f); if (!(std::numeric_limits::min() <= f && f <= std::numeric_limits::max())) { @@ -153,17 +154,18 @@ bool CastToInt4(PyObject* arg, T* output) { } *output = T(static_cast<::int8_t>(f)); return true; + }; + if (PyArray_IsScalar(arg, Half)) { + return floating_conversion(Eigen::half{}); + } + if (PyArray_IsScalar(arg, Float)) { + return floating_conversion(float{}); } if (PyArray_IsScalar(arg, Double)) { - double d; - PyArray_ScalarAsCtype(arg, &d); - if (!(std::numeric_limits::min() <= d && - d <= std::numeric_limits::max())) { - PyErr_SetString(PyExc_OverflowError, kOutOfRange); - return false; - } - *output = T(static_cast<::int8_t>(d)); - return true; + return floating_conversion(double{}); + } + if (PyArray_IsScalar(arg, LongDouble)) { + return floating_conversion((long double){}); } return false; } @@ -216,14 +218,13 @@ PyObject* PyInt4_tp_new(PyTypeObject* type, PyObject* args, PyObject* kwds) { template PyObject* PyInt4_nb_float(PyObject* self) { T x = PyInt4_Value_Unchecked(self); - return PyFloat_FromDouble(static_cast(static_cast(x))); + return PyFloat_FromDouble(static_cast(x)); } template PyObject* PyInt4_nb_int(PyObject* self) { T x = PyInt4_Value_Unchecked(self); - long y = static_cast(static_cast(x)); // NOLINT - return PyLong_FromLong(y); + return PyLong_FromLong(static_cast(x)); // NOLINT } template @@ -538,12 +539,11 @@ int NPyInt4_CompareFunc(const void* v1, const void* v2, void* arr) { template int NPyInt4_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) { const T* bdata = reinterpret_cast(data); - // Start with a max_val of NaN, this results in the first iteration preferring - // bdata[0]. - int max_val = std::numeric_limits::max(); + // Start with a max_val of INT_MIN, this results in the first iteration + // preferring bdata[0]. + int max_val = std::numeric_limits::lowest(); for (npy_intp i = 0; i < n; ++i) { - // This condition is chosen so that NaNs are always considered "max". - if (!(static_cast(bdata[i]) <= max_val)) { + if (static_cast(bdata[i]) > max_val) { max_val = static_cast(bdata[i]); *max_ind = i; } @@ -554,12 +554,11 @@ int NPyInt4_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind, void* arr) { template int NPyInt4_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) { const T* bdata = reinterpret_cast(data); - int min_val = std::numeric_limits::lowest(); - // Start with a min_val of NaN, this results in the first iteration preferring - // bdata[0]. + int min_val = std::numeric_limits::max(); + // Start with a min_val of INT_MAX, this results in the first iteration + // preferring bdata[0]. for (npy_intp i = 0; i < n; ++i) { - // This condition is chosen so that NaNs are always considered "min". - if (!(static_cast(bdata[i]) >= min_val)) { + if (static_cast(bdata[i]) < min_val) { min_val = static_cast(bdata[i]); *min_ind = i; } @@ -567,30 +566,21 @@ int NPyInt4_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind, void* arr) { return 0; } -template ::value || - std::is_same::value), - bool> = true> +template int CastToInt(T value) { - if (std::isnan(value) || std::isinf(value) || - value < std::numeric_limits::lowest() || - value > std::numeric_limits::max()) { - return 0; + if constexpr (is_complex_v) { + return CastToInt(value.real()); + } else { + static_assert(std::numeric_limits::is_specialized); + if constexpr (!std::numeric_limits::is_integer) { + if (std::isnan(value) || std::isinf(value) || + value < std::numeric_limits::lowest() || + value > std::numeric_limits::max()) { + return 0; + } + } + return static_cast(value); } - return static_cast(value); -} - -template ::value, bool> = true> -int CastToInt(T value) { - return static_cast(value); -} - -int CastToInt(int4 value) { return static_cast(value); } - -int CastToInt(uint4 value) { return static_cast(value); } - -template -int CastToInt(std::complex value) { - return CastToInt(value.real()); } // Performs a NumPy array cast from type 'From' to 'To'. diff --git a/ml_dtypes/tests/int4_test.py b/ml_dtypes/tests/int4_test.py index d62bf9f0..0bb8dcfa 100644 --- a/ml_dtypes/tests/int4_test.py +++ b/ml_dtypes/tests/int4_test.py @@ -61,7 +61,10 @@ def testPickleable(self, scalar_type): self.assertEqual(x_out.dtype, x.dtype) np.testing.assert_array_equal(x_out.astype(int), x.astype(int)) - @parameterized.product(scalar_type=INT4_TYPES, python_scalar=[int, float]) + @parameterized.product( + scalar_type=INT4_TYPES, + python_scalar=[int, float, np.float16, np.longdouble], + ) def testRoundTripToPythonScalar(self, scalar_type, python_scalar): for v in VALUES[scalar_type]: self.assertEqual(v, scalar_type(v)) @@ -241,12 +244,16 @@ def testArray(self, scalar_type): @parameterized.product( scalar_type=INT4_TYPES, - ufunc=[np.nonzero, np.logical_not], + ufunc=[np.nonzero, np.logical_not, np.argmax, np.argmin], ) def testUnaryPredicateUfunc(self, scalar_type, ufunc): x = np.array(VALUES[scalar_type]) y = np.array(VALUES[scalar_type], dtype=scalar_type) - np.testing.assert_array_equal(ufunc(x), ufunc(y)) + # Compute `ufunc(y)` first so we don't get lucky by reusing memory + # initialized by `ufunc(x)`. + y_result = ufunc(y) + x_result = ufunc(x) + np.testing.assert_array_equal(x_result, y_result) @parameterized.product( scalar_type=INT4_TYPES,