Skip to content

Commit 1bed620

Browse files
committed
C deserialization should accept f2 numpy half type
1 parent 002c373 commit 1bed620

6 files changed

Lines changed: 53 additions & 29 deletions

File tree

c/src/neighbors/brute_force.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
/*
3-
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
3+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
44
* SPDX-License-Identifier: Apache-2.0
55
*/
66

@@ -245,10 +245,13 @@ extern "C" cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res,
245245
index->dtype.bits = dtype.itemsize * 8;
246246
if (dtype.kind == 'f' && dtype.itemsize == 4) {
247247
index->dtype.code = kDLFloat;
248-
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename));
249-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
248+
index->addr =
249+
reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename));
250+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
251+
dtype.itemsize == 2) {
250252
index->dtype.code = kDLFloat;
251-
index->addr = reinterpret_cast<uintptr_t>(_deserialize<half>(res, filename));
253+
index->addr =
254+
reinterpret_cast<uintptr_t>(_deserialize<half>(res, filename));
252255
} else {
253256
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
254257
}

c/src/neighbors/cagra.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -879,10 +879,13 @@ extern "C" cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
879879

880880
index->dtype.bits = dtype.itemsize * 8;
881881
if (dtype.kind == 'f' && dtype.itemsize == 4) {
882-
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename));
882+
index->addr =
883+
reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename));
883884
index->dtype.code = kDLFloat;
884-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
885-
index->addr = reinterpret_cast<uintptr_t>(_deserialize<half>(res, filename));
885+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
886+
dtype.itemsize == 2) {
887+
index->addr =
888+
reinterpret_cast<uintptr_t>(_deserialize<half>(res, filename));
886889
index->dtype.code = kDLFloat;
887890
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
888891
index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t>(res, filename));

c/src/neighbors/ivf_flat.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
/*
3-
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
3+
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
44
* SPDX-License-Identifier: Apache-2.0
55
*/
66

@@ -305,10 +305,13 @@ extern "C" cuvsError_t cuvsIvfFlatDeserialize(cuvsResources_t res,
305305

306306
index->dtype.bits = dtype.itemsize * 8;
307307
if (dtype.kind == 'f' && dtype.itemsize == 4) {
308-
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float, int64_t>(res, filename));
308+
index->addr = reinterpret_cast<uintptr_t>(
309+
_deserialize<float, int64_t>(res, filename));
309310
index->dtype.code = kDLFloat;
310-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
311-
index->addr = reinterpret_cast<uintptr_t>(_deserialize<half, int64_t>(res, filename));
311+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
312+
dtype.itemsize == 2) {
313+
index->addr = reinterpret_cast<uintptr_t>(
314+
_deserialize<half, int64_t>(res, filename));
312315
index->dtype.code = kDLFloat;
313316
index->dtype.bits = 16;
314317
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {

c/src/neighbors/mg_cagra.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -407,10 +407,13 @@ extern "C" cuvsError_t cuvsMultiGpuCagraDeserialize(cuvsResources_t res,
407407
index->dtype.bits = dtype.itemsize * 8;
408408
if (dtype.kind == 'f' && dtype.itemsize == 4) {
409409
index->dtype.code = kDLFloat;
410-
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename));
411-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
410+
index->addr =
411+
reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename));
412+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
413+
dtype.itemsize == 2) {
412414
index->dtype.code = kDLFloat;
413-
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename));
415+
index->addr =
416+
reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename));
414417
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
415418
index->dtype.code = kDLInt;
416419
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<int8_t>(res, filename));
@@ -438,10 +441,13 @@ extern "C" cuvsError_t cuvsMultiGpuCagraDistribute(cuvsResources_t res,
438441
index->dtype.bits = dtype.itemsize * 8;
439442
if (dtype.kind == 'f' && dtype.itemsize == 4) {
440443
index->dtype.code = kDLFloat;
441-
index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<float>(res, filename));
442-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
444+
index->addr =
445+
reinterpret_cast<uintptr_t>(_mg_distribute<float>(res, filename));
446+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
447+
dtype.itemsize == 2) {
443448
index->dtype.code = kDLFloat;
444-
index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<half>(res, filename));
449+
index->addr =
450+
reinterpret_cast<uintptr_t>(_mg_distribute<half>(res, filename));
445451
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
446452
index->dtype.code = kDLInt;
447453
index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<int8_t>(res, filename));

c/src/neighbors/mg_ivf_flat.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -404,10 +404,13 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDeserialize(cuvsResources_t res,
404404
index->dtype.bits = dtype.itemsize * 8;
405405
if (dtype.kind == 'f' && dtype.itemsize == 4) {
406406
index->dtype.code = kDLFloat;
407-
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename));
408-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
407+
index->addr =
408+
reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename));
409+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
410+
dtype.itemsize == 2) {
409411
index->dtype.code = kDLFloat;
410-
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename));
412+
index->addr =
413+
reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename));
411414
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
412415
index->dtype.code = kDLInt;
413416
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<int8_t>(res, filename));
@@ -435,10 +438,13 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDistribute(cuvsResources_t res,
435438
index->dtype.bits = dtype.itemsize * 8;
436439
if (dtype.kind == 'f' && dtype.itemsize == 4) {
437440
index->dtype.code = kDLFloat;
438-
index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<float>(res, filename));
439-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
441+
index->addr =
442+
reinterpret_cast<uintptr_t>(_mg_distribute<float>(res, filename));
443+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
444+
dtype.itemsize == 2) {
440445
index->dtype.code = kDLFloat;
441-
index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<half>(res, filename));
446+
index->addr =
447+
reinterpret_cast<uintptr_t>(_mg_distribute<half>(res, filename));
442448
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
443449
index->dtype.code = kDLInt;
444450
index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<int8_t>(res, filename));

c/src/neighbors/mg_ivf_pq.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -396,10 +396,13 @@ extern "C" cuvsError_t cuvsMultiGpuIvfPqDeserialize(cuvsResources_t res,
396396
index->dtype.bits = dtype.itemsize * 8;
397397
if (dtype.kind == 'f' && dtype.itemsize == 4) {
398398
index->dtype.code = kDLFloat;
399-
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename));
400-
} else if (dtype.kind == 'e' && dtype.itemsize == 2) {
399+
index->addr =
400+
reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename));
401+
} else if ((dtype.kind == 'f' || dtype.kind == 'e') &&
402+
dtype.itemsize == 2) {
401403
index->dtype.code = kDLFloat;
402-
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename));
404+
index->addr =
405+
reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename));
403406
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
404407
index->dtype.code = kDLInt;
405408
index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<int8_t>(res, filename));

0 commit comments

Comments
 (0)