From fdf6743938d6f514c3aa5d31eb12a7d5e640f771 Mon Sep 17 00:00:00 2001 From: blaise-muhirwa Date: Mon, 5 Feb 2024 14:32:04 -0800 Subject: [PATCH] special-casing .npy files --- experiments/run-big-bench.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/experiments/run-big-bench.py b/experiments/run-big-bench.py index b251436..3379095 100644 --- a/experiments/run-big-bench.py +++ b/experiments/run-big-bench.py @@ -26,16 +26,6 @@ } -def load_sift_dataset( - train_dataset_path: str, queries_path: str, gtruth_path: str -) -> Tuple[np.ndarray]: - return ( - np.load(train_dataset_path).astype(np.float32), - np.load(queries_path).astype(np.float32), - np.load(gtruth_path).astype(np.uint32), - ) - - def load_benchmark_dataset( train_dataset_path: str, queries_path: str, @@ -81,6 +71,17 @@ def load_ground_truth(path: str) -> Tuple[np.ndarray, np.ndarray, int, int]: verify_paths_exist([train_dataset_path, queries_path, gtruth_path]) + files_have_npy_extensions = all( + dataset.endswith("npy") + for dataset in [train_dataset_path, queries_path, gtruth_path] + ) + if files_have_npy_extensions: + return ( + np.load(train_dataset_path).astype(np.float32, copy=False), + np.load(queries_path, np.float32, copy=False), + np.load(gtruth_path, np.uint32, copy=False), + ) + train_dtype = np.float32 if train_dataset_path.endswith("fbin") else np.uint8 total_size = os.path.getsize(train_dataset_path) // np.dtype(train_dtype).itemsize