|
1 | 1 | /* |
2 | | - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. |
3 | 3 | * SPDX-License-Identifier: Apache-2.0 |
4 | 4 | */ |
5 | 5 |
|
@@ -404,10 +404,13 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDeserialize(cuvsResources_t res, |
404 | 404 | index->dtype.bits = dtype.itemsize * 8; |
405 | 405 | if (dtype.kind == 'f' && dtype.itemsize == 4) { |
406 | 406 | index->dtype.code = kDLFloat; |
407 | | - index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename)); |
408 | | - } else if (dtype.kind == 'e' && dtype.itemsize == 2) { |
| 407 | + index->addr = |
| 408 | + reinterpret_cast<uintptr_t>(_mg_deserialize<float>(res, filename)); |
| 409 | + } else if ((dtype.kind == 'f' || dtype.kind == 'e') && |
| 410 | + dtype.itemsize == 2) { |
409 | 411 | index->dtype.code = kDLFloat; |
410 | | - index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename)); |
| 412 | + index->addr = |
| 413 | + reinterpret_cast<uintptr_t>(_mg_deserialize<half>(res, filename)); |
411 | 414 | } else if (dtype.kind == 'i' && dtype.itemsize == 1) { |
412 | 415 | index->dtype.code = kDLInt; |
413 | 416 | index->addr = reinterpret_cast<uintptr_t>(_mg_deserialize<int8_t>(res, filename)); |
@@ -435,10 +438,13 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDistribute(cuvsResources_t res, |
435 | 438 | index->dtype.bits = dtype.itemsize * 8; |
436 | 439 | if (dtype.kind == 'f' && dtype.itemsize == 4) { |
437 | 440 | index->dtype.code = kDLFloat; |
438 | | - index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<float>(res, filename)); |
439 | | - } else if (dtype.kind == 'e' && dtype.itemsize == 2) { |
| 441 | + index->addr = |
| 442 | + reinterpret_cast<uintptr_t>(_mg_distribute<float>(res, filename)); |
| 443 | + } else if ((dtype.kind == 'f' || dtype.kind == 'e') && |
| 444 | + dtype.itemsize == 2) { |
440 | 445 | index->dtype.code = kDLFloat; |
441 | | - index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<half>(res, filename)); |
| 446 | + index->addr = |
| 447 | + reinterpret_cast<uintptr_t>(_mg_distribute<half>(res, filename)); |
442 | 448 | } else if (dtype.kind == 'i' && dtype.itemsize == 1) { |
443 | 449 | index->dtype.code = kDLInt; |
444 | 450 | index->addr = reinterpret_cast<uintptr_t>(_mg_distribute<int8_t>(res, filename)); |
|
0 commit comments