|
8 | 8 | #include <Python.h>
|
9 | 9 |
|
10 | 10 | #include "src/common.h"
|
| 11 | +#include "src/some_dtype.h" |
11 | 12 |
|
12 | 13 | namespace sample_dtypes {
|
13 | 14 |
|
| 15 | +static PyObject *RegisterDtype(PyObject *self, PyObject *args) { |
| 16 | + std::cout << "DEBUG: " << __func__ << std::endl; |
| 17 | + |
| 18 | + const char *name = nullptr; |
| 19 | + unsigned int count = 0; |
| 20 | + PyTypeObject *data_type = nullptr; |
| 21 | + if (!PyArg_ParseTuple(args, "sO!I", &name, &PyType_Type, &data_type, |
| 22 | + &count)) { |
| 23 | + std::cout << "DEBUG: " << "PyArg_ParseTuple failed" << std::endl; |
| 24 | + return nullptr; |
| 25 | + } |
| 26 | + |
| 27 | + // Get the size of type |
| 28 | + unsigned data_type_size = 0; |
| 29 | + if (PyType_IsSubtype(data_type, &PyLong_Type)) { |
| 30 | + data_type_size = sizeof(uint64_t); |
| 31 | + std::cout << "DEBUG: data_type(int) specified" << std::endl; |
| 32 | + } else if (PyType_IsSubtype(data_type, &PyFloat_Type)) { |
| 33 | + data_type_size = sizeof(float); |
| 34 | + std::cout << "DEBUG: data_type(float) specified" << std::endl; |
| 35 | + } else { |
| 36 | + // Not a supported python data-type, check numpy types |
| 37 | + // PyObject *type_dict = PyType_GetDict(type); |
| 38 | + PyObject *type_nbytes = PyObject_GetAttrString( |
| 39 | + reinterpret_cast<PyObject *>(data_type), "nbytes"); |
| 40 | + if (type_nbytes == nullptr) { |
| 41 | + std::cout << "DEBUG: PyObject_GetAttrString 'nbytes' failed" << std::endl; |
| 42 | + PyErr_Format(PyExc_TypeError, "Cannot interpret '%R' as a data type", |
| 43 | + data_type); |
| 44 | + return nullptr; |
| 45 | + } |
| 46 | + if (!PyArg_Parse(type_nbytes, "I", &data_type_size)) { |
| 47 | + std::cout << "DEBUG: PyArg_Parse failed" << std::endl; |
| 48 | + return nullptr; |
| 49 | + } |
| 50 | + std::cout << "DEBUG: numpy.dtype specified, nbytes=" << data_type_size |
| 51 | + << std::endl; |
| 52 | + } |
| 53 | + |
| 54 | + PyObject *type_name = PyType_GetQualName(data_type); |
| 55 | + const char *type_name_str = nullptr; |
| 56 | + if (!PyArg_Parse(type_name, "s", &type_name_str)) { |
| 57 | + std::cout << "DEBUG: PyArg_Parse failed" << std::endl; |
| 58 | + return nullptr; |
| 59 | + } |
| 60 | + std::cout << "DEBUG: _register_dtype_ext(name=" << name |
| 61 | + << ", data_type=" << type_name_str << ", count=" << count |
| 62 | + << ") -> data_type_size=" << data_type_size << std::endl; |
| 63 | + |
| 64 | + // Create python type object |
| 65 | + PyType_Slot type_slots[]{ |
| 66 | + {0, nullptr}, |
| 67 | + }; |
| 68 | + typedef struct { |
| 69 | + PyObject_HEAD |
| 70 | + // int value[0]; |
| 71 | + } type_struct; |
| 72 | + PyType_Spec type_spec = { |
| 73 | + .name = name, |
| 74 | + .basicsize = |
| 75 | + static_cast<int>(sizeof(type_struct) + data_type_size * count), |
| 76 | + .itemsize = 0, |
| 77 | + .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, |
| 78 | + .slots = type_slots, |
| 79 | + }; |
| 80 | + PyObject *type = PyType_FromSpecWithBases( |
| 81 | + &type_spec, reinterpret_cast<PyObject *>(&PyGenericArrType_Type)); |
| 82 | + if (!type) { |
| 83 | + return nullptr; |
| 84 | + } |
| 85 | + |
| 86 | + return type; |
| 87 | +} |
| 88 | + |
14 | 89 | /*
|
15 | 90 | *TODO: Implement Dtype casts registration
|
16 | 91 | */
|
@@ -49,13 +124,25 @@ bool Initialize() {
|
49 | 124 | return success;
|
50 | 125 | }
|
51 | 126 |
|
| 127 | +static PyMethodDef module_methods[] = {{"_register_dtype_ext", RegisterDtype, |
| 128 | + METH_VARARGS, |
| 129 | + "Create/register a dtype class"}, |
| 130 | + {nullptr, nullptr, 0}}; |
| 131 | + |
52 | 132 | static PyModuleDef module_def = {
|
53 | 133 | PyModuleDef_HEAD_INIT,
|
54 | 134 | "_sample_dtypes_ext",
|
| 135 | + "sample_dtypes extension module", |
| 136 | + -1, |
| 137 | + module_methods, |
55 | 138 | };
|
56 | 139 |
|
57 | 140 | PyMODINIT_FUNC PyInit__sample_dtypes_ext() {
|
58 |
| - std::cout << "DEBUG: " << __func__ << std::endl; |
| 141 | + std::cout << "DEBUG: " << __DATE__ << " " << __TIME__ << ": " << __func__ |
| 142 | + << std::endl; |
| 143 | +#ifdef _HYBRID |
| 144 | + std::cout << "HYBRID build" << std::endl; |
| 145 | +#endif |
59 | 146 |
|
60 | 147 | Safe_PyObjectPtr m = make_safe(PyModule_Create(&module_def));
|
61 | 148 | if (!m) {
|
|
0 commit comments