diff --git a/src/raglite/_search.py b/src/raglite/_search.py index c65bfbc..1d43a2f 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -33,7 +33,7 @@ def search(query: str, *, config: RAGLiteConfig | None = None) -> tuple[list[Chu def vector_search( - query: str | FloatMatrix, *, config: RAGLiteConfig | None = None + query: str | FloatMatrix, *, oversample: int = 8, config: RAGLiteConfig | None = None ) -> tuple[list[ChunkId], list[float]]: """Search chunks using ANN vector search.""" # Read the config. @@ -66,7 +66,7 @@ def vector_search( results = session.exec( select(ChunkEmbedding.chunk_id, distance) .order_by(distance) - .limit(config.num_chunks) + .limit(config.num_chunks * oversample) ) chunk_ids_, distance = zip(*results, strict=True) chunk_ids, similarity = np.asarray(chunk_ids_), 1.0 - np.asarray(distance) @@ -79,7 +79,7 @@ def vector_search( from pynndescent import NNDescent multi_vector_indices, distance = cast(NNDescent, index).query( - query_embedding[np.newaxis, :], k=config.num_chunks + query_embedding[np.newaxis, :], k=config.num_chunks * oversample ) similarity = 1 - distance[0, :] # Transform the multi-vector indices into chunk indices, and then to chunk ids. @@ -175,7 +175,7 @@ def reciprocal_rank_fusion( def hybrid_search( - query: str, *, oversample: int = 4, config: RAGLiteConfig | None = None + query: str, *, oversample: int = 1, config: RAGLiteConfig | None = None ) -> tuple[list[ChunkId], list[float]]: """Search chunks by combining ANN vector search with BM25 keyword search.""" config = config or RAGLiteConfig()