Skip to content

Commit 4435be0

Browse files
committed
Pure python implementation of NPY_DT_setitem/NPY_DT_getitem
1 parent 520ea6a commit 4435be0

File tree

3 files changed

+30
-83
lines changed

3 files changed

+30
-83
lines changed

sample_dtypes/scalar.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Scalar types needed by the dtype machinery."""
22

3+
import ctypes
34
import numpy as np
45

56

@@ -26,29 +27,35 @@ def alignment(self) -> int:
2627
return self._ndarr.dtype.alignment
2728

2829
def is_compatible(self, other: 'SampleScalar') -> bool:
29-
print(f'TODO: scalar.py: SampleScalar.is_compatible({other})')
30-
# This implies `self.elsize == other.elsize and self.alignment == other.alignment`
31-
return (
32-
self._ndarr.shape == other._ndarr.shape
33-
and self._ndarr.dtype == other._ndarr.dtype
34-
)
30+
print(f'scalar.py: SampleScalar.is_compatible({other})')
31+
return self._ndarr.shape == other._ndarr.shape
3532

3633
def __repr__(self) -> str:
3734
return f'{type(self).__name__}(id={hex(id(self))}, shape={self._ndarr.shape}, dtype={self._ndarr.dtype})'
3835

39-
@property
40-
def ndarr(self) -> np.ndarray:
41-
return self._ndarr
42-
43-
def setitem(self, target: np.ndarray) -> None:
44-
"""Get raw-data"""
45-
print(
46-
f'TODO: setitem, taget: {type(target)}, {target.shape}, {target.dtype}'
36+
def _get_np_view(self, dataptr: int) -> np.ndarray:
37+
"""Create numpy array that uses `dataptr` memory block"""
38+
ct_array = ctypes.cast(
39+
dataptr, ctypes.POINTER(np.ctypeslib.as_ctypes_type(self._ndarr.dtype))
4740
)
48-
target[...] = self._ndarr
41+
return np.ctypeslib.as_array(ct_array, shape=self._ndarr.shape)
42+
43+
def setitem(self, src: 'SampleScalar', dataptr: int) -> None:
44+
"""Python `NPY_DT_setitem` implementation
45+
46+
NOTE:
47+
`dataptr` is a buffer address, valid during execution of the function only
48+
"""
49+
if not self.is_compatible(src):
50+
raise ValueError('Incompatible item value')
51+
52+
self._get_np_view(dataptr)[...] = src._ndarr
53+
54+
def getitem(self, dataptr: int) -> 'SampleScalar':
55+
"""Python `NPY_DT_getitem` implementation
4956
50-
def getitem(self, source: np.ndarray) -> 'SampleScalar':
51-
"""Store raw-data into a new object"""
57+
NOTE: See setitem()
58+
"""
5259
new = self.copy()
53-
new._ndarr.flat[:] = source
60+
new._ndarr[...] = self._get_np_view(dataptr)
5461
return new

src/dtype.c

Lines changed: 4 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -103,37 +103,6 @@ sampledtype_ensure_canonical(SampleDTypeObject *self) {
103103
return self;
104104
}
105105

106-
static PyObject *create_element_view(SampleDTypeObject *descr, char *dataptr,
107-
npy_bool readonly) {
108-
PyObject *ndarr_obj = PyObject_GetAttrString(descr->sample_scalar, "ndarr");
109-
if (ndarr_obj == NULL) {
110-
return NULL;
111-
}
112-
if (!PyArray_Check(ndarr_obj)) {
113-
Py_DECREF(ndarr_obj);
114-
PyErr_SetString(PyExc_TypeError,
115-
"SampleScalar.ndarr() must return NumPy array");
116-
return NULL;
117-
}
118-
119-
PyArrayObject *ndarr_arrobj = (PyArrayObject *)ndarr_obj;
120-
PyArray_Descr *ndarr_descr = PyArray_DESCR(ndarr_arrobj);
121-
122-
PyObject *view = PyArray_NewFromDescr(
123-
&PyArray_Type, ndarr_descr, PyArray_NDIM(ndarr_arrobj),
124-
PyArray_DIMS(ndarr_arrobj), PyArray_STRIDES(ndarr_arrobj),
125-
dataptr, // data pointer for the view
126-
readonly ? 0 : NPY_ARRAY_WRITEABLE, NULL);
127-
if (view != NULL) {
128-
// Now 'view' owns the descriptor, it will be dereferenced when view is
129-
// deallocated
130-
Py_INCREF(ndarr_descr);
131-
}
132-
133-
Py_DECREF(ndarr_obj);
134-
return view;
135-
}
136-
137106
static int sampledtype_setitem(SampleDTypeObject *descr, PyObject *obj,
138107
char *dataptr) {
139108
printf("%s, target elsise %lld, type_num %d\n", __func__, descr->base.elsize,
@@ -145,26 +114,8 @@ static int sampledtype_setitem(SampleDTypeObject *descr, PyObject *obj,
145114
return -1;
146115
}
147116

148-
PyObject *res =
149-
PyObject_CallMethod(descr->sample_scalar, "is_compatible", "O", obj);
150-
if (res == NULL) {
151-
return -1;
152-
}
153-
int is_compatible = PyObject_IsTrue(res);
154-
Py_DECREF(res);
155-
156-
if (!is_compatible) {
157-
PyErr_Format(PyExc_ValueError, "Incompatible item value");
158-
return -1;
159-
}
160-
161-
PyObject *view = create_element_view(descr, dataptr, NPY_FALSE);
162-
if (view == NULL) {
163-
return -1;
164-
}
165-
166-
res = PyObject_CallMethod(obj, "setitem", "O", view);
167-
Py_DECREF(view);
117+
PyObject *res = PyObject_CallMethod(descr->sample_scalar, "setitem", "On",
118+
obj, (Py_ssize_t)dataptr);
168119
if (res == NULL) {
169120
return -1;
170121
}
@@ -177,19 +128,8 @@ static PyObject *sampledtype_getitem(SampleDTypeObject *descr, char *dataptr) {
177128
printf("%s, source elsize %lld, type_num %d\n", __func__, descr->base.elsize,
178129
descr->base.type_num);
179130

180-
PyObject *view = create_element_view(descr, dataptr, NPY_TRUE);
181-
if (view == NULL) {
182-
return NULL;
183-
}
184-
185-
PyObject *res =
186-
PyObject_CallMethod(descr->sample_scalar, "getitem", "O", view);
187-
Py_DECREF(view);
188-
if (res == NULL) {
189-
return NULL;
190-
}
191-
192-
return res;
131+
return PyObject_CallMethod(descr->sample_scalar, "getitem", "n",
132+
(Py_ssize_t)dataptr);
193133
}
194134

195135
static PyType_Slot SampleDType_Slots[] = {

tests/test_create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_setitem():
2828
import sample_dtypes
2929

3030
print('* create empty array')
31-
scalar = sample_dtypes.SampleScalar(dtype=int)
31+
scalar = sample_dtypes.SampleScalar(shape=(3,), dtype=int)
3232

3333
print(' * set compatible value')
3434
arr: np.ndarray = np.empty(3, sample_dtypes.SampleDType(scalar))

0 commit comments

Comments
 (0)