diff --git a/.gitignore b/.gitignore index fd9b584f5..ff6cab99c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ htmlcov/ docs/build/ .vscode/ .python-version +.DS_Store diff --git a/docs/source/api.rst b/docs/source/api.rst index 4e010932c..740a5373d 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -40,3 +40,81 @@ HybridCypherRetriever .. autoclass:: neo4j_genai.retrievers.hybrid.HybridCypherRetriever :members: + + +****** +Errors +****** + + +* :class:`neo4j_genai.exceptions.Neo4jGenAiError` + + * :class:`neo4j_genai.exceptions.RetrieverInitializationError` + + * :class:`neo4j_genai.exceptions.SearchValidationError` + + * :class:`neo4j_genai.exceptions.FilterValidationError` + + * :class:`neo4j_genai.exceptions.EmbeddingRequiredError` + + * :class:`neo4j_genai.exceptions.InvalidRetrieverResultError` + + * :class:`neo4j_genai.exceptions.Neo4jIndexError` + + * :class:`neo4j_genai.exceptions.Neo4jVersionError` + + +Neo4jGenAiError +=============== + +.. autoclass:: neo4j_genai.exceptions.Neo4jGenAiError + :show-inheritance: + + +RetrieverInitializationError +============================ + +.. autoclass:: neo4j_genai.exceptions.RetrieverInitializationError + :show-inheritance: + + +SearchValidationError +===================== + +.. autoclass:: neo4j_genai.exceptions.SearchValidationError + :show-inheritance: + + +FilterValidationError +===================== + +.. autoclass:: neo4j_genai.exceptions.FilterValidationError + :show-inheritance: + + +EmbeddingRequiredError +====================== + +.. autoclass:: neo4j_genai.exceptions.EmbeddingRequiredError + :show-inheritance: + + +InvalidRetrieverResultError +=========================== + +.. autoclass:: neo4j_genai.exceptions.InvalidRetrieverResultError + :show-inheritance: + + +Neo4jIndexError +=============== + +.. autoclass:: neo4j_genai.exceptions.Neo4jIndexError + :show-inheritance: + + +Neo4jVersionError +================= + +.. autoclass:: neo4j_genai.exceptions.Neo4jVersionError + :show-inheritance: diff --git a/src/neo4j_genai/exceptions.py b/src/neo4j_genai/exceptions.py new file mode 100644 index 000000000..96373b091 --- /dev/null +++ b/src/neo4j_genai/exceptions.py @@ -0,0 +1,67 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Neo4jGenAiError(Exception): + """Global exception used for the neo4j-genai package.""" + + pass + + +class RetrieverInitializationError(Neo4jGenAiError): + """Exception raised when initialization of a retriever fails.""" + + def __init__(self, errors: str): + super().__init__(f"Initialization failed: {errors}") + self.errors = errors + + +class SearchValidationError(Neo4jGenAiError): + """Exception raised for validation errors during search.""" + + def __init__(self, errors): + super().__init__(f"Search validation failed: {errors}") + self.errors = errors + + +class FilterValidationError(Neo4jGenAiError): + """Exception raised when input validation for metadata filtering fails.""" + + pass + + +class EmbeddingRequiredError(Neo4jGenAiError): + """Exception raised when an embedding method is required but not provided.""" + + pass + + +class InvalidRetrieverResultError(Neo4jGenAiError): + """Exception raised when the Retriever fails to return a result.""" + + pass + + +class Neo4jIndexError(Neo4jGenAiError): + """Exception raised when handling Neo4j index fails.""" + + pass + + +class Neo4jVersionError(Neo4jGenAiError): + """Exception raised when Neo4j version does not meet minimum requirements.""" + + def __init__(self): + super().__init__("This package only supports Neo4j version 5.18.1 or greater") diff --git a/src/neo4j_genai/filters.py b/src/neo4j_genai/filters.py index ec80ac8c6..3fa0a33cc 100644 --- a/src/neo4j_genai/filters.py +++ b/src/neo4j_genai/filters.py @@ -16,6 +16,7 @@ from typing import Any, Type from collections import Counter +from neo4j_genai.exceptions import FilterValidationError DEFAULT_NODE_ALIAS = "node" @@ -244,12 +245,12 @@ def _handle_field_filter( """ # first, perform some sanity checks if not isinstance(field, str): - raise ValueError( + raise FilterValidationError( f"Field should be a string but got: {type(field)} with value: {field}" ) if field.startswith(OPERATOR_PREFIX): - raise ValueError( + raise FilterValidationError( f"Invalid filter condition. Expected a field but got an operator: " f"{field}" ) @@ -257,7 +258,7 @@ def _handle_field_filter( if isinstance(value, dict): # This is a filter specification e.g. {"$gte": 0} if len(value) != 1: - raise ValueError( + raise FilterValidationError( "Invalid filter condition. Expected a value which " "is a dictionary with a single key that corresponds to an operator " f"but got a dictionary with {len(value)} keys. The first few " @@ -267,7 +268,7 @@ def _handle_field_filter( operator = operator.lower() # Verify that that operator is an operator if operator not in SUPPORTED_OPERATORS: - raise ValueError( + raise FilterValidationError( f"Invalid operator: {operator}. " f"Expected one of {SUPPORTED_OPERATORS}" ) @@ -280,7 +281,7 @@ def _handle_field_filter( # two tests (lower_bound <= value <= higher_bound) if operator == OPERATOR_BETWEEN: if len(filter_value) != 2: - raise ValueError( + raise FilterValidationError( f"Expected lower and upper bounds in a list, got {filter_value}" ) low, high = filter_value @@ -312,7 +313,7 @@ def _construct_metadata_filter( """ if not isinstance(filter, dict): - raise ValueError(f"Filter must be a dictionary, got {type(filter)}") + raise FilterValidationError(f"Filter must be a dictionary, got {type(filter)}") # if we have more than one entry, this is an implicit "AND" filter if len(filter) > 1: return _construct_metadata_filter( @@ -329,13 +330,15 @@ def _construct_metadata_filter( # Here we handle the $and and $or operators if not isinstance(value, list): - raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") + raise FilterValidationError( + f"Expected a list, but got {type(value)} for value: {value}" + ) if key.lower() == OPERATOR_AND: cypher_operator = " AND " elif key.lower() == OPERATOR_OR: cypher_operator = " OR " else: - raise ValueError(f"Unsupported operator: {key}") + raise FilterValidationError(f"Unsupported operator: {key}") query = cypher_operator.join( [ f"({ _construct_metadata_filter(el, param_store, node_alias)})" diff --git a/src/neo4j_genai/indexes.py b/src/neo4j_genai/indexes.py index 32ace5784..528983669 100644 --- a/src/neo4j_genai/indexes.py +++ b/src/neo4j_genai/indexes.py @@ -15,6 +15,8 @@ import neo4j from pydantic import ValidationError + +from .exceptions import Neo4jIndexError from .types import VectorIndexModel, FulltextIndexModel import logging @@ -64,7 +66,7 @@ def create_vector_index( } ) except ValidationError as e: - raise ValueError(f"Error for inputs to create_vector_index {str(e)}") + raise Neo4jIndexError(f"Error for inputs to create_vector_index {str(e)}") try: query = ( @@ -77,8 +79,7 @@ def create_vector_index( {"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn}, ) except neo4j.exceptions.ClientError as e: - logger.error(f"Neo4j vector index creation failed {e}") - raise + raise Neo4jIndexError(f"Neo4j vector index creation failed: {e}") def create_fulltext_index( @@ -113,7 +114,7 @@ def create_fulltext_index( } ) except ValidationError as e: - raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}") + raise Neo4jIndexError(f"Error for inputs to create_fulltext_index: {str(e)}") try: query = ( @@ -124,8 +125,7 @@ def create_fulltext_index( logger.info(f"Creating fulltext index named '{name}'") driver.execute_query(query, {"name": name}) except neo4j.exceptions.ClientError as e: - logger.error(f"Neo4j fulltext index creation failed {e}") - raise + raise Neo4jIndexError(f"Neo4j fulltext index creation failed {e}") def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None: @@ -149,5 +149,4 @@ def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None: logger.info(f"Dropping index named '{name}'") driver.execute_query(query, parameters) except neo4j.exceptions.ClientError as e: - logger.error(f"Dropping Neo4j index failed {e}") - raise + raise Neo4jIndexError(f"Dropping Neo4j index failed: {e}") diff --git a/src/neo4j_genai/retrievers/base.py b/src/neo4j_genai/retrievers/base.py index 217fbe980..f1dd713fb 100644 --- a/src/neo4j_genai/retrievers/base.py +++ b/src/neo4j_genai/retrievers/base.py @@ -16,6 +16,8 @@ from typing import Optional, Any import neo4j +from neo4j_genai.exceptions import Neo4jVersionError + class Retriever(ABC): """ @@ -32,7 +34,7 @@ def _verify_version(self) -> None: Queries the Neo4j database to retrieve its version and compares it against a target version (5.18.1) that is known to support vector - indexing. Raises a ValueError if the connected Neo4j version is + indexing. Raises a Neo4jMinVersionError if the connected Neo4j version is not supported. """ records, _, _ = self.driver.execute_query("CALL dbms.components()") @@ -49,9 +51,7 @@ def _verify_version(self) -> None: target_version = (5, 18, 1) if version_tuple < target_version: - raise ValueError( - "This package only supports Neo4j version 5.18.1 or greater" - ) + raise Neo4jVersionError() @abstractmethod def search(self, *args, **kwargs) -> Any: diff --git a/src/neo4j_genai/retrievers/external/weaviate/types.py b/src/neo4j_genai/retrievers/external/weaviate/types.py new file mode 100644 index 000000000..1400c249f --- /dev/null +++ b/src/neo4j_genai/retrievers/external/weaviate/types.py @@ -0,0 +1,77 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from pydantic import ( + field_validator, + BaseModel, + PositiveInt, + model_validator, + ConfigDict, +) +from weaviate.client import WeaviateClient +from weaviate.collections.classes.filters import _Filters + +from neo4j_genai.retrievers.utils import validate_search_query_input +from neo4j_genai.types import Neo4jDriverModel, EmbedderModel + + +class WeaviateModel(BaseModel): + client: WeaviateClient + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("client") + def check_client(cls, value): + if not isinstance(value, WeaviateClient): + raise TypeError( + "Provided client needs to be of type weaviate.client.WeaviateClient" + ) + return value + + +class WeaviateNeo4jRetrieverModel(BaseModel): + driver_model: Neo4jDriverModel + client_model: WeaviateModel + collection: str + id_property_external: str + id_property_neo4j: str + embedder_model: Optional[EmbedderModel] + return_properties: Optional[list[str]] = None + retrieval_query: Optional[str] = None + + +class WeaviateNeo4jSearchModel(BaseModel): + top_k: PositiveInt = 5 + query_vector: Optional[list[float]] = None + query_text: Optional[str] = None + weaviate_filters: Optional[_Filters] = None + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("weaviate_filters") + def check_weaviate_filters(cls, value): + if value and not isinstance(value, _Filters): + raise TypeError( + "Provided filters need to be of type weaviate.collections.classes.filters._Filters" + ) + return value + + @model_validator(mode="before") + def check_query(cls, values): + """ + Validates that one of either query_vector or query_text is provided exclusively. + """ + query_vector, query_text = values.get("query_vector"), values.get("query_text") + validate_search_query_input(query_text, query_vector) + return values diff --git a/src/neo4j_genai/retrievers/external/weaviate/weaviate.py b/src/neo4j_genai/retrievers/external/weaviate/weaviate.py index 3db774827..3682acf12 100644 --- a/src/neo4j_genai/retrievers/external/weaviate/weaviate.py +++ b/src/neo4j_genai/retrievers/external/weaviate/weaviate.py @@ -14,15 +14,24 @@ # limitations under the License. from typing import Optional + +from pydantic import ValidationError + +from neo4j_genai.exceptions import RetrieverInitializationError, SearchValidationError from neo4j_genai.retrievers.base import ExternalRetriever from neo4j_genai.embedder import Embedder -from neo4j_genai.retrievers.utils import validate_search_query_input +from neo4j_genai.retrievers.external.weaviate.types import ( + WeaviateModel, + WeaviateNeo4jRetrieverModel, + WeaviateNeo4jSearchModel, +) import weaviate.classes as wvc from weaviate.client import WeaviateClient from weaviate.collections.classes.filters import _Filters import neo4j import logging from neo4j_genai.neo4j_queries import get_query_tail +from neo4j_genai.types import Neo4jDriverModel, EmbedderModel logger = logging.getLogger(__name__) @@ -39,13 +48,35 @@ def __init__( return_properties: Optional[list[str]] = None, retrieval_query: Optional[str] = None, ): + try: + driver_model = Neo4jDriverModel(driver=driver) + weaviate_model = WeaviateModel(client=client) + embedder_model = EmbedderModel(embedder=embedder) if embedder else None + validated_data = WeaviateNeo4jRetrieverModel( + driver_model=driver_model, + client_model=weaviate_model, + collection=collection, + id_property_external=id_property_external, + id_property_neo4j=id_property_neo4j, + embedder_model=embedder_model, + return_properties=return_properties, + retrieval_query=retrieval_query, + ) + except ValidationError as e: + raise RetrieverInitializationError(e.errors()) + super().__init__(id_property_external, id_property_neo4j) - self.driver = driver - self.client = client - self.search_collection = client.collections.get(collection) - self.embedder = embedder - self.return_properties = return_properties - self.retrieval_query = retrieval_query + self.driver = validated_data.driver_model.driver + self.client = validated_data.client_model.client + collection = validated_data.collection + self.search_collection = self.client.collections.get(collection) + self.embedder = ( + validated_data.embedder_model.embedder + if validated_data.embedder_model + else None + ) + self.return_properties = validated_data.return_properties + self.retrieval_query = validated_data.retrieval_query def search( self, @@ -71,12 +102,23 @@ def search( top_k (int, optional): The number of neighbors to return. Defaults to 5. weaviate_filters (Optional[_Filters], optional): The filters to apply to the search query in Weaviate. Defaults to None. Raises: - ValueError: If validation of the input arguments fail. + SearchValidationError: If validation of the input arguments fail. Returns: list[neo4j.Record]: The results of the search query """ - - validate_search_query_input(query_text=query_text, query_vector=query_vector) + try: + validated_data = WeaviateNeo4jSearchModel( + top_k=top_k, + query_vector=query_vector, + query_text=query_text, + weaviate_filters=weaviate_filters, + ) + query_text = validated_data.query_text + query_vector = validated_data.query_vector + top_k = validated_data.top_k + weaviate_filters = validated_data.weaviate_filters + except ValidationError as e: + raise SearchValidationError(e.errors()) # If we want to use a local embedder, we still want to call the near_vector method # so we want to create the vector as early as possible here diff --git a/src/neo4j_genai/retrievers/hybrid.py b/src/neo4j_genai/retrievers/hybrid.py index 2f7108afd..b735b1f1c 100644 --- a/src/neo4j_genai/retrievers/hybrid.py +++ b/src/neo4j_genai/retrievers/hybrid.py @@ -18,6 +18,11 @@ from pydantic import ValidationError from neo4j_genai.embedder import Embedder +from neo4j_genai.exceptions import ( + RetrieverInitializationError, + SearchValidationError, + EmbeddingRequiredError, +) from neo4j_genai.retrievers.base import Retriever from neo4j_genai.types import ( HybridSearchModel, @@ -81,7 +86,7 @@ def __init__( return_properties=return_properties, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise RetrieverInitializationError(e.errors()) super().__init__(validated_data.driver_model.driver) self.vector_index_name = validated_data.vector_index_name @@ -116,8 +121,8 @@ def search( top_k (int, optional): The number of neighbors to return. Defaults to 5. Raises: - ValueError: If validation of the input arguments fail. - ValueError: If no embedder is provided. + SearchValidationError: If validation of the input arguments fail. + EmbeddingRequiredError: If no embedder is provided. Returns: list[neo4j.Record]: The results of the search query @@ -131,13 +136,15 @@ def search( query_text=query_text, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise SearchValidationError(e.errors()) parameters = validated_data.model_dump(exclude_none=True) if query_text and not query_vector: if not self.embedder: - raise ValueError("Embedding method required for text query.") + raise EmbeddingRequiredError( + "Embedding method required for text query." + ) query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector @@ -199,7 +206,7 @@ def __init__( embedder_model=embedder_model, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise RetrieverInitializationError(e.errors()) super().__init__(validated_data.driver_model.driver) self.vector_index_name = validated_data.vector_index_name @@ -236,8 +243,8 @@ def search( query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None. Raises: - ValueError: If validation of the input arguments fail. - ValueError: If no embedder is provided. + SearchValidationError: If validation of the input arguments fail. + EmbeddingRequiredError: If no embedder is provided. Returns: list[neo4j.Record]: The results of the search query @@ -252,13 +259,15 @@ def search( query_params=query_params, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise SearchValidationError(e.errors()) parameters = validated_data.model_dump(exclude_none=True) if query_text and not query_vector: if not self.embedder: - raise ValueError("Embedding method required for text query.") + raise EmbeddingRequiredError( + "Embedding method required for text query." + ) query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector diff --git a/src/neo4j_genai/retrievers/vector.py b/src/neo4j_genai/retrievers/vector.py index 8770f70fe..b3a3ebac6 100644 --- a/src/neo4j_genai/retrievers/vector.py +++ b/src/neo4j_genai/retrievers/vector.py @@ -15,6 +15,13 @@ from typing import Optional, Any import neo4j + +from neo4j_genai.exceptions import ( + RetrieverInitializationError, + SearchValidationError, + EmbeddingRequiredError, + InvalidRetrieverResultError, +) from neo4j_genai.retrievers.base import Retriever from pydantic import ValidationError @@ -76,7 +83,7 @@ def __init__( return_properties=return_properties, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise RetrieverInitializationError(e.errors()) super().__init__(driver) self.index_name = validated_data.index_name @@ -111,8 +118,8 @@ def search( filters (Optional[dict[str, Any]]): Filters for metadata pre-filtering. Defaults to None. Raises: - ValueError: If validation of the input arguments fail. - ValueError: If no embedder is provided. + SearchValidationError: If validation of the input arguments fail. + EmbeddingRequiredError: If no embedder is provided. Returns: list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores. @@ -125,14 +132,15 @@ def search( query_text=query_text, ) except ValidationError as e: - error_details = e.errors() - raise ValueError(f"Validation failed: {error_details}") + raise SearchValidationError(e.errors()) parameters = validated_data.model_dump(exclude_none=True) if query_text: if not self.embedder: - raise ValueError("Embedding method required for text query.") + raise EmbeddingRequiredError( + "Embedding method required for text query." + ) query_vector = self.embedder.embed_query(query_text) parameters["query_vector"] = query_vector del parameters["query_text"] @@ -158,9 +166,8 @@ def search( for record in records ] except ValidationError as e: - error_details = e.errors() - raise ValueError( - f"Validation failed while constructing output: {error_details}" + raise InvalidRetrieverResultError( + f"Failed constructing VectorSearchRecord output: {e.errors()}" ) @@ -209,7 +216,7 @@ def __init__( embedder_model=embedder_model, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise RetrieverInitializationError(e.errors()) super().__init__(driver) self.index_name = validated_data.index_name @@ -246,8 +253,8 @@ def search( filters (Optional[dict[str, Any]]): Filters for metadata pre-filtering. Defaults to None. Raises: - ValueError: If validation of the input arguments fail. - ValueError: If no embedder is provided. + SearchValidationError: If validation of the input arguments fail. + EmbeddingRequiredError: If no embedder is provided. Returns: list[VectorSearchRecord]: The `top_k` neighbors found in vector search with their nodes and scores. @@ -261,13 +268,15 @@ def search( query_params=query_params, ) except ValidationError as e: - raise ValueError(f"Validation failed: {e.errors()}") + raise SearchValidationError(e.errors()) parameters = validated_data.model_dump(exclude_none=True) if query_text: if not self.embedder: - raise ValueError("Embedding method required for text query.") + raise EmbeddingRequiredError( + "Embedding method required for text query." + ) parameters["query_vector"] = self.embedder.embed_query(query_text) del parameters["query_text"] diff --git a/tests/unit/retrievers/external/test_weaviate.py b/tests/unit/retrievers/external/test_weaviate.py index 6d3af46b1..07ccb716b 100644 --- a/tests/unit/retrievers/external/test_weaviate.py +++ b/tests/unit/retrievers/external/test_weaviate.py @@ -15,6 +15,9 @@ from unittest.mock import MagicMock from types import SimpleNamespace + +from weaviate import WeaviateClient + from neo4j_genai.retrievers.external.weaviate import ( WeaviateNeo4jRetriever, get_match_query, @@ -22,7 +25,7 @@ # Weaviate class with fake methods -class WClient: +class WClient(WeaviateClient): def __init__(self, node_id_value=None, node_match_score=None): self.collections = MagicMock() self.collections.get = MagicMock() diff --git a/tests/unit/retrievers/test_base.py b/tests/unit/retrievers/test_base.py index 0386b81b6..3667d922b 100644 --- a/tests/unit/retrievers/test_base.py +++ b/tests/unit/retrievers/test_base.py @@ -14,6 +14,7 @@ # limitations under the License. import pytest +from neo4j_genai.exceptions import Neo4jVersionError from neo4j_genai.retrievers.base import Retriever @@ -21,9 +22,9 @@ "db_version,expected_exception", [ (["5.18-aura"], None), - (["5.3-aura"], ValueError), + (["5.3-aura"], Neo4jVersionError), (["5.19.0"], None), - (["4.3.5"], ValueError), + (["4.3.5"], Neo4jVersionError), ], ) def test_retriever_version_support(driver, db_version, expected_exception): diff --git a/tests/unit/retrievers/test_hybrid.py b/tests/unit/retrievers/test_hybrid.py index 6e0688c3c..aaf69e994 100644 --- a/tests/unit/retrievers/test_hybrid.py +++ b/tests/unit/retrievers/test_hybrid.py @@ -18,6 +18,7 @@ import pytest from neo4j_genai import HybridRetriever, HybridCypherRetriever +from neo4j_genai.exceptions import RetrieverInitializationError, EmbeddingRequiredError from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType @@ -45,7 +46,7 @@ def test_vector_cypher_retriever_initialization(driver): @patch("neo4j_genai.HybridRetriever._verify_version") def test_hybrid_retriever_invalid_fulltext_index_name(_verify_version_mock, driver): - with pytest.raises(ValueError) as exc_info: + with pytest.raises(RetrieverInitializationError) as exc_info: HybridRetriever( driver=driver, vector_index_name="my-index", fulltext_index_name=42 ) @@ -56,7 +57,7 @@ def test_hybrid_retriever_invalid_fulltext_index_name(_verify_version_mock, driv @patch("neo4j_genai.HybridCypherRetriever._verify_version") def test_hybrid_cypher_retriever_invalid_retrieval_query(_verify_version_mock, driver): - with pytest.raises(ValueError) as exc_info: + with pytest.raises(RetrieverInitializationError) as exc_info: HybridCypherRetriever( driver=driver, vector_index_name="my-index", @@ -143,7 +144,9 @@ def test_error_when_hybrid_search_only_text_no_embedder(hybrid_retriever): query_text = "may thy knife chip and shatter" top_k = 5 - with pytest.raises(ValueError, match="Embedding method required for text query."): + with pytest.raises( + EmbeddingRequiredError, match="Embedding method required for text query." + ): hybrid_retriever.search( query_text=query_text, top_k=top_k, @@ -156,7 +159,9 @@ def test_hybrid_search_retriever_search_missing_embedder_for_text( query_text = "may thy knife chip and shatter" top_k = 5 - with pytest.raises(ValueError, match="Embedding method required for text query"): + with pytest.raises( + EmbeddingRequiredError, match="Embedding method required for text query" + ): hybrid_retriever.search( query_text=query_text, top_k=top_k, diff --git a/tests/unit/retrievers/test_vector.py b/tests/unit/retrievers/test_vector.py index 3755d8a60..ca3dbe634 100644 --- a/tests/unit/retrievers/test_vector.py +++ b/tests/unit/retrievers/test_vector.py @@ -18,6 +18,12 @@ from neo4j.exceptions import CypherSyntaxError from neo4j_genai import VectorRetriever, VectorCypherRetriever +from neo4j_genai.exceptions import ( + RetrieverInitializationError, + EmbeddingRequiredError, + SearchValidationError, + InvalidRetrieverResultError, +) from neo4j_genai.neo4j_queries import get_search_query from neo4j_genai.types import SearchType, VectorSearchRecord @@ -30,7 +36,7 @@ def test_vector_retriever_initialization(driver): @patch("neo4j_genai.VectorRetriever._verify_version") def test_vector_retriever_invalid_index_name(_verify_version_mock, driver): - with pytest.raises(ValueError) as exc_info: + with pytest.raises(RetrieverInitializationError) as exc_info: VectorRetriever(driver=driver, index_name=42) assert "index_name" in str(exc_info.value) @@ -39,7 +45,7 @@ def test_vector_retriever_invalid_index_name(_verify_version_mock, driver): @patch("neo4j_genai.VectorCypherRetriever._verify_version") def test_vector_cypher_retriever_invalid_retrieval_query(_verify_version_mock, driver): - with pytest.raises(ValueError) as exc_info: + with pytest.raises(RetrieverInitializationError) as exc_info: VectorCypherRetriever(driver=driver, index_name="my-index", retrieval_query=42) assert "retrieval_query" in str(exc_info.value) @@ -156,7 +162,9 @@ def test_vector_retriever_search_missing_embedder_for_text(vector_retriever): query_text = "may thy knife chip and shatter" top_k = 5 - with pytest.raises(ValueError, match="Embedding method required for text query"): + with pytest.raises( + EmbeddingRequiredError, match="Embedding method required for text query" + ): vector_retriever.search(query_text=query_text, top_k=top_k) @@ -166,7 +174,8 @@ def test_vector_retriever_search_both_text_and_vector(vector_retriever): top_k = 5 with pytest.raises( - ValueError, match="You must provide exactly one of query_vector or query_text." + SearchValidationError, + match="You must provide exactly one of query_vector or query_text.", ): vector_retriever.search( query_text=query_text, @@ -181,7 +190,9 @@ def test_vector_cypher_retriever_search_missing_embedder_for_text( query_text = "may thy knife chip and shatter" top_k = 5 - with pytest.raises(ValueError, match="Embedding method required for text query"): + with pytest.raises( + EmbeddingRequiredError, match="Embedding method required for text query" + ): vector_cypher_retriever.search(query_text=query_text, top_k=top_k) @@ -191,7 +202,8 @@ def test_vector_cypher_retriever_search_both_text_and_vector(vector_cypher_retri top_k = 5 with pytest.raises( - ValueError, match="You must provide exactly one of query_vector or query_text." + SearchValidationError, + match="You must provide exactly one of query_vector or query_text.", ): vector_cypher_retriever.search( query_text=query_text, @@ -217,7 +229,7 @@ def test_similarity_search_vector_bad_results( ] search_query, _ = get_search_query(SearchType.VECTOR) - with pytest.raises(ValueError): + with pytest.raises(InvalidRetrieverResultError): retriever.search(query_vector=query_vector, top_k=top_k) retriever.driver.execute_query.assert_called_once_with( diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 66d3d6c8b..3086460c9 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -16,6 +16,7 @@ import pytest +from neo4j_genai.exceptions import FilterValidationError from neo4j_genai.filters import ( get_metadata_filter, _single_condition_cypher, @@ -194,7 +195,7 @@ def test_single_condition_cypher_escaped_field_name(param_store_empty): def test_handle_field_filter_not_a_string(param_store_empty): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _handle_field_filter(1, "value", param_store=param_store_empty) assert "Field should be a string but got: <class 'int'> with value: 1" in str( excinfo @@ -202,7 +203,7 @@ def test_handle_field_filter_not_a_string(param_store_empty): def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _handle_field_filter("$field_name", "value", param_store=param_store_empty) assert ( "Invalid filter condition. Expected a field but got an operator: $field_name" @@ -211,7 +212,7 @@ def test_handle_field_filter_field_start_with_dollar_sign(param_store_empty): def test_handle_field_filter_bad_value(param_store_empty): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _handle_field_filter( "field", value={"operator1": "value1", "operator2": "value2"}, @@ -221,7 +222,7 @@ def test_handle_field_filter_bad_value(param_store_empty): def test_handle_field_filter_bad_operator_name(param_store_empty): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _handle_field_filter( "field", value={"$invalid": "value"}, param_store=param_store_empty ) @@ -237,7 +238,7 @@ def test_handle_field_filter_operator_between(param_store_empty): def test_handle_field_filter_operator_between_not_enough_parameters(param_store_empty): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _handle_field_filter( "field", value={ @@ -355,7 +356,7 @@ def test_handle_field_filter_ilike(_single_condition_cypher_mocked, param_store_ def test_construct_metadata_filter_filter_is_not_a_dict( _handle_field_filter_mock, param_store_empty ): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _construct_metadata_filter([], param_store_empty, node_alias="n") assert "Filter must be a dictionary, got <class 'list'>" in str(excinfo) @@ -429,7 +430,7 @@ def test_construct_metadata_filter_or( def test_construct_metadata_filter_invalid_operator(param_store_empty): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(FilterValidationError) as excinfo: _construct_metadata_filter( {"$invalid": [{}, {}]}, param_store_empty, node_alias="n" ) @@ -582,17 +583,17 @@ def test_get_metadata_filter_and_or_combined(): # now testing bad filters def test_get_metadata_filter_field_name_with_dollar_sign(): filters = {"$field": "value"} - with pytest.raises(ValueError): + with pytest.raises(FilterValidationError): get_metadata_filter(filters) def test_get_metadata_filter_and_no_list(): filters = {"$and": {}} - with pytest.raises(ValueError): + with pytest.raises(FilterValidationError): get_metadata_filter(filters) def test_get_metadata_filter_unsupported_operator(): filters = {"field": {"$unsupported": "value"}} - with pytest.raises(ValueError): + with pytest.raises(FilterValidationError): get_metadata_filter(filters) diff --git a/tests/unit/test_indexes.py b/tests/unit/test_indexes.py index c5509da9a..4adaa4190 100644 --- a/tests/unit/test_indexes.py +++ b/tests/unit/test_indexes.py @@ -15,6 +15,7 @@ import neo4j.exceptions import pytest +from neo4j_genai.exceptions import Neo4jIndexError from neo4j_genai.indexes import ( create_vector_index, drop_index_if_exists, @@ -57,25 +58,25 @@ def test_create_vector_index_ensure_escaping(driver): def test_create_vector_index_negative_dimension(driver): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(Neo4jIndexError) as excinfo: create_vector_index(driver, "my-index", "People", "name", -5, "cosine") assert "Error for inputs to create_vector_index" in str(excinfo) def test_create_vector_index_validation_error_dimensions(driver): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(Neo4jIndexError) as excinfo: create_vector_index(driver, "my-index", "People", "name", "no-dim", "cosine") assert "Error for inputs to create_vector_index" in str(excinfo) def test_create_vector_index_raises_error_with_neo4j_client_error(driver): driver.execute_query.side_effect = neo4j.exceptions.ClientError - with pytest.raises(neo4j.exceptions.ClientError): + with pytest.raises(Neo4jIndexError): create_vector_index(driver, "my-index", "People", "name", 2048, "cosine") def test_create_vector_index_validation_error_similarity_fn(driver): - with pytest.raises(ValueError) as excinfo: + with pytest.raises(Neo4jIndexError) as excinfo: create_vector_index(driver, "my-index", "People", "name", 1536, "algebra") assert "Error for inputs to create_vector_index" in str(excinfo) @@ -124,7 +125,7 @@ def test_create_fulltext_index_raises_error_with_neo4j_client_error(driver): text_node_properties = ["property-1", "property-2"] driver.execute_query.side_effect = neo4j.exceptions.ClientError - with pytest.raises(neo4j.exceptions.ClientError): + with pytest.raises(Neo4jIndexError): create_fulltext_index(driver, "my-index", label, text_node_properties) @@ -132,7 +133,7 @@ def test_create_fulltext_index_empty_node_properties(driver): label = "node-label" node_properties = [] - with pytest.raises(ValueError) as excinfo: + with pytest.raises(Neo4jIndexError) as excinfo: create_fulltext_index(driver, "my-index", label, node_properties) assert "Error for inputs to create_fulltext_index" in str(excinfo)