Skip to content

Commit 520ea6a

Browse files
committed
Redirect NPY_DT_setitem/NPY_DT_getitem to python
1 parent cdc206b commit 520ea6a

File tree

4 files changed

+84
-36
lines changed

4 files changed

+84
-36
lines changed

sample_dtypes/scalar.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,18 @@ def __repr__(self) -> str:
3737
return f'{type(self).__name__}(id={hex(id(self))}, shape={self._ndarr.shape}, dtype={self._ndarr.dtype})'
3838

3939
@property
40-
def data_buffer(self) -> bytes:
40+
def ndarr(self) -> np.ndarray:
41+
return self._ndarr
42+
43+
def setitem(self, target: np.ndarray) -> None:
4144
"""Get raw-data"""
42-
return self._ndarr.tobytes(order=BYTES_BUF_ORDER)
45+
print(
46+
f'TODO: setitem, taget: {type(target)}, {target.shape}, {target.dtype}'
47+
)
48+
target[...] = self._ndarr
4349

44-
def from_data_buffer(self, buf: bytes) -> 'SampleScalar':
45-
"""Create new object using raw-data"""
46-
assert (
47-
len(buf) == self._ndarr.nbytes
48-
), f'Buffer size must be {self._ndarr.nbytes}'
50+
def getitem(self, source: np.ndarray) -> 'SampleScalar':
51+
"""Store raw-data into a new object"""
4952
new = self.copy()
50-
new._ndarr.flat[:] = np.frombuffer(buf, dtype=self._ndarr.dtype)
53+
new._ndarr.flat[:] = source
5154
return new

src/dtype.c

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,37 @@ sampledtype_ensure_canonical(SampleDTypeObject *self) {
103103
return self;
104104
}
105105

106+
static PyObject *create_element_view(SampleDTypeObject *descr, char *dataptr,
107+
npy_bool readonly) {
108+
PyObject *ndarr_obj = PyObject_GetAttrString(descr->sample_scalar, "ndarr");
109+
if (ndarr_obj == NULL) {
110+
return NULL;
111+
}
112+
if (!PyArray_Check(ndarr_obj)) {
113+
Py_DECREF(ndarr_obj);
114+
PyErr_SetString(PyExc_TypeError,
115+
"SampleScalar.ndarr() must return NumPy array");
116+
return NULL;
117+
}
118+
119+
PyArrayObject *ndarr_arrobj = (PyArrayObject *)ndarr_obj;
120+
PyArray_Descr *ndarr_descr = PyArray_DESCR(ndarr_arrobj);
121+
122+
PyObject *view = PyArray_NewFromDescr(
123+
&PyArray_Type, ndarr_descr, PyArray_NDIM(ndarr_arrobj),
124+
PyArray_DIMS(ndarr_arrobj), PyArray_STRIDES(ndarr_arrobj),
125+
dataptr, // data pointer for the view
126+
readonly ? 0 : NPY_ARRAY_WRITEABLE, NULL);
127+
if (view != NULL) {
128+
// Now 'view' owns the descriptor, it will be dereferenced when view is
129+
// deallocated
130+
Py_INCREF(ndarr_descr);
131+
}
132+
133+
Py_DECREF(ndarr_obj);
134+
return view;
135+
}
136+
106137
static int sampledtype_setitem(SampleDTypeObject *descr, PyObject *obj,
107138
char *dataptr) {
108139
printf("%s, target elsise %lld, type_num %d\n", __func__, descr->base.elsize,
@@ -127,29 +158,17 @@ static int sampledtype_setitem(SampleDTypeObject *descr, PyObject *obj,
127158
return -1;
128159
}
129160

130-
// Copy data-bytes from 'obj' into 'dataptr'
131-
PyObject *ndarr_bytes = PyObject_GetAttrString(obj, "data_buffer");
132-
if (ndarr_bytes == NULL) {
161+
PyObject *view = create_element_view(descr, dataptr, NPY_FALSE);
162+
if (view == NULL) {
133163
return -1;
134164
}
135165

136-
const char *src_data = NULL;
137-
Py_ssize_t data_size = 0;
138-
if (PyBytes_AsStringAndSize(ndarr_bytes, &src_data, &data_size) != 0) {
139-
Py_DECREF(ndarr_bytes);
140-
return -1;
141-
}
142-
143-
if (data_size != descr->base.elsize) {
144-
Py_DECREF(ndarr_bytes);
145-
PyErr_Format(PyExc_ValueError,
146-
"Unexpected item data buffer size: %zd / %zd", data_size,
147-
descr->base.elsize);
166+
res = PyObject_CallMethod(obj, "setitem", "O", view);
167+
Py_DECREF(view);
168+
if (res == NULL) {
148169
return -1;
149170
}
150-
151-
memcpy(dataptr, src_data, data_size);
152-
Py_DECREF(ndarr_bytes);
171+
Py_DECREF(res);
153172

154173
return 0;
155174
}
@@ -158,16 +177,17 @@ static PyObject *sampledtype_getitem(SampleDTypeObject *descr, char *dataptr) {
158177
printf("%s, source elsize %lld, type_num %d\n", __func__, descr->base.elsize,
159178
descr->base.type_num);
160179

161-
// Pass data-bytes from 'dataptr' to 'obj'
162-
PyObject *ndarr_bytes =
163-
PyBytes_FromStringAndSize(dataptr, descr->base.elsize);
164-
if (ndarr_bytes == NULL) {
180+
PyObject *view = create_element_view(descr, dataptr, NPY_TRUE);
181+
if (view == NULL) {
165182
return NULL;
166183
}
167184

168-
PyObject *res = PyObject_CallMethod(descr->sample_scalar, "from_data_buffer",
169-
"O", ndarr_bytes);
170-
Py_DECREF(ndarr_bytes);
185+
PyObject *res =
186+
PyObject_CallMethod(descr->sample_scalar, "getitem", "O", view);
187+
Py_DECREF(view);
188+
if (res == NULL) {
189+
return NULL;
190+
}
171191

172192
return res;
173193
}

tests/test_create.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Test basic module import"""
1+
"""Test DType creation/instantiation"""
22

33
import numpy as np
44
import pytest
@@ -31,12 +31,12 @@ def test_setitem():
3131
scalar = sample_dtypes.SampleScalar(dtype=int)
3232

3333
print(' * set compatible value')
34-
arr = np.empty(3, sample_dtypes.SampleDType(scalar))
34+
arr: np.ndarray = np.empty(3, sample_dtypes.SampleDType(scalar))
3535
scalar._ndarr.flat[...] = np.arange(scalar._ndarr.size) + 10
3636
arr[1] = scalar
3737
np.testing.assert_equal(arr[1]._ndarr, scalar._ndarr)
3838

3939
print(' * set incompatible value')
40-
arr: np.ndarray = np.empty(3, sample_dtypes.SampleDType)
40+
arr = np.empty(3, sample_dtypes.SampleDType)
4141
with pytest.raises(ValueError):
4242
arr[0] = scalar

tests/test_refcount.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Test internal object reference counting"""
2+
3+
import sys
4+
import numpy as np
5+
6+
7+
def test_dtype_refcnt():
8+
"""Test SampleScalar internal dtype reference counts
9+
10+
NOTE: This fails when debugging as of extra ownership of local object
11+
"""
12+
import sample_dtypes
13+
14+
# Use non-singleton (non-immortal) dtypes to have valid ref-counts
15+
dtype = np.dtype([('a', np.int16)])
16+
assert (
17+
sys.getrefcount(dtype) == 2
18+
), 'Unexpected initial ref-count for non-immortal object'
19+
20+
scalar = sample_dtypes.SampleScalar(dtype=dtype)
21+
arr: np.ndarray = np.empty(3, sample_dtypes.SampleDType(scalar))
22+
arr[1] = scalar
23+
24+
del arr, scalar
25+
assert sys.getrefcount(dtype) == 2, 'Unexpected dtype ref-count'

0 commit comments

Comments
 (0)