Skip to content

Commit 77d4406

Browse files
authored
Merge pull request #145 from juntyr/astype
Implement cast support for ubyte and half
2 parents 1d6e8eb + 4f6fbe6 commit 77d4406

File tree

3 files changed

+128
-17
lines changed

3 files changed

+128
-17
lines changed

quaddtype/meson.build

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ if is_windows
1313
add_project_arguments('-DWIN32', '-D_WINDOWS', language : ['c', 'cpp'])
1414
endif
1515

16+
qblas_dep = dependency('qblas', fallback: ['qblas', 'qblas_dep'])
17+
1618
sleef_subproj = subproject('sleef', required: true)
1719
sleef_dep = sleef_subproj.get_variable('sleef_dep')
1820
sleefquad_dep = sleef_subproj.get_variable('sleefquad_dep')
@@ -22,10 +24,13 @@ incdir_numpy = run_command(py,
2224
check : true
2325
).stdout().strip()
2426

25-
# OpenMP dependency (optional, for threading)
27+
npymath_path = incdir_numpy / '..' / 'lib'
28+
npymath_lib = c.find_library('npymath', dirs: npymath_path)
29+
30+
dependencies = [py_dep, qblas_dep, sleef_dep, sleefquad_dep, npymath_lib]
31+
32+
# Add OpenMP dependency (optional, for threading)
2633
openmp_dep = dependency('openmp', required: false)
27-
qblas_dep = dependency('qblas', fallback: ['qblas', 'qblas_dep'])
28-
dependencies = [py_dep, qblas_dep, sleef_dep, sleefquad_dep]
2934
if openmp_dep.found()
3035
dependencies += openmp_dep
3136
endif

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ extern "C" {
99
#include <Python.h>
1010

1111
#include "numpy/arrayobject.h"
12+
#include "numpy/halffloat.h"
1213
#include "numpy/ndarraytypes.h"
1314
#include "numpy/dtype_api.h"
1415
}
@@ -20,7 +21,7 @@ extern "C" {
2021
#include "casts.h"
2122
#include "dtype.h"
2223

23-
#define NUM_CASTS 29 // 14 to_casts + 14 from_casts + 1 quad_to_quad
24+
#define NUM_CASTS 33 // 16 to_casts + 16 from_casts + 1 quad_to_quad
2425

2526
static NPY_CASTING
2627
quad_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self),
@@ -150,15 +151,27 @@ quad_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const da
150151
return 0;
151152
}
152153

154+
// Tag dispatching to ensure npy_bool/npy_ubyte and npy_half/npy_ushort do not alias in templates
155+
// see e.g. https://stackoverflow.com/q/32522279
156+
struct spec_npy_bool {};
157+
struct spec_npy_half {};
158+
159+
template<typename T>
160+
struct NpyType { typedef T TYPE; };
161+
template<>
162+
struct NpyType<spec_npy_bool>{ typedef npy_bool TYPE; };
163+
template<>
164+
struct NpyType<spec_npy_half>{ typedef npy_half TYPE; };
165+
153166
// Casting from other types to QuadDType
154167

155168
template <typename T>
156169
static inline quad_value
157-
to_quad(T x, QuadBackendType backend);
170+
to_quad(typename NpyType<T>::TYPE x, QuadBackendType backend);
158171

159172
template <>
160173
inline quad_value
161-
to_quad<npy_bool>(npy_bool x, QuadBackendType backend)
174+
to_quad<spec_npy_bool>(npy_bool x, QuadBackendType backend)
162175
{
163176
quad_value result;
164177
if (backend == BACKEND_SLEEF) {
@@ -184,6 +197,20 @@ to_quad<npy_byte>(npy_byte x, QuadBackendType backend)
184197
return result;
185198
}
186199

200+
template <>
201+
inline quad_value
202+
to_quad<npy_ubyte>(npy_ubyte x, QuadBackendType backend)
203+
{
204+
quad_value result;
205+
if (backend == BACKEND_SLEEF) {
206+
result.sleef_value = Sleef_cast_from_uint64q1(x);
207+
}
208+
else {
209+
result.longdouble_value = (long double)x;
210+
}
211+
return result;
212+
}
213+
187214
template <>
188215
inline quad_value
189216
to_quad<npy_short>(npy_short x, QuadBackendType backend)
@@ -295,6 +322,21 @@ to_quad<npy_ulonglong>(npy_ulonglong x, QuadBackendType backend)
295322
}
296323
return result;
297324
}
325+
326+
template <>
327+
inline quad_value
328+
to_quad<spec_npy_half>(npy_half x, QuadBackendType backend)
329+
{
330+
quad_value result;
331+
if (backend == BACKEND_SLEEF) {
332+
result.sleef_value = Sleef_cast_from_doubleq1(npy_half_to_double(x));
333+
}
334+
else {
335+
result.longdouble_value = (long double)npy_half_to_double(x);
336+
}
337+
return result;
338+
}
339+
298340
template <>
299341
inline quad_value
300342
to_quad<float>(float x, QuadBackendType backend)
@@ -374,10 +416,10 @@ numpy_to_quad_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
374416
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
375417

376418
while (N--) {
377-
T in_val;
419+
typename NpyType<T>::TYPE in_val;
378420
quad_value out_val;
379421

380-
memcpy(&in_val, in_ptr, sizeof(T));
422+
memcpy(&in_val, in_ptr, sizeof(typename NpyType<T>::TYPE));
381423
out_val = to_quad<T>(in_val, backend);
382424
memcpy(out_ptr, &out_val, elem_size);
383425

@@ -401,7 +443,7 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
401443
QuadBackendType backend = descr_out->backend;
402444

403445
while (N--) {
404-
T in_val = *(T *)in_ptr;
446+
typename NpyType<T>::TYPE in_val = *(typename NpyType<T>::TYPE *)in_ptr;
405447
quad_value out_val = to_quad<T>(in_val, backend);
406448

407449
if (backend == BACKEND_SLEEF) {
@@ -420,12 +462,12 @@ numpy_to_quad_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
420462
// Casting from QuadDType to other types
421463

422464
template <typename T>
423-
static inline T
465+
static inline typename NpyType<T>::TYPE
424466
from_quad(quad_value x, QuadBackendType backend);
425467

426468
template <>
427469
inline npy_bool
428-
from_quad<npy_bool>(quad_value x, QuadBackendType backend)
470+
from_quad<spec_npy_bool>(quad_value x, QuadBackendType backend)
429471
{
430472
if (backend == BACKEND_SLEEF) {
431473
return Sleef_cast_to_int64q1(x.sleef_value) != 0;
@@ -447,6 +489,18 @@ from_quad<npy_byte>(quad_value x, QuadBackendType backend)
447489
}
448490
}
449491

492+
template <>
493+
inline npy_ubyte
494+
from_quad<npy_ubyte>(quad_value x, QuadBackendType backend)
495+
{
496+
if (backend == BACKEND_SLEEF) {
497+
return (npy_ubyte)Sleef_cast_to_uint64q1(x.sleef_value);
498+
}
499+
else {
500+
return (npy_ubyte)x.longdouble_value;
501+
}
502+
}
503+
450504
template <>
451505
inline npy_short
452506
from_quad<npy_short>(quad_value x, QuadBackendType backend)
@@ -543,6 +597,18 @@ from_quad<npy_ulonglong>(quad_value x, QuadBackendType backend)
543597
}
544598
}
545599

600+
template <>
601+
inline npy_half
602+
from_quad<spec_npy_half>(quad_value x, QuadBackendType backend)
603+
{
604+
if (backend == BACKEND_SLEEF) {
605+
return npy_double_to_half(Sleef_cast_to_doubleq1(x.sleef_value));
606+
}
607+
else {
608+
return npy_double_to_half((double)x.longdouble_value);
609+
}
610+
}
611+
546612
template <>
547613
inline float
548614
from_quad<float>(quad_value x, QuadBackendType backend)
@@ -611,8 +677,8 @@ quad_to_numpy_strided_loop_unaligned(PyArrayMethod_Context *context, char *const
611677
quad_value in_val;
612678
memcpy(&in_val, in_ptr, elem_size);
613679

614-
T out_val = from_quad<T>(in_val, backend);
615-
memcpy(out_ptr, &out_val, sizeof(T));
680+
typename NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
681+
memcpy(out_ptr, &out_val, sizeof(typename NpyType<T>::TYPE));
616682

617683
in_ptr += strides[0];
618684
out_ptr += strides[1];
@@ -642,8 +708,8 @@ quad_to_numpy_strided_loop_aligned(PyArrayMethod_Context *context, char *const d
642708
in_val.longdouble_value = *(long double *)in_ptr;
643709
}
644710

645-
T out_val = from_quad<T>(in_val, backend);
646-
*(T *)(out_ptr) = out_val;
711+
typename NpyType<T>::TYPE out_val = from_quad<T>(in_val, backend);
712+
*(typename NpyType<T>::TYPE *)(out_ptr) = out_val;
647713

648714
in_ptr += strides[0];
649715
out_ptr += strides[1];
@@ -739,8 +805,9 @@ init_casts_internal(void)
739805

740806
add_spec(quad2quad_spec);
741807

742-
add_cast_to<npy_bool>(&PyArray_BoolDType);
808+
add_cast_to<spec_npy_bool>(&PyArray_BoolDType);
743809
add_cast_to<npy_byte>(&PyArray_ByteDType);
810+
add_cast_to<npy_ubyte>(&PyArray_UByteDType);
744811
add_cast_to<npy_short>(&PyArray_ShortDType);
745812
add_cast_to<npy_ushort>(&PyArray_UShortDType);
746813
add_cast_to<npy_int>(&PyArray_IntDType);
@@ -749,12 +816,14 @@ init_casts_internal(void)
749816
add_cast_to<npy_ulong>(&PyArray_ULongDType);
750817
add_cast_to<npy_longlong>(&PyArray_LongLongDType);
751818
add_cast_to<npy_ulonglong>(&PyArray_ULongLongDType);
819+
add_cast_to<spec_npy_half>(&PyArray_HalfDType);
752820
add_cast_to<float>(&PyArray_FloatDType);
753821
add_cast_to<double>(&PyArray_DoubleDType);
754822
add_cast_to<long double>(&PyArray_LongDoubleDType);
755823

756-
add_cast_from<npy_bool>(&PyArray_BoolDType);
824+
add_cast_from<spec_npy_bool>(&PyArray_BoolDType);
757825
add_cast_from<npy_byte>(&PyArray_ByteDType);
826+
add_cast_from<npy_ubyte>(&PyArray_UByteDType);
758827
add_cast_from<npy_short>(&PyArray_ShortDType);
759828
add_cast_from<npy_ushort>(&PyArray_UShortDType);
760829
add_cast_from<npy_int>(&PyArray_IntDType);
@@ -763,6 +832,7 @@ init_casts_internal(void)
763832
add_cast_from<npy_ulong>(&PyArray_ULongDType);
764833
add_cast_from<npy_longlong>(&PyArray_LongLongDType);
765834
add_cast_from<npy_ulonglong>(&PyArray_ULongLongDType);
835+
add_cast_from<spec_npy_half>(&PyArray_HalfDType);
766836
add_cast_from<float>(&PyArray_FloatDType);
767837
add_cast_from<double>(&PyArray_DoubleDType);
768838
add_cast_from<long double>(&PyArray_LongDoubleDType);

quaddtype/tests/test_quaddtype.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,42 @@ def test_finfo_int_constant(name, value):
3939
assert getattr(numpy_quaddtype, name) == value
4040

4141

42+
@pytest.mark.parametrize("dtype", [
43+
"bool",
44+
"byte", "int8", "ubyte", "uint8",
45+
"short", "int16", "ushort", "uint16",
46+
"int", "int32", "uint", "uint32",
47+
"long", "ulong",
48+
"longlong", "int64", "ulonglong", "uint64",
49+
"half", "float16",
50+
"float", "float32",
51+
"double", "float64",
52+
"longdouble", "float96", "float128",
53+
])
54+
def test_supported_astype(dtype):
55+
if dtype in ("float96", "float128") and getattr(np, dtype, None) is None:
56+
pytest.skip(f"{dtype} is unsupported on the current platform")
57+
58+
orig = np.array(1, dtype=dtype)
59+
quad = orig.astype(QuadPrecDType, casting="safe")
60+
back = quad.astype(dtype, casting="unsafe")
61+
62+
assert quad == 1
63+
assert back == orig
64+
65+
66+
@pytest.mark.parametrize("dtype", ["S10", "U10", "T", "V10", "datetime64[ms]", "timedelta64[ms]"])
67+
def test_unsupported_astype(dtype):
68+
if dtype == "V10":
69+
pytest.xfail("casts to and from V10 segfault")
70+
71+
with pytest.raises(TypeError, match="cast"):
72+
np.array(1, dtype=dtype).astype(QuadPrecDType, casting="unsafe")
73+
74+
with pytest.raises(TypeError, match="cast"):
75+
np.array(QuadPrecision(1)).astype(dtype, casting="unsafe")
76+
77+
4278
def test_basic_equality():
4379
assert QuadPrecision("12") == QuadPrecision(
4480
"12.0") == QuadPrecision("12.00")

0 commit comments

Comments
 (0)