Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom exceptions for better error-handling #46

Merged
merged 7 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ htmlcov/
docs/build/
.vscode/
.python-version
.DS_Store
78 changes: 78 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,81 @@ HybridCypherRetriever

.. autoclass:: neo4j_genai.retrievers.hybrid.HybridCypherRetriever
:members:


******
Errors
******


* :class:`neo4j_genai.exceptions.Neo4jGenAiError`

* :class:`neo4j_genai.exceptions.RetrieverInitializationError`

* :class:`neo4j_genai.exceptions.SearchValidationError`

* :class:`neo4j_genai.exceptions.FilterValidationError`

* :class:`neo4j_genai.exceptions.EmbeddingRequiredError`

* :class:`neo4j_genai.exceptions.InvalidRetrieverResultError`

* :class:`neo4j_genai.exceptions.Neo4jIndexError`

* :class:`neo4j_genai.exceptions.Neo4jVersionError`


Neo4jGenAiError
===============

.. autoclass:: neo4j_genai.exceptions.Neo4jGenAiError
:show-inheritance:


RetrieverInitializationError
============================

.. autoclass:: neo4j_genai.exceptions.RetrieverInitializationError
:show-inheritance:


SearchValidationError
=====================

.. autoclass:: neo4j_genai.exceptions.SearchValidationError
:show-inheritance:


FilterValidationError
=====================

.. autoclass:: neo4j_genai.exceptions.FilterValidationError
:show-inheritance:


EmbeddingRequiredError
======================

.. autoclass:: neo4j_genai.exceptions.EmbeddingRequiredError
:show-inheritance:


InvalidRetrieverResultError
===========================

.. autoclass:: neo4j_genai.exceptions.InvalidRetrieverResultError
:show-inheritance:


Neo4jIndexError
===============

.. autoclass:: neo4j_genai.exceptions.Neo4jIndexError
:show-inheritance:


Neo4jVersionError
=================

.. autoclass:: neo4j_genai.exceptions.Neo4jVersionError
:show-inheritance:
67 changes: 67 additions & 0 deletions src/neo4j_genai/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class Neo4jGenAiError(Exception):
"""Global exception used for the neo4j-genai package."""

pass


class RetrieverInitializationError(Neo4jGenAiError):
"""Exception raised when initialization of a retriever fails."""

def __init__(self, errors: str):
super().__init__(f"Initialization failed: {errors}")
self.errors = errors


class SearchValidationError(Neo4jGenAiError):
"""Exception raised for validation errors during search."""

def __init__(self, errors):
super().__init__(f"Search validation failed: {errors}")
self.errors = errors


class FilterValidationError(Neo4jGenAiError):
"""Exception raised when input validation for metadata filtering fails."""

pass


class EmbeddingRequiredError(Neo4jGenAiError):
"""Exception raised when an embedding method is required but not provided."""

pass


class InvalidRetrieverResultError(Neo4jGenAiError):
"""Exception raised when the Retriever fails to return a result."""

pass


class Neo4jIndexError(Neo4jGenAiError):
"""Exception raised when handling Neo4j index fails."""

pass


class Neo4jVersionError(Neo4jGenAiError):
"""Exception raised when Neo4j version does not meet minimum requirements."""

def __init__(self):
super().__init__("This package only supports Neo4j version 5.18.1 or greater")
19 changes: 11 additions & 8 deletions src/neo4j_genai/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Type
from collections import Counter

from neo4j_genai.exceptions import FilterValidationError

DEFAULT_NODE_ALIAS = "node"

Expand Down Expand Up @@ -244,20 +245,20 @@ def _handle_field_filter(
"""
# first, perform some sanity checks
if not isinstance(field, str):
raise ValueError(
raise FilterValidationError(
f"Field should be a string but got: {type(field)} with value: {field}"
)

if field.startswith(OPERATOR_PREFIX):
raise ValueError(
raise FilterValidationError(
f"Invalid filter condition. Expected a field but got an operator: "
f"{field}"
)

if isinstance(value, dict):
# This is a filter specification e.g. {"$gte": 0}
if len(value) != 1:
raise ValueError(
raise FilterValidationError(
"Invalid filter condition. Expected a value which "
"is a dictionary with a single key that corresponds to an operator "
f"but got a dictionary with {len(value)} keys. The first few "
Expand All @@ -267,7 +268,7 @@ def _handle_field_filter(
operator = operator.lower()
# Verify that that operator is an operator
if operator not in SUPPORTED_OPERATORS:
raise ValueError(
raise FilterValidationError(
f"Invalid operator: {operator}. "
f"Expected one of {SUPPORTED_OPERATORS}"
)
Expand All @@ -280,7 +281,7 @@ def _handle_field_filter(
# two tests (lower_bound <= value <= higher_bound)
if operator == OPERATOR_BETWEEN:
if len(filter_value) != 2:
raise ValueError(
raise FilterValidationError(
f"Expected lower and upper bounds in a list, got {filter_value}"
)
low, high = filter_value
Expand Down Expand Up @@ -312,7 +313,7 @@ def _construct_metadata_filter(
"""

if not isinstance(filter, dict):
raise ValueError(f"Filter must be a dictionary, got {type(filter)}")
raise FilterValidationError(f"Filter must be a dictionary, got {type(filter)}")
# if we have more than one entry, this is an implicit "AND" filter
if len(filter) > 1:
return _construct_metadata_filter(
Expand All @@ -329,13 +330,15 @@ def _construct_metadata_filter(

# Here we handle the $and and $or operators
if not isinstance(value, list):
raise ValueError(f"Expected a list, but got {type(value)} for value: {value}")
raise FilterValidationError(
f"Expected a list, but got {type(value)} for value: {value}"
)
if key.lower() == OPERATOR_AND:
cypher_operator = " AND "
elif key.lower() == OPERATOR_OR:
cypher_operator = " OR "
else:
raise ValueError(f"Unsupported operator: {key}")
raise FilterValidationError(f"Unsupported operator: {key}")
query = cypher_operator.join(
[
f"({ _construct_metadata_filter(el, param_store, node_alias)})"
Expand Down
15 changes: 7 additions & 8 deletions src/neo4j_genai/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import neo4j
from pydantic import ValidationError

from .exceptions import Neo4jIndexError
from .types import VectorIndexModel, FulltextIndexModel
import logging

Expand Down Expand Up @@ -64,7 +66,7 @@ def create_vector_index(
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_vector_index {str(e)}")
raise Neo4jIndexError(f"Error for inputs to create_vector_index {str(e)}")

try:
query = (
Expand All @@ -77,8 +79,7 @@ def create_vector_index(
{"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn},
)
except neo4j.exceptions.ClientError as e:
logger.error(f"Neo4j vector index creation failed {e}")
raise
raise Neo4jIndexError(f"Neo4j vector index creation failed: {e}")


def create_fulltext_index(
Expand Down Expand Up @@ -113,7 +114,7 @@ def create_fulltext_index(
}
)
except ValidationError as e:
raise ValueError(f"Error for inputs to create_fulltext_index {str(e)}")
raise Neo4jIndexError(f"Error for inputs to create_fulltext_index: {str(e)}")

try:
query = (
Expand All @@ -124,8 +125,7 @@ def create_fulltext_index(
logger.info(f"Creating fulltext index named '{name}'")
driver.execute_query(query, {"name": name})
except neo4j.exceptions.ClientError as e:
logger.error(f"Neo4j fulltext index creation failed {e}")
raise
raise Neo4jIndexError(f"Neo4j fulltext index creation failed {e}")


def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None:
Expand All @@ -149,5 +149,4 @@ def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None:
logger.info(f"Dropping index named '{name}'")
driver.execute_query(query, parameters)
except neo4j.exceptions.ClientError as e:
logger.error(f"Dropping Neo4j index failed {e}")
raise
raise Neo4jIndexError(f"Dropping Neo4j index failed: {e}")
8 changes: 4 additions & 4 deletions src/neo4j_genai/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from typing import Optional, Any
import neo4j

from neo4j_genai.exceptions import Neo4jVersionError


class Retriever(ABC):
"""
Expand All @@ -32,7 +34,7 @@ def _verify_version(self) -> None:

Queries the Neo4j database to retrieve its version and compares it
against a target version (5.18.1) that is known to support vector
indexing. Raises a ValueError if the connected Neo4j version is
indexing. Raises a Neo4jMinVersionError if the connected Neo4j version is
not supported.
"""
records, _, _ = self.driver.execute_query("CALL dbms.components()")
Expand All @@ -49,9 +51,7 @@ def _verify_version(self) -> None:
target_version = (5, 18, 1)

if version_tuple < target_version:
raise ValueError(
"This package only supports Neo4j version 5.18.1 or greater"
)
raise Neo4jVersionError()

@abstractmethod
def search(self, *args, **kwargs) -> Any:
Expand Down
77 changes: 77 additions & 0 deletions src/neo4j_genai/retrievers/external/weaviate/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from pydantic import (
field_validator,
BaseModel,
PositiveInt,
model_validator,
ConfigDict,
)
from weaviate.client import WeaviateClient
from weaviate.collections.classes.filters import _Filters

from neo4j_genai.retrievers.utils import validate_search_query_input
from neo4j_genai.types import Neo4jDriverModel, EmbedderModel


class WeaviateModel(BaseModel):
client: WeaviateClient
model_config = ConfigDict(arbitrary_types_allowed=True)

@field_validator("client")
def check_client(cls, value):
if not isinstance(value, WeaviateClient):
raise TypeError(
"Provided client needs to be of type weaviate.client.WeaviateClient"
)
return value


class WeaviateNeo4jRetrieverModel(BaseModel):
driver_model: Neo4jDriverModel
client_model: WeaviateModel
collection: str
id_property_external: str
id_property_neo4j: str
embedder_model: Optional[EmbedderModel]
return_properties: Optional[list[str]] = None
retrieval_query: Optional[str] = None


class WeaviateNeo4jSearchModel(BaseModel):
top_k: PositiveInt = 5
query_vector: Optional[list[float]] = None
query_text: Optional[str] = None
weaviate_filters: Optional[_Filters] = None
model_config = ConfigDict(arbitrary_types_allowed=True)

@field_validator("weaviate_filters")
def check_weaviate_filters(cls, value):
if value and not isinstance(value, _Filters):
raise TypeError(
"Provided filters need to be of type weaviate.collections.classes.filters._Filters"
)
return value

@model_validator(mode="before")
def check_query(cls, values):
"""
Validates that one of either query_vector or query_text is provided exclusively.
"""
query_vector, query_text = values.get("query_vector"), values.get("query_text")
validate_search_query_input(query_text, query_vector)
return values
Loading