@@ -1081,44 +1081,12 @@ string_isnan_resolve_descriptors(
1081
1081
* Copied from NumPy, because NumPy doesn't always use it :)
1082
1082
*/
1083
1083
static int
1084
- ufunc_promoter_internal (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1085
- PyArray_DTypeMeta *signature[],
1086
- PyArray_DTypeMeta *new_op_dtypes[],
1087
- PyArray_DTypeMeta *final_dtype)
1084
+ string_inputs_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1085
+ PyArray_DTypeMeta *signature[],
1086
+ PyArray_DTypeMeta *new_op_dtypes[],
1087
+ PyArray_DTypeMeta *final_dtype)
1088
1088
{
1089
- /* If nin < 2 promotion is a no-op, so it should not be registered */
1090
- assert(ufunc->nin > 1);
1091
- if (op_dtypes[0] == NULL) {
1092
- assert(ufunc->nin == 2 && ufunc->nout == 1); /* must be reduction */
1093
- Py_INCREF(op_dtypes[1]);
1094
- new_op_dtypes[0] = op_dtypes[1];
1095
- Py_INCREF(op_dtypes[1]);
1096
- new_op_dtypes[1] = op_dtypes[1];
1097
- Py_INCREF(op_dtypes[1]);
1098
- new_op_dtypes[2] = op_dtypes[1];
1099
- return 0;
1100
- }
1101
- PyArray_DTypeMeta *common = NULL;
1102
- /*
1103
- * If a signature is used and homogeneous in its outputs use that
1104
- * (Could/should likely be rather applied to inputs also, although outs
1105
- * only could have some advantage and input dtypes are rarely enforced.)
1106
- */
1107
- for (int i = ufunc->nin; i < ufunc->nargs; i++) {
1108
- if (signature[i] != NULL) {
1109
- if (common == NULL) {
1110
- Py_INCREF(signature[i]);
1111
- common = signature[i];
1112
- }
1113
- else if (common != signature[i]) {
1114
- Py_CLEAR(common); /* Not homogeneous, unset common */
1115
- break;
1116
- }
1117
- }
1118
- }
1119
- Py_XDECREF(common);
1120
-
1121
- /* Otherwise, set all input operands to final_dtype */
1089
+ /* set all input operands to final_dtype */
1122
1090
for (int i = 0; i < ufunc->nargs; i++) {
1123
1091
PyArray_DTypeMeta *tmp = final_dtype;
1124
1092
if (signature[i]) {
@@ -1127,6 +1095,7 @@ ufunc_promoter_internal(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1127
1095
Py_INCREF(tmp);
1128
1096
new_op_dtypes[i] = tmp;
1129
1097
}
1098
+ /* don't touch output dtypes */
1130
1099
for (int i = ufunc->nin; i < ufunc->nargs; i++) {
1131
1100
Py_XINCREF(op_dtypes[i]);
1132
1101
new_op_dtypes[i] = op_dtypes[i];
@@ -1140,19 +1109,50 @@ string_object_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1140
1109
PyArray_DTypeMeta *signature[],
1141
1110
PyArray_DTypeMeta *new_op_dtypes[])
1142
1111
{
1143
- return ufunc_promoter_internal ((PyUFuncObject *)ufunc, op_dtypes,
1144
- signature, new_op_dtypes,
1145
- (PyArray_DTypeMeta *)&PyArray_ObjectDType);
1112
+ return string_inputs_promoter ((PyUFuncObject *)ufunc, op_dtypes, signature ,
1113
+ new_op_dtypes,
1114
+ (PyArray_DTypeMeta *)&PyArray_ObjectDType);
1146
1115
}
1147
1116
1148
1117
static int
1149
1118
string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1150
1119
PyArray_DTypeMeta *signature[],
1151
1120
PyArray_DTypeMeta *new_op_dtypes[])
1152
1121
{
1153
- return ufunc_promoter_internal((PyUFuncObject *)ufunc, op_dtypes,
1154
- signature, new_op_dtypes,
1155
- (PyArray_DTypeMeta *)&StringDType);
1122
+ return string_inputs_promoter((PyUFuncObject *)ufunc, op_dtypes, signature,
1123
+ new_op_dtypes,
1124
+ (PyArray_DTypeMeta *)&StringDType);
1125
+ }
1126
+
1127
+ static int
1128
+ string_multiply_promoter(PyObject *ufunc_obj, PyArray_DTypeMeta *op_dtypes[],
1129
+ PyArray_DTypeMeta *signature[],
1130
+ PyArray_DTypeMeta *new_op_dtypes[])
1131
+ {
1132
+ PyUFuncObject *ufunc = (PyUFuncObject *)ufunc_obj;
1133
+ for (int i = 0; i < ufunc->nargs; i++) {
1134
+ PyArray_DTypeMeta *tmp = NULL;
1135
+ if (signature[i]) {
1136
+ tmp = signature[i];
1137
+ }
1138
+ else if (op_dtypes[i] == &PyArray_PyIntAbstractDType) {
1139
+ tmp = &PyArray_Int64DType;
1140
+ }
1141
+ else if (op_dtypes[i]) {
1142
+ tmp = op_dtypes[i];
1143
+ }
1144
+ else {
1145
+ tmp = (PyArray_DTypeMeta *)&StringDType;
1146
+ }
1147
+ Py_INCREF(tmp);
1148
+ new_op_dtypes[i] = tmp;
1149
+ }
1150
+ /* don't touch output dtypes */
1151
+ for (int i = ufunc->nin; i < ufunc->nargs; i++) {
1152
+ Py_XINCREF(op_dtypes[i]);
1153
+ new_op_dtypes[i] = op_dtypes[i];
1154
+ }
1155
+ return 0;
1156
1156
}
1157
1157
1158
1158
// Register a ufunc.
@@ -1161,14 +1161,18 @@ string_unicode_promoter(PyObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
1161
1161
int
1162
1162
init_ufunc(PyObject *numpy, const char *ufunc_name, PyArray_DTypeMeta **dtypes,
1163
1163
resolve_descriptors_function *resolve_func,
1164
- PyArrayMethod_StridedLoop *loop_func, const char *loop_name ,
1165
- int nin, int nout, NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
1164
+ PyArrayMethod_StridedLoop *loop_func, int nin, int nout ,
1165
+ NPY_CASTING casting, NPY_ARRAYMETHOD_FLAGS flags)
1166
1166
{
1167
1167
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
1168
1168
if (ufunc == NULL) {
1169
1169
return -1;
1170
1170
}
1171
1171
1172
+ char loop_name[256] = {0};
1173
+
1174
+ snprintf(loop_name, sizeof(loop_name), "string_%s", ufunc_name);
1175
+
1172
1176
PyArrayMethod_Spec spec = {
1173
1177
.name = loop_name,
1174
1178
.nin = nin,
@@ -1208,7 +1212,7 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
1208
1212
PyArray_DTypeMeta *ldtype, PyArray_DTypeMeta *rdtype,
1209
1213
PyArray_DTypeMeta *edtype, promoter_function *promoter_impl)
1210
1214
{
1211
- PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
1215
+ PyObject *ufunc = PyObject_GetAttrString((PyObject *) numpy, ufunc_name);
1212
1216
1213
1217
if (ufunc == NULL) {
1214
1218
return -1;
@@ -1251,8 +1255,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
1251
1255
\
1252
1256
if (init_ufunc(numpy, "multiply", multiply_right_##shortname##_types, \
1253
1257
&multiply_resolve_descriptors, \
1254
- &multiply_right_##shortname##_strided_loop, \
1255
- "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1258
+ &multiply_right_##shortname##_strided_loop, 2, 1, \
1259
+ NPY_NO_CASTING, 0) < 0) { \
1256
1260
goto error; \
1257
1261
} \
1258
1262
\
@@ -1262,8 +1266,8 @@ add_promoter(PyObject *numpy, const char *ufunc_name,
1262
1266
\
1263
1267
if (init_ufunc(numpy, "multiply", multiply_left_##shortname##_types, \
1264
1268
&multiply_resolve_descriptors, \
1265
- &multiply_left_##shortname##_strided_loop, \
1266
- "string_multiply", 2, 1, NPY_NO_CASTING, 0) < 0) { \
1269
+ &multiply_left_##shortname##_strided_loop, 2, 1, \
1270
+ NPY_NO_CASTING, 0) < 0) { \
1267
1271
goto error; \
1268
1272
}
1269
1273
@@ -1279,53 +1283,23 @@ init_ufuncs(void)
1279
1283
"greater", "greater_equal",
1280
1284
"less", "less_equal"};
1281
1285
1286
+ static PyArrayMethod_StridedLoop *strided_loops[6] = {
1287
+ &string_equal_strided_loop, &string_not_equal_strided_loop,
1288
+ &string_greater_strided_loop, &string_greater_equal_strided_loop,
1289
+ &string_less_strided_loop, &string_less_equal_strided_loop,
1290
+ };
1291
+
1282
1292
PyArray_DTypeMeta *comparison_dtypes[] = {
1283
1293
(PyArray_DTypeMeta *)&StringDType,
1284
1294
(PyArray_DTypeMeta *)&StringDType, &PyArray_BoolDType};
1285
1295
1286
- if (init_ufunc(numpy, "equal", comparison_dtypes,
1287
- &string_comparison_resolve_descriptors,
1288
- &string_equal_strided_loop, "string_equal", 2, 1,
1289
- NPY_NO_CASTING, 0) < 0) {
1290
- goto error;
1291
- }
1292
-
1293
- if (init_ufunc(numpy, "not_equal", comparison_dtypes,
1294
- &string_comparison_resolve_descriptors,
1295
- &string_not_equal_strided_loop, "string_not_equal", 2, 1,
1296
- NPY_NO_CASTING, 0) < 0) {
1297
- goto error;
1298
- }
1299
-
1300
- if (init_ufunc(numpy, "greater", comparison_dtypes,
1301
- &string_comparison_resolve_descriptors,
1302
- &string_greater_strided_loop, "string_greater", 2, 1,
1303
- NPY_NO_CASTING, 0) < 0) {
1304
- goto error;
1305
- }
1306
-
1307
- if (init_ufunc(numpy, "greater_equal", comparison_dtypes,
1308
- &string_comparison_resolve_descriptors,
1309
- &string_greater_equal_strided_loop, "string_greater_equal",
1310
- 2, 1, NPY_NO_CASTING, 0) < 0) {
1311
- goto error;
1312
- }
1313
-
1314
- if (init_ufunc(numpy, "less", comparison_dtypes,
1315
- &string_comparison_resolve_descriptors,
1316
- &string_less_strided_loop, "string_less", 2, 1,
1317
- NPY_NO_CASTING, 0) < 0) {
1318
- goto error;
1319
- }
1320
-
1321
- if (init_ufunc(numpy, "less_equal", comparison_dtypes,
1322
- &string_comparison_resolve_descriptors,
1323
- &string_less_equal_strided_loop, "string_less_equal", 2, 1,
1324
- NPY_NO_CASTING, 0) < 0) {
1325
- goto error;
1326
- }
1327
-
1328
1296
for (int i = 0; i < 6; i++) {
1297
+ if (init_ufunc(numpy, comparison_ufunc_names[i], comparison_dtypes,
1298
+ &string_comparison_resolve_descriptors,
1299
+ strided_loops[i], 2, 1, NPY_NO_CASTING, 0) < 0) {
1300
+ goto error;
1301
+ }
1302
+
1329
1303
if (add_promoter(numpy, comparison_ufunc_names[i],
1330
1304
(PyArray_DTypeMeta *)&StringDType,
1331
1305
&PyArray_UnicodeDType, &PyArray_BoolDType,
@@ -1360,8 +1334,7 @@ init_ufuncs(void)
1360
1334
1361
1335
if (init_ufunc(numpy, "isnan", isnan_dtypes,
1362
1336
&string_isnan_resolve_descriptors,
1363
- &string_isnan_strided_loop, "string_isnan", 1, 1,
1364
- NPY_NO_CASTING, 0) < 0) {
1337
+ &string_isnan_strided_loop, 1, 1, NPY_NO_CASTING, 0) < 0) {
1365
1338
goto error;
1366
1339
}
1367
1340
@@ -1372,20 +1345,17 @@ init_ufuncs(void)
1372
1345
};
1373
1346
1374
1347
if (init_ufunc(numpy, "maximum", binary_dtypes, binary_resolve_descriptors,
1375
- &maximum_strided_loop, "string_maximum", 2, 1,
1376
- NPY_NO_CASTING, 0) < 0) {
1348
+ &maximum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
1377
1349
goto error;
1378
1350
}
1379
1351
1380
1352
if (init_ufunc(numpy, "minimum", binary_dtypes, binary_resolve_descriptors,
1381
- &minimum_strided_loop, "string_minimum", 2, 1,
1382
- NPY_NO_CASTING, 0) < 0) {
1353
+ &minimum_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
1383
1354
goto error;
1384
1355
}
1385
1356
1386
1357
if (init_ufunc(numpy, "add", binary_dtypes, binary_resolve_descriptors,
1387
- &add_strided_loop, "string_add", 2, 1, NPY_NO_CASTING,
1388
- 0) < 0) {
1358
+ &add_strided_loop, 2, 1, NPY_NO_CASTING, 0) < 0) {
1389
1359
goto error;
1390
1360
}
1391
1361
@@ -1414,6 +1384,20 @@ init_ufuncs(void)
1414
1384
INIT_MULTIPLY(ULongLong, ulonglong);
1415
1385
#endif
1416
1386
1387
+ if (add_promoter(numpy, "multiply", (PyArray_DTypeMeta *)&StringDType,
1388
+ &PyArray_PyIntAbstractDType,
1389
+ (PyArray_DTypeMeta *)&StringDType,
1390
+ string_multiply_promoter) < 0) {
1391
+ goto error;
1392
+ }
1393
+
1394
+ if (add_promoter(numpy, "multiply", &PyArray_PyIntAbstractDType,
1395
+ (PyArray_DTypeMeta *)&StringDType,
1396
+ (PyArray_DTypeMeta *)&StringDType,
1397
+ string_multiply_promoter) < 0) {
1398
+ goto error;
1399
+ }
1400
+
1417
1401
Py_DECREF(numpy);
1418
1402
return 0;
1419
1403
0 commit comments