Skip to content

Commit 6ecb59d

Browse files
committed
Experiment: add register_dtype and some_dtype.h
1 parent fe99978 commit 6ecb59d

File tree

5 files changed

+129
-3
lines changed

5 files changed

+129
-3
lines changed

sample_dtypes/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,17 @@
33
__version__ = "0.0.0" # Keep in sync with pyproject.toml:version
44
__all__ = [
55
"__version__",
6+
"register_dtype",
67
]
8+
9+
import sample_dtypes
10+
from sample_dtypes._sample_dtypes_ext import _register_dtype_ext
11+
12+
13+
def register_dtype(name: str, data_type: type, count: int):
14+
dtype = _register_dtype_ext(name, data_type, count)
15+
setattr(sample_dtypes, name, dtype)
16+
setattr(dtype, "__module__", sample_dtypes)
17+
# Support dtype(type_name)
18+
# setattr(dtype, 'dtype', npy_descr)
19+
return dtype

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
if platform.system() == "Windows":
1010
COMPILE_ARGS = [
11-
"/std:c++17",
11+
"/std:c++20",
1212
]
1313
else:
1414
COMPILE_ARGS = [
15-
"-std=c++17",
15+
"-std=c++20",
1616
]
1717

1818

src/dtypes.cc

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,84 @@
88
#include <Python.h>
99

1010
#include "src/common.h"
11+
#include "src/some_dtype.h"
1112

1213
namespace sample_dtypes {
1314

15+
static PyObject *RegisterDtype(PyObject *self, PyObject *args) {
16+
std::cout << "DEBUG: " << __func__ << std::endl;
17+
18+
const char *name = nullptr;
19+
unsigned int count = 0;
20+
PyTypeObject *data_type = nullptr;
21+
if (!PyArg_ParseTuple(args, "sO!I", &name, &PyType_Type, &data_type,
22+
&count)) {
23+
std::cout << "DEBUG: " << "PyArg_ParseTuple failed" << std::endl;
24+
return nullptr;
25+
}
26+
27+
// Get the size of type
28+
unsigned data_type_size = 0;
29+
if (PyType_IsSubtype(data_type, &PyLong_Type)) {
30+
data_type_size = sizeof(uint64_t);
31+
std::cout << "DEBUG: data_type(int) specified" << std::endl;
32+
} else if (PyType_IsSubtype(data_type, &PyFloat_Type)) {
33+
data_type_size = sizeof(float);
34+
std::cout << "DEBUG: data_type(float) specified" << std::endl;
35+
} else {
36+
// Not a supported python data-type, check numpy types
37+
// PyObject *type_dict = PyType_GetDict(type);
38+
PyObject *type_nbytes = PyObject_GetAttrString(
39+
reinterpret_cast<PyObject *>(data_type), "nbytes");
40+
if (type_nbytes == nullptr) {
41+
std::cout << "DEBUG: PyObject_GetAttrString 'nbytes' failed" << std::endl;
42+
PyErr_Format(PyExc_TypeError, "Cannot interpret '%R' as a data type",
43+
data_type);
44+
return nullptr;
45+
}
46+
if (!PyArg_Parse(type_nbytes, "I", &data_type_size)) {
47+
std::cout << "DEBUG: PyArg_Parse failed" << std::endl;
48+
return nullptr;
49+
}
50+
std::cout << "DEBUG: numpy.dtype specified, nbytes=" << data_type_size
51+
<< std::endl;
52+
}
53+
54+
PyObject *type_name = PyType_GetQualName(data_type);
55+
const char *type_name_str = nullptr;
56+
if (!PyArg_Parse(type_name, "s", &type_name_str)) {
57+
std::cout << "DEBUG: PyArg_Parse failed" << std::endl;
58+
return nullptr;
59+
}
60+
std::cout << "DEBUG: _register_dtype_ext(name=" << name
61+
<< ", data_type=" << type_name_str << ", count=" << count
62+
<< ") -> data_type_size=" << data_type_size << std::endl;
63+
64+
// Create python type object
65+
PyType_Slot type_slots[]{
66+
{0, nullptr},
67+
};
68+
typedef struct {
69+
PyObject_HEAD
70+
// int value[0];
71+
} type_struct;
72+
PyType_Spec type_spec = {
73+
.name = name,
74+
.basicsize =
75+
static_cast<int>(sizeof(type_struct) + data_type_size * count),
76+
.itemsize = 0,
77+
.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
78+
.slots = type_slots,
79+
};
80+
PyObject *type = PyType_FromSpecWithBases(
81+
&type_spec, reinterpret_cast<PyObject *>(&PyGenericArrType_Type));
82+
if (!type) {
83+
return nullptr;
84+
}
85+
86+
return type;
87+
}
88+
1489
/*
1590
*TODO: Implement Dtype casts registration
1691
*/
@@ -49,13 +124,25 @@ bool Initialize() {
49124
return success;
50125
}
51126

127+
static PyMethodDef module_methods[] = {{"_register_dtype_ext", RegisterDtype,
128+
METH_VARARGS,
129+
"Create/register a dtype class"},
130+
{nullptr, nullptr, 0}};
131+
52132
static PyModuleDef module_def = {
53133
PyModuleDef_HEAD_INIT,
54134
"_sample_dtypes_ext",
135+
"sample_dtypes extension module",
136+
-1,
137+
module_methods,
55138
};
56139

57140
PyMODINIT_FUNC PyInit__sample_dtypes_ext() {
58-
std::cout << "DEBUG: " << __func__ << std::endl;
141+
std::cout << "DEBUG: " << __DATE__ << " " << __TIME__ << ": " << __func__
142+
<< std::endl;
143+
#ifdef _HYBRID
144+
std::cout << "HYBRID build" << std::endl;
145+
#endif
59146

60147
Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def));
61148
if (!m) {

src/some_dtype.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#ifndef SAMPLE_DTYPES_SOME_DTYPE_H_
2+
#define SAMPLE_DTYPES_SOME_DTYPE_H_
3+
4+
namespace sample_dtypes {
5+
6+
template <int N, typename T> struct some_dtype {
7+
T v[N];
8+
};
9+
10+
} // namespace sample_dtypes
11+
12+
#endif // SAMPLE_DTYPES_SOME_DTYPE_H_

tests/test_register.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Test register_dtype()"""
2+
3+
4+
def test_register_dtype():
5+
from sample_dtypes import _sample_dtypes_ext as pyd
6+
7+
print(
8+
f'_register_dtype_ext: {type(pyd._register_dtype_ext)}, {type(pyd.__doc__)}'
9+
)
10+
11+
import sample_dtypes
12+
13+
dtype = sample_dtypes.register_dtype('type_name', int, 4)
14+
assert dtype.__module__ is sample_dtypes

0 commit comments

Comments
 (0)