Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 29 additions & 97 deletions brevia/collections.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,29 @@
"""Collections handling functions"""
from langchain_community.vectorstores.pgembedding import CollectionStore
from sqlalchemy.orm import Session
from brevia import connection
from brevia.utilities.uuid import is_valid_uuid


def collections_info(collection: str | None = None):
""" Retrieve collections data """
filter_collection = CollectionStore.name == collection
if collection is None:
filter_collection = CollectionStore.name is not None

with Session(connection.db_connection()) as session:
query = (
session.query(
CollectionStore.uuid,
CollectionStore.name,
CollectionStore.cmetadata,
)
.filter(filter_collection)
.limit(100)
)

result = [u._asdict() for u in query.all()]
session.close()

return result


def collection_name_exists(name: str) -> bool:
""" Check if a collection name exists already """
store = single_collection_by_name(name)

return store is not None


def collection_exists(uuid: str) -> bool:
""" Check if a collection uuid exists """
store = single_collection(uuid)

return store is not None


def single_collection(uuid: str) -> (CollectionStore | None):
""" Get single collection by UUID """
if not is_valid_uuid(uuid):
return None
with Session(connection.db_connection()) as session:
return session.get(CollectionStore, uuid)


def single_collection_by_name(name: str) -> (CollectionStore | None):
""" Get single collection by name"""
with Session(connection.db_connection()) as session:
return CollectionStore.get_by_name(session=session, name=name)


def create_collection(
name: str,
cmetadata: dict,
) -> CollectionStore:
""" Create single collection """
with Session(connection.db_connection()) as session:
collection_store = CollectionStore(
name=name,
cmetadata=cmetadata,
)
session.expire_on_commit = False
session.add(collection_store)
session.commit()

return collection_store


def update_collection(
uuid: str,
name: str,
cmetadata: dict,
):
""" Update single collection """
with Session(connection.db_connection()) as session:
collection = session.get(CollectionStore, uuid)
collection.name = name
collection.cmetadata = cmetadata
session.add(collection)
session.commit()


def delete_collection(
uuid: str,
):
""" Delete single collection """
with Session(connection.db_connection()) as session:
collection = session.get(CollectionStore, uuid)
session.delete(collection)
session.commit()
"""Collections handling functions (DEPRECATED)"""
import warnings
from brevia.collections_tools import (
collections_info,
collection_name_exists,
collection_exists,
create_collection,
update_collection,
delete_collection,
single_collection,
single_collection_by_name,
)

warnings.warn(
"'collections' module is deprecated; use 'collections_tools' instead.",
DeprecationWarning,
stacklevel=2
)

__all__ = [
'collections_info',
'collection_name_exists',
'collection_exists',
'create_collection',
'update_collection',
'delete_collection',
'single_collection',
'single_collection_by_name',
]
97 changes: 97 additions & 0 deletions brevia/collections_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Collections handling functions"""
from langchain_community.vectorstores.pgembedding import CollectionStore
from sqlalchemy.orm import Session
from brevia import connection
from brevia.utilities.uuid import is_valid_uuid


def collections_info(collection: str | None = None):
""" Retrieve collections data """
filter_collection = CollectionStore.name == collection
if collection is None:
filter_collection = CollectionStore.name is not None

with Session(connection.db_connection()) as session:
query = (
session.query(
CollectionStore.uuid,
CollectionStore.name,
CollectionStore.cmetadata,
)
.filter(filter_collection)
.limit(100)
)

result = [u._asdict() for u in query.all()]
session.close()

return result


def collection_name_exists(name: str) -> bool:
""" Check if a collection name exists already """
store = single_collection_by_name(name)

return store is not None


def collection_exists(uuid: str) -> bool:
""" Check if a collection uuid exists """
store = single_collection(uuid)

return store is not None


def single_collection(uuid: str) -> (CollectionStore | None):
""" Get single collection by UUID """
if not is_valid_uuid(uuid):
return None
with Session(connection.db_connection()) as session:
return session.get(CollectionStore, uuid)


def single_collection_by_name(name: str) -> (CollectionStore | None):
""" Get single collection by name"""
with Session(connection.db_connection()) as session:
return CollectionStore.get_by_name(session=session, name=name)


def create_collection(
name: str,
cmetadata: dict,
) -> CollectionStore:
""" Create single collection """
with Session(connection.db_connection()) as session:
collection_store = CollectionStore(
name=name,
cmetadata=cmetadata,
)
session.expire_on_commit = False
session.add(collection_store)
session.commit()

return collection_store


def update_collection(
uuid: str,
name: str,
cmetadata: dict,
):
""" Update single collection """
with Session(connection.db_connection()) as session:
collection = session.get(CollectionStore, uuid)
collection.name = name
collection.cmetadata = cmetadata
session.add(collection)
session.commit()


def delete_collection(
uuid: str,
):
""" Delete single collection """
with Session(connection.db_connection()) as session:
collection = session.get(CollectionStore, uuid)
session.delete(collection)
session.commit()
15 changes: 10 additions & 5 deletions brevia/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from langchain_community.vectorstores.pgembedding import CollectionStore
from fastapi import HTTPException, status, Header, Depends, UploadFile
from fastapi.security import OAuth2PasswordBearer
from brevia import collections, tokens
from brevia.tokens import verify_token
from brevia.collections_tools import (
collection_exists,
single_collection_by_name,
collection_name_exists,
)
from brevia.settings import get_settings

oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token') # use token authentication
Expand Down Expand Up @@ -39,7 +44,7 @@ def application_json(content_type: str = Header(...)):
def token_auth(token: str = Depends(oauth2_scheme)):
"""Check authorization header bearer token"""
try:
tokens.verify_token(token)
verify_token(token)
except JWTError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -49,7 +54,7 @@ def token_auth(token: str = Depends(oauth2_scheme)):

def check_collection_name(name: str) -> CollectionStore:
"""Raise a 404 response if a collection name does not exist"""
collection = collections.single_collection_by_name(name)
collection = single_collection_by_name(name)
if not collection:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
Expand All @@ -61,7 +66,7 @@ def check_collection_name(name: str) -> CollectionStore:

def check_collection_uuid(uuid: str):
"""Raise a 404 response if a collection uuid does not exist"""
if not collections.collection_exists(uuid=uuid):
if not collection_exists(uuid=uuid):
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f"Collection id '{uuid}' was not found",
Expand All @@ -70,7 +75,7 @@ def check_collection_uuid(uuid: str):

def check_collection_name_absent(name: str):
"""Raise a 409 conflict if a collection name already exists"""
if collections.collection_name_exists(name=name):
if collection_name_exists(name=name):
raise HTTPException(
status.HTTP_409_CONFLICT,
f"Collection '{name}' exists",
Expand Down
2 changes: 1 addition & 1 deletion brevia/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from requests import HTTPError
from sqlalchemy.orm import Session
from brevia import connection, load_file
from brevia.collections import single_collection_by_name
from brevia.collections_tools import single_collection_by_name
from brevia.models import load_embeddings
from brevia.settings import get_settings
from brevia.utilities.json_api import query_data_pagination
Expand Down
2 changes: 1 addition & 1 deletion brevia/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from langchain_core.prompts.loading import load_prompt_from_config
from pydantic import BaseModel
from brevia.connection import connection_string
from brevia.collections import single_collection_by_name
from brevia.collections_tools import single_collection_by_name
from brevia.callback import AsyncLoggingCallbackHandler
from brevia.models import load_chatmodel, load_embeddings
from brevia.settings import get_settings
Expand Down
26 changes: 16 additions & 10 deletions brevia/routers/collections_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
check_collection_name_absent,
check_collection_uuid
)
from brevia import collections
from brevia.collections_tools import (
collections_info,
create_collection,
update_collection,
delete_collection,
single_collection,
)

router = APIRouter()

Expand All @@ -18,7 +24,7 @@
)
async def collections_index(name: str | None = None):
""" GET /collections endpoint, information on available collections """
return collections.collections_info(collection=name)
return collections_info(collection=name)


@router.get(
Expand All @@ -30,7 +36,7 @@ async def read_collection(uuid: str):
""" GET /collections/{uuid} endpoint"""
check_collection_uuid(uuid)

return collections.single_collection(uuid)
return single_collection(uuid)


class CollectionBody(BaseModel):
Expand All @@ -45,11 +51,11 @@ class CollectionBody(BaseModel):
dependencies=get_dependencies(),
tags=['Collections'],
)
def create_collection(body: CollectionBody):
def add_collection(body: CollectionBody):
""" POST /collections endpoint"""
check_collection_name_absent(body.name)

return collections.create_collection(
return create_collection(
name=body.name,
cmetadata=body.cmetadata,
)
Expand All @@ -61,15 +67,15 @@ def create_collection(body: CollectionBody):
dependencies=get_dependencies(),
tags=['Collections'],
)
def update_collection(uuid: str, body: CollectionBody):
def change_collection(uuid: str, body: CollectionBody):
""" PATCH /collections endpoint"""
check_collection_uuid(uuid)
current = collections.single_collection(uuid)
current = single_collection(uuid)
if current.name != body.name:
# if name is changed check that it's not in use
check_collection_name_absent(body.name)

collections.update_collection(uuid=uuid, name=body.name, cmetadata=body.cmetadata)
update_collection(uuid=uuid, name=body.name, cmetadata=body.cmetadata)


@router.delete(
Expand All @@ -78,7 +84,7 @@ def update_collection(uuid: str, body: CollectionBody):
dependencies=get_dependencies(json_content_type=False),
tags=['Collections'],
)
def delete_collection(uuid: str):
def remove_collection(uuid: str):
""" DELETE /collections endpoint"""
check_collection_uuid(uuid)
collections.delete_collection(uuid)
delete_collection(uuid)
5 changes: 3 additions & 2 deletions brevia/routers/index_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
save_upload_file_tmp,
check_collection_uuid,
)
from brevia import index, collections, load_file
from brevia.collections_tools import single_collection
from brevia import index, load_file

router = APIRouter()

Expand Down Expand Up @@ -49,7 +50,7 @@ def index_document(item: IndexBody):

def load_collection(collection_id: str) -> CollectionStore:
""" Load collection by ID and throw 404 if not found"""
collection = collections.single_collection(collection_id)
collection = single_collection(collection_id)
if collection is None:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
Expand Down
Loading