112 changes: 100 additions & 12 deletions ArtExtract_Mingchun/retrival/searching_tool.py
@@ -38,13 +38,52 @@ def load_index(path: str, metric: Metric, kind: str, dim: int, meta: Optional[di
return IndexBundle(index=index, metric=metric, kind=kind, dim=dim, meta=meta or {})


# ============ Change 1: normalize() + unnormalized warning ============
def normalize(X: np.ndarray) -> np.ndarray:
"""
L2-normalize row vectors to unit length.

Required before adding vectors to an ip (inner-product) index if you want
cosine similarity semantics. Without this, metric='ip' computes raw dot
product — NOT cosine — and silently returns wrong similarity rankings.

Args:
X: shape (N, D) or (D,) — promoted to 2D internally.

Returns:
L2-normalized float32 array; a 1D input comes back with shape (1, D).

Example:
X = normalize(X)
bundle = build_index_flat(X, metric="ip") # now correctly cosine
"""
X = np.atleast_2d(X).astype("float32")
norms = np.linalg.norm(X, axis=1, keepdims=True)
return X / (norms + 1e-8)


def _warn_if_unnormalized(X: np.ndarray, metric: Metric, context: str) -> None:
"""Warn if metric='ip' but vectors are not unit-length. Samples up to 64 rows."""
if metric != "ip":
return
sample = X[:min(64, len(X))]
norms = np.linalg.norm(sample, axis=1)
if not np.allclose(norms, 1.0, atol=1e-3):
print(
f"[searching_tool WARNING] {context}: metric='ip' but vectors are not "
f"L2-normalized (mean norm={norms.mean():.4f}). "
f"Call normalize(X) first for correct cosine similarity."
)
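# --- Example (illustration only, not part of this diff): why normalize() is
# needed for cosine semantics with metric='ip'. A minimal sketch using only
# numpy and the normalize() helper above; raw inner product rewards vector
# magnitude, not direction:
#
#     q = np.array([1.0, 0.0], dtype="float32")
#     a = np.array([10.0, 10.0], dtype="float32")  # large norm, 45 degrees off-axis
#     b = np.array([1.0, 0.0], dtype="float32")    # unit norm, same direction as q
#     X = np.stack([a, b])
#     X @ q             # raw IP -> [10., 1.]: a ranked first (wrong for cosine)
#     normalize(X) @ q  # cosine -> [0.7071, 1.]: b ranked first (correct)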


# ============ Building Index ============
def build_index_flat(X: np.ndarray, metric: Metric = "ip") -> IndexBundle:
"""
Small scale (<=100K): Flat index.
IndexFlatL2: Euclidean distance.
IndexFlatIP: Inner Product (for cosine similarity, vectors should be normalized first).
"""
_warn_if_unnormalized(X, metric, "build_index_flat")
N, D = X.shape
if metric == "ip":
index = faiss.IndexFlatIP(D)
@@ -54,13 +93,15 @@ def build_index_flat(X: np.ndarray, metric: Metric = "ip") -> IndexBundle:
return IndexBundle(index=index, metric=metric, kind="flat", dim=D, meta={"ntotal": index.ntotal})


def build_index_ivfpq(X: np.ndarray, metric: Metric = "ip", nlist: Optional[int] = None,m: int = 16, nbits: int = 8,
def build_index_ivfpq(X: np.ndarray, metric: Metric = "ip", nlist: Optional[int] = None, m: int = 16, nbits: int = 8,
train_samples: int = 200_000, nprobe: int = 16) -> IndexBundle:
"""
Moderate scale (100K~10M): IVFPQ
nlist: number of Voronoi cells (clusters); default: 4*sqrt(N), at least 64
m: number of sub-vectors (must divide D)
nprobe: cells visited at search time — can be overridden per-query in search()
"""
_warn_if_unnormalized(X, metric, "build_index_ivfpq")
N, D = X.shape
if nlist is None:
nlist = max(64, int(4 * math.sqrt(N)))
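# --- Worked example (illustration only) for the defaults above: with
# N = 1_000_000, nlist = max(64, int(4 * sqrt(1e6))) = 4000 cells; with a tiny
# N = 100, the floor keeps nlist at 64. Separately, m must divide D exactly:
# D = 128 pairs naturally with m = 16 (eight dims per sub-vector), while
# D = 100 would need m from {1, 2, 4, 5, 10, 20, 25, 50, 100}.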
@@ -86,8 +127,9 @@ def build_index_hnsw(X: np.ndarray, metric: Metric = "ip", M: int = 32, efC: int
Large scale (>10M): HNSW
M: number of neighbors per node (higher M = denser graph = better accuracy/slower)
efC: construction parameter (higher=better accuracy/slower indexing)
efS: search parameter (higher=better accuracy/slower searching) — can be overridden per-query in search()
"""
_warn_if_unnormalized(X, metric, "build_index_hnsw")
N, D = X.shape
faiss_metric = faiss.METRIC_INNER_PRODUCT if metric == "ip" else faiss.METRIC_L2
index = faiss.IndexHNSWFlat(D, M, faiss_metric)
@@ -105,14 +147,58 @@ def search(
topk: int = 5,
exclude_self: bool = True,
exclude_indices: Optional[np.ndarray] = None,
nprobe: Optional[int] = None,
efsearch: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
""" Search the index with query vectors Q and return topk results """
"""
Search the index with query vectors Q and return topk results.

Args:
bundle: IndexBundle to search against.
Q: Query vectors, shape (D,) or (n_q, D).
1D vectors are safely promoted to (1, D) internally.
topk: Number of nearest neighbors to return per query.
exclude_self: Exclude the query's own row index from results.
exclude_indices: Per-query indices to exclude, shape (n_q,) or scalar.
nprobe: IVFPQ only — overrides nprobe for this call only.
Higher = better recall, slower. Restored after search.
efsearch: HNSW only — overrides efSearch for this call only.
Higher = better recall, slower. Restored after search.
"""
# Change 4: 1D shape guard — promote to (1, D) so all downstream logic is safe
was_1d = Q.ndim == 1
if was_1d:
Q = Q[None, :]
if Q.dtype != np.float32:
Q = Q.astype("float32")

n_q = Q.shape[0]

# Changes 2 & 3: runtime nprobe / efSearch overrides
# Save originals and restore after search so the bundle config is unchanged.
_orig_nprobe = None
_orig_efsearch = None

if nprobe is not None:
if bundle.kind != "ivfpq":
raise ValueError(f"nprobe override is only valid for ivfpq bundles, got '{bundle.kind}'")
_orig_nprobe = bundle.index.nprobe
bundle.index.nprobe = nprobe

if efsearch is not None:
if bundle.kind != "hnsw":
raise ValueError(f"efsearch override is only valid for hnsw bundles, got '{bundle.kind}'")
_orig_efsearch = bundle.index.hnsw.efSearch
bundle.index.hnsw.efSearch = efsearch

try:
need = topk + 1 if (exclude_self or exclude_indices is not None) else topk
D_raw, I_raw = bundle.index.search(Q, need)
finally:
# always restore — even if search() raises
if _orig_nprobe is not None:
bundle.index.nprobe = _orig_nprobe
if _orig_efsearch is not None:
bundle.index.hnsw.efSearch = _orig_efsearch

if exclude_indices is not None:
exclude_indices = np.atleast_1d(exclude_indices)
@@ -126,7 +212,7 @@ def search(
mask = I_raw[r] != ex
D_out[r] = D_raw[r][mask][:topk]
I_out[r] = I_raw[r][mask][:topk]
return (D_out[0], I_out[0]) if was_1d else (D_out, I_out)

if exclude_self:
D_out = np.empty((n_q, topk), dtype=D_raw.dtype)
@@ -135,9 +221,11 @@ def search(
mask = I_raw[r] != r
D_out[r] = D_raw[r][mask][:topk]
I_out[r] = I_raw[r][mask][:topk]
return (D_out[0], I_out[0]) if was_1d else (D_out, I_out)

D_out, I_out = D_raw[:, :topk], I_raw[:, :topk]
return (D_out[0], I_out[0]) if was_1d else (D_out, I_out)
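# --- Example (illustration only): per-call recall/latency tuning via the new
# override arguments. A minimal sketch; q is a hypothetical (D,) float32 query
# and X the database matrix used above.
#
#     ivf = build_index_ivfpq(X, metric="ip")
#     D1, I1 = search(ivf, q, topk=10)             # bundle default nprobe (16)
#     D2, I2 = search(ivf, q, topk=10, nprobe=64)  # higher recall, slower;
#                                                  # original nprobe restored afterwards
#     hnsw = build_index_hnsw(X, metric="ip")
#     D3, I3 = search(hnsw, q, topk=10, efsearch=128)
#     search(ivf, q, topk=10, efsearch=128)        # ValueError: efsearch needs an hnsw bundle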


# ============ Re-ranking ============
def rerank(Q: np.ndarray, X_cands: np.ndarray) -> np.ndarray:
@@ -157,7 +245,7 @@ def rerank(Q: np.ndarray, X_cands: np.ndarray) -> np.ndarray:


# ============ Evaluating ============
def recall_k(flat_bundle: IndexBundle, ann_bundle: IndexBundle, X: np.ndarray, k: int = 10,
nsamp: int = 200, seed: int = 42) -> float:
"""Compute recall@k between a flat index and an ANN index"""
rng = np.random.default_rng(seed)
@@ -176,4 +264,4 @@ def sweep_nprobe(ivfpq_bundle: IndexBundle, flat_bundle: IndexBundle, X: np.ndar
ivfpq_bundle.index.nprobe = nprobe
rec = recall_k(flat_bundle, ivfpq_bundle, X, k=10, nsamp=min(200, len(X)))
results.append((nprobe, rec))
return results
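# --- Example (illustration only): choosing an nprobe operating point with
# sweep_nprobe(). A minimal sketch; assumes X is the matrix both bundles were
# built from, and uses the sweep's default nprobe values.
#
#     flat = build_index_flat(X, metric="ip")
#     ivf = build_index_ivfpq(X, metric="ip")
#     for nprobe, rec in sweep_nprobe(ivf, flat, X):
#         print(f"nprobe={nprobe:4d}  recall@10={rec:.3f}")
#     # pick the smallest nprobe whose recall@10 meets your target (e.g. >= 0.95)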