diff --git a/brevia/collections.py b/brevia/collections.py index 4bd7448..d3cb06d 100644 --- a/brevia/collections.py +++ b/brevia/collections.py @@ -1,97 +1,29 @@ -"""Collections handling functions""" -from langchain_community.vectorstores.pgembedding import CollectionStore -from sqlalchemy.orm import Session -from brevia import connection -from brevia.utilities.uuid import is_valid_uuid - - -def collections_info(collection: str | None = None): - """ Retrieve collections data """ - filter_collection = CollectionStore.name == collection - if collection is None: - filter_collection = CollectionStore.name is not None - - with Session(connection.db_connection()) as session: - query = ( - session.query( - CollectionStore.uuid, - CollectionStore.name, - CollectionStore.cmetadata, - ) - .filter(filter_collection) - .limit(100) - ) - - result = [u._asdict() for u in query.all()] - session.close() - - return result - - -def collection_name_exists(name: str) -> bool: - """ Check if a collection name exists already """ - store = single_collection_by_name(name) - - return store is not None - - -def collection_exists(uuid: str) -> bool: - """ Check if a collection uuid exists """ - store = single_collection(uuid) - - return store is not None - - -def single_collection(uuid: str) -> (CollectionStore | None): - """ Get single collection by UUID """ - if not is_valid_uuid(uuid): - return None - with Session(connection.db_connection()) as session: - return session.get(CollectionStore, uuid) - - -def single_collection_by_name(name: str) -> (CollectionStore | None): - """ Get single collection by name""" - with Session(connection.db_connection()) as session: - return CollectionStore.get_by_name(session=session, name=name) - - -def create_collection( - name: str, - cmetadata: dict, -) -> CollectionStore: - """ Create single collection """ - with Session(connection.db_connection()) as session: - collection_store = CollectionStore( - name=name, - cmetadata=cmetadata, - ) - session.expire_on_commit = False - session.add(collection_store) - session.commit() - - return collection_store - - -def update_collection( - uuid: str, - name: str, - cmetadata: dict, -): - """ Update single collection """ - with Session(connection.db_connection()) as session: - collection = session.get(CollectionStore, uuid) - collection.name = name - collection.cmetadata = cmetadata - session.add(collection) - session.commit() - - -def delete_collection( - uuid: str, -): - """ Delete single collection """ - with Session(connection.db_connection()) as session: - collection = session.get(CollectionStore, uuid) - session.delete(collection) - session.commit() +"""Collections handling functions (DEPRECATED)""" +import warnings +from brevia.collections_tools import ( + collections_info, + collection_name_exists, + collection_exists, + create_collection, + update_collection, + delete_collection, + single_collection, + single_collection_by_name, +) + +warnings.warn( + "'collections' module is deprecated; use 'collections_tools' instead.", + DeprecationWarning, + stacklevel=2 +) + +__all__ = [ + 'collections_info', + 'collection_name_exists', + 'collection_exists', + 'create_collection', + 'update_collection', + 'delete_collection', + 'single_collection', + 'single_collection_by_name', +] diff --git a/brevia/collections_tools.py b/brevia/collections_tools.py new file mode 100644 index 0000000..4bd7448 --- /dev/null +++ b/brevia/collections_tools.py @@ -0,0 +1,97 @@ +"""Collections handling functions""" +from langchain_community.vectorstores.pgembedding import CollectionStore +from sqlalchemy.orm import Session +from brevia import connection +from brevia.utilities.uuid import is_valid_uuid + + +def collections_info(collection: str | None = None): + """ Retrieve collections data """ + filter_collection = CollectionStore.name == collection + if collection is None: + filter_collection = CollectionStore.name is not None + + with Session(connection.db_connection()) as session: + query = ( + session.query( + CollectionStore.uuid, + CollectionStore.name, + CollectionStore.cmetadata, + ) + .filter(filter_collection) + .limit(100) + ) + + result = [u._asdict() for u in query.all()] + session.close() + + return result + + +def collection_name_exists(name: str) -> bool: + """ Check if a collection name exists already """ + store = single_collection_by_name(name) + + return store is not None + + +def collection_exists(uuid: str) -> bool: + """ Check if a collection uuid exists """ + store = single_collection(uuid) + + return store is not None + + +def single_collection(uuid: str) -> (CollectionStore | None): + """ Get single collection by UUID """ + if not is_valid_uuid(uuid): + return None + with Session(connection.db_connection()) as session: + return session.get(CollectionStore, uuid) + + +def single_collection_by_name(name: str) -> (CollectionStore | None): + """ Get single collection by name""" + with Session(connection.db_connection()) as session: + return CollectionStore.get_by_name(session=session, name=name) + + +def create_collection( + name: str, + cmetadata: dict, +) -> CollectionStore: + """ Create single collection """ + with Session(connection.db_connection()) as session: + collection_store = CollectionStore( + name=name, + cmetadata=cmetadata, + ) + session.expire_on_commit = False + session.add(collection_store) + session.commit() + + return collection_store + + +def update_collection( + uuid: str, + name: str, + cmetadata: dict, +): + """ Update single collection """ + with Session(connection.db_connection()) as session: + collection = session.get(CollectionStore, uuid) + collection.name = name + collection.cmetadata = cmetadata + session.add(collection) + session.commit() + + +def delete_collection( + uuid: str, +): + """ Delete single collection """ + with Session(connection.db_connection()) as session: + collection = session.get(CollectionStore, uuid) + session.delete(collection) + session.commit() diff --git a/brevia/dependencies.py b/brevia/dependencies.py index 216bf93..dd081b0 100644 --- a/brevia/dependencies.py +++ b/brevia/dependencies.py @@ -6,7 +6,12 @@ from langchain_community.vectorstores.pgembedding import CollectionStore from fastapi import HTTPException, status, Header, Depends, UploadFile from fastapi.security import OAuth2PasswordBearer -from brevia import collections, tokens +from brevia.tokens import verify_token +from brevia.collections_tools import ( + collection_exists, + single_collection_by_name, + collection_name_exists, +) from brevia.settings import get_settings oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token') # use token authentication @@ -39,7 +44,7 @@ def application_json(content_type: str = Header(...)): def token_auth(token: str = Depends(oauth2_scheme)): """Check authorization header bearer token""" try: - tokens.verify_token(token) + verify_token(token) except JWTError as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -49,7 +54,7 @@ def token_auth(token: str = Depends(oauth2_scheme)): def check_collection_name(name: str) -> CollectionStore: """Raise a 404 response if a collection name does not exist""" - collection = collections.single_collection_by_name(name) + collection = single_collection_by_name(name) if not collection: raise HTTPException( status.HTTP_404_NOT_FOUND, @@ -61,7 +66,7 @@ def check_collection_name(name: str) -> CollectionStore: def check_collection_uuid(uuid: str): """Raise a 404 response if a collection uuid does not exist""" - if not collections.collection_exists(uuid=uuid): + if not collection_exists(uuid=uuid): raise HTTPException( status.HTTP_404_NOT_FOUND, f"Collection id '{uuid}' was not found", @@ -70,7 +75,7 @@ def check_collection_uuid(uuid: str): def check_collection_name_absent(name: str): """Raise a 409 conflict if a collection name already exists""" - if collections.collection_name_exists(name=name): + if collection_name_exists(name=name): raise HTTPException( status.HTTP_409_CONFLICT, f"Collection '{name}' exists", diff --git a/brevia/index.py b/brevia/index.py index 2e406ab..4490020 100644 --- a/brevia/index.py +++ b/brevia/index.py @@ -12,7 +12,7 @@ from requests import HTTPError from sqlalchemy.orm import Session from brevia import connection, load_file -from brevia.collections import single_collection_by_name +from brevia.collections_tools import single_collection_by_name from brevia.models import load_embeddings from brevia.settings import get_settings from brevia.utilities.json_api import query_data_pagination diff --git a/brevia/query.py b/brevia/query.py index b841cb4..e15bb07 100644 --- a/brevia/query.py +++ b/brevia/query.py @@ -22,7 +22,7 @@ from langchain_core.prompts.loading import load_prompt_from_config from pydantic import BaseModel from brevia.connection import connection_string -from brevia.collections import single_collection_by_name +from brevia.collections_tools import single_collection_by_name from brevia.callback import AsyncLoggingCallbackHandler from brevia.models import load_chatmodel, load_embeddings from brevia.settings import get_settings diff --git a/brevia/routers/collections_router.py b/brevia/routers/collections_router.py index bad4cb3..3c62c72 100644 --- a/brevia/routers/collections_router.py +++ b/brevia/routers/collections_router.py @@ -6,7 +6,13 @@ check_collection_name_absent, check_collection_uuid ) -from brevia import collections +from brevia.collections_tools import ( + collections_info, + create_collection, + update_collection, + delete_collection, + single_collection, +) router = APIRouter() @@ -18,7 +24,7 @@ ) async def collections_index(name: str | None = None): """ GET /collections endpoint, information on available collections """ - return collections.collections_info(collection=name) + return collections_info(collection=name) @router.get( @@ -30,7 +36,7 @@ async def read_collection(uuid: str): """ GET /collections/{uuid} endpoint""" check_collection_uuid(uuid) - return collections.single_collection(uuid) + return single_collection(uuid) class CollectionBody(BaseModel): @@ -45,11 +51,11 @@ class CollectionBody(BaseModel): dependencies=get_dependencies(), tags=['Collections'], ) -def create_collection(body: CollectionBody): +def add_collection(body: CollectionBody): """ POST /collections endpoint""" check_collection_name_absent(body.name) - return collections.create_collection( + return create_collection( name=body.name, cmetadata=body.cmetadata, ) @@ -61,15 +67,15 @@ def create_collection(body: CollectionBody): dependencies=get_dependencies(), tags=['Collections'], ) -def update_collection(uuid: str, body: CollectionBody): +def change_collection(uuid: str, body: CollectionBody): """ PATCH /collections endpoint""" check_collection_uuid(uuid) - current = collections.single_collection(uuid) + current = single_collection(uuid) if current.name != body.name: # if name is changed check that it's not in use check_collection_name_absent(body.name) - collections.update_collection(uuid=uuid, name=body.name, cmetadata=body.cmetadata) + update_collection(uuid=uuid, name=body.name, cmetadata=body.cmetadata) @router.delete( @@ -78,7 +84,7 @@ def update_collection(uuid: str, body: CollectionBody): dependencies=get_dependencies(json_content_type=False), tags=['Collections'], ) -def delete_collection(uuid: str): +def remove_collection(uuid: str): """ DELETE /collections endpoint""" check_collection_uuid(uuid) - collections.delete_collection(uuid) + delete_collection(uuid) diff --git a/brevia/routers/index_router.py b/brevia/routers/index_router.py index 0bbc93f..569c502 100644 --- a/brevia/routers/index_router.py +++ b/brevia/routers/index_router.py @@ -13,7 +13,8 @@ save_upload_file_tmp, check_collection_uuid, ) -from brevia import index, collections, load_file +from brevia.collections_tools import single_collection +from brevia import index, load_file router = APIRouter() @@ -49,7 +50,7 @@ def index_document(item: IndexBody): def load_collection(collection_id: str) -> CollectionStore: """ Load collection by ID and throw 404 if not found""" - collection = collections.single_collection(collection_id) + collection = single_collection(collection_id) if collection is None: raise HTTPException( status.HTTP_404_NOT_FOUND, diff --git a/brevia/utilities/collections_io.py b/brevia/utilities/collections_io.py index 3c2aa27..b4cbd33 100644 --- a/brevia/utilities/collections_io.py +++ b/brevia/utilities/collections_io.py @@ -1,6 +1,7 @@ """Utility functions to import/export collections using CSV postgres files.""" from os import path -from brevia import connection, collections +from brevia.collections_tools import collection_name_exists +from brevia import connection def export_collection_data( @@ -8,7 +9,7 @@ def export_collection_data( collection: str, ): """Export collection data using `psql`""" - if not collections.collection_name_exists(collection): + if not collection_name_exists(collection): raise ValueError(f"Collection '{collection}' was not found") csv_file_collection = f"{folder_path}/{collection}-collection.csv" @@ -41,7 +42,7 @@ def import_collection_data( collection: str, ): """Import collection data using `psql`""" - if collections.collection_name_exists(collection): + if collection_name_exists(collection): raise ValueError(f"Collection '{collection}' already exists, exiting") csv_file_collection = f"{folder_path}/{collection}-collection.csv" diff --git a/tests/routers/test_chat_history_router.py b/tests/routers/test_chat_history_router.py index 6988ffc..fdc4823 100644 --- a/tests/routers/test_chat_history_router.py +++ b/tests/routers/test_chat_history_router.py @@ -5,7 +5,7 @@ from fastapi import FastAPI from sqlalchemy.orm import Session from brevia.routers import chat_history_router -from brevia.collections import create_collection +from brevia.collections_tools import create_collection from brevia.connection import db_connection from brevia.chat_history import add_history, ChatHistoryStore diff --git a/tests/routers/test_collections_router.py b/tests/routers/test_collections_router.py index b01f190..e0149ce 100644 --- a/tests/routers/test_collections_router.py +++ b/tests/routers/test_collections_router.py @@ -2,7 +2,7 @@ from fastapi.testclient import TestClient from fastapi import FastAPI from brevia.routers import collections_router -from brevia.collections import create_collection +from brevia.collections_tools import create_collection app = FastAPI() app.include_router(collections_router.router) diff --git a/tests/routers/test_index_router.py b/tests/routers/test_index_router.py index 4159a00..0b002cb 100644 --- a/tests/routers/test_index_router.py +++ b/tests/routers/test_index_router.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from langchain.docstore.document import Document from brevia.routers import index_router -from brevia.collections import create_collection +from brevia.collections_tools import create_collection from brevia.index import add_document, read_document from unittest.mock import patch diff --git a/tests/routers/test_qa_router.py b/tests/routers/test_qa_router.py index 73de27e..21b680b 100644 --- a/tests/routers/test_qa_router.py +++ b/tests/routers/test_qa_router.py @@ -6,7 +6,7 @@ from brevia.routers.qa_router import ( router, ChatBody, chat_language, retrieve_chat_history, extract_content_score ) -from brevia.collections import create_collection +from brevia.collections_tools import create_collection from brevia.index import add_document from brevia.settings import get_settings diff --git a/tests/test_callback.py b/tests/test_callback.py index 95807e5..a58e983 100644 --- a/tests/test_callback.py +++ b/tests/test_callback.py @@ -3,7 +3,7 @@ from json import loads from langchain.docstore.document import Document from brevia.callback import ConversationCallbackHandler, TokensCallbackHandler -from brevia.collections import create_collection +from brevia.collections_tools import create_collection def test_chain_result(): diff --git a/tests/test_chat_history.py b/tests/test_chat_history.py index 9398bb9..fb056c8 100644 --- a/tests/test_chat_history.py +++ b/tests/test_chat_history.py @@ -9,7 +9,7 @@ get_history, ChatHistoryFilter, ) -from brevia.collections import create_collection +from brevia.collections_tools import create_collection def test_history(): diff --git a/tests/test_collections.py b/tests/test_collections.py index 3ca2dc7..5a68b60 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -1,5 +1,4 @@ -"""collections module tests""" -import uuid +"""collections module tests (deprecated)""" from brevia.collections import ( collections_info, collection_name_exists, @@ -7,49 +6,25 @@ create_collection, update_collection, delete_collection, - single_collection + single_collection, + single_collection_by_name, ) -def test_create_collection(): - """ Test create_collection function """ +def test_collection_module(): + """ Test deprecated `collections` module function""" + result = collections_info() + assert result == [] result = create_collection('test_collection', {'key': 'value'}) assert result.uuid is not None assert result.name == 'test_collection' assert result.cmetadata == {'key': 'value'} - - -def test_collections_info(): - """Test collections_info()""" - result = collections_info() - assert result == [] - - -def test_collection_name_exists(): - """Test collection_name_exists function""" - create_collection('test_collection2', {}) - assert collection_name_exists('test_collection2') is True - assert collection_name_exists('nonexistent_collection') is False - - -def test_collection_exists(): - """Test collection_exists function""" - collection = create_collection('new_collection', {}) - assert collection_exists(collection.uuid) is True - assert collection_exists(uuid.uuid4()) is False - - -def test_update_collection(): - """Test update_collection function""" - collection = create_collection('new_collection', {}) - update_collection(collection.uuid, 'updated_collection', {'new_key': 'new_value'}) - updated = single_collection(collection.uuid) + update_collection(result.uuid, 'updated_collection', {'new_key': 'new_value'}) + updated = single_collection(result.uuid) assert updated.name == 'updated_collection' assert updated.cmetadata == {'new_key': 'new_value'} - - -def test_delete_collection(): - """Test delete_collection function""" - collection = create_collection('new_collection', {}) - delete_collection(collection.uuid) + coll = single_collection_by_name('updated_collection') + assert coll.uuid == result.uuid + assert collection_exists(result.uuid) + delete_collection(result.uuid) assert collection_name_exists('new_collection') is False diff --git a/tests/test_collections_tools.py b/tests/test_collections_tools.py new file mode 100644 index 0000000..4a0d76d --- /dev/null +++ b/tests/test_collections_tools.py @@ -0,0 +1,55 @@ +"""collections_tools module tests""" +import uuid +from brevia.collections_tools import ( + collections_info, + collection_name_exists, + collection_exists, + create_collection, + update_collection, + delete_collection, + single_collection +) + + +def test_create_collection(): + """ Test create_collection function """ + result = create_collection('test_collection', {'key': 'value'}) + assert result.uuid is not None + assert result.name == 'test_collection' + assert result.cmetadata == {'key': 'value'} + + +def test_collections_info(): + """Test collections_info()""" + result = collections_info() + assert result == [] + + +def test_collection_name_exists(): + """Test collection_name_exists function""" + create_collection('test_collection2', {}) + assert collection_name_exists('test_collection2') is True + assert collection_name_exists('nonexistent_collection') is False + + +def test_collection_exists(): + """Test collection_exists function""" + collection = create_collection('new_collection', {}) + assert collection_exists(collection.uuid) is True + assert collection_exists(uuid.uuid4()) is False + + +def test_update_collection(): + """Test update_collection function""" + collection = create_collection('new_collection', {}) + update_collection(collection.uuid, 'updated_collection', {'new_key': 'new_value'}) + updated = single_collection(collection.uuid) + assert updated.name == 'updated_collection' + assert updated.cmetadata == {'new_key': 'new_value'} + + +def test_delete_collection(): + """Test delete_collection function""" + collection = create_collection('new_collection', {}) + delete_collection(collection.uuid) + assert collection_name_exists('new_collection') is False diff --git a/tests/test_commands.py b/tests/test_commands.py index b8781a0..e9dd433 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -16,7 +16,7 @@ create_openapi, update_collection_links, ) -from brevia.collections import create_collection, collection_name_exists +from brevia.collections_tools import create_collection, collection_name_exists from brevia.settings import get_settings from brevia.index import add_document diff --git a/tests/test_index.py b/tests/test_index.py index eb2f919..d44d18b 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -12,7 +12,7 @@ add_document, document_has_changed, select_load_link_options, documents_metadata, create_splitter, ) -from brevia.collections import create_collection +from brevia.collections_tools import create_collection from brevia.settings import get_settings diff --git a/tests/test_query.py b/tests/test_query.py index 4b57fd6..0e1ad83 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -13,7 +13,7 @@ ChatParams, SearchQuery, ) -from brevia.collections import create_collection +from brevia.collections_tools import create_collection from brevia.index import add_document from brevia.settings import get_settings diff --git a/tests/utilities/test_collections_io.py b/tests/utilities/test_collections_io.py index 7d54c6d..969810f 100644 --- a/tests/utilities/test_collections_io.py +++ b/tests/utilities/test_collections_io.py @@ -6,7 +6,7 @@ export_collection_data, import_collection_data, ) -from brevia.collections import create_collection +from brevia.collections_tools import create_collection def test_export_collection_data(): diff --git a/tests/utilities/test_files_import.py b/tests/utilities/test_files_import.py index 65b6ac5..64a7eb4 100644 --- a/tests/utilities/test_files_import.py +++ b/tests/utilities/test_files_import.py @@ -2,7 +2,7 @@ from pathlib import Path import pytest from brevia.utilities.files_import import index_file_folder -from brevia.collections import create_collection +from brevia.collections_tools import create_collection def test_index_file_folder():