diff --git a/.env.example b/.env.example
index 4ae028019242..4961c508d711 100644
--- a/.env.example
+++ b/.env.example
@@ -54,6 +54,8 @@ EXTERNAL_SUPABASE_URL=http://localhost:54321
SUPABASE_SERVICE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImV4cCI6MTk4MzgxMjk5Nn0.EGIM96RAZx35lJzdJsyH-qQwv8Hdp7fsn3W0YpN81IU
PG_DATABASE_URL=postgresql://postgres:postgres@host.docker.internal:54322/postgres
PG_DATABASE_ASYNC_URL=postgresql+asyncpg://postgres:postgres@host.docker.internal:54322/postgres
+SQLALCHEMY_POOL_SIZE=10
+SQLALCHEMY_MAX_POOL_OVERFLOW=0
JWT_SECRET_KEY=super-secret-jwt-token-with-at-least-32-characters-long
AUTHENTICATE=true
TELEMETRY_ENABLED=true
diff --git a/.gitignore b/.gitignore
index 89fbda293a27..6e6de69cc764 100644
--- a/.gitignore
+++ b/.gitignore
@@ -103,3 +103,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json
.tox
Pipfile
*.pkl
+backend/benchmarks/data.json
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 86370d352832..9cd662a50680 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,8 +1,9 @@
{
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit",
+ "source.organizeImports.ruff": "explicit",
"source.fixAll": "explicit",
- "source.unusedImports": "explicit",
+ "source.unusedImports": "explicit"
},
"editor.formatOnSave": true,
"editor.formatOnSaveMode": "file",
@@ -24,9 +25,7 @@
"source.fixAll": "explicit"
}
},
- "python.analysis.extraPaths": [
- "./backend"
- ],
+ "python.analysis.extraPaths": ["./backend"],
"python.defaultInterpreterPath": "python3",
"python.testing.pytestArgs": [
"-v",
@@ -43,6 +42,5 @@
"reportMissingImports": "error",
"reportUnusedImport": "warning",
"reportGeneralTypeIssues": "warning"
- },
- "makefile.configureOnOpen": false
+ }
}
diff --git a/backend/api/quivr_api/__init__.py b/backend/api/quivr_api/__init__.py
index f25c4b4b9308..92c5ed5104e5 100644
--- a/backend/api/quivr_api/__init__.py
+++ b/backend/api/quivr_api/__init__.py
@@ -1,15 +1,18 @@
from quivr_api.modules.brain.entity.brain_entity import Brain
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from .modules.chat.entity.chat import Chat, ChatHistory
-from .modules.sync.entity.sync_models import NotionSyncFile
+from .modules.sync.entity.sync_models import NotionSyncFile, Sync
from .modules.user.entity.user_identity import User
__all__ = [
"Chat",
"ChatHistory",
+ "BrainUserDB",
"User",
"NotionSyncFile",
"KnowledgeDB",
"Brain",
+ "Sync",
]
diff --git a/backend/api/quivr_api/models/settings.py b/backend/api/quivr_api/models/settings.py
index 5f4050a99781..745677e9415e 100644
--- a/backend/api/quivr_api/models/settings.py
+++ b/backend/api/quivr_api/models/settings.py
@@ -103,6 +103,8 @@ def set_once_user_properties(self, user_id: UUID, event_name, properties: dict):
class BrainSettings(BaseSettings):
model_config = SettingsConfigDict(validate_default=False)
+ pg_database_url: str
+ pg_database_async_url: str
openai_api_key: str = ""
azure_openai_embeddings_url: str = ""
supabase_url: str = ""
@@ -112,9 +114,10 @@ class BrainSettings(BaseSettings):
ollama_api_base_url: str | None = None
langfuse_public_key: str | None = None
langfuse_secret_key: str | None = None
- pg_database_url: str
- pg_database_async_url: str
+ sqlalchemy_pool_size: int = 5
+ sqlalchemy_max_pool_overflow: int = 5
embedding_dim: int = 1536
+ max_file_size: int = int(5e7)
class ResendSettings(BaseSettings):
diff --git a/backend/api/quivr_api/models/sqlalchemy_repository.py b/backend/api/quivr_api/models/sqlalchemy_repository.py
deleted file mode 100644
index 7b295187973a..000000000000
--- a/backend/api/quivr_api/models/sqlalchemy_repository.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from datetime import datetime
-from uuid import uuid4
-
-from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import relationship
-
-Base = declarative_base()
-
-
-class User(Base):
- __tablename__ = "users"
-
- user_id = Column(String, primary_key=True)
- email = Column(String)
- date = Column(DateTime)
- daily_requests_count = Column(Integer)
-
-
-class Brain(Base):
- __tablename__ = "brains"
-
- brain_id = Column(Integer, primary_key=True)
- name = Column(String)
- users = relationship("BrainUser", back_populates="brain")
- vectors = relationship("BrainVector", back_populates="brain")
-
-
-class BrainUser(Base):
- __tablename__ = "brains_users"
-
- id = Column(Integer, primary_key=True)
- user_id = Column(Integer, ForeignKey("users.user_id"))
- brain_id = Column(Integer, ForeignKey("brains.brain_id"))
- rights = Column(String)
-
- user = relationship("User")
- brain = relationship("Brain", back_populates="users")
-
-
-class BrainVector(Base):
- __tablename__ = "brains_vectors"
-
- vector_id = Column(String, primary_key=True, default=lambda: str(uuid4()))
- brain_id = Column(Integer, ForeignKey("brains.brain_id"))
- file_sha1 = Column(String)
-
- brain = relationship("Brain", back_populates="vectors")
-
-
-class BrainSubscriptionInvitation(Base):
- __tablename__ = "brain_subscription_invitations"
-
- id = Column(Integer, primary_key=True) # Assuming an integer primary key named 'id'
- brain_id = Column(String, ForeignKey("brains.brain_id"))
- email = Column(String, ForeignKey("users.email"))
- rights = Column(String)
-
- brain = relationship("Brain")
- user = relationship("User", foreign_keys=[email])
-
-
-class ApiKey(Base):
- __tablename__ = "api_keys"
-
- key_id = Column(String, primary_key=True, default=lambda: str(uuid4()))
- user_id = Column(Integer, ForeignKey("users.user_id"))
- api_key = Column(String, unique=True)
- creation_time = Column(DateTime, default=datetime.utcnow)
- is_active = Column(Boolean, default=True)
- deleted_time = Column(DateTime, nullable=True)
-
- user = relationship("User")
diff --git a/backend/api/quivr_api/modules/brain/entity/brain_entity.py b/backend/api/quivr_api/modules/brain/entity/brain_entity.py
index 708b8d48220c..28bc55d3759f 100644
--- a/backend/api/quivr_api/modules/brain/entity/brain_entity.py
+++ b/backend/api/quivr_api/modules/brain/entity/brain_entity.py
@@ -10,6 +10,7 @@
from sqlmodel import TIMESTAMP, Column, Field, Relationship, SQLModel, text
from sqlmodel import UUID as PGUUID
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
from quivr_api.modules.brain.entity.integration_brain import (
IntegrationDescriptionEntity,
IntegrationEntity,
@@ -69,7 +70,12 @@ class Brain(AsyncAttrs, SQLModel, table=True):
knowledges: List[KnowledgeDB] = Relationship(
back_populates="brains", link_model=KnowledgeBrain
)
-
+ users: List["User"] = Relationship( # type: ignore # noqa: F821
+ back_populates="brains",
+ link_model=BrainUserDB,
+ )
+ snippet_color: str | None = Field(default="#d0c6f2")
+ snippet_emoji: str | None = Field(default="🧠")
# TODO : add
# "meaning" "public"."vector",
# "tags" "public"."tags"[]
diff --git a/backend/api/quivr_api/modules/brain/entity/brain_user.py b/backend/api/quivr_api/modules/brain/entity/brain_user.py
new file mode 100644
index 000000000000..24b1b029c307
--- /dev/null
+++ b/backend/api/quivr_api/modules/brain/entity/brain_user.py
@@ -0,0 +1,14 @@
+from uuid import UUID
+
+from sqlmodel import Field, SQLModel
+
+
+class BrainUserDB(SQLModel, table=True):
+ __tablename__ = "brains_users" # type: ignore
+
+ brain_id: UUID = Field(
+ nullable=False, foreign_key="brains.brain_id", primary_key=True
+ )
+ user_id: UUID = Field(nullable=False, foreign_key="users.id", primary_key=True)
+ default_brain: bool
+ rights: str
diff --git a/backend/api/quivr_api/modules/brain/repository/brains_users.py b/backend/api/quivr_api/modules/brain/repository/brains_users.py
index 9176eeb35ce6..cdbc69f903b8 100644
--- a/backend/api/quivr_api/modules/brain/repository/brains_users.py
+++ b/backend/api/quivr_api/modules/brain/repository/brains_users.py
@@ -2,7 +2,7 @@
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import (
- BrainUser,
+ BrainUserDB,
MinimalUserBrainEntity,
)
from quivr_api.modules.brain.repository.interfaces.brains_users_interface import (
@@ -161,7 +161,7 @@ def get_user_default_brain_id(self, user_id: UUID) -> UUID | None:
return None
return UUID(response[0].get("brain_id"))
- def get_brain_users(self, brain_id: UUID) -> list[BrainUser]:
+ def get_brain_users(self, brain_id: UUID) -> list[BrainUserDB]:
response = (
self.db.table("brains_users")
.select("id:brain_id, *")
@@ -169,7 +169,7 @@ def get_brain_users(self, brain_id: UUID) -> list[BrainUser]:
.execute()
)
- return [BrainUser(**item) for item in response.data]
+ return [BrainUserDB(**item) for item in response.data]
def delete_brain_subscribers(self, brain_id: UUID):
results = (
diff --git a/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py b/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py
index dabe8ef924b7..d87365239a41 100644
--- a/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py
+++ b/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py
@@ -3,7 +3,7 @@
from uuid import UUID
from quivr_api.modules.brain.entity.brain_entity import (
- BrainUser,
+ BrainUserDB,
MinimalUserBrainEntity,
)
@@ -56,7 +56,7 @@ def get_user_default_brain_id(self, user_id: UUID) -> UUID | None:
pass
@abstractmethod
- def get_brain_users(self, brain_id: UUID) -> List[BrainUser]:
+ def get_brain_users(self, brain_id: UUID) -> List[BrainUserDB]:
"""
Get all users for a brain
"""
@@ -88,7 +88,7 @@ def update_brain_user_default_status(
@abstractmethod
def update_brain_user_rights(
self, brain_id: UUID, user_id: UUID, rights: str
- ) -> BrainUser:
+ ) -> BrainUserDB:
"""
Update the rights for a user in a brain
"""
diff --git a/backend/api/quivr_api/modules/brain/service/brain_service.py b/backend/api/quivr_api/modules/brain/service/brain_service.py
index 7b9da881c7ff..0a701326fbc7 100644
--- a/backend/api/quivr_api/modules/brain/service/brain_service.py
+++ b/backend/api/quivr_api/modules/brain/service/brain_service.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple
from uuid import UUID
from fastapi import HTTPException
@@ -54,7 +54,7 @@ def find_brain_from_question(
chat_id: UUID,
history,
vector_store: CustomSupabaseVectorStore,
- ) -> Tuple[Optional[BrainEntity], Dict[str, str]]:
+ ) -> Tuple[Optional[BrainEntity], dict[str, str]]:
"""Find the brain to use for a question.
Args:
diff --git a/backend/api/quivr_api/modules/brain/service/brain_user_service.py b/backend/api/quivr_api/modules/brain/service/brain_user_service.py
index 031cfb8a3351..61699fcededf 100644
--- a/backend/api/quivr_api/modules/brain/service/brain_user_service.py
+++ b/backend/api/quivr_api/modules/brain/service/brain_user_service.py
@@ -6,7 +6,7 @@
from quivr_api.logger import get_logger
from quivr_api.modules.brain.entity.brain_entity import (
BrainEntity,
- BrainUser,
+ BrainUserDB,
MinimalUserBrainEntity,
RoleEnum,
)
@@ -74,7 +74,7 @@ def get_user_brains(self, user_id: UUID) -> list[MinimalUserBrainEntity]:
return results # type: ignore
- def get_brain_users(self, brain_id: UUID) -> List[BrainUser]:
+ def get_brain_users(self, brain_id: UUID) -> List[BrainUserDB]:
return self.brain_user_repository.get_brain_users(brain_id)
def update_brain_user_rights(
diff --git a/backend/api/quivr_api/modules/dependencies.py b/backend/api/quivr_api/modules/dependencies.py
index e74bdee45d1c..75cf3eb1df1e 100644
--- a/backend/api/quivr_api/modules/dependencies.py
+++ b/backend/api/quivr_api/modules/dependencies.py
@@ -5,12 +5,7 @@
from fastapi import Depends
from langchain.embeddings.base import Embeddings
from langchain_community.embeddings.ollama import OllamaEmbeddings
-
-# from langchain_community.vectorstores.supabase import SupabaseVectorStore
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
-
-# from quivr_api.modules.vector.service.vector_service import VectorService
-# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import Session, text
@@ -61,16 +56,20 @@ def get_repository_cls(cls) -> Type[R]:
future=True,
# NOTE: pessimistic bound on
pool_pre_ping=True,
- pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6
+ pool_size=1, # NOTE: no bouncer for now, if 6 process workers => 6
+ max_overflow=0,
pool_recycle=1800,
)
async_engine = create_async_engine(
settings.pg_database_async_url,
- connect_args={"server_settings": {"application_name": "quivr-api-async"}},
+ connect_args={
+ "server_settings": {"application_name": "quivr-api-async"},
+ },
echo=True if os.getenv("ORM_DEBUG") else False,
future=True,
pool_pre_ping=True,
- pool_size=5, # NOTE: no bouncer for now, if 6 process workers => 6
+ pool_size=settings.sqlalchemy_pool_size, # NOTE: no bouncer for now, if 6 process workers => 6
+ max_overflow=settings.sqlalchemy_max_pool_overflow,
pool_recycle=1800,
isolation_level="AUTOCOMMIT",
)
diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py
index 68d01afb0c5a..5fc5f62842cb 100644
--- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py
+++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py
@@ -1,19 +1,35 @@
+import asyncio
from http import HTTPStatus
-from typing import Annotated, List, Optional
+from typing import List, Optional
from uuid import UUID
-from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status
+from fastapi import (
+ APIRouter,
+ Depends,
+ File,
+ HTTPException,
+ Query,
+ Response,
+ UploadFile,
+ status,
+)
+from quivr_core.models import KnowledgeStatus
+from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.brain.entity.brain_entity import RoleEnum
+from quivr_api.models.settings import settings
from quivr_api.modules.brain.service.brain_authorization_service import (
- has_brain_authorization,
validate_brain_authorization,
)
from quivr_api.modules.dependencies import get_service
-from quivr_api.modules.knowledge.dto.inputs import AddKnowledge
-from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeUpdate
+from quivr_api.modules.knowledge.dto.inputs import (
+ AddKnowledge,
+ KnowledgeUpdate,
+ LinkKnowledgeBrain,
+ UnlinkKnowledgeBrain,
+)
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO, sort_knowledge_dtos
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
KnowledgeDeleteError,
KnowledgeForbiddenAccess,
@@ -21,30 +37,34 @@
UploadError,
)
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
+from quivr_api.modules.notification.service.notification_service import (
+ NotificationService,
+)
+from quivr_api.modules.sync.service.sync_service import SyncsService
from quivr_api.modules.upload.service.generate_file_signed_url import (
generate_file_signed_url,
)
from quivr_api.modules.user.entity.user_identity import UserIdentity
-knowledge_router = APIRouter()
logger = get_logger(__name__)
+knowledge_router = APIRouter()
-get_km_service = get_service(KnowledgeService)
-KnowledgeServiceDep = Annotated[KnowledgeService, Depends(get_km_service)]
+notification_service = NotificationService()
+get_knowledge_service = get_service(KnowledgeService)
+get_sync_service = get_service(SyncsService)
@knowledge_router.get(
"/knowledge", dependencies=[Depends(AuthBearer())], tags=["Knowledge"]
)
async def list_knowledge_in_brain_endpoint(
- knowledge_service: KnowledgeServiceDep,
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
brain_id: UUID = Query(..., description="The ID of the brain"),
current_user: UserIdentity = Depends(get_current_user),
):
"""
Retrieve and list all the knowledge in a brain.
"""
-
validate_brain_authorization(brain_id=brain_id, user_id=current_user.id)
knowledges = await knowledge_service.get_all_knowledge_in_brain(brain_id)
@@ -52,33 +72,6 @@ async def list_knowledge_in_brain_endpoint(
return {"knowledges": knowledges}
-@knowledge_router.delete(
- "/knowledge/{knowledge_id}",
- dependencies=[
- Depends(AuthBearer()),
- Depends(has_brain_authorization(RoleEnum.Owner)),
- ],
- tags=["Knowledge"],
-)
-async def delete_knowledge_brain(
- knowledge_id: UUID,
- knowledge_service: KnowledgeServiceDep,
- current_user: UserIdentity = Depends(get_current_user),
- brain_id: UUID = Query(..., description="The ID of the brain"),
-):
- """
- Delete a specific knowledge from a brain.
- """
-
- knowledge = await knowledge_service.get_knowledge(knowledge_id)
- file_name = knowledge.file_name if knowledge.file_name else knowledge.url
- await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id)
-
- return {
- "message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}."
- }
-
-
@knowledge_router.get(
"/knowledge/{knowledge_id}/signed_download_url",
dependencies=[Depends(AuthBearer())],
@@ -86,7 +79,7 @@ async def delete_knowledge_brain(
)
async def generate_signed_url_endpoint(
knowledge_id: UUID,
- knowledge_service: KnowledgeServiceDep,
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
current_user: UserIdentity = Depends(get_current_user),
):
"""
@@ -120,15 +113,20 @@ async def generate_signed_url_endpoint(
@knowledge_router.post(
"/knowledge/",
tags=["Knowledge"],
- response_model=Knowledge,
+ response_model=KnowledgeDTO,
)
async def create_knowledge(
knowledge_data: str = File(...),
file: Optional[UploadFile] = None,
- knowledge_service: KnowledgeService = Depends(get_km_service),
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
current_user: UserIdentity = Depends(get_current_user),
):
knowledge = AddKnowledge.model_validate_json(knowledge_data)
+ if file and file.size and file.size > settings.max_file_size:
+ raise HTTPException(
+ status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
+ detail="Uploaded file is too large",
+ )
if not knowledge.file_name and not knowledge.url:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -159,19 +157,20 @@ async def create_knowledge(
@knowledge_router.get(
- "/knowledge/children",
- response_model=List[Knowledge] | None,
+ "/knowledge/files",
+ response_model=List[KnowledgeDTO] | None,
tags=["Knowledge"],
)
async def list_knowledge(
parent_id: UUID | None = None,
- knowledge_service: KnowledgeService = Depends(get_km_service),
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
# TODO: Returns one level of children
children = await knowledge_service.list_knowledge(parent_id, current_user.id)
- return [await c.to_dto(get_children=False) for c in children]
+ children_dto = [await c.to_dto(get_children=False) for c in children]
+ return sort_knowledge_dtos(children_dto)
except KnowledgeNotFoundException as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail=f"{e.message}"
@@ -186,12 +185,12 @@ async def list_knowledge(
@knowledge_router.get(
"/knowledge/{knowledge_id}",
- response_model=Knowledge,
+ response_model=KnowledgeDTO,
tags=["Knowledge"],
)
async def get_knowledge(
knowledge_id: UUID,
- knowledge_service: KnowledgeService = Depends(get_km_service),
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
@@ -213,13 +212,13 @@ async def get_knowledge(
@knowledge_router.patch(
"/knowledge/{knowledge_id}",
status_code=status.HTTP_202_ACCEPTED,
- response_model=Knowledge,
+ response_model=KnowledgeDTO,
tags=["Knowledge"],
)
async def update_knowledge(
knowledge_id: UUID,
payload: KnowledgeUpdate,
- knowledge_service: KnowledgeService = Depends(get_km_service),
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
@@ -230,7 +229,7 @@ async def update_knowledge(
detail="You do not have permission to access this knowledge.",
)
km = await knowledge_service.update_knowledge(km, payload)
- return km
+ return await km.to_dto()
except KnowledgeNotFoundException as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}"
@@ -246,12 +245,11 @@ async def update_knowledge(
)
async def delete_knowledge(
knowledge_id: UUID,
- knowledge_service: KnowledgeService = Depends(get_km_service),
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
current_user: UserIdentity = Depends(get_current_user),
):
try:
km = await knowledge_service.get_knowledge(knowledge_id)
-
if km.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -265,3 +263,96 @@ async def delete_knowledge(
)
except KnowledgeDeleteError:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+
+@knowledge_router.post(
+ "/knowledge/link_to_brains/",
+ status_code=status.HTTP_201_CREATED,
+ response_model=List[KnowledgeDTO],
+ tags=["Knowledge"],
+)
+async def link_knowledge_to_brain(
+ link_request: LinkKnowledgeBrain,
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
+ current_user: UserIdentity = Depends(get_current_user),
+):
+ brains_ids, knowledge_dto = (
+ link_request.brain_ids,
+ link_request.knowledge,
+ )
+ if len(brains_ids) == 0:
+ return Response(status_code=status.HTTP_204_NO_CONTENT)
+
+ if knowledge_dto.id is None:
+ if knowledge_dto.sync_file_id is None:
+ raise HTTPException(
+ status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unknown knowledge entity"
+ )
+ # Create a knowledge from this sync
+ knowledge = await knowledge_service.create_knowledge(
+ user_id=current_user.id,
+ knowledge_to_add=AddKnowledge(**knowledge_dto.model_dump()),
+ upload_file=None,
+ )
+ # TODO (@AmineDiro): Check if tree is necessary or updating this knowledge suffice
+ linked_kms = await knowledge_service.link_knowledge_tree_brains(
+ knowledge, brains_ids=brains_ids, user_id=current_user.id
+ )
+
+ else:
+ linked_kms = await knowledge_service.link_knowledge_tree_brains(
+ knowledge_dto.id, brains_ids=brains_ids, user_id=current_user.id
+ )
+
+ for knowledge in [
+ k
+ for k in linked_kms
+ if await k.awaitable_attrs.status
+ not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING]
+ ]:
+ celery.send_task(
+ "process_file_task",
+ kwargs={
+ "knowledge_id": await knowledge.awaitable_attrs.id,
+ },
+ )
+ knowledge = await knowledge_service.update_knowledge(
+ knowledge=knowledge,
+ payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING),
+ )
+
+ linked_kms = await asyncio.gather(*[k.to_dto() for k in linked_kms])
+ return sort_knowledge_dtos(linked_kms)
+
+
+@knowledge_router.delete(
+ "/knowledge/unlink_from_brains/",
+ response_model=List[KnowledgeDTO] | None,
+ tags=["Knowledge"],
+)
+async def unlink_knowledge_from_brain(
+ unlink_request: UnlinkKnowledgeBrain,
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
+ current_user: UserIdentity = Depends(get_current_user),
+):
+ brains_ids, knowledge_id = unlink_request.brain_ids, unlink_request.knowledge_id
+
+ if len(brains_ids) == 0:
+ raise HTTPException(
+ status_code=status.HTTP_204_NO_CONTENT,
+ )
+
+ km = await knowledge_service.get_knowledge(knowledge_id)
+ if km.user_id != current_user.id:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="You do not have permission to remove this knowledge.",
+ )
+
+ unlinked_kms = await knowledge_service.unlink_knowledge_tree_brains(
+ knowledge=knowledge_id, brains_ids=brains_ids, user_id=current_user.id
+ )
+
+ if unlinked_kms:
+ unlinked_knowledges = await asyncio.gather(*[k.to_dto() for k in unlinked_kms])
+ return sort_knowledge_dtos(unlinked_knowledges)
diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py
index 85a2438e9205..9bf6da2e6a44 100644
--- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py
+++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py
@@ -1,9 +1,12 @@
-from typing import Dict, Optional
+from datetime import datetime
+from typing import Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
from quivr_core.models import KnowledgeStatus
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
+
class CreateKnowledgeProperties(BaseModel):
brain_id: UUID
@@ -23,9 +26,33 @@ class CreateKnowledgeProperties(BaseModel):
class AddKnowledge(BaseModel):
file_name: Optional[str] = None
url: Optional[str] = None
- extension: str = ".txt"
+ is_folder: bool = False
source: str = "local"
source_link: Optional[str] = None
+ parent_id: UUID | None = None
+ sync_id: int | None = None
+ sync_file_id: str | None = None
metadata: Optional[Dict[str, str]] = None
- is_folder: bool = False
+
+
+class KnowledgeUpdate(BaseModel):
+ file_name: Optional[str] = None
+ status: Optional[KnowledgeStatus] = None
+ url: Optional[str] = None
+ file_sha1: Optional[str] = None
+ extension: Optional[str] = None
parent_id: Optional[UUID] = None
+ source: Optional[str] = None
+ source_link: Optional[str] = None
+ metadata: Optional[Dict[str, str]] = None
+ last_synced_at: Optional[datetime] = None
+
+
+class LinkKnowledgeBrain(BaseModel):
+ knowledge: KnowledgeDTO
+ brain_ids: List[UUID]
+
+
+class UnlinkKnowledgeBrain(BaseModel):
+ knowledge_id: UUID
+ brain_ids: List[UUID]
diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py
index 20218dfce3e6..7de9587689f9 100644
--- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py
+++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py
@@ -1,9 +1,43 @@
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Self
from uuid import UUID
from pydantic import BaseModel
+from quivr_core.models import KnowledgeStatus
class DeleteKnowledgeResponse(BaseModel):
file_name: str | None = None
status: str = "DELETED"
knowledge_id: UUID
+
+
+class KnowledgeDTO(BaseModel):
+ id: Optional[UUID]
+ file_size: int = 0
+ status: Optional[KnowledgeStatus]
+ file_name: Optional[str] = None
+ url: Optional[str] = None
+ extension: str = ".txt"
+ is_folder: bool = False
+ updated_at: datetime
+ created_at: datetime
+ source: Optional[str] = None
+ source_link: Optional[str] = None
+ file_sha1: Optional[str] = None
+ metadata: Optional[Dict[str, str]] = None
+ user_id: UUID
+ # TODO: brain dto here not the brain nor the model_dump
+ brains: List[Dict[str, Any]]
+ parent: Optional[Self]
+ children: List[Self]
+ sync_id: int | None
+ sync_file_id: str | None
+ last_synced_at: datetime | None = None
+
+
+def sort_knowledge_dtos(dtos: List[KnowledgeDTO]) -> List[KnowledgeDTO]:
+ return sorted(
+ dtos,
+ key=lambda dto: (not dto.is_folder, dto.file_name is None, dto.file_name or ""),
+ )
diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py
index e08f3c0abcdb..39f63ac627f2 100644
--- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py
+++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py
@@ -1,56 +1,29 @@
+import asyncio
from datetime import datetime
from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
from uuid import UUID
-from pydantic import BaseModel
from quivr_core.models import KnowledgeStatus
from sqlalchemy import JSON, TIMESTAMP, Column, text
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel import UUID as PGUUID
from sqlmodel import Field, Relationship, SQLModel
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO, sort_knowledge_dtos
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
+from quivr_api.modules.sync.entity.sync_models import Sync
class KnowledgeSource(str, Enum):
LOCAL = "local"
WEB = "web"
- GDRIVE = "google drive"
+ NOTETAKER = "notetaker"
+ GOOGLE = "google"
+ AZURE = "azure"
DROPBOX = "dropbox"
- SHAREPOINT = "sharepoint"
-
-
-class Knowledge(BaseModel):
- id: UUID
- file_size: int = 0
- status: KnowledgeStatus
- file_name: Optional[str] = None
- url: Optional[str] = None
- extension: str = ".txt"
- is_folder: bool = False
- updated_at: datetime
- created_at: datetime
- source: Optional[str] = None
- source_link: Optional[str] = None
- file_sha1: Optional[str] = None
- metadata: Optional[Dict[str, str]] = None
- user_id: Optional[UUID] = None
- brains: List[Dict[str, Any]]
- parent: Optional["Knowledge"]
- children: Optional[list["Knowledge"]]
-
-
-class KnowledgeUpdate(BaseModel):
- file_name: Optional[str] = None
- status: Optional[KnowledgeStatus] = None
- url: Optional[str] = None
- file_sha1: Optional[str] = None
- extension: Optional[str] = None
- parent_id: Optional[UUID] = None
- source: Optional[str] = None
- source_link: Optional[str] = None
- metadata: Optional[Dict[str, str]] = None
+ NOTION = "notion"
+ GITHUB = "github"
class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
@@ -64,7 +37,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
primary_key=True,
),
)
- file_name: str = Field(default="", max_length=255)
+ file_name: Optional[str] = Field(default=None, max_length=255)
url: Optional[str] = Field(default=None, max_length=2048)
extension: str = Field(default=".txt", max_length=100)
status: str = Field(max_length=50)
@@ -77,18 +50,25 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
created_at: datetime | None = Field(
default=None,
sa_column=Column(
- TIMESTAMP(timezone=False),
+ TIMESTAMP(timezone=True),
server_default=text("CURRENT_TIMESTAMP"),
),
)
updated_at: datetime | None = Field(
default=None,
sa_column=Column(
- TIMESTAMP(timezone=False),
+ TIMESTAMP(timezone=True),
server_default=text("CURRENT_TIMESTAMP"),
onupdate=datetime.utcnow,
),
)
+
+ last_synced_at: datetime | None = Field(
+ default=None,
+ sa_column=Column(
+ TIMESTAMP(timezone=True),
+ ),
+ )
metadata_: Optional[Dict[str, str]] = Field(
default=None, sa_column=Column("metadata", JSON)
)
@@ -97,7 +77,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
brains: List["Brain"] = Relationship( # type: ignore # noqa: F821
back_populates="knowledges",
link_model=KnowledgeBrain,
- sa_relationship_kwargs={"lazy": "select"},
+ sa_relationship_kwargs={"lazy": "joined"},
)
parent_id: UUID | None = Field(
@@ -113,25 +93,37 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True):
"cascade": "all, delete-orphan",
},
)
+ sync_id: int | None = Field(
+ default=None, foreign_key="syncs.id", ondelete="CASCADE"
+ )
+ sync: Sync | None = Relationship(
+ back_populates="knowledges", sa_relationship_kwargs={"lazy": "joined"}
+ )
+ sync_file_id: str | None = Field(default=None)
# TODO: nested folder search
async def to_dto(
self, get_children: bool = True, get_parent: bool = True
- ) -> Knowledge:
+ ) -> KnowledgeDTO:
assert (
- self.updated_at
+ await self.awaitable_attrs.updated_at
), "knowledge should be inserted before transforming to dto"
assert (
- self.created_at
+ await self.awaitable_attrs.created_at
), "knowledge should be inserted before transforming to dto"
brains = await self.awaitable_attrs.brains
+ brains = sorted(brains, key=lambda b: (b is None, b.name))
children: list[KnowledgeDB] = (
await self.awaitable_attrs.children if get_children else []
)
+ children_dto = await asyncio.gather(
+ *[c.to_dto(get_children=False) for c in children]
+ )
+ children_dto = sort_knowledge_dtos(children_dto)
parent = await self.awaitable_attrs.parent if get_parent else None
- parent = await parent.to_dto(get_children=False) if parent else None
+ parent_dto = await parent.to_dto(get_children=False) if parent else None
- return Knowledge(
+ return KnowledgeDTO(
id=self.id, # type: ignore
file_name=self.file_name,
url=self.url,
@@ -142,11 +134,14 @@ async def to_dto(
is_folder=self.is_folder,
file_size=self.file_size or 0,
file_sha1=self.file_sha1,
- updated_at=self.updated_at,
- created_at=self.created_at,
+ updated_at=await self.awaitable_attrs.updated_at,
+ created_at=await self.awaitable_attrs.created_at,
metadata=self.metadata_, # type: ignore
brains=[b.model_dump() for b in brains],
- parent=parent,
- children=[await c.to_dto(get_children=False) for c in children],
+ parent=parent_dto,
+ children=children_dto,
user_id=self.user_id,
+ sync_id=self.sync_id,
+ sync_file_id=self.sync_file_id,
+ last_synced_at=self.last_synced_at,
)
diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py
index 0f9b8e8ae771..017f6fb98386 100644
--- a/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py
+++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py
@@ -1,7 +1,6 @@
from datetime import datetime
from uuid import UUID
-from sqlalchemy import TIMESTAMP, Column, text
from sqlmodel import TIMESTAMP, Column, Field, SQLModel, text
from sqlmodel import UUID as PGUUID
diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py
index 82da7c8e8c9b..48b07fd8d6f7 100644
--- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py
+++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py
@@ -1,26 +1,33 @@
-from typing import Any, Sequence
+from datetime import datetime, timezone
+from typing import Any, List, Sequence
from uuid import UUID
from fastapi import HTTPException
from quivr_core.models import KnowledgeStatus
from sqlalchemy.exc import IntegrityError, NoResultFound
from sqlalchemy.orm import joinedload
-from sqlmodel import select, text
+from sqlalchemy.sql.functions import random
+from sqlmodel import and_, col, select, text
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.logger import get_logger
-from quivr_api.modules.brain.entity.brain_entity import Brain
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainUserDB, RoleEnum
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
-from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
+from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate
+from quivr_api.modules.knowledge.dto.outputs import (
+ DeleteKnowledgeResponse,
+ KnowledgeDTO,
+)
from quivr_api.modules.knowledge.entity.knowledge import (
- Knowledge,
KnowledgeDB,
- KnowledgeUpdate,
+ KnowledgeSource,
)
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
+ KnowledgeCreationError,
KnowledgeNotFoundException,
KnowledgeUpdateError,
)
+from quivr_api.modules.sync.entity.sync_models import SyncType
logger = get_logger(__name__)
@@ -31,11 +38,15 @@ def __init__(self, session: AsyncSession):
supabase_client = get_supabase_client()
self.db = supabase_client
- async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB:
+ async def create_knowledge(
+ self, knowledge: KnowledgeDB, autocommit: bool
+ ) -> KnowledgeDB:
try:
self.session.add(knowledge)
- await self.session.commit()
- await self.session.refresh(knowledge)
+ if autocommit:
+ await self.session.commit()
+ await self.session.refresh(knowledge)
+ await self.session.flush()
except IntegrityError:
await self.session.rollback()
raise
@@ -47,7 +58,8 @@ async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB:
async def update_knowledge(
self,
knowledge: KnowledgeDB,
- payload: Knowledge | KnowledgeUpdate | dict[str, Any],
+ payload: KnowledgeDTO | KnowledgeUpdate | dict[str, Any],
+ autocommit: bool,
) -> KnowledgeDB:
try:
logger.debug(f"updating {knowledge.id} with payload {payload}")
@@ -57,16 +69,106 @@ async def update_knowledge(
update_data = payload.model_dump(exclude_unset=True)
for field in update_data:
setattr(knowledge, field, update_data[field])
-
+ knowledge.updated_at = datetime.now(timezone.utc)
self.session.add(knowledge)
- await self.session.commit()
- await self.session.refresh(knowledge)
+ if autocommit:
+ await self.session.commit()
+ await self.session.refresh(knowledge)
+ else:
+ await self.session.flush()
return knowledge
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Error updating knowledge {e}")
raise KnowledgeUpdateError
+ async def unlink_knowledge_tree_brains(
+ self, knowledge: KnowledgeDB, brains_ids: List[UUID], user_id: UUID
+ ) -> list[KnowledgeDB] | None:
+ assert knowledge.id, "can't link knowledge not in db"
+ try:
+ # TODO: Move check somewhere else
+ stmt = (
+ select(Brain)
+ .join(BrainUserDB, col(Brain.brain_id) == col(BrainUserDB.brain_id))
+ .where(
+ and_(
+ col(Brain.brain_id).in_(brains_ids),
+ BrainUserDB.user_id == user_id,
+ BrainUserDB.rights == RoleEnum.Owner,
+ )
+ )
+ )
+ unlink_brains = list((await self.session.exec(stmt)).unique().all())
+ unlink_brain_ids = {b.brain_id for b in unlink_brains}
+
+ if len(unlink_brains) == 0:
+ logger.info(
+ f"No brains for user_id={user_id}, brains_list={brains_ids}"
+ )
+ return
+ children = await self.get_knowledge_tree(knowledge.id)
+ all_kms = [knowledge, *children]
+ for k in all_kms:
+ k.brains = list(
+ filter(lambda b: b.brain_id not in unlink_brain_ids, k.brains)
+ )
+ [self.session.add(k) for k in all_kms]
+ await self.session.commit()
+ [await self.session.refresh(k) for k in all_kms]
+ return all_kms
+ except IntegrityError:
+ await self.session.rollback()
+ raise
+ except Exception:
+ await self.session.rollback()
+ raise
+
+ async def link_knowledge_tree_brains(
+ self, knowledge: KnowledgeDB, brains_ids: List[UUID], user_id: UUID
+ ) -> list[KnowledgeDB]:
+ assert knowledge.id, "can't link knowledge not in db"
+ # TODO(@aminediro @StanGirard): verification should be done elsewhere
+ # should rewrite BrainService and Brain Authorization to be as middleware
+ try:
+ stmt = (
+ select(Brain)
+ .join(BrainUserDB, col(Brain.brain_id) == col(BrainUserDB.brain_id))
+ .where(
+ and_(
+ col(Brain.brain_id).in_(brains_ids),
+ BrainUserDB.user_id == user_id,
+ BrainUserDB.rights == RoleEnum.Owner,
+ )
+ )
+ )
+ brains = list((await self.session.exec(stmt)).unique().all())
+ if len(brains) == 0:
+ logger.error(
+ f"No brains for user_id={user_id}, brains_list={brains_ids}"
+ )
+ raise KnowledgeCreationError("can't associate knowledge to brains")
+ children = await self.get_knowledge_tree(knowledge.id)
+ all_kms = [knowledge, *children]
+ for k in all_kms:
+ km_brains = {km_brain.brain_id for km_brain in k.brains}
+ for b in filter(
+ lambda b: b.brain_id not in km_brains,
+ brains,
+ ):
+ k.brains.append(b)
+ for k in all_kms:
+ await self.session.merge(k)
+ await self.session.commit()
+ [await self.session.refresh(k) for k in all_kms]
+ return all_kms
+ except IntegrityError:
+ await self.session.rollback()
+ raise
+ except Exception:
+ await self.session.rollback()
+ raise
+
async def insert_knowledge_brain(
self, knowledge: KnowledgeDB, brain_id: UUID
) -> KnowledgeDB:
@@ -114,10 +216,14 @@ async def remove_knowledge_from_brain(
await self.session.refresh(knowledge)
return knowledge
- async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
+ async def remove_knowledge(
+ self, knowledge: KnowledgeDB, autocommit: bool
+ ) -> DeleteKnowledgeResponse:
assert knowledge.id
await self.session.delete(knowledge)
- await self.session.commit()
+ if autocommit:
+ await self.session.commit()
+
return DeleteKnowledgeResponse(
status="deleted", knowledge_id=knowledge.id, file_name=knowledge.file_name
)
@@ -153,6 +259,16 @@ async def get_knowledge_by_sync_id(self, sync_id: int) -> KnowledgeDB:
return knowledge
+ async def get_all_knowledge_sync_user(
+ self, sync_id: int, user_id: UUID | None = None
+ ) -> List[KnowledgeDB]:
+ query = select(KnowledgeDB).where(KnowledgeDB.sync_id == sync_id)
+ if user_id:
+ query = query.where(KnowledgeDB.user_id == user_id)
+
+ result = await self.session.exec(query)
+ return list(result.unique().all())
+
async def get_knowledge_by_file_name_brain_id(
self, file_name: str, brain_id: UUID
) -> KnowledgeDB:
@@ -179,51 +295,42 @@ async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB:
return knowledge
- async def get_all_children(self, parent_id: UUID) -> list[KnowledgeDB]:
- query = text("""
- WITH RECURSIVE knowledge_tree AS (
- SELECT *
- FROM knowledge
- WHERE parent_id = :parent_id
- UNION ALL
- SELECT k.*
- FROM knowledge k
- JOIN knowledge_tree kt ON k.parent_id = kt.id
- )
- SELECT * FROM knowledge_tree
- """)
-
- result = await self.session.execute(query, params={"parent_id": parent_id})
- rows = result.fetchall()
- knowledge_list = []
- for row in rows:
- knowledge = KnowledgeDB(
- id=row.id,
- parent_id=row.parent_id,
- file_name=row.file_name,
- url=row.url,
- extension=row.extension,
- status=row.status,
- source=row.source,
- source_link=row.source_link,
- file_size=row.file_size,
- file_sha1=row.file_sha1,
- created_at=row.created_at,
- updated_at=row.updated_at,
- metadata_=row.metadata,
- is_folder=row.is_folder,
- user_id=row.user_id,
+ async def get_knowledge_tree(self, parent_id: UUID) -> list[KnowledgeDB]:
+ from sqlalchemy.orm import aliased
+
+ Knowledge = aliased(KnowledgeDB)
+ KnowledgeRecursive = aliased(KnowledgeDB)
+
+ recursive_cte = (
+ select(KnowledgeRecursive)
+ .where(KnowledgeRecursive.parent_id == parent_id)
+ .cte(name="knowledge_tree", recursive=True)
+ )
+
+ recursive_cte = recursive_cte.union_all(
+ select(Knowledge).join(
+ recursive_cte, col(Knowledge.parent_id) == col(recursive_cte.c.id)
)
- knowledge_list.append(knowledge)
+ )
+ # TODO(@AmineDiro): Optimize get_knowledge_tree
+ query = (
+ select(KnowledgeDB)
+ .join(recursive_cte, col(KnowledgeDB.id) == recursive_cte.c.id)
+ .options(joinedload(KnowledgeDB.brains))
+ )
+
+ result = await self.session.exec(query)
+ knowledge_list = result.unique().all()
- return knowledge_list
+ return list(knowledge_list)
async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]:
query = (
select(KnowledgeDB)
.where(KnowledgeDB.parent_id.is_(None)) # type: ignore
.where(KnowledgeDB.user_id == user_id)
- .options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore
+ .where(KnowledgeDB.source == KnowledgeSource.LOCAL)
+ .options(joinedload(KnowledgeDB.children)) # type: ignore
)
result = await self.session.exec(query)
kms = result.unique().all()
@@ -246,14 +353,13 @@ async def get_knowledge_by_id(
return knowledge
async def get_brain_by_id(
- self, brain_id: UUID, get_knowledge: bool = False
+ self, brain_id: UUID, get_knowledge: bool = True
) -> Brain:
query = select(Brain).where(Brain.brain_id == brain_id)
if get_knowledge:
query = query.options(
joinedload(Brain.knowledges).joinedload(KnowledgeDB.brains)
)
-
result = await self.session.exec(query)
brain = result.first()
if not brain:
@@ -346,3 +452,28 @@ async def get_all_knowledge(self) -> Sequence[KnowledgeDB]:
query = select(KnowledgeDB)
result = await self.session.exec(query)
return result.all()
+
+ async def get_outdated_syncs(
+ self,
+ limit_time: datetime,
+ batch_size: int,
+ km_sync_type: SyncType,
+ ) -> List[KnowledgeDB]:
+ is_folder_check = km_sync_type == SyncType.FOLDER
+ query = (
+ select(KnowledgeDB)
+ .where(
+ KnowledgeDB.is_folder == is_folder_check,
+ col(KnowledgeDB.sync_id).isnot(None),
+ col(KnowledgeDB.sync_file_id).isnot(None),
+ col(KnowledgeDB.last_synced_at) < limit_time,
+ col(KnowledgeDB.brains).any(),
+ )
+ # Oldest first
+ .order_by(col(KnowledgeDB.last_synced_at).asc(), random())
+ .limit(batch_size)
+ )
+
+ # Execute the query (assuming you have a session)
+ result = await self.session.exec(query)
+ return list(result.unique().all())
diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage.py b/backend/api/quivr_api/modules/knowledge/repository/storage.py
index ad35659dbbd0..9214a8200aad 100644
--- a/backend/api/quivr_api/modules/knowledge/repository/storage.py
+++ b/backend/api/quivr_api/modules/knowledge/repository/storage.py
@@ -5,18 +5,30 @@
from quivr_api.modules.dependencies import get_supabase_async_client
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
+from supabase.client import AsyncClient
logger = get_logger(__name__)
class SupabaseS3Storage(StorageInterface):
- def __init__(self):
- self.client = None
+ def __init__(self, client: AsyncClient | None = None):
+ self.client = client
async def _set_client(self):
if self.client is None:
self.client = await get_supabase_async_client()
+ async def download_file(
+ self,
+ knowledge: KnowledgeDB,
+ bucket_name: str = "quivr",
+ ) -> bytes:
+ await self._set_client()
+ assert self.client
+ path = self.get_storage_path(knowledge)
+ file_data = await self.client.storage.from_(bucket_name).download(path)
+ return file_data
+
def get_storage_path(
self,
knowledge: KnowledgeDB,
diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py b/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py
index bd5a3debc03a..3a3e8cb8cc06 100644
--- a/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py
+++ b/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py
@@ -12,6 +12,10 @@ def get_storage_path(
) -> str:
pass
+ @abstractmethod
+ async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes:
+ pass
+
@abstractmethod
async def upload_file_storage(
self,
diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py
index cb36c4ef87b9..71a96124fa2f 100644
--- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py
+++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py
@@ -1,5 +1,6 @@
import asyncio
import io
+from datetime import datetime
from typing import Any, List
from uuid import UUID
@@ -7,18 +8,22 @@
from quivr_core.models import KnowledgeStatus
from sqlalchemy.exc import NoResultFound
+from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
+from quivr_api.modules.brain.entity.brain_entity import Brain
from quivr_api.modules.dependencies import BaseService
from quivr_api.modules.knowledge.dto.inputs import (
AddKnowledge,
CreateKnowledgeProperties,
+ KnowledgeUpdate,
+)
+from quivr_api.modules.knowledge.dto.outputs import (
+ DeleteKnowledgeResponse,
+ KnowledgeDTO,
)
-from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse
from quivr_api.modules.knowledge.entity.knowledge import (
- Knowledge,
KnowledgeDB,
KnowledgeSource,
- KnowledgeUpdate,
)
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
@@ -28,12 +33,9 @@
KnowledgeForbiddenAccess,
UploadError,
)
-from quivr_api.modules.sync.entity.sync_models import (
- DBSyncFile,
- DownloadedSyncFile,
- SyncFile,
-)
+from quivr_api.modules.sync.entity.sync_models import SyncFile, SyncType
from quivr_api.modules.upload.service.upload_file import check_file_exists
+from quivr_api.utils.knowledge_utils import parse_file_extension
logger = get_logger(__name__)
@@ -44,17 +46,67 @@ class KnowledgeService(BaseService[KnowledgeRepository]):
def __init__(
self,
repository: KnowledgeRepository,
- storage: StorageInterface = SupabaseS3Storage(),
+ storage: StorageInterface = SupabaseS3Storage(client=None),
):
self.repository = repository
self.storage = storage
- async def get_knowledge_sync(self, sync_id: int) -> Knowledge:
- km_db = await self.repository.get_knowledge_by_sync_id(sync_id)
- assert km_db.id, "Knowledge ID not generated"
- km = await km_db.to_dto()
+ async def get_knowledge_sync(self, sync_id: int) -> KnowledgeDTO:
+ km = await self.repository.get_knowledge_by_sync_id(sync_id)
+ assert km.id, "Knowledge ID not generated"
+ km = await km.to_dto()
return km
+ async def create_or_link_sync_knowledge(
+ self,
+ syncfile_id_to_knowledge: dict[str, KnowledgeDB],
+ parent_knowledge: KnowledgeDB,
+ sync_file: SyncFile,
+ autocommit: bool = True,
+ ):
+ existing_km = syncfile_id_to_knowledge.get(sync_file.id)
+ if existing_km is not None:
+ # NOTE: function called in worker processor
+ # The parent_knowledge was just added to db (we are processing it)
+ # This implies that we could have sync children files and folders that were processed before
+ # If SyncKnowledge already exists
+ # IF STATUS == PROCESSED:
+ # => It's already processed in some other brain !
+ # => Link it to the parent and update its brains to the correct ones
+ # ELSE Reprocess the file
+ km_brains = {km_brain.brain_id for km_brain in existing_km.brains}
+ for brain in filter(
+ lambda b: b.brain_id not in km_brains,
+ parent_knowledge.brains,
+ ):
+ await self.repository.update_knowledge(
+ existing_km,
+ KnowledgeUpdate(parent_id=parent_knowledge.id),
+ autocommit=autocommit,
+ )
+ await self.repository.link_to_brain(
+ existing_km, brain_id=brain.brain_id
+ )
+ return existing_km
+ else:
+ # create sync file knowledge
+ # automagically gets the brains associated with the parent
+ file_knowledge = await self.create_knowledge(
+ user_id=parent_knowledge.user_id,
+ knowledge_to_add=AddKnowledge(
+ file_name=sync_file.name,
+ is_folder=sync_file.is_folder,
+ source=parent_knowledge.source, # same as parent
+ source_link=sync_file.web_view_link,
+ parent_id=parent_knowledge.id,
+ sync_id=parent_knowledge.sync_id,
+ sync_file_id=sync_file.id,
+ ),
+ status=KnowledgeStatus.PROCESSING,
+ upload_file=None,
+ )
+ return file_knowledge
+
# TODO: this is temporary fix for getting knowledge path.
# KM storage path should be unrelated to brain
async def get_knowledge_storage_path(
@@ -73,6 +125,14 @@ async def get_knowledge_storage_path(
except NoResultFound:
raise FileNotFoundError(f"No knowledge for file_name: {file_name}")
+ async def map_syncs_knowledge_user(
+ self, sync_id: int, user_id: UUID
+ ) -> dict[str, KnowledgeDB]:
+ list_kms = await self.repository.get_all_knowledge_sync_user(
+ sync_id=sync_id, user_id=user_id
+ )
+ return {k.sync_file_id: k for k in list_kms if k.sync_file_id}
+
async def list_knowledge(
self, knowledge_id: UUID | None, user_id: UUID | None = None
) -> list[KnowledgeDB]:
@@ -93,74 +153,64 @@ async def get_knowledge(
async def update_knowledge(
self,
- knowledge: KnowledgeDB,
- payload: Knowledge | KnowledgeUpdate | dict[str, Any],
+ knowledge: KnowledgeDB | UUID,
+ payload: KnowledgeDTO | KnowledgeUpdate | dict[str, Any],
+ autocommit: bool = True,
):
- return await self.repository.update_knowledge(knowledge, payload)
-
- # TODO: Remove all of this
- # TODO (@aminediro): Replace with ON CONFLICT smarter query...
- # there is a chance of race condition but for now we let it crash in worker
- # the tasks will be dealt with on retry
- async def update_sha1_conflict(
- self, knowledge: KnowledgeDB, brain_id: UUID, file_sha1: str
- ) -> bool:
- assert knowledge.id
- knowledge.file_sha1 = file_sha1
-
- try:
- existing_knowledge = await self.repository.get_knowledge_by_sha1(
- knowledge.file_sha1
- )
- logger.debug("The content of the knowledge already exists in the brain. ")
- # Get existing knowledge sha1 and brains
- if (
- existing_knowledge.status == KnowledgeStatus.UPLOADED
- or existing_knowledge.status == KnowledgeStatus.PROCESSING
- ):
- existing_brains = await existing_knowledge.awaitable_attrs.brains
- if brain_id in [b.brain_id for b in existing_brains]:
- logger.debug("Added file to brain that already has the knowledge")
- raise FileExistsError(
- f"Existing file in brain {brain_id} with name {existing_knowledge.file_name}"
- )
- else:
- await self.repository.link_to_brain(existing_knowledge, brain_id)
- await self.remove_knowledge_brain(brain_id, knowledge.id)
- return False
- else:
- logger.debug(f"Removing previous errored file {existing_knowledge.id}")
- assert existing_knowledge.id
- await self.remove_knowledge_brain(brain_id, existing_knowledge.id)
- await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
- return True
- except NoResultFound:
- logger.debug(
- f"First knowledge with sha1. Updating file_sha1 of {knowledge.id}"
- )
- await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1)
- return True
+ if isinstance(knowledge, UUID):
+ knowledge = await self.repository.get_knowledge_by_id(knowledge)
+ return await self.repository.update_knowledge(knowledge, payload, autocommit)
async def create_knowledge(
self,
user_id: UUID,
knowledge_to_add: AddKnowledge,
upload_file: UploadFile | None = None,
+ status: KnowledgeStatus = KnowledgeStatus.RESERVED,
+ link_brains: list[Brain] = [],
+ autocommit: bool = True,
+ process_async: bool = True,
) -> KnowledgeDB:
+ brains = []
+ if knowledge_to_add.parent_id:
+ parent_knowledge = await self.get_knowledge(knowledge_to_add.parent_id)
+ brains = await parent_knowledge.awaitable_attrs.brains
+ if len(link_brains) > 0:
+ brains.extend(
+ [
+ b
+ for b in link_brains
+ if b.brain_id not in {b.brain_id for b in brains}
+ ]
+ )
+ # TODO: slugify url names here !!
+ extension = (
+ parse_file_extension(knowledge_to_add.file_name)
+ if knowledge_to_add.file_name
+ else ""
+ )
+
knowledgedb = KnowledgeDB(
user_id=user_id,
file_name=knowledge_to_add.file_name,
is_folder=knowledge_to_add.is_folder,
url=knowledge_to_add.url,
- extension=knowledge_to_add.extension,
+ extension=extension,
source=knowledge_to_add.source,
source_link=knowledge_to_add.source_link,
file_size=upload_file.size if upload_file else 0,
metadata_=knowledge_to_add.metadata, # type: ignore
- status=KnowledgeStatus.RESERVED,
+ status=status,
parent_id=knowledge_to_add.parent_id,
+ sync_id=knowledge_to_add.sync_id,
+ sync_file_id=knowledge_to_add.sync_file_id,
+ brains=brains,
+ )
+
+ knowledge_db = await self.repository.create_knowledge(
+ knowledgedb, autocommit=autocommit
)
- knowledge_db = await self.repository.create_knowledge(knowledgedb)
+
try:
if knowledgedb.source == KnowledgeSource.LOCAL and upload_file:
# NOTE(@aminediro): Unnecessary mem buffer because supabase doesnt accept FileIO..
@@ -169,23 +219,48 @@ async def create_knowledge(
knowledgedb, buff_reader
)
knowledgedb.source_link = storage_path
- knowledge_db = await self.repository.update_knowledge(
- knowledge_db,
- KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore
- )
- return knowledge_db
+ knowledge_db = await self.repository.update_knowledge(
+ knowledge_db,
+ KnowledgeUpdate(status=KnowledgeStatus.UPLOADED),
+ autocommit=autocommit,
+ )
+ if knowledge_db.brains and len(knowledge_db.brains) > 0 and process_async:
+ # Schedule this new knowledge to be processed
+ knowledge_db = await self.repository.update_knowledge(
+ knowledge_db,
+ KnowledgeUpdate(status=KnowledgeStatus.PROCESSING),
+ autocommit=autocommit,
+ )
+ celery.send_task(
+ "process_file_task",
+ kwargs={
+ "knowledge_id": knowledge_db.id,
+ },
+ )
+ return knowledge_db
+ else:
+ knowledge_db = await self.repository.update_knowledge(
+ knowledge_db,
+ KnowledgeUpdate(status=KnowledgeStatus.UPLOADED),
+ autocommit=autocommit,
+ )
+ return knowledge_db
+
except Exception as e:
logger.exception(
f"Error uploading knowledge {knowledgedb.id} to storage : {e}"
)
- await self.repository.remove_knowledge(knowledge=knowledge_db)
+ await self.repository.remove_knowledge(
+ knowledge=knowledge_db, autocommit=autocommit
+ )
raise UploadError()
async def insert_knowledge_brain(
self,
user_id: UUID,
knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name
- ) -> Knowledge:
+ ) -> KnowledgeDTO:
+ # TODO: check input
knowledge = KnowledgeDB(
file_name=knowledge_to_add.file_name,
url=knowledge_to_add.url,
@@ -207,7 +282,7 @@ async def insert_knowledge_brain(
inserted_knowledge = await knowledge_db.to_dto()
return inserted_knowledge
- async def get_all_knowledge_in_brain(self, brain_id: UUID) -> List[Knowledge]:
+ async def get_all_knowledge_in_brain(self, brain_id: UUID) -> List[KnowledgeDTO]:
brain = await self.repository.get_brain_by_id(brain_id, get_knowledge=True)
all_knowledges: List[KnowledgeDB] = await brain.awaitable_attrs.knowledges
knowledges = [
@@ -240,23 +315,33 @@ async def update_status_knowledge(
async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str):
return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1)
- async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse:
+ async def remove_knowledge(
+ self, knowledge: KnowledgeDB, autocommit: bool = True
+ ) -> DeleteKnowledgeResponse:
assert knowledge.id
try:
# TODO:
# - Notion folders are special, they are themselves files and should be removed from storage
- children = await self.repository.get_all_children(knowledge.id)
- km_paths = [
- self.storage.get_storage_path(k) for k in children if not k.is_folder
- ]
- if not knowledge.is_folder:
- km_paths.append(self.storage.get_storage_path(knowledge))
-
+ km_paths = []
+ if knowledge.source == KnowledgeSource.LOCAL:
+ if knowledge.is_folder:
+ children = await self.repository.get_knowledge_tree(knowledge.id)
+ km_paths.extend(
+ [
+ self.storage.get_storage_path(k)
+ for k in children
+ if not k.is_folder
+ ]
+ )
+ if not knowledge.is_folder:
+ km_paths.append(self.storage.get_storage_path(knowledge))
# recursively deletes files
- deleted_km = await self.repository.remove_knowledge(knowledge)
+ deleted_km = await self.repository.remove_knowledge(
+ knowledge, autocommit=autocommit
+ )
+ # TODO: remove storage asynchronously in background task or in some task
await asyncio.gather(*[self.storage.remove_file(p) for p in km_paths])
-
return deleted_km
except Exception as e:
logger.error(f"Error while remove knowledge : {e}")
@@ -295,45 +380,30 @@ async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None:
f"All knowledge in brain {brain_id} removed successfully from table"
)
- # TODO: REDO THIS MESS !!!!
- # REMOVE ALL SYNC TABLES and start from scratch
- async def update_or_create_knowledge_sync(
- self,
- brain_id: UUID,
- user_id: UUID,
- file: SyncFile,
- new_sync_file: DBSyncFile | None,
- prev_sync_file: DBSyncFile | None,
- downloaded_file: DownloadedSyncFile,
- source: str,
- source_link: str,
- ) -> Knowledge:
- sync_id = None
- # TODO: THIS IS A HACK!! Remove all of this
- if prev_sync_file:
- prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id)
- if len(prev_knowledge.brains) > 1:
- await self.repository.remove_knowledge_from_brain(
- prev_knowledge.id, brain_id
- )
- else:
- await self.repository.remove_knowledge_by_id(prev_knowledge.id)
- sync_id = prev_sync_file.id
-
- sync_id = new_sync_file.id if new_sync_file else sync_id
- knowledge_to_add = CreateKnowledgeProperties(
- brain_id=brain_id,
- file_name=file.name,
- extension=downloaded_file.extension,
- source=source,
- status=KnowledgeStatus.PROCESSING,
- source_link=source_link,
- file_size=file.size if file.size else 0,
- # FIXME (@aminediro): This is a temporary fix, redo in KMS
- file_sha1=None,
- metadata={"sync_file_id": str(sync_id)},
+ async def link_knowledge_tree_brains(
+ self, knowledge: KnowledgeDB | UUID, brains_ids: List[UUID], user_id: UUID
+ ) -> List[KnowledgeDB]:
+ if isinstance(knowledge, UUID):
+ knowledge = await self.repository.get_knowledge_by_id(knowledge)
+ return await self.repository.link_knowledge_tree_brains(
+ knowledge, brains_ids=brains_ids, user_id=user_id
+ )
+
+ async def unlink_knowledge_tree_brains(
+ self, knowledge: KnowledgeDB | UUID, brains_ids: List[UUID], user_id: UUID
+ ) -> List[KnowledgeDB] | None:
+ if isinstance(knowledge, UUID):
+ knowledge = await self.repository.get_knowledge_by_id(knowledge)
+ return await self.repository.unlink_knowledge_tree_brains(
+ knowledge, brains_ids=brains_ids, user_id=user_id
)
- added_knowledge = await self.insert_knowledge_brain(
- knowledge_to_add=knowledge_to_add, user_id=user_id
+
+ async def get_outdated_syncs(
+ self,
+ limit_time: datetime,
+ km_sync_type: SyncType,
+ batch_size: int = 1,
+ ) -> List[KnowledgeDB]:
+ return await self.repository.get_outdated_syncs(
+ limit_time=limit_time, batch_size=batch_size, km_sync_type=km_sync_type
)
- return added_knowledge
diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py
index 2074110f6b5c..c87d20af5968 100644
--- a/backend/api/quivr_api/modules/knowledge/tests/conftest.py
+++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py
@@ -1,7 +1,17 @@
from io import BufferedReader, FileIO
+from uuid import uuid4
-from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB
+import pytest_asyncio
+from sqlmodel import select, text
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync
+from quivr_api.modules.user.entity.user_identity import User
class ErrorStorage(StorageInterface):
@@ -15,7 +25,7 @@ async def upload_file_storage(
def get_storage_path(
self,
- knowledge: KnowledgeDB | Knowledge,
+ knowledge: KnowledgeDB | KnowledgeDTO,
) -> str:
if knowledge.id is None:
raise ValueError("knowledge should have a valid id")
@@ -24,6 +34,9 @@ def get_storage_path(
async def remove_file(self, storage_path: str):
raise SystemError
+ async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes:
+ raise NotImplementedError
+
class FakeStorage(StorageInterface):
def __init__(self):
@@ -31,7 +44,7 @@ def __init__(self):
def get_storage_path(
self,
- knowledge: KnowledgeDB | Knowledge,
+ knowledge: KnowledgeDB | KnowledgeDTO,
) -> str:
if knowledge.id is None:
raise ValueError("knowledge should have a valid id")
@@ -46,7 +59,13 @@ async def upload_file_storage(
storage_path = f"{knowledge.id}"
if not upsert and storage_path in self.storage:
raise ValueError(f"File already exists at {storage_path}")
- self.storage[storage_path] = knowledge_data
+ if isinstance(knowledge_data, FileIO) or isinstance(
+ knowledge_data, BufferedReader
+ ):
+ self.storage[storage_path] = knowledge_data.read()
+ else:
+ self.storage[storage_path] = knowledge_data
+
return storage_path
async def remove_file(self, storage_path: str):
@@ -60,8 +79,77 @@ def get_file(self, storage_path: str) -> FileIO | BufferedReader | bytes:
raise FileNotFoundError(f"File not found at {storage_path}")
return self.storage[storage_path]
- def knowledge_exists(self, knowledge: KnowledgeDB | Knowledge) -> bool:
+ def knowledge_exists(self, knowledge: KnowledgeDB | KnowledgeDTO) -> bool:
return self.get_storage_path(knowledge) in self.storage
def clear_storage(self):
self.storage.clear()
+
+ async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes:
+ storage_path = self.get_storage_path(knowledge)
+ return self.storage[storage_path]
+
+
+@pytest_asyncio.fixture(scope="function")
+async def other_user(session: AsyncSession):
+ sql = text(
+ """
+ INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
+ ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
+ """
+ )
+ await session.execute(sql, params={"id": uuid4()})
+
+ other_user = (
+ await session.exec(select(User).where(User.email == "other@quivr.app"))
+ ).one()
+ return other_user
+
+
+@pytest_asyncio.fixture(scope="function")
+async def user(session):
+ user_1 = (
+ await session.exec(select(User).where(User.email == "admin@quivr.app"))
+ ).one()
+ return user_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync(session: AsyncSession, user: User) -> Sync:
+ assert user.id
+ sync = Sync(
+ name="test_sync",
+ email="test@test.com",
+ user_id=user.id,
+ credentials={"test": "test"},
+ provider=SyncProvider.GOOGLE,
+ )
+
+ session.add(sync)
+ await session.commit()
+ await session.refresh(sync)
+ return sync
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain(session):
+ brain_1 = Brain(
+ name="brain_1",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain_1)
+ await session.commit()
+ return brain_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain2(session):
+ brain_1 = Brain(
+ name="brain_2",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain_1)
+ await session.commit()
+ return brain_1
diff --git a/backend/api/quivr_api/modules/knowledge/tests/integration_test.py b/backend/api/quivr_api/modules/knowledge/tests/integration_test.py
new file mode 100644
index 000000000000..33b552f4039d
--- /dev/null
+++ b/backend/api/quivr_api/modules/knowledge/tests/integration_test.py
@@ -0,0 +1,49 @@
+import asyncio
+import json
+from uuid import UUID
+
+from httpx import AsyncClient
+
+from quivr_api.modules.knowledge.dto.inputs import LinkKnowledgeBrain
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
+
+
+async def main():
+ url = "http://localhost:5050"
+ km_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": False,
+ "parent_id": None,
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(km_data), "application/json"),
+ "file": ("test_file.txt", b"Test file content", "application/octet-stream"),
+ }
+
+ async with AsyncClient(
+ base_url=url, headers={"Authorization": "Bearer 123"}
+ ) as test_client:
+ response = await test_client.post(
+ "/knowledge/",
+ files=multipart_data,
+ )
+ response.raise_for_status()
+ km = KnowledgeDTO.model_validate(response.json())
+
+ json_data = LinkKnowledgeBrain(
+ brain_ids=[UUID("40ba47d7-51b2-4b2a-9247-89e29619efb0")],
+ knowledge=km,
+ ).model_dump_json()
+ response = await test_client.post(
+ "/knowledge/link_to_brains/",
+ content=json_data,
+ headers={"Content-Type": "application/json"},
+ )
+ response.raise_for_status()
+ print(response.json())
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py
index cf6313e97a19..21f59168f9f0 100644
--- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py
+++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py
@@ -1,17 +1,28 @@
import json
+from datetime import datetime
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
+from quivr_core.models import KnowledgeStatus
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.main import app
from quivr_api.middlewares.auth.auth_bearer import get_current_user
-from quivr_api.modules.knowledge.controller.knowledge_routes import get_km_service
+from quivr_api.models.settings import BrainSettings
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
+from quivr_api.modules.knowledge.controller.knowledge_routes import (
+ get_knowledge_service,
+)
+from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate, LinkKnowledgeBrain
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.knowledge.tests.conftest import FakeStorage
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync
from quivr_api.modules.user.entity.user_identity import User, UserIdentity
@@ -24,6 +35,43 @@ async def user(session: AsyncSession) -> User:
return user_1
+@pytest_asyncio.fixture(scope="function")
+async def brain(session, user):
+ assert user.id
+ brain_1 = Brain(
+ name="test_brain",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain_1)
+ await session.commit()
+ await session.refresh(brain_1)
+ assert brain_1.brain_id
+ brain_user = BrainUserDB(
+ brain_id=brain_1.brain_id, user_id=user.id, default_brain=True, rights="Owner"
+ )
+ session.add(brain_user)
+ await session.commit()
+ return brain_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync(session: AsyncSession, user: User) -> Sync:
+ assert user.id
+ sync = Sync(
+ name="test_sync",
+ email="test@test.com",
+ user_id=user.id,
+ credentials={"test": "test"},
+ provider=SyncProvider.GOOGLE,
+ )
+
+ session.add(sync)
+ await session.commit()
+ await session.refresh(sync)
+ return sync
+
+
@pytest_asyncio.fixture(scope="function")
async def test_client(session: AsyncSession, user: User):
def default_current_user() -> UserIdentity:
@@ -36,11 +84,12 @@ async def test_service():
return KnowledgeService(repository, storage)
app.dependency_overrides[get_current_user] = default_current_user
- app.dependency_overrides[get_km_service] = test_service
+ app.dependency_overrides[get_knowledge_service] = test_service
# app.dependency_overrides[get_async_session] = lambda: session
async with AsyncClient(
- transport=ASGITransport(app=app), base_url="http://test"
+ transport=ASGITransport(app=app), # type: ignore
+ base_url="http://test",
) as ac:
yield ac
app.dependency_overrides = {}
@@ -68,7 +117,327 @@ async def test_post_knowledge(test_client: AsyncClient):
assert response.status_code == 200
+@pytest.mark.asyncio(loop_scope="session")
+async def test_post_knowledge_folder(test_client: AsyncClient):
+ km_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": None,
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(km_data), "application/json"),
+ }
+
+ response = await test_client.post(
+ "/knowledge/",
+ files=multipart_data,
+ )
+
+ assert response.status_code == 200
+ km = KnowledgeDTO.model_validate(response.json())
+
+ assert km.id
+ assert km.is_folder
+ assert km.parent is None
+ assert km.children == []
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_add_knowledge_large_file(monkeypatch, test_client):
+ _settings = BrainSettings()
+ _settings.max_file_size = 2
+ monkeypatch.setattr(
+ "quivr_api.modules.knowledge.controller.knowledge_routes.settings", _settings
+ )
+ km_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": False,
+ "parent_id": None,
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(km_data), "application/json"),
+ "file": ("test_file.txt", b"Test file content", "application/octet-stream"),
+ }
+
+ response = await test_client.post(
+ "/knowledge/",
+ files=multipart_data,
+ )
+
+ assert response.status_code == 413
+
+
@pytest.mark.asyncio(loop_scope="session")
async def test_add_knowledge_invalid_input(test_client):
response = await test_client.post("/knowledge/", files={})
assert response.status_code == 422
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_link_knowledge_sync_file(
+ monkeypatch,
+ session: AsyncSession,
+ test_client: AsyncClient,
+ brain: Brain,
+ user: User,
+ sync: Sync,
+):
+ tasks = {}
+
+ def _send_task(*args, **kwargs):
+ tasks["args"] = args
+ tasks["kwargs"] = {**kwargs["kwargs"]}
+
+ monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+
+ assert user.id
+ assert brain.brain_id
+ km = KnowledgeDTO(
+ id=None,
+ file_name="test.txt",
+ extension=".txt",
+ status=None,
+ user_id=user.id,
+ created_at=datetime.now(),
+ updated_at=datetime.now(),
+ brains=[],
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test.txt",
+ sync_id=sync.id,
+ sync_file_id="sync_file_id_1",
+ parent=None,
+ children=[],
+ )
+ json_data = LinkKnowledgeBrain(
+ brain_ids=[brain.brain_id], knowledge=km
+ ).model_dump_json()
+ response = await test_client.post(
+ "/knowledge/link_to_brains/",
+ content=json_data,
+ headers={"Content-Type": "application/json"},
+ )
+
+ assert response.status_code == 201
+ km = KnowledgeDTO.model_validate(response.json()[0])
+ assert km.id
+ assert km.status == KnowledgeStatus.PROCESSING
+ assert len(km.brains) == 1
+
+ # Assert task added to celery
+ assert len(tasks) > 0
+ assert tasks["args"] == ("process_file_task",)
+
+ minimal_task_kwargs = {
+ "knowledge_id": km.id,
+ }
+ all(
+ minimal_task_kwargs[key] == tasks["kwargs"][key] # type: ignore
+ for key in minimal_task_kwargs
+ )
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_link_knowledge_folder(
+ monkeypatch,
+ session: AsyncSession,
+ test_client: AsyncClient,
+ brain: Brain,
+ user: User,
+ sync: Sync,
+):
+ assert brain.brain_id
+ tasks = {}
+
+ def _send_task(*args, **kwargs):
+ tasks["args"] = args
+ tasks["kwargs"] = {**kwargs["kwargs"]}
+
+ monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+
+ folder_data = {
+ "file_name": "folder",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": None,
+ }
+ response = await test_client.post(
+ "/knowledge/",
+ files={
+ "knowledge_data": (None, json.dumps(folder_data), "application/json"),
+ },
+ )
+ # 1. Insert folder
+ folder_km = KnowledgeDTO.model_validate(response.json())
+ file_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": str(folder_km.id),
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(file_data), "application/json"),
+ }
+ # 2. Insert file in folder
+ response = await test_client.post(
+ "/knowledge/",
+ files=multipart_data,
+ )
+ file_km = KnowledgeDTO.model_validate(response.json())
+
+ json_data = LinkKnowledgeBrain(
+ brain_ids=[brain.brain_id], knowledge=folder_km
+ ).model_dump_json()
+
+ response = await test_client.post(
+ "/knowledge/link_to_brains/",
+ content=json_data,
+ headers={"Content-Type": "application/json"},
+ )
+ assert response.status_code == 201
+ updated_kms = [KnowledgeDTO.model_validate(d) for d in response.json()]
+
+ # 3. Validate that created knowledges are correct
+ assert len(updated_kms) == 2
+ assert next(
+ filter(lambda k: k.id == folder_km.id, updated_kms)
+ ), "file not in updated list"
+ assert next(
+ filter(lambda k: k.id == file_km.id, updated_kms)
+ ), "file not in updated list"
+ for km in updated_kms:
+ assert len(km.brains) == 1
+
+ # 4. Assert both files are being scheduled for processing
+ assert len(tasks) == 2
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_move_knowledge_to_folder(
+ monkeypatch,
+ session: AsyncSession,
+ test_client: AsyncClient,
+ brain: Brain,
+ user: User,
+ sync: Sync,
+):
+ assert brain.brain_id
+ tasks = {}
+
+ def _send_task(*args, **kwargs):
+ tasks["args"] = args
+ tasks["kwargs"] = {**kwargs["kwargs"]}
+
+ monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+
+ folder_data = {
+ "file_name": "folder",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": None,
+ }
+ response = await test_client.post(
+ "/knowledge/",
+ files={
+ "knowledge_data": (None, json.dumps(folder_data), "application/json"),
+ },
+ )
+ # 1. Insert folder
+ folder_km = KnowledgeDTO.model_validate(response.json())
+ file_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": None,
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(file_data), "application/json"),
+ }
+ # 2. Insert file in Root
+ response = await test_client.post(
+ "/knowledge/",
+ files=multipart_data,
+ )
+ file_km = KnowledgeDTO.model_validate(response.json())
+
+ # Move file to folder
+ update = KnowledgeUpdate(parent_id=folder_km.id)
+ response = await test_client.patch(
+ f"/knowledge/{file_km.id}",
+ content=update.model_dump_json(exclude_unset=True),
+ headers={"Content-Type": "application/json"},
+ )
+ assert response.status_code == 202
+ updated_km = KnowledgeDTO.model_validate(response.json())
+
+ # 3. Validate that created knowledges are correct
+ assert updated_km.parent and updated_km.parent.id
+ assert updated_km.parent.id == folder_km.id
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_move_knowledge_to_root(
+ monkeypatch,
+ session: AsyncSession,
+ test_client: AsyncClient,
+ brain: Brain,
+ user: User,
+ sync: Sync,
+):
+ assert brain.brain_id
+ tasks = {}
+
+ def _send_task(*args, **kwargs):
+ tasks["args"] = args
+ tasks["kwargs"] = {**kwargs["kwargs"]}
+
+ monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+
+ folder_data = {
+ "file_name": "folder",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": None,
+ }
+ response = await test_client.post(
+ "/knowledge/",
+ files={
+ "knowledge_data": (None, json.dumps(folder_data), "application/json"),
+ },
+ )
+ # 1. Insert folder
+ folder_km = KnowledgeDTO.model_validate(response.json())
+ file_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": True,
+ "parent_id": str(folder_km.id),
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(file_data), "application/json"),
+ }
+ # 2. Insert file in Root
+ response = await test_client.post(
+ "/knowledge/",
+ files=multipart_data,
+ )
+ file_km = KnowledgeDTO.model_validate(response.json())
+
+ # Move file to Root
+ update = KnowledgeUpdate(parent_id=None)
+ response = await test_client.patch(
+ f"/knowledge/{file_km.id}",
+ content=update.model_dump_json(exclude_unset=True),
+ headers={"Content-Type": "application/json"},
+ )
+ assert response.status_code == 202
+ updated_km = KnowledgeDTO.model_validate(response.json())
+
+ # 3. Validate that updated
+ assert updated_km.parent is None
diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py
index 92563fe78b30..8964e2964384 100644
--- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py
+++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py
@@ -1,55 +1,21 @@
+from datetime import datetime
from typing import List, Tuple
from uuid import uuid4
import pytest
import pytest_asyncio
from quivr_core.models import KnowledgeStatus
-from sqlmodel import select, text
+from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
-from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
-from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
+from quivr_api.modules.brain.entity.brain_entity import Brain
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO, sort_knowledge_dtos
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource
from quivr_api.modules.user.entity.user_identity import User
TestData = Tuple[Brain, List[KnowledgeDB]]
-@pytest_asyncio.fixture(scope="function")
-async def other_user(session: AsyncSession):
- sql = text(
- """
- INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES
- ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL);
- """
- )
- await session.execute(sql, params={"id": uuid4()})
-
- other_user = (
- await session.exec(select(User).where(User.email == "other@quivr.app"))
- ).one()
- return other_user
-
-
-@pytest_asyncio.fixture(scope="function")
-async def user(session):
- user_1 = (
- await session.exec(select(User).where(User.email == "admin@quivr.app"))
- ).one()
- return user_1
-
-
-@pytest_asyncio.fixture(scope="function")
-async def brain(session):
- brain_1 = Brain(
- name="test_brain",
- description="this is a test brain",
- brain_type=BrainType.integration,
- )
- session.add(brain_1)
- await session.commit()
- return brain_1
-
-
@pytest_asyncio.fixture(scope="function")
async def folder(session, user):
folder = KnowledgeDB(
@@ -175,7 +141,7 @@ async def test_knowledge_remove_folder_cascade(
@pytest.mark.asyncio(loop_scope="session")
-async def test_knowledge_dto(session, user, brain):
+async def test_knowledge_dto(session, user, brain, brain2, sync):
# add folder in brain
folder = KnowledgeDB(
file_name="folder_1",
@@ -199,8 +165,10 @@ async def test_knowledge_dto(session, user, brain):
file_size=100,
file_sha1="test_sha1",
user_id=user.id,
- brains=[brain],
+ brains=[brain2, brain],
parent=folder,
+ sync_file_id="file1",
+ sync=sync,
)
session.add(km)
session.add(km)
@@ -220,10 +188,57 @@ async def test_knowledge_dto(session, user, brain):
assert km_dto.file_sha1 == km.file_sha1
assert km_dto.updated_at == km.updated_at
assert km_dto.created_at == km.created_at
- assert km_dto.metadata == km.metadata_ # type: ignor
+ assert km_dto.metadata == km.metadata_ # type: ignore
assert km_dto.parent
assert km_dto.parent.id == folder.id
-
+ # Syncs fields
+ assert km_dto.sync_id == km.sync_id
+ assert km_dto.sync_file_id == km.sync_file_id
+ # Check brain_name order
+ assert len(km_dto.brains) == 2
+ assert km_dto.brains[1]["name"] > km_dto.brains[0]["name"]
+
+ # Check folder to dto
folder_dto = await folder.to_dto()
assert folder_dto.brains[0] == brain.model_dump()
assert folder_dto.children == [await km.to_dto()]
+
+
+def test_sort_knowledge_dtos():
+ user_id = uuid4()
+
+ data_dict = {
+ "extension": ".txt",
+ "status": None,
+ "user_id": user_id,
+ "created_at": datetime.now(),
+ "updated_at": datetime.now(),
+ "brains": [],
+ "source": KnowledgeSource.LOCAL,
+ "source_link": "://test.txt",
+ "sync_id": None,
+ "sync_file_id": None,
+ "parent": None,
+ "children": [],
+ }
+ dtos = [
+ KnowledgeDTO(id=uuid4(), is_folder=False, file_name=None, **data_dict),
+ KnowledgeDTO(id=uuid4(), is_folder=False, file_name="B", **data_dict),
+ KnowledgeDTO(id=uuid4(), is_folder=True, file_name="A", **data_dict),
+ KnowledgeDTO(id=uuid4(), is_folder=True, file_name=None, **data_dict),
+ ]
+
+ sorted_dtos = sort_knowledge_dtos(dtos)
+
+ # First element should be a folder with file_name="A"
+ assert sorted_dtos[0].is_folder is True
+ assert sorted_dtos[0].file_name == "A"
+ # Second element should be a folder with file_name=None
+ assert sorted_dtos[1].is_folder is True
+ assert sorted_dtos[1].file_name is None
+ # Third element should be a file with file_name="B"
+ assert sorted_dtos[2].is_folder is False
+ assert sorted_dtos[2].file_name == "B"
+ # Fourth element should be a file with file_name=None
+ assert sorted_dtos[3].is_folder is False
+ assert sorted_dtos[3].file_name is None
diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py
index 7381b6e917dd..c2cd9450bcbd 100644
--- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py
+++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py
@@ -1,4 +1,5 @@
import os
+from datetime import datetime, timedelta, timezone
from io import BytesIO
from typing import List, Tuple
from uuid import uuid4
@@ -11,8 +12,13 @@
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
-from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeStatus
-from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeUpdate
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
+from quivr_api.modules.knowledge.dto.inputs import (
+ AddKnowledge,
+ KnowledgeStatus,
+ KnowledgeUpdate,
+)
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource
from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
from quivr_api.modules.knowledge.service.knowledge_exceptions import (
@@ -22,6 +28,8 @@
)
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.knowledge.tests.conftest import ErrorStorage, FakeStorage
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync, SyncType
from quivr_api.modules.upload.service.upload_file import upload_file_storage
from quivr_api.modules.user.entity.user_identity import User
from quivr_api.modules.vector.entity.vector import Vector
@@ -54,6 +62,66 @@ async def user(session: AsyncSession) -> User:
return user_1
+@pytest_asyncio.fixture(scope="function")
+async def brain_user(session, user: User) -> Brain:
+ assert user.id
+ brain_1 = Brain(
+ name="test_brain",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain_1)
+ await session.commit()
+ await session.refresh(brain_1)
+ assert brain_1.brain_id
+ brain_user = BrainUserDB(
+ brain_id=brain_1.brain_id, user_id=user.id, default_brain=True, rights="Owner"
+ )
+ session.add(brain_user)
+ await session.commit()
+ return brain_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain_user2(session, user: User) -> Brain:
+ assert user.id
+ brain = Brain(
+ name="test_brain2",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain)
+ await session.commit()
+ await session.refresh(brain)
+ assert brain.brain_id
+ brain_user = BrainUserDB(
+ brain_id=brain.brain_id, user_id=user.id, default_brain=True, rights="Owner"
+ )
+ session.add(brain_user)
+ await session.commit()
+ return brain
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain_user3(session, user: User) -> Brain:
+ assert user.id
+ brain = Brain(
+ name="test_brain2",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain)
+ await session.commit()
+ await session.refresh(brain)
+ assert brain.brain_id
+ brain_user = BrainUserDB(
+ brain_id=brain.brain_id, user_id=user.id, default_brain=True, rights="Owner"
+ )
+ session.add(brain_user)
+ await session.commit()
+ return brain
+
+
@pytest_asyncio.fixture(scope="function")
async def test_data(session: AsyncSession) -> TestData:
user_1 = (
@@ -188,6 +256,31 @@ async def folder_km(session: AsyncSession, user: User):
return folder
+@pytest_asyncio.fixture(scope="function")
+async def folder_km_brain(session: AsyncSession, brain_user: Brain):
+ "local folder linked to a brain"
+ user: User = (await brain_user.awaitable_attrs.users)[0]
+ assert user.id
+ folder = KnowledgeDB(
+ file_name="folder_1",
+ extension="",
+ status="UPLOADED",
+ source="local",
+ source_link="local",
+ file_size=0,
+ file_sha1=None,
+ brains=[brain_user],
+ children=[],
+ user_id=user.id,
+ is_folder=True,
+ parent_id=None,
+ )
+ session.add(folder)
+ await session.commit()
+ await session.refresh(folder)
+ return folder
+
+
@pytest.mark.asyncio(loop_scope="session")
async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData):
brain, knowledges = test_data
@@ -369,167 +462,6 @@ async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData
assert brain.brain_id in brains_of_knowledge
-@pytest.mark.asyncio(loop_scope="session")
-async def test_should_process_knowledge_exists(
- session: AsyncSession, test_data: TestData
-):
- brain, [existing_knowledge, _] = test_data
- assert brain.brain_id
- new = KnowledgeDB(
- file_name="new",
- extension="txt",
- status="PROCESSING",
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1=None,
- brains=[brain],
- user_id=existing_knowledge.user_id,
- )
- session.add(new)
- await session.commit()
- await session.refresh(new)
- repo = KnowledgeRepository(session)
- service = KnowledgeService(repo)
- assert existing_knowledge.file_sha1
- with pytest.raises(FileExistsError):
- await service.update_sha1_conflict(
- new, brain.brain_id, file_sha1=existing_knowledge.file_sha1
- )
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_should_process_knowledge_link_brain(
- session: AsyncSession, test_data: TestData
-):
- repo = KnowledgeRepository(session)
- service = KnowledgeService(repo)
- brain, [existing_knowledge, _] = test_data
- user_id = existing_knowledge.user_id
- assert brain.brain_id
- prev = KnowledgeDB(
- file_name="prev",
- extension=".txt",
- status=KnowledgeStatus.UPLOADED,
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1="test1",
- brains=[brain],
- user_id=user_id,
- )
- brain_2 = Brain(
- name="test_brain",
- description="this is a test brain",
- brain_type=BrainType.integration,
- )
- session.add(brain_2)
- session.add(prev)
- await session.commit()
- await session.refresh(prev)
- await session.refresh(brain_2)
-
- assert prev.id
- assert brain_2.brain_id
-
- new = KnowledgeDB(
- file_name="new",
- extension="txt",
- status="PROCESSING",
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1=None,
- brains=[brain_2],
- user_id=user_id,
- )
- session.add(new)
- await session.commit()
- await session.refresh(new)
-
- incoming_knowledge = await new.to_dto()
- assert prev.file_sha1
-
- should_process = await service.update_sha1_conflict(
- incoming_knowledge, brain_2.brain_id, file_sha1=prev.file_sha1
- )
- assert not should_process
-
- # Check prev knowledge was linked
- assert incoming_knowledge.file_sha1
- prev_knowledge = await service.repository.get_knowledge_by_id(prev.id)
- prev_brains = await prev_knowledge.awaitable_attrs.brains
- assert {b.brain_id for b in prev_brains} == {
- brain.brain_id,
- brain_2.brain_id,
- }
- # Check new knowledge was removed
- assert new.id
- with pytest.raises(KnowledgeNotFoundException):
- await service.repository.get_knowledge_by_id(new.id)
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_should_process_knowledge_prev_error(
- session: AsyncSession, test_data: TestData
-):
- repo = KnowledgeRepository(session)
- service = KnowledgeService(repo)
- brain, [existing_knowledge, _] = test_data
- user_id = existing_knowledge.user_id
- assert brain.brain_id
- prev = KnowledgeDB(
- file_name="prev",
- extension="txt",
- status=KnowledgeStatus.ERROR,
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1="test1",
- brains=[brain],
- user_id=user_id,
- )
- session.add(prev)
- await session.commit()
- await session.refresh(prev)
-
- assert prev.id
-
- new = KnowledgeDB(
- file_name="new",
- extension="txt",
- status="PROCESSING",
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1=None,
- brains=[brain],
- user_id=user_id,
- )
- session.add(new)
- await session.commit()
- await session.refresh(new)
-
- incoming_knowledge = await new.to_dto()
- assert prev.file_sha1
- should_process = await service.update_sha1_conflict(
- incoming_knowledge, brain.brain_id, file_sha1=prev.file_sha1
- )
-
- # Checks we should process this file
- assert should_process
- # Previous errored file is cleaned up
- with pytest.raises(KnowledgeNotFoundException):
- await service.repository.get_knowledge_by_id(prev.id)
-
- assert new.id
- new = await service.repository.get_knowledge_by_id(new.id)
- assert new.file_sha1
-
-
-@pytest.mark.skip(
- reason="Bug: UnboundLocalError: cannot access local variable 'response'"
-)
@pytest.mark.asyncio(loop_scope="session")
async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData):
_, [knowledge, _] = test_data
@@ -585,6 +517,30 @@ async def test_create_knowledge_file(session: AsyncSession, user: User):
storage.knowledge_exists(km)
+@pytest.mark.asyncio(loop_scope="session")
+async def test_create_knowledge_web(session: AsyncSession, user: User):
+ assert user.id
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ km_to_add = AddKnowledge(
+ url="http://quivr.app",
+ source=KnowledgeSource.WEB,
+ is_folder=False,
+ parent_id=None,
+ )
+
+ km = await service.create_knowledge(
+ user_id=user.id, knowledge_to_add=km_to_add, upload_file=None
+ )
+
+ assert km.id
+ assert km.url == km_to_add.url
+ assert km.status == KnowledgeStatus.UPLOADED
+ assert not km.is_folder
+
+
@pytest.mark.asyncio(loop_scope="session")
async def test_create_knowledge_folder(session: AsyncSession, user: User):
assert user.id
@@ -593,7 +549,7 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User):
service = KnowledgeService(repository, storage)
km_to_add = AddKnowledge(
- file_name="test",
+ file_name="test.txt",
source="local",
is_folder=True,
parent_id=None,
@@ -610,9 +566,9 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User):
assert km.id
# Knowledge properties
assert km.file_name == km_to_add.file_name
+ assert km.extension == ".txt"
assert km.is_folder == km_to_add.is_folder
assert km.url == km_to_add.url
- assert km.extension == km_to_add.extension
assert km.source == km_to_add.source
assert km.file_size == 128
assert km.metadata_ == km_to_add.metadata
@@ -622,6 +578,56 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User):
assert storage.knowledge_exists(km)
+@pytest.mark.asyncio(loop_scope="session")
+async def test_create_knowledge_file_in_folder_in_brain(
+ monkeypatch, session: AsyncSession, user: User, folder_km_brain: KnowledgeDB
+):
+ tasks = {}
+
+ def _send_task(*args, **kwargs):
+ tasks["args"] = args
+ tasks["kwargs"] = {**kwargs["kwargs"]}
+
+ monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+ assert user.id
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ km_to_add = AddKnowledge(
+ file_name="test",
+ source="local",
+ is_folder=True,
+ parent_id=folder_km_brain.id,
+ )
+ km_data = BytesIO(os.urandom(128))
+ km = await service.create_knowledge(
+ user_id=user.id,
+ knowledge_to_add=km_to_add,
+ upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name),
+ )
+
+ assert km.file_name == km_to_add.file_name
+ assert km.id
+ # Knowledge properties
+ assert km.file_name == km_to_add.file_name
+ assert km.is_folder == km_to_add.is_folder
+ assert km.url == km_to_add.url
+ assert km.source == km_to_add.source
+ assert km.file_size == 128
+ assert km.metadata_ == km_to_add.metadata
+ assert km.is_folder == km_to_add.is_folder
+ assert km.status == KnowledgeStatus.PROCESSING
+ # Knowledge was saved
+ assert storage.knowledge_exists(km)
+ assert km.brains
+ assert len(km.brains) > 0
+ assert km.brains[0].brain_id == folder_km_brain.brains[0].brain_id
+
+ # Scheduled
+ assert len(tasks) > 0
+
+
@pytest.mark.asyncio(loop_scope="session")
async def test_create_knowledge_upload_error(session: AsyncSession, user: User):
assert user.id
@@ -1020,3 +1026,319 @@ async def test_list_knowledge(session: AsyncSession, user: User):
assert len(kms) == 1
assert kms[0].id == nested_file.id
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_link_knowledge_brain(
+ session: AsyncSession, user: User, brain_user: Brain
+):
+ assert user.id
+ assert brain_user.brain_id
+
+ root_folder = KnowledgeDB(
+ file_name="folder",
+ extension="",
+ status="UPLOADED",
+ source="local",
+ source_link="local",
+ file_size=4,
+ file_sha1=None,
+ brains=[],
+ children=[],
+ user_id=user.id,
+ is_folder=True,
+ )
+ nested_file = KnowledgeDB(
+ file_name="file_2",
+ extension="",
+ status="UPLOADED",
+ source="local",
+ source_link="local",
+ file_size=10,
+ file_sha1=None,
+ user_id=user.id,
+ parent=root_folder,
+ )
+ session.add(nested_file)
+ session.add(root_folder)
+ await session.commit()
+ await session.refresh(root_folder)
+ await session.refresh(nested_file)
+
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ await service.link_knowledge_tree_brains(
+ root_folder, brains_ids=[brain_user.brain_id], user_id=user.id
+ )
+ kms = await service.get_all_knowledge_in_brain(brain_id=brain_user.brain_id)
+ assert len(kms) == 2
+ assert {k.id for k in kms} == {root_folder.id, nested_file.id}
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_link_knowledge_brain_existing_brains(
+ session: AsyncSession, user: User, brain_user: Brain
+):
+ """test knowledge already in brain and we add it to the same brain because we added his parent"""
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_unlink_knowledge_brain(
+ session: AsyncSession,
+ user: User,
+ brain_user: Brain,
+ brain_user2: Brain,
+ brain_user3: Brain,
+):
+ assert user.id
+ assert brain_user.brain_id
+ assert brain_user2.brain_id
+ assert brain_user3.brain_id
+
+ root_folder = KnowledgeDB(
+ file_name="folder",
+ extension="",
+ status="UPLOADED",
+ source="local",
+ source_link="local",
+ file_size=4,
+ file_sha1=None,
+ brains=[brain_user, brain_user2],
+ children=[],
+ user_id=user.id,
+ is_folder=True,
+ )
+ file = KnowledgeDB(
+ file_name="file_2",
+ extension="",
+ status="UPLOADED",
+ source="local",
+ source_link="local",
+ file_size=10,
+ file_sha1=None,
+ user_id=user.id,
+ parent=root_folder,
+ # 1 additional brain
+ brains=[brain_user, brain_user2, brain_user3],
+ )
+ session.add(file)
+ session.add(root_folder)
+ await session.commit()
+ await session.refresh(root_folder)
+ await session.refresh(file)
+
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ await service.unlink_knowledge_tree_brains(
+ root_folder,
+ brains_ids=[brain_user.brain_id, brain_user2.brain_id],
+ user_id=user.id,
+ )
+ kms = await service.get_all_knowledge_in_brain(brain_id=brain_user.brain_id)
+ assert len(kms) == 0
+
+ kms = await service.get_all_knowledge_in_brain(brain_id=brain_user2.brain_id)
+ assert len(kms) == 0
+
+ kms = await service.get_all_knowledge_in_brain(brain_id=brain_user3.brain_id)
+ assert len(kms) == 1
+ assert kms[0].id == file.id
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_outdated_sync_update_date(
+ session: AsyncSession, user: User, brain_user: Brain, sync: Sync
+):
+ assert user.id
+ assert brain_user.brain_id
+
+ file1 = KnowledgeDB(
+ file_name="folder",
+ extension="",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://file2",
+ file_size=4,
+ file_sha1="",
+ brains=[brain_user],
+ user_id=user.id,
+ sync_id=sync.id,
+ sync_file_id="file2",
+ last_synced_at=datetime.now() - timedelta(days=2),
+ )
+ file2 = KnowledgeDB(
+ file_name="file_2",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://file2",
+ file_size=10,
+ file_sha1=None,
+ brains=[brain_user],
+ user_id=user.id,
+ sync_id=sync.id,
+ sync_file_id="file2",
+ last_synced_at=datetime.now(),
+ )
+ session.add(file2)
+ session.add(file1)
+ await session.commit()
+ await session.refresh(file1)
+ await session.refresh(file2)
+
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ last_time = datetime.now(timezone.utc) - timedelta(hours=4)
+ kms = await service.get_outdated_syncs(
+ limit_time=last_time, batch_size=10, km_sync_type=SyncType.FILE
+ )
+ assert len(kms) == 1
+ assert kms[0].id == file1.id
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_outdated_sync_file_only_brains(
+ session: AsyncSession, user: User, brain_user: Brain, sync: Sync
+):
+ assert user.id
+ assert brain_user.brain_id
+
+ file2 = KnowledgeDB(
+ file_name="file",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=KnowledgeSource.LOCAL,
+ source_link="path",
+ file_size=4,
+ file_sha1="",
+ brains=[],
+ user_id=user.id,
+ )
+
+ file1 = KnowledgeDB(
+ file_name="file",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://file2",
+ file_size=4,
+ file_sha1="",
+ brains=[],
+ user_id=user.id,
+ sync_id=sync.id,
+ sync_file_id="file2",
+ last_synced_at=datetime.now() - timedelta(days=2),
+ )
+ session.add(file1)
+ session.add(file2)
+ await session.commit()
+ await session.refresh(file1)
+ await session.refresh(file2)
+
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ last_time = datetime.now(timezone.utc) - timedelta(hours=4)
+ kms = await service.get_outdated_syncs(
+ limit_time=last_time, batch_size=10, km_sync_type=SyncType.FILE
+ )
+ assert len(kms) == 0
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_outdated_sync_file_only(
+ session: AsyncSession, user: User, brain_user: Brain, sync: Sync
+):
+ assert user.id
+ assert brain_user.brain_id
+
+ file1 = KnowledgeDB(
+ file_name="folder",
+ extension="",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://file2",
+ file_size=4,
+ file_sha1="",
+ brains=[brain_user],
+ user_id=user.id,
+ sync_id=sync.id,
+ sync_file_id="file2",
+ is_folder=True,
+ last_synced_at=datetime.now() - timedelta(days=2),
+ )
+ session.add(file1)
+ await session.commit()
+ await session.refresh(file1)
+
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ last_time = datetime.now(timezone.utc) - timedelta(hours=4)
+ kms = await service.get_outdated_syncs(
+ limit_time=last_time, batch_size=10, km_sync_type=SyncType.FILE
+ )
+ assert len(kms) == 0
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_outdated_folders_sync(
+ session: AsyncSession, user: User, brain_user: Brain, sync: Sync
+):
+ assert user.id
+ assert brain_user.brain_id
+
+ folder = KnowledgeDB(
+ file_name="folder",
+ extension="",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://file2",
+ file_size=0,
+ file_sha1="",
+ brains=[brain_user],
+ user_id=user.id,
+ sync_id=sync.id,
+ sync_file_id="folder1",
+ is_folder=True,
+ last_synced_at=datetime.now() - timedelta(days=2),
+ )
+ file = KnowledgeDB(
+ file_name="file",
+ extension="",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://file2",
+ file_size=4,
+ file_sha1="",
+ brains=[brain_user],
+ user_id=user.id,
+ sync_id=sync.id,
+ sync_file_id="file",
+ is_folder=False,
+ last_synced_at=datetime.now() - timedelta(days=2),
+ parent=folder,
+ )
+ session.add(folder)
+ session.add(file)
+ await session.commit()
+ await session.refresh(folder)
+
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ service = KnowledgeService(repository, storage)
+
+ last_time = datetime.now(timezone.utc) - timedelta(hours=4)
+ kms = await service.get_outdated_syncs(
+ limit_time=last_time, batch_size=10, km_sync_type=SyncType.FOLDER
+ )
+ assert len(kms) == 1
+ assert kms[0].id == folder.id
diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py
index d2949f9fc8c9..dcfdc13074c0 100644
--- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py
@@ -1,18 +1,17 @@
import os
import requests
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import HTMLResponse
from msal import ConfidentialClientApplication
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.sync.dto.inputs import (
- SyncsUserInput,
- SyncsUserStatus,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.dependencies import get_service
+from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state
from quivr_api.modules.user.entity.user_identity import UserIdentity
from .successfull_connection import successfullConnectionPage
@@ -20,9 +19,8 @@
# Initialize logger
logger = get_logger(__name__)
-# Initialize sync service
-sync_service = SyncService()
-sync_user_service = SyncUserService()
+
+syncs_service_dep = get_service(SyncsService)
# Initialize API router
azure_sync_router = APIRouter()
@@ -45,8 +43,11 @@
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-def authorize_azure(
- request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
+async def authorize_azure(
+ request: Request,
+ name: str,
+ syncs_service: SyncsService = Depends(syncs_service_dep),
+ current_user: UserIdentity = Depends(get_current_user),
):
"""
Authorize Azure sync for the current user.
@@ -62,26 +63,30 @@ def authorize_azure(
CLIENT_ID, client_credential=CLIENT_SECRET, authority=AUTHORITY
)
logger.debug(f"Authorizing Azure sync for user: {current_user.id}")
- state = f"user_id={current_user.id}, name={name}"
+ state = await syncs_service.create_oauth2_state(
+ provider=SyncProvider.AZURE, name=name, user_id=current_user.id
+ )
flow = client.initiate_auth_code_flow(
- scopes=SCOPE, redirect_uri=REDIRECT_URI, state=state, prompt="select_account"
+ scopes=SCOPE,
+ redirect_uri=REDIRECT_URI,
+ state=state.model_dump_json(),
+ prompt="select_account",
)
-
- sync_user_input = SyncsUserInput(
- user_id=str(current_user.id),
- name=name,
- provider="Azure",
- credentials={},
- state={"state": state},
- additional_data={"flow": flow},
- status=str(SyncsUserStatus.SYNCING),
+ # Azure needs additional data
+ await syncs_service.update_sync(
+ sync_id=state.sync_id,
+ sync_user_input=SyncUpdateInput(
+ additional_data={"flow": flow}, status=SyncStatus.SYNCED
+ ),
)
- sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": flow["auth_uri"]}
@azure_sync_router.get("/sync/azure/oauth2callback", tags=["Sync"])
-def oauth2callback_azure(request: Request):
+async def oauth2callback_azure(
+ request: Request,
+ syncs_service: SyncsService = Depends(syncs_service_dep),
+):
"""
Handle OAuth2 callback from Azure.
@@ -94,41 +99,28 @@ def oauth2callback_azure(request: Request):
client = ConfidentialClientApplication(
CLIENT_ID, client_credential=CLIENT_SECRET, authority=AUTHORITY
)
- state = request.query_params.get("state")
- state_split = state.split(",")
- current_user = state_split[0].split("=")[1] # Extract user_id from state
- name = state_split[1].split("=")[1] if state else None
- state_dict = {"state": state}
+ state_str = request.query_params.get("state")
+ state = parse_oauth2_state(state_str)
logger.debug(
- f"Handling OAuth2 callback for user: {current_user} with state: {state}"
+ f"Handling OAuth2 callback for user: {state.user_id} with state: {state}"
)
- sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
- logger.info(f"Retrieved sync user state: {sync_user_state}")
-
- if not sync_user_state or state_dict != sync_user_state.state:
- logger.error("Invalid state parameter")
- raise HTTPException(status_code=400, detail="Invalid state parameter")
- if str(sync_user_state.user_id) != current_user:
- logger.info(f"Sync user state: {sync_user_state}")
- logger.info(f"Current user: {current_user}")
- logger.info(f"Sync user state user_id: {sync_user_state.user_id}")
- logger.error("Invalid user")
- raise HTTPException(status_code=400, detail="Invalid user")
-
- result = client.acquire_token_by_auth_code_flow(
- sync_user_state.additional_data["flow"], dict(request.query_params)
+ sync = await syncs_service.get_from_oauth2_state(state)
+ if sync.additional_data is None:
+ raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
+
+ flow_data = client.acquire_token_by_auth_code_flow(
+ sync.additional_data["flow"], dict(request.query_params)
)
- if "access_token" not in result:
- logger.error(f"Failed to acquire token: {result}")
+
+ if "access_token" not in flow_data:
+ logger.error(f"Failed to acquire token: {flow_data}")
raise HTTPException(
status_code=400,
- detail=f"Failed to acquire token: {result}",
+ detail=f"Failed to acquire token: {flow_data}",
)
- access_token = result["access_token"]
-
- creds = result
- logger.info(f"Fetched OAuth2 token for user: {current_user}")
+ access_token = flow_data["access_token"]
+ logger.info(f"Fetched OAuth2 token for user: {state.user_id}")
# Fetch user email from Microsoft Graph API
graph_url = "https://graph.microsoft.com/v1.0/me"
@@ -140,14 +132,11 @@ def oauth2callback_azure(request: Request):
user_info = response.json()
user_email = user_info.get("mail") or user_info.get("userPrincipalName")
- logger.info(f"Retrieved email for user: {current_user} - {user_email}")
+ logger.info(f"Retrieved email for user: {state.user_id} - {user_email}")
- sync_user_input = SyncUserUpdateInput(
- credentials=result,
- email=user_email,
- status=str(SyncsUserStatus.SYNCED),
+ sync_user_input = SyncUpdateInput(
+ credentials=flow_data, state={}, email=user_email, status=SyncStatus.SYNCED
)
-
- sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
- logger.info(f"Azure sync created successfully for user: {current_user}")
+ await syncs_service.update_sync(state.sync_id, sync_user_input)
+ logger.info(f"Azure sync created successfully for user: {state.user_id}")
return HTMLResponse(successfullConnectionPage)
diff --git a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py
index df3c955a9356..b715e899ec7b 100644
--- a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py
@@ -1,5 +1,5 @@
import os
-from uuid import UUID
+from typing import Tuple
from dropbox import Dropbox, DropboxOAuth2Flow
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -7,12 +7,11 @@
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.sync.dto.inputs import (
- SyncsUserInput,
- SyncsUserStatus,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.dependencies import get_service
+from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state
from quivr_api.modules.user.entity.user_identity import UserIdentity
from .successfull_connection import successfullConnectionPage
@@ -24,8 +23,7 @@
SCOPE = ["files.metadata.read", "account_info.read", "files.content.read"]
# Initialize sync service
-sync_service = SyncService()
-sync_user_service = SyncUserService()
+syncs_service_dep = get_service(SyncsService)
logger = get_logger(__name__)
@@ -38,8 +36,11 @@
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-def authorize_dropbox(
- request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
+async def authorize_dropbox(
+ request: Request,
+ name: str,
+ current_user: UserIdentity = Depends(get_current_user),
+ syncs_service: SyncsService = Depends(syncs_service_dep),
):
"""
Authorize DropBox sync for the current user.
@@ -63,27 +64,31 @@ def authorize_dropbox(
token_access_type="offline",
scope=SCOPE,
)
- state: str = f"user_id={current_user.id}, name={name}"
- authorize_url = auth_flow.start(state)
-
+ state = await syncs_service.create_oauth2_state(
+ provider=SyncProvider.DROPBOX, name=name, user_id=current_user.id
+ )
+ authorize_url = auth_flow.start(state.model_dump_json())
logger.info(
f"Generated authorization URL: {authorize_url} for user: {current_user.id}"
)
- sync_user_input = SyncsUserInput(
- name=name,
- user_id=str(current_user.id),
- provider="DropBox",
- credentials={},
- state={"state": state},
- additional_data={},
- status=str(SyncsUserStatus.SYNCING),
- )
- sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorize_url}
+def parse_dropbox_oauth2_session(state_str: str | None) -> Tuple[dict[str, str], str]:
+ if state_str is None:
+ raise ValueError
+ session = {}
+ session["csrf-token"] = state_str.split("|")[0] if "|" in state_str else ""
+ logger.debug("Keys in session : %s", session.keys())
+ logger.debug("Value in session : %s", session.values())
+ return session, state_str.split("|")[1]
+
+
@dropbox_sync_router.get("/sync/dropbox/oauth2callback", tags=["Sync"])
-def oauth2callback_dropbox(request: Request):
+async def oauth2callback_dropbox(
+ request: Request,
+ syncs_service: SyncsService = Depends(syncs_service_dep),
+):
"""
Handle OAuth2 callback from DropBox.
@@ -93,34 +98,10 @@ def oauth2callback_dropbox(request: Request):
Returns:
dict: A dictionary containing a success message.
"""
- state = request.query_params.get("state")
- if not state:
- raise HTTPException(status_code=400, detail="Invalid state parameter")
- session = {}
- session["csrf-token"] = state.split("|")[0] if "|" in state else ""
-
- logger.debug("Keys in session : %s", session.keys())
- logger.debug("Value in session : %s", session.values())
-
- state = state.split("|")[1] if "|" in state else state # type: ignore
- state_dict = {"state": state}
- state_split = state.split(",") # type: ignore
- current_user = UUID(state_split[0].split("=")[1]) if state else None
- logger.debug(
- f"Handling OAuth2 callback for user: {current_user} with state: {state} and state_dict: {state_dict}"
- )
- sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
-
- if not sync_user_state or state_dict != sync_user_state.state:
- logger.error("Invalid state parameter")
- raise HTTPException(status_code=400, detail="Invalid state parameter")
- else:
- logger.info(
- f"CURRENT USER: {current_user}, SYNC USER STATE USER: {sync_user_state.user_id}"
- )
-
- if sync_user_state.user_id != current_user:
- raise HTTPException(status_code=400, detail="Invalid user")
+ state_str = request.query_params.get("state")
+ session, state_str = parse_dropbox_oauth2_session(state_str)
+ state = parse_oauth2_state(state_str)
+ sync = await syncs_service.get_from_oauth2_state(state)
auth_flow = DropboxOAuth2Flow(
DROPBOX_APP_KEY,
@@ -143,22 +124,21 @@ def oauth2callback_dropbox(request: Request):
user_email = account_info.email # type: ignore
account_id = account_info.account_id # type: ignore
- result: dict[str, str] = {
+ credentials: dict[str, str] = {
"access_token": oauth_result.access_token,
"refresh_token": oauth_result.refresh_token,
"account_id": account_id,
"expires_in": str(oauth_result.expires_at),
}
- sync_user_input = SyncUserUpdateInput(
- credentials=result,
- # state={},
+ sync_user_input = SyncUpdateInput(
+ credentials=credentials,
+ state={},
email=user_email,
- status=str(SyncsUserStatus.SYNCED),
+ status=SyncStatus.SYNCED,
)
- assert current_user
- sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
- logger.info(f"DropBox sync created successfully for user: {current_user}")
+ await syncs_service.update_sync(sync.id, sync_user_input)
+ logger.info(f"DropBox sync created successfully for user: {state.user_id}")
return HTMLResponse(successfullConnectionPage)
except Exception as e:
logger.error(f"Error: {e}")
diff --git a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py
index fc4cd4a91a49..e42f6d4d7e22 100644
--- a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py
@@ -6,12 +6,11 @@
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.sync.dto.inputs import (
- SyncsUserInput,
- SyncsUserStatus,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.dependencies import get_service
+from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state
from quivr_api.modules.user.entity.user_identity import UserIdentity
from .successfull_connection import successfullConnectionPage
@@ -20,8 +19,7 @@
logger = get_logger(__name__)
# Initialize sync service
-sync_service = SyncService()
-sync_user_service = SyncUserService()
+syncs_service_dep = get_service(SyncsService)
# Initialize API router
github_sync_router = APIRouter()
@@ -39,8 +37,11 @@
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-def authorize_github(
- request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
+async def authorize_github(
+ request: Request,
+ name: str,
+ syncs_service: SyncsService = Depends(syncs_service_dep),
+ current_user: UserIdentity = Depends(get_current_user),
):
"""
Authorize GitHub sync for the current user.
@@ -53,26 +54,20 @@ def authorize_github(
dict: A dictionary containing the authorization URL.
"""
logger.debug(f"Authorizing GitHub sync for user: {current_user.id}")
- state = f"user_id={current_user.id},name={name}"
+ state = await syncs_service.create_oauth2_state(
+ provider=SyncProvider.GITHUB, name=name, user_id=current_user.id
+ )
authorization_url = (
f"https://github.com/login/oauth/authorize?client_id={CLIENT_ID}"
- f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state}"
- )
-
- sync_user_input = SyncsUserInput(
- user_id=str(current_user.id),
- name=name,
- provider="GitHub",
- credentials={},
- state={"state": state},
- status=str(SyncsUserStatus.SYNCING),
+ f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state.model_dump_json()}"
)
- sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorization_url}
@github_sync_router.get("/sync/github/oauth2callback", tags=["Sync"])
-def oauth2callback_github(request: Request):
+async def oauth2callback_github(
+ request: Request, syncs_service: SyncsService = Depends(syncs_service_dep)
+):
"""
Handle OAuth2 callback from GitHub.
@@ -82,24 +77,12 @@ def oauth2callback_github(request: Request):
Returns:
dict: A dictionary containing a success message.
"""
- state = request.query_params.get("state")
- state_split = state.split(",")
- current_user = state_split[0].split("=")[1] # Extract user_id from state
- name = state_split[1].split("=")[1] if state else None
- state_dict = {"state": state}
+ state_str = request.query_params.get("state")
+ state = parse_oauth2_state(state_str)
logger.debug(
- f"Handling OAuth2 callback for user: {current_user} with state: {state}"
+ f"Handling OAuth2 callback for user: {state.user_id} with state: {state}"
)
- sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
- logger.info(f"Retrieved sync user state: {sync_user_state}")
-
- if state_dict != sync_user_state["state"]:
- logger.error("Invalid state parameter")
- raise HTTPException(status_code=400, detail="Invalid state parameter")
- if sync_user_state.get("user_id") != current_user:
- logger.error("Invalid user")
- raise HTTPException(status_code=400, detail="Invalid user")
-
+ sync = await syncs_service.get_from_oauth2_state(state)
token_url = "https://github.com/login/oauth/access_token"
data = {
"client_id": CLIENT_ID,
@@ -126,8 +109,7 @@ def oauth2callback_github(request: Request):
detail=f"Failed to acquire token: {result}",
)
- creds = result
- logger.info(f"Fetched OAuth2 token for user: {current_user}")
+ logger.info(f"Fetched OAuth2 token for user: {state.user_id}")
# Fetch user email from GitHub API
github_api_url = "https://api.github.com/user"
@@ -150,15 +132,16 @@ def oauth2callback_github(request: Request):
logger.error("Failed to fetch user email from GitHub API")
raise HTTPException(status_code=400, detail="Failed to fetch user email")
- logger.info(f"Retrieved email for user: {current_user} - {user_email}")
+ logger.info(f"Retrieved email for user: {state.user_id} - {user_email}")
- sync_user_input = SyncUserUpdateInput(
+ sync_user_input = SyncUpdateInput(
credentials=result,
- # state={},
+ state={},
email=user_email,
- status=str(SyncsUserStatus.SYNCED),
+ status=SyncStatus.SYNCED,
)
- sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
- logger.info(f"GitHub sync created successfully for user: {current_user}")
+ # TODO: This an additional select query :(
+ await syncs_service.update_sync(sync.id, sync_user_input)
+ logger.info(f"GitHub sync created successfully for user: {state.user_id}")
return HTMLResponse(successfullConnectionPage)
diff --git a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py
index 3e4b5c9a5e87..e9dd95a52d2e 100644
--- a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py
@@ -1,20 +1,18 @@
import json
import os
-from uuid import UUID
-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from google_auth_oauthlib.flow import Flow
from googleapiclient.discovery import build
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.sync.dto.inputs import (
- SyncsUserInput,
- SyncsUserStatus,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.dependencies import get_service
+from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state
from quivr_api.modules.user.entity.user_identity import UserIdentity
from .successfull_connection import successfullConnectionPage
@@ -26,8 +24,7 @@
logger = get_logger(__name__)
# Initialize sync service
-sync_service = SyncService()
-sync_user_service = SyncUserService()
+syncs_service_dep = get_service(SyncsService)
# Initialize API router
google_sync_router = APIRouter()
@@ -66,8 +63,11 @@
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-def authorize_google(
- request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
+async def authorize_google(
+ request: Request,
+ name: str,
+ current_user: UserIdentity = Depends(get_current_user),
+ syncs_service: SyncsService = Depends(syncs_service_dep),
):
"""
Authorize Google Drive sync for the current user.
@@ -88,31 +88,27 @@ def authorize_google(
scopes=SCOPES,
redirect_uri=redirect_uri,
)
- state = f"user_id={current_user.id}, name={name}"
+
+ state = await syncs_service.create_oauth2_state(
+ provider=SyncProvider.GOOGLE, name=name, user_id=current_user.id
+ )
authorization_url, state = flow.authorization_url(
access_type="offline",
include_granted_scopes="true",
- state=state,
+ state=state.model_dump_json(),
prompt="consent",
)
logger.info(
f"Generated authorization URL: {authorization_url} for user: {current_user.id}"
)
- sync_user_input = SyncsUserInput(
- name=name,
- user_id=str(current_user.id),
- provider="Google",
- credentials={},
- state={"state": state},
- additional_data={},
- status=str(SyncsUserStatus.SYNCED),
- )
- sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorization_url}
@google_sync_router.get("/sync/google/oauth2callback", tags=["Sync"])
-def oauth2callback_google(request: Request):
+async def oauth2callback_google(
+ request: Request,
+ syncs_service: SyncsService = Depends(syncs_service_dep),
+):
"""
Handle OAuth2 callback from Google.
@@ -122,49 +118,35 @@ def oauth2callback_google(request: Request):
Returns:
dict: A dictionary containing a success message.
"""
- state = request.query_params.get("state")
- state_dict = {"state": state}
- logger.info(f"State: {state}")
- state_split = state.split(",")
- current_user = UUID(state_split[0].split("=")[1]) if state else None
- assert current_user, f"oauth2callback_googl empty current_user in {request}"
+ state_str = request.query_params.get("state")
+ state = parse_oauth2_state(state_str)
logger.debug(
- f"Handling OAuth2 callback for user: {current_user} with state: {state}"
+ f"Handling OAuth2 callback for user: {state.user_id} with state: {state}"
)
- sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
- logger.info(f"Retrieved sync user state: {sync_user_state}")
-
- if not sync_user_state or state_dict != sync_user_state.state:
- logger.error("Invalid state parameter")
- raise HTTPException(status_code=400, detail="Invalid state parameter")
- if sync_user_state.user_id != current_user:
- logger.error("Invalid user")
- logger.info(f"Invalid user: {current_user}")
- raise HTTPException(status_code=400, detail="Invalid user")
-
+ sync = await syncs_service.get_from_oauth2_state(state)
redirect_uri = f"{BASE_REDIRECT_URI}"
flow = Flow.from_client_config(
CLIENT_SECRETS_FILE_CONTENT,
scopes=SCOPES,
- state=state,
+ state=state_str,
redirect_uri=redirect_uri,
)
flow.fetch_token(authorization_response=str(request.url))
creds = flow.credentials
- logger.info(f"Fetched OAuth2 token for user: {current_user}")
+ logger.info(f"Fetched OAuth2 token for user: {state.user_id}")
# Use the credentials to get the user's email
service = build("oauth2", "v2", credentials=creds)
user_info = service.userinfo().get().execute()
user_email = user_info.get("email")
- logger.info(f"Retrieved email for user: {current_user} - {user_email}")
+ logger.info(f"Retrieved email for user: {state.user_id} - {user_email}")
- sync_user_input = SyncUserUpdateInput(
+ sync_user_input = SyncUpdateInput(
credentials=json.loads(creds.to_json()),
- # state={},
+ state={},
email=user_email,
- status=str(SyncsUserStatus.SYNCED),
+ status=SyncStatus.SYNCED,
)
- sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
- logger.info(f"Google Drive sync created successfully for user: {current_user}")
+ sync = await syncs_service.update_sync(sync.id, sync_user_input)
+ logger.info(f"Google Drive sync created successfully for user: {state.user_id}")
return HTMLResponse(successfullConnectionPage)
diff --git a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py
index 4c450ecb1ce8..b151a043cb2d 100644
--- a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py
@@ -1,21 +1,19 @@
import base64
import os
-from uuid import UUID
import requests
-from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from notion_client import Client
from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
-from quivr_api.modules.sync.dto.inputs import (
- SyncsUserInput,
- SyncsUserStatus,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.dependencies import get_service
+from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state
from quivr_api.modules.user.entity.user_identity import UserIdentity
from .successfull_connection import successfullConnectionPage
@@ -29,8 +27,7 @@
# Initialize sync service
-sync_service = SyncService()
-sync_user_service = SyncUserService()
+syncs_service_dep = get_service(SyncsService)
logger = get_logger(__name__)
@@ -43,8 +40,11 @@
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-def authorize_notion(
- request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
+async def authorize_notion(
+ request: Request,
+ name: str,
+ current_user: UserIdentity = Depends(get_current_user),
+ syncs_service: SyncsService = Depends(syncs_service_dep),
):
"""
Authorize Notion sync for the current user.
@@ -57,26 +57,22 @@ def authorize_notion(
dict: A dictionary containing the authorization URL.
"""
logger.debug(f"Authorizing Notion sync for user: {current_user.id}, name : {name}")
- state: str = f"user_id={current_user.id}, name={name}"
- authorize_url = str(NOTION_AUTH_URL) + f"&state={state}"
-
- logger.info(
- f"Generated authorization URL: {authorize_url} for user: {current_user.id}"
+ state = await syncs_service.create_oauth2_state(
+ provider=SyncProvider.NOTION, name=name, user_id=current_user.id
)
- sync_user_input = SyncsUserInput(
- name=name,
- user_id=str(current_user.id),
- provider="Notion",
- credentials={},
- state={"state": state},
- status=str(SyncsUserStatus.SYNCING),
+ # Finalize the state
+ authorize_url = str(NOTION_AUTH_URL) + f"&state={state.model_dump_json()}"
+ logger.debug(
+ f"Generated authorization URL: {authorize_url} for user: {current_user.id}"
)
- sync_user_service.create_sync_user(sync_user_input)
return {"authorization_url": authorize_url}
@notion_sync_router.get("/sync/notion/oauth2callback", tags=["Sync"])
-def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks):
+async def oauth2callback_notion(
+ request: Request,
+ syncs_service: SyncsService = Depends(syncs_service_dep),
+):
"""
Handle OAuth2 callback from Notion.
@@ -87,29 +83,9 @@ def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks):
dict: A dictionary containing a success message.
"""
code = request.query_params.get("code")
- state = request.query_params.get("state")
- if not state:
- raise HTTPException(status_code=400, detail="Invalid state parameter")
-
- state_dict = {"state": state}
- state_split = state.split(",") # type: ignore
- current_user = UUID(state_split[0].split("=")[1]) if state else None
- assert current_user, "Oauth callback user is None"
- logger.debug(
- f"Handling OAuth2 callback for user: {current_user} with state: {state} and state_dict: {state_dict}"
- )
- sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
-
- if not sync_user_state or state_dict != sync_user_state.state:
- logger.error(f"Invalid state parameter for {sync_user_state}")
- raise HTTPException(status_code=400, detail="Invalid state parameter")
- else:
- logger.info(
- f"Current user: {current_user}, sync user state: {sync_user_state.state}"
- )
- if sync_user_state.user_id != current_user:
- raise HTTPException(status_code=400, detail="Invalid user")
+ state_str = request.query_params.get("state")
+ state = parse_oauth2_state(state_str)
try:
token_url = "https://api.notion.com/v1/oauth/token"
@@ -148,22 +124,19 @@ def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks):
"expires_in": oauth_result.get("expires_in", ""),
}
- sync_user_input = SyncUserUpdateInput(
+ sync_user_input = SyncUpdateInput(
credentials=result,
- # state={},
+ state={},
email=user_email,
- status=str(SyncsUserStatus.SYNCING),
+ status=SyncStatus.SYNCED,
)
- sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
- logger.info(f"Notion sync created successfully for user: {current_user}")
+ await syncs_service.update_sync(state.sync_id, sync_user_input)
+
+ logger.info(f"Notion sync created successfully for user: {state.user_id}")
# launch celery task to sync notion data
celery.send_task(
"fetch_and_store_notion_files_task",
- kwargs={
- "access_token": access_token,
- "user_id": current_user,
- "sync_user_id": sync_user_state.id,
- },
+ kwargs={"access_token": access_token, "user_id": state.user_id},
)
return HTMLResponse(successfullConnectionPage)
diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py
index 3adcbe41b4bb..e01d9df8d752 100644
--- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py
+++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py
@@ -1,15 +1,16 @@
+import asyncio
import os
-import uuid
-from typing import Annotated, List
+from datetime import datetime
+from typing import List, Tuple
+from uuid import UUID
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends, status
-from quivr_api.celery_config import celery
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.dependencies import get_service
-from quivr_api.modules.notification.dto.inputs import CreateNotification
-from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeDTO
+from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.notification.service.notification_service import (
NotificationService,
)
@@ -19,11 +20,9 @@
from quivr_api.modules.sync.controller.google_sync_routes import google_sync_router
from quivr_api.modules.sync.controller.notion_sync_routes import notion_sync_router
from quivr_api.modules.sync.dto import SyncsDescription
-from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
-from quivr_api.modules.sync.dto.outputs import AuthMethodEnum
-from quivr_api.modules.sync.entity.sync_models import SyncsActive
-from quivr_api.modules.sync.service.sync_notion import SyncNotionService
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.sync.dto.outputs import AuthMethodEnum, SyncProvider
+from quivr_api.modules.sync.entity.sync_models import SyncFile
+from quivr_api.modules.sync.service.sync_service import SyncsService
from quivr_api.modules.user.entity.user_identity import UserIdentity
notification_service = NotificationService()
@@ -35,9 +34,8 @@
logger = get_logger(__name__)
# Initialize sync service
-sync_service = SyncService()
-sync_user_service = SyncUserService()
-NotionServiceDep = Annotated[SyncNotionService, Depends(get_service(SyncNotionService))]
+get_sync_service = get_service(SyncsService)
+get_knowledge_service = get_service(KnowledgeService)
# Initialize API router
@@ -53,31 +51,31 @@
# Google sync description
google_sync = SyncsDescription(
- name="Google",
+ name=SyncProvider.GOOGLE,
description="Sync your Google Drive with Quivr",
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
azure_sync = SyncsDescription(
- name="Azure",
+ name=SyncProvider.AZURE,
description="Sync your Azure Drive with Quivr",
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
dropbox_sync = SyncsDescription(
- name="DropBox",
+ name=SyncProvider.DROPBOX,
description="Sync your DropBox Drive with Quivr",
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
notion_sync = SyncsDescription(
- name="Notion",
+ name=SyncProvider.NOTION,
description="Sync your Notion with Quivr",
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
github_sync = SyncsDescription(
- name="GitHub",
+ name=SyncProvider.GITHUB,
description="Sync your GitHub Drive with Quivr",
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
@@ -89,7 +87,7 @@
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-async def get_syncs(current_user: UserIdentity = Depends(get_current_user)):
+async def get_all_sync_typs(current_user: UserIdentity = Depends(get_current_user)):
"""
Get all available sync descriptions.
@@ -108,7 +106,10 @@ async def get_syncs(current_user: UserIdentity = Depends(get_current_user)):
dependencies=[Depends(AuthBearer())],
tags=["Sync"],
)
-async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user)):
+async def get_user_syncs(
+ current_user: UserIdentity = Depends(get_current_user),
+ syncs_service: SyncsService = Depends(get_sync_service),
+):
"""
Get syncs for the current user.
@@ -119,7 +120,7 @@ async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user))
List: A list of syncs for the user.
"""
logger.debug(f"Fetching user syncs for user: {current_user.id}")
- return sync_user_service.get_syncs_user(current_user.id)
+ return await syncs_service.get_user_syncs(current_user.id)
@sync_router.delete(
@@ -129,7 +130,9 @@ async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user))
tags=["Sync"],
)
async def delete_user_sync(
- sync_id: int, current_user: UserIdentity = Depends(get_current_user)
+ sync_id: int,
+ current_user: UserIdentity = Depends(get_current_user),
+ syncs_service: SyncsService = Depends(get_sync_service),
):
"""
Delete a sync for the current user.
@@ -144,237 +147,21 @@ async def delete_user_sync(
logger.debug(
f"Deleting user sync for user: {current_user.id} with sync ID: {sync_id}"
)
- sync_user_service.delete_sync_user(sync_id, str(current_user.id)) # type: ignore
+ await syncs_service.delete_sync(sync_id, current_user.id)
return None
-@sync_router.post(
- "/sync/active",
- response_model=SyncsActive,
- dependencies=[Depends(AuthBearer())],
- tags=["Sync"],
-)
-async def create_sync_active(
- sync_active_input: SyncsActiveInput,
- current_user: UserIdentity = Depends(get_current_user),
-):
- """
- Create a new active sync for the current user.
-
- Args:
- sync_active_input (SyncsActiveInput): The sync active input data.
- current_user (UserIdentity): The current authenticated user.
-
- Returns:
- SyncsActive: The created sync active data.
- """
- logger.debug(
- f"Creating active sync for user: {current_user.id} with data: {sync_active_input}"
- )
- bulk_id = uuid.uuid4()
- notification = notification_service.add_notification(
- CreateNotification(
- user_id=current_user.id,
- status=NotificationsStatusEnum.INFO,
- title="Synchronization created! ",
- description="Your brain is preparing to sync files. This may take a few minutes before proceeding.",
- category="generic",
- bulk_id=bulk_id,
- brain_id=sync_active_input.brain_id,
- )
- )
- sync_active_input.notification_id = str(notification.id)
- sync_active = sync_service.create_sync_active(
- sync_active_input, str(current_user.id)
- )
- if not sync_active:
- raise HTTPException(
- status_code=500, detail=f"Error creating sync active for {current_user}"
- )
-
- celery.send_task(
- "process_sync_task",
- kwargs={
- "sync_id": sync_active.id,
- "user_id": sync_active.user_id,
- "files_ids": sync_active_input.settings.files,
- "folder_ids": sync_active_input.settings.folders,
- },
- )
-
- return sync_active
-
-
-@sync_router.put(
- "/sync/active/{sync_id}",
- response_model=SyncsActive | None,
- dependencies=[Depends(AuthBearer())],
- tags=["Sync"],
-)
-async def update_sync_active(
- sync_id: int,
- sync_active_input: SyncsActiveUpdateInput,
- current_user: UserIdentity = Depends(get_current_user),
-):
- """
- Update an existing active sync for the current user.
-
- Args:
- sync_id (str): The ID of the active sync to update.
- sync_active_input (SyncsActiveUpdateInput): The updated sync active input data.
- current_user (UserIdentity): The current authenticated user.
-
- Returns:
- SyncsActive: The updated sync active data.
- """
- logger.info(
- f"Updating active sync for user: {current_user.id} with data: {sync_active_input}"
- )
-
- details_sync_active = sync_service.get_details_sync_active(sync_id)
-
- if details_sync_active is None:
- raise HTTPException(
- status_code=500,
- detail="Error updating sync",
- )
-
- if sync_active_input.settings is None:
- return {"message": "No modification to sync active"}
-
- input_file_ids = (
- sync_active_input.settings.files if sync_active_input.settings.files else []
- )
- input_folder_ids = (
- sync_active_input.settings.folders if sync_active_input.settings.folders else []
- )
-
- if (input_file_ids == details_sync_active["settings"]["files"]) and (
- input_folder_ids == details_sync_active["settings"]["folders"]
- ):
- logger.info({"message": "No modification to sync active"})
- return None
-
- logger.debug(
- f"Updating sync_id {details_sync_active['id']}. Sync prev_settings={details_sync_active['settings'] }, Sync active input={sync_active_input.settings}"
- )
-
- bulk_id = uuid.uuid4()
- sync_active_input.force_sync = True
- notification = notification_service.add_notification(
- CreateNotification(
- user_id=current_user.id,
- status=NotificationsStatusEnum.INFO,
- title="Sync updated! Synchronization takes a few minutes to complete",
- description="Your brain is syncing files. This may take a few minutes before proceeding.",
- category="generic",
- bulk_id=bulk_id,
- brain_id=details_sync_active["brain_id"], # type: ignore
- )
- )
- sync_active_input.notification_id = str(notification.id)
- sync_active = sync_service.update_sync_active(sync_id, sync_active_input)
- if not sync_active:
- raise HTTPException(
- status_code=500,
- detail=f"Error updating sync active for {current_user.id}",
- )
- logger.debug(
- f"Sending task process_sync_task for sync_id={sync_id}, user_id={current_user.id}"
- )
-
- added_files_ids = set(input_file_ids).difference(
- set(details_sync_active["settings"]["files"])
- )
- added_folder_ids = set(input_folder_ids).difference(
- set(details_sync_active["settings"]["folders"])
- )
- if len(added_files_ids) + len(added_folder_ids) > 0:
- celery.send_task(
- "process_sync_task",
- kwargs={
- "sync_id": sync_active.id,
- "user_id": sync_active.user_id,
- "files_ids": list(added_files_ids),
- "folder_ids": list(added_folder_ids),
- },
- )
-
- else:
- return None
-
-
-@sync_router.delete(
- "/sync/active/{sync_id}",
- status_code=status.HTTP_204_NO_CONTENT,
- dependencies=[Depends(AuthBearer())],
- tags=["Sync"],
-)
-async def delete_sync_active(
- sync_id: int, current_user: UserIdentity = Depends(get_current_user)
-):
- """
- Delete an existing active sync for the current user.
-
- Args:
- sync_id (str): The ID of the active sync to delete.
- current_user (UserIdentity): The current authenticated user.
-
- Returns:
- None
- """
- logger.debug(
- f"Deleting active sync for user: {current_user.id} with sync ID: {sync_id}"
- )
-
- details_sync_active = sync_service.get_details_sync_active(sync_id)
- notification_service.add_notification(
- CreateNotification(
- user_id=current_user.id,
- status=NotificationsStatusEnum.SUCCESS,
- title="Sync deleted!",
- description="Sync deleted!",
- category="generic",
- bulk_id=uuid.uuid4(),
- brain_id=details_sync_active["brain_id"], # type: ignore
- )
- )
- sync_service.delete_sync_active(sync_id, str(current_user.id)) # type: ignore
- return None
-
-
-@sync_router.get(
- "/sync/active",
- response_model=List[SyncsActive],
- dependencies=[Depends(AuthBearer())],
- tags=["Sync"],
-)
-async def get_active_syncs_for_user(
- current_user: UserIdentity = Depends(get_current_user),
-):
- """
- Get all active syncs for the current user.
-
- Args:
- current_user (UserIdentity): The current authenticated user.
-
- Returns:
- List[SyncsActive]: A list of active syncs for the current user.
- """
- logger.debug(f"Fetching active syncs for user: {current_user.id}")
- return sync_service.get_syncs_active(str(current_user.id))
-
-
@sync_router.get(
"/sync/{sync_id}/files",
- dependencies=[Depends(AuthBearer())],
+ response_model=List[KnowledgeDTO] | None,
tags=["Sync"],
)
-async def get_files_folder_user_sync(
- user_sync_id: int,
- notion_service: NotionServiceDep,
+async def list_sync_files(
+ sync_id: int,
folder_id: str | None = None,
current_user: UserIdentity = Depends(get_current_user),
+ syncs_service: SyncsService = Depends(get_sync_service),
+ knowledge_service: KnowledgeService = Depends(get_knowledge_service),
):
"""
Get files for an active sync.
@@ -387,25 +174,65 @@ async def get_files_folder_user_sync(
Returns:
SyncsActive: The active sync data.
"""
- logger.debug(
- f"Fetching files for user sync: {user_sync_id} for user: {current_user.id}"
- )
- return await sync_user_service.get_files_folder_user_sync(
- user_sync_id, current_user.id, folder_id, notion_service=notion_service
- )
+ logger.debug(f"Fetching files for user sync: {sync_id} for user: {current_user.id}")
+ # TODO: check to see if this is inefficient
+ # Gets knowledge for each call to list the files,
+ # The logic is that getting from DB will be faster than provider repsonse ?
+ # NOTE: asyncio.gather didn't correcly typecheck
-@sync_router.get(
- "/sync/active/interval",
- dependencies=[Depends(AuthBearer())],
- tags=["Sync"],
-)
-async def get_syncs_active_in_interval() -> List[SyncsActive]:
- """
- Get all active syncs that need to be synced.
+ async def fetch_sync_knowledge(
+ sync_id: int,
+ user_id: UUID,
+ folder_id: str | None,
+ ) -> Tuple[dict[str, KnowledgeDB], List[SyncFile] | None]:
+ map_knowledges_task = knowledge_service.map_syncs_knowledge_user(
+ sync_id=sync_id, user_id=user_id
+ )
+ sync_files_task = syncs_service.get_files_folder_user_sync(
+ sync_id,
+ user_id,
+ folder_id,
+ )
+ return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821
- Returns:
- List: A list of active syncs that need to be synced.
- """
- logger.debug("Fetching active syncs in interval")
- return await sync_service.get_syncs_active_in_interval()
+ sync = await syncs_service.get_sync_by_id(sync_id=sync_id)
+ syncfile_to_knowledge, sync_files = await fetch_sync_knowledge(
+ sync_id=sync_id,
+ user_id=current_user.id,
+ folder_id=folder_id,
+ )
+ if not sync_files:
+ return None
+
+ kms = []
+ for file in sync_files:
+ existing_km = syncfile_to_knowledge.get(file.id)
+ if existing_km:
+ kms.append(await existing_km.to_dto(get_children=False, get_parent=False))
+ else:
+ last_modified_at = (
+ file.last_modified_at if file.last_modified_at else datetime.now()
+ )
+ kms.append(
+ KnowledgeDTO(
+ id=None,
+ file_name=file.name,
+ is_folder=file.is_folder,
+ extension=file.extension,
+ source=sync.provider,
+ source_link=file.web_view_link,
+ user_id=current_user.id,
+ brains=[],
+ parent=None,
+ children=[],
+ # TODO: Handle a sync not added status
+ status=None,
+ # TODO: retrieve created at from sync provider
+ created_at=last_modified_at,
+ updated_at=last_modified_at,
+ sync_id=sync_id,
+ sync_file_id=file.id,
+ )
+ )
+ return kms
diff --git a/backend/api/quivr_api/modules/sync/dto/__init__.py b/backend/api/quivr_api/modules/sync/dto/__init__.py
index 986765df6302..b40f07618217 100644
--- a/backend/api/quivr_api/modules/sync/dto/__init__.py
+++ b/backend/api/quivr_api/modules/sync/dto/__init__.py
@@ -1 +1 @@
-from .outputs import SyncsDescription, SyncsUserOutput
+from .outputs import SyncsDescription, SyncsOutput
diff --git a/backend/api/quivr_api/modules/sync/dto/inputs.py b/backend/api/quivr_api/modules/sync/dto/inputs.py
index b192216acf71..4a3824f7c67a 100644
--- a/backend/api/quivr_api/modules/sync/dto/inputs.py
+++ b/backend/api/quivr_api/modules/sync/dto/inputs.py
@@ -1,24 +1,17 @@
import enum
-from typing import List, Optional
+from uuid import UUID
from pydantic import BaseModel
-class SyncsUserStatus(enum.Enum):
- """
- Enum for the status of a sync user.
- """
-
+class SyncStatus(str, enum.Enum):
SYNCED = "SYNCED"
SYNCING = "SYNCING"
ERROR = "ERROR"
REMOVED = "REMOVED"
- def __str__(self):
- return self.value
-
-class SyncsUserInput(BaseModel):
+class SyncCreateInput(BaseModel):
"""
Input model for creating a new sync user.
@@ -30,7 +23,7 @@ class SyncsUserInput(BaseModel):
state (dict): The state information for the sync user.
"""
- user_id: str
+ user_id: UUID
name: str
email: str | None = None
provider: str
@@ -40,7 +33,7 @@ class SyncsUserInput(BaseModel):
status: str
-class SyncUserUpdateInput(BaseModel):
+class SyncUpdateInput(BaseModel):
"""
Input model for updating an existing sync user.
@@ -49,82 +42,8 @@ class SyncUserUpdateInput(BaseModel):
state (dict): The updated state information for the sync user.
"""
- credentials: dict
+ additional_data: dict | None = None
+ credentials: dict | None = None
state: dict | None = None
- email: str
- status: str
-
-
-class SyncActiveSettings(BaseModel):
- """
- Sync active settings.
-
- Attributes:
- folders (List[str] | None): A list of folder paths to be synced, or None if not applicable.
- files (List[str] | None): A list of file paths to be synced, or None if not applicable.
- """
-
- folders: Optional[List[str]] = None
- files: Optional[List[str]] = None
-
-
-class SyncsActiveInput(BaseModel):
- """
- Input model for creating a new active sync.
-
- Attributes:
- name (str): The name of the sync.
- syncs_user_id (int): The ID of the sync user associated with this sync.
- settings (SyncActiveSettings): The settings for the active sync.
- """
-
- name: str
- syncs_user_id: int
- settings: SyncActiveSettings
- brain_id: str
- notification_id: Optional[str] = None
-
-
-class SyncsActiveUpdateInput(BaseModel):
- """
- Input model for updating an existing active sync.
-
- Attributes:
- name (str): The updated name of the sync.
- sync_interval_minutes (int): The updated sync interval in minutes.
- settings (dict): The updated settings for the active sync.
- """
-
- name: Optional[str] = None
- settings: Optional[SyncActiveSettings] = None
- last_synced: Optional[str] = None
- force_sync: Optional[bool] = False
- notification_id: Optional[str] = None
-
-
-class SyncFileInput(BaseModel):
- """
- Input model for creating a new sync file.
-
- Attributes:
- path (str): The path of the file.
- syncs_active_id (int): The ID of the active sync associated with this file.
- """
-
- path: str
- syncs_active_id: int
- last_modified: str
- brain_id: str
- supported: Optional[bool] = True
-
-
-class SyncFileUpdateInput(BaseModel):
- """
- Input model for updating an existing sync file.
-
- Attributes:
- last_modified (datetime.datetime): The updated last modified date and time.
- """
-
- last_modified: Optional[str] = None
- supported: Optional[bool] = None
+ email: str | None = None
+ status: SyncStatus
diff --git a/backend/api/quivr_api/modules/sync/dto/outputs.py b/backend/api/quivr_api/modules/sync/dto/outputs.py
index 498f66177b01..3bdf004c2e36 100644
--- a/backend/api/quivr_api/modules/sync/dto/outputs.py
+++ b/backend/api/quivr_api/modules/sync/dto/outputs.py
@@ -1,4 +1,5 @@
from enum import Enum
+from uuid import UUID
from pydantic import BaseModel
@@ -7,14 +8,24 @@ class AuthMethodEnum(str, Enum):
URI_WITH_CALLBACK = "uri_with_callback"
+class SyncProvider(str, Enum):
+ GOOGLE = "google"
+ AZURE = "azure"
+ DROPBOX = "dropbox"
+ NOTION = "notion"
+ GITHUB = "github"
+
+
class SyncsDescription(BaseModel):
name: str
description: str
auth_method: AuthMethodEnum
-class SyncsUserOutput(BaseModel):
- user_id: str
- provider: str
- state: dict
- credentials: dict
+class SyncsOutput(BaseModel):
+ id: int
+ user_id: UUID
+ provider: SyncProvider
+ state: dict | None
+ credentials: dict | None
+ additional_data: dict | None
diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py
index 0b22cba38aea..036b84669ffb 100644
--- a/backend/api/quivr_api/modules/sync/entity/sync_models.py
+++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py
@@ -2,13 +2,24 @@
import io
from dataclasses import dataclass
from datetime import datetime
-from typing import Optional
+from enum import Enum, auto
+from typing import Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
-from sqlmodel import TIMESTAMP, Column, Field, Relationship, SQLModel, text
+from sqlmodel import ( # noqa: F811
+ JSON,
+ TIMESTAMP,
+ Column,
+ Field,
+ Relationship,
+ SQLModel,
+ text,
+)
from sqlmodel import UUID as PGUUID
+from quivr_api.modules.sync.dto.inputs import SyncStatus
+from quivr_api.modules.sync.dto.outputs import SyncProvider, SyncsOutput
from quivr_api.modules.user.entity.user_identity import User
@@ -27,57 +38,68 @@ def file_sha1(self) -> str:
return m.hexdigest()
-class DBSyncFile(BaseModel):
- id: int
- path: str
- syncs_active_id: int
- last_modified: str
- brain_id: str
- supported: bool
-
-
class SyncFile(BaseModel):
id: str
name: str
is_folder: bool
- last_modified: str
- mime_type: str
+ last_modified_at: Optional[datetime]
+ extension: str
web_view_link: str
size: Optional[int] = None
- notification_id: UUID | None = None
icon: Optional[str] = None
parent_id: Optional[str] = None
type: Optional[str] = None
-class SyncsUser(BaseModel):
- id: int
- user_id: UUID
- name: str
- email: str | None = None
- provider: str
- credentials: dict
- state: dict
- additional_data: dict
- status: Optional[str] = None
+class SyncType(Enum):
+ FOLDER = auto()
+ FILE = auto()
+
+class Sync(SQLModel, table=True):
+ __tablename__ = "syncs" # type: ignore
-class SyncsActive(BaseModel):
- id: int
+ id: int | None = Field(default=None, primary_key=True)
name: str
- syncs_user_id: int
- user_id: UUID
- settings: dict
- last_synced: str
- sync_interval_minutes: int
- brain_id: UUID
- syncs_user: Optional[SyncsUser] = None
- notification_id: Optional[str] = None
-
-
-# TODO: all of this should be rewritten
-class SyncsActiveDetails(BaseModel):
- pass
+ email: str | None = None
+ provider: str
+ email: str | None = Field(default=None)
+ user_id: UUID = Field(foreign_key="users.id", nullable=False)
+ credentials: Dict[str, str] | None = Field(
+ default=None, sa_column=Column("credentials", JSON)
+ )
+ state: Dict[str, str] | None = Field(default=None, sa_column=Column("state", JSON))
+ status: str = Field(default=SyncStatus.SYNCING)
+ created_at: datetime | None = Field(
+ default=None,
+ sa_column=Column(
+ TIMESTAMP(timezone=False),
+ server_default=text("CURRENT_TIMESTAMP"),
+ ),
+ )
+ updated_at: datetime | None = Field(
+ default=None,
+ sa_column=Column(
+ TIMESTAMP(timezone=False),
+ server_default=text("CURRENT_TIMESTAMP"),
+ onupdate=datetime.utcnow,
+ ),
+ )
+ additional_data: dict | None = Field(
+ default=None, sa_column=Column("additional_data", JSON)
+ )
+ knowledges: List["KnowledgeDB"] | None = Relationship(back_populates="sync")
+
+ def to_dto(self) -> SyncsOutput:
+ assert self.id, "can't create create output if sync isn't inserted"
+ return SyncsOutput(
+ id=self.id,
+ user_id=self.user_id,
+ provider=SyncProvider(self.provider.lower()),
+ credentials=self.credentials,
+ state=self.state,
+ additional_data=self.additional_data,
+ )
class NotionSyncFile(SQLModel, table=True):
diff --git a/backend/api/quivr_api/modules/sync/repository/notion_repository.py b/backend/api/quivr_api/modules/sync/repository/notion_repository.py
new file mode 100644
index 000000000000..f87be94c029e
--- /dev/null
+++ b/backend/api/quivr_api/modules/sync/repository/notion_repository.py
@@ -0,0 +1,130 @@
+from typing import List, Sequence
+from uuid import UUID
+
+from sqlalchemy import or_
+from sqlalchemy.exc import IntegrityError
+from sqlmodel import col, select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from quivr_api.logger import get_logger
+from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
+from quivr_api.modules.notification.service.notification_service import (
+ NotificationService,
+)
+from quivr_api.modules.sync.entity.sync_models import NotionSyncFile
+
+notification_service = NotificationService()
+
+logger = get_logger(__name__)
+
+
+class NotionRepository(BaseRepository):
+ def __init__(self, session: AsyncSession):
+ super().__init__(session)
+ self.session = session
+ self.db = get_supabase_client()
+
+ async def get_user_notion_files(self, user_id: UUID) -> Sequence[NotionSyncFile]:
+ query = select(NotionSyncFile).where(NotionSyncFile.user_id == user_id)
+ response = await self.session.exec(query)
+ return response.all()
+
+ async def create_notion_files(
+ self, notion_files: List[NotionSyncFile]
+ ) -> List[NotionSyncFile]:
+ try:
+ self.session.add_all(notion_files)
+ await self.session.commit()
+ except IntegrityError:
+ await self.session.rollback()
+ raise Exception("Integrity error while creating notion files.")
+ except Exception as e:
+ await self.session.rollback()
+ raise e
+
+ return notion_files
+
+ async def update_notion_file(self, updated_notion_file: NotionSyncFile) -> bool:
+ try:
+ is_update = False
+ query = select(NotionSyncFile).where(
+ NotionSyncFile.notion_id == updated_notion_file.notion_id
+ )
+ result = await self.session.exec(query)
+ existing_page = result.one_or_none()
+
+ if existing_page:
+ # Update existing page
+ existing_page.name = updated_notion_file.name
+ existing_page.last_modified = updated_notion_file.last_modified
+ self.session.add(existing_page)
+ is_update = True
+ else:
+ # Add new page
+ self.session.add(updated_notion_file)
+
+ await self.session.commit()
+
+ # Refresh the object that's actually in the session
+ refreshed_file = existing_page if existing_page else updated_notion_file
+ await self.session.refresh(refreshed_file)
+
+ logger.info(f"Updated notion file in notion repo: {refreshed_file}")
+ return is_update
+
+ except IntegrityError as ie:
+ logger.error(f"IntegrityError occurred: {ie}")
+ await self.session.rollback()
+ raise Exception("Integrity error while updating notion file.")
+ except Exception as e:
+ logger.error(f"Exception occurred: {e}")
+ await self.session.rollback()
+ raise
+
+ async def get_notion_files_by_ids(self, ids: List[str]) -> Sequence[NotionSyncFile]:
+ query = select(NotionSyncFile).where(NotionSyncFile.notion_id.in_(ids)) # type: ignore
+ response = await self.session.exec(query)
+ return response.all()
+
+ async def get_notion_files_by_parent_id(
+ self, parent_id: str | None
+ ) -> Sequence[NotionSyncFile]:
+ query = select(NotionSyncFile).where(NotionSyncFile.parent_id == parent_id)
+ response = await self.session.exec(query)
+ return response.all()
+
+ async def get_all_notion_files(self) -> Sequence[NotionSyncFile]:
+ query = select(NotionSyncFile)
+ response = await self.session.exec(query)
+ return response.all()
+
+ async def is_folder_page(self, page_id: str) -> bool:
+ query = select(NotionSyncFile).where(NotionSyncFile.parent_id == page_id)
+ response = await self.session.exec(query)
+ return response.first() is not None
+
+ async def delete_notion_page(self, notion_id: UUID):
+ query = select(NotionSyncFile).where(NotionSyncFile.notion_id == notion_id)
+ response = await self.session.exec(query)
+ notion_file = response.first()
+ if notion_file:
+ await self.session.delete(notion_file)
+ await self.session.commit()
+ return notion_file
+ return None
+
+ async def delete_notion_pages(self, notion_ids: List[UUID]):
+ query = select(NotionSyncFile).where(
+ or_(
+ col(NotionSyncFile.notion_id).in_(notion_ids),
+ col(NotionSyncFile.parent_id).in_(notion_ids),
+ )
+ )
+ response = await self.session.exec(query)
+ notion_files = response.all()
+ if notion_files:
+ for notion_file in notion_files:
+ await self.session.delete(notion_file)
+ await self.session.commit()
+ return notion_files
+ return None
diff --git a/backend/api/quivr_api/modules/sync/repository/sync_files.py b/backend/api/quivr_api/modules/sync/repository/sync_files.py
deleted file mode 100644
index e814192a74b8..000000000000
--- a/backend/api/quivr_api/modules/sync/repository/sync_files.py
+++ /dev/null
@@ -1,129 +0,0 @@
-from quivr_api.logger import get_logger
-from quivr_api.modules.dependencies import get_supabase_client
-from quivr_api.modules.sync.dto.inputs import (
- SyncFileInput,
- SyncFileUpdateInput,
-)
-from quivr_api.modules.sync.entity.sync_models import DBSyncFile, SyncFile, SyncsActive
-from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface
-
-logger = get_logger(__name__)
-
-
-class SyncFilesRepository(SyncFileInterface):
- def __init__(self):
- """
- Initialize the SyncFiles class with a Supabase client.
- """
- supabase_client = get_supabase_client()
- self.db = supabase_client # type: ignore
- logger.debug("Supabase client initialized")
-
- def create_sync_file(self, sync_file_input: SyncFileInput) -> DBSyncFile | None:
- """
- Create a new sync file in the database.
-
- Args:
- sync_file_input (SyncFileInput): The input data for creating a sync file.
-
- Returns:
- SyncsFiles: The created sync file data.
- """
- logger.info("Creating sync file with input: %s", sync_file_input)
- response = (
- self.db.from_("syncs_files")
- .insert(
- {
- "path": sync_file_input.path,
- "syncs_active_id": sync_file_input.syncs_active_id,
- "last_modified": sync_file_input.last_modified,
- "brain_id": sync_file_input.brain_id,
- }
- )
- .execute()
- )
- if response.data:
- logger.info("Sync file created successfully: %s", response.data[0])
- return DBSyncFile(**response.data[0])
- logger.warning("Failed to create sync file")
-
- def get_sync_files(self, sync_active_id: int) -> list[DBSyncFile]:
- """
- Retrieve sync files from the database.
-
- Args:
- sync_active_id (int): The ID of the active sync.
-
- Returns:
- list[SyncsFiles]: A list of sync files matching the criteria.
- """
- logger.info("Retrieving sync files for sync_active_id: %s", sync_active_id)
- response = (
- self.db.from_("syncs_files")
- .select("*")
- .eq("syncs_active_id", sync_active_id)
- .execute()
- )
- if response.data:
- # logger.info("Sync files retrieved successfully: %s", response.data)
- return [DBSyncFile(**file) for file in response.data]
- logger.warning("No sync files found for sync_active_id: %s", sync_active_id)
- return []
-
- def update_sync_file(self, sync_file_id: int, sync_file_input: SyncFileUpdateInput):
- """
- Update a sync file in the database.
-
- Args:
- sync_file_id (int): The ID of the sync file.
- sync_file_input (SyncFileUpdateInput): The input data for updating the sync file.
- """
- logger.info(
- "Updating sync file with sync_file_id: %s, input: %s",
- sync_file_id,
- sync_file_input,
- )
- self.db.from_("syncs_files").update(
- sync_file_input.model_dump(exclude_unset=True)
- ).eq("id", sync_file_id).execute()
- logger.info("Sync file updated successfully")
-
- def update_or_create_sync_file(
- self,
- file: SyncFile,
- sync_active: SyncsActive,
- previous_file: DBSyncFile | None,
- supported: bool,
- ) -> DBSyncFile | None:
- if previous_file:
- logger.debug(f"Upserting file {previous_file} in database.")
- sync_file = self.update_sync_file(
- previous_file.id,
- SyncFileUpdateInput(
- last_modified=file.last_modified,
- supported=previous_file.supported or supported,
- ),
- )
- else:
- logger.debug("Creating new file in database.")
- sync_file = self.create_sync_file(
- SyncFileInput(
- path=file.name,
- syncs_active_id=sync_active.id,
- last_modified=file.last_modified,
- brain_id=str(sync_active.brain_id),
- supported=supported,
- )
- )
- return sync_file
-
- def delete_sync_file(self, sync_file_id: int):
- """
- Delete a sync file from the database.
-
- Args:
- sync_file_id (int): The ID of the sync file.
- """
- logger.info("Deleting sync file with sync_file_id: %s", sync_file_id)
- self.db.from_("syncs_files").delete().eq("id", sync_file_id).execute()
- logger.info("Sync file deleted successfully")
diff --git a/backend/api/quivr_api/modules/sync/repository/sync_interfaces.py b/backend/api/quivr_api/modules/sync/repository/sync_interfaces.py
deleted file mode 100644
index dc6a9add6eb9..000000000000
--- a/backend/api/quivr_api/modules/sync/repository/sync_interfaces.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, List, Literal
-from uuid import UUID
-
-from quivr_api.modules.sync.dto.inputs import (
- SyncFileInput,
- SyncFileUpdateInput,
- SyncsActiveInput,
- SyncsActiveUpdateInput,
- SyncsUserInput,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.entity.sync_models import (
- DBSyncFile,
- SyncFile,
- SyncsActive,
-)
-
-
-class SyncUserInterface(ABC):
- @abstractmethod
- def create_sync_user(
- self,
- sync_user_input: SyncsUserInput,
- ):
- pass
-
- @abstractmethod
- def get_syncs_user(self, user_id: str, sync_user_id: int | None = None):
- pass
-
- @abstractmethod
- def get_sync_user_by_id(self, sync_id: int):
- pass
-
- @abstractmethod
- def delete_sync_user(self, sync_user_id: int, user_id: UUID | str):
- pass
-
- @abstractmethod
- def get_sync_user_by_state(self, state: dict):
- pass
-
- @abstractmethod
- def update_sync_user(
- self, sync_user_id: int, state: dict, sync_user_input: SyncUserUpdateInput
- ):
- pass
-
- @abstractmethod
- async def get_files_folder_user_sync(
- self,
- sync_active_id: int,
- user_id: str,
- notion_service: Any = None,
- folder_id: int | str | None = None,
- recursive: bool = False,
- ) -> None | dict[str, List[SyncFile]] | Literal["No sync found"]:
- pass
-
- @abstractmethod
- def get_all_notion_user_syncs(self):
- pass
-
-
-class SyncInterface(ABC):
- @abstractmethod
- def create_sync_active(
- self,
- sync_active_input: SyncsActiveInput,
- user_id: str,
- ) -> SyncsActive | None:
- pass
-
- @abstractmethod
- def get_syncs_active(self, user_id: UUID | str) -> List[SyncsActive]:
- pass
-
- @abstractmethod
- def update_sync_active(
- self, sync_id: UUID | int, sync_active_input: SyncsActiveUpdateInput
- ):
- pass
-
- @abstractmethod
- def delete_sync_active(self, sync_active_id: int, user_id: str):
- pass
-
- @abstractmethod
- def get_details_sync_active(self, sync_active_id: int):
- pass
-
- @abstractmethod
- async def get_syncs_active_in_interval(self) -> List[SyncsActive]:
- pass
-
-
-class SyncFileInterface(ABC):
- @abstractmethod
- def create_sync_file(self, sync_file_input: SyncFileInput) -> DBSyncFile:
- pass
-
- @abstractmethod
- def get_sync_files(self, sync_active_id: int) -> list[DBSyncFile]:
- pass
-
- @abstractmethod
- def update_sync_file(self, sync_file_id: int, sync_file_input: SyncFileUpdateInput):
- pass
-
- @abstractmethod
- def delete_sync_file(self, sync_file_id: int):
- pass
-
- @abstractmethod
- def update_or_create_sync_file(
- self,
- file: SyncFile,
- sync_active: SyncsActive,
- previous_file: DBSyncFile | None,
- supported: bool,
- ) -> DBSyncFile | None:
- pass
diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py
index e669d582e6e1..584074b1fe29 100644
--- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py
+++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py
@@ -1,327 +1,213 @@
-from datetime import datetime, timedelta
-from typing import List, Sequence
+from sqlite3 import IntegrityError
+from typing import Any, List
from uuid import UUID
-from sqlalchemy import or_
-from sqlalchemy.exc import IntegrityError
-from sqlmodel import col, select
+from sqlmodel import delete, select
from sqlmodel.ext.asyncio.session import AsyncSession
from quivr_api.logger import get_logger
from quivr_api.modules.dependencies import BaseRepository, get_supabase_client
-from quivr_api.modules.notification.service.notification_service import (
- NotificationService,
+from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile
+from quivr_api.modules.sync.repository.notion_repository import NotionRepository
+from quivr_api.modules.sync.service.sync_notion import SyncNotionService
+from quivr_api.modules.sync.utils.sync import (
+ AzureDriveSync,
+ BaseSync,
+ DropboxSync,
+ GitHubSync,
+ GoogleDriveSync,
+ NotionSync,
+)
+from quivr_api.modules.sync.utils.sync_exceptions import (
+ SyncEmptyCredentials,
+ SyncNotFoundException,
+ SyncProviderError,
+ SyncUpdateError,
)
-from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
-from quivr_api.modules.sync.entity.sync_models import NotionSyncFile, SyncsActive
-from quivr_api.modules.sync.repository.sync_interfaces import SyncInterface
-
-notification_service = NotificationService()
logger = get_logger(__name__)
-class Sync(SyncInterface):
- def __init__(self):
- """
- Initialize the Sync class with a Supabase client.
- """
- supabase_client = get_supabase_client()
- self.db = supabase_client # type: ignore
- logger.debug("Supabase client initialized")
-
- def create_sync_active(
- self, sync_active_input: SyncsActiveInput, user_id: str
- ) -> SyncsActive | None:
- """
- Create a new active sync in the database.
-
- Args:
- sync_active_input (SyncsActiveInput): The input data for creating an active sync.
- user_id (str): The user ID associated with the active sync.
-
- Returns:
- SyncsActive or None: The created active sync data or None if creation failed.
- """
- logger.info(
- "Creating active sync for user_id: %s with input: %s",
- user_id,
- sync_active_input,
- )
- sync_active_input_dict = sync_active_input.model_dump()
- sync_active_input_dict["user_id"] = user_id
- response = (
- self.db.from_("syncs_active").insert(sync_active_input_dict).execute()
- )
- if response.data:
- logger.info("Active sync created successfully: %s", response.data[0])
- return SyncsActive(**response.data[0])
-
- logger.error("Failed to create active sync for user_id: %s", user_id)
+class SyncsRepository(BaseRepository):
+ def __init__(
+ self,
+ session: AsyncSession,
+ sync_provider_mapping: dict[SyncProvider, BaseSync] | None = None,
+ ):
+ self.session = session
+ self.db = get_supabase_client()
- def get_syncs_active(self, user_id: UUID | str) -> List[SyncsActive]:
+ if sync_provider_mapping is None:
+ self.sync_provider_mapping: dict[SyncProvider, BaseSync] = {
+ SyncProvider.GOOGLE: GoogleDriveSync(),
+ SyncProvider.DROPBOX: DropboxSync(),
+ SyncProvider.AZURE: AzureDriveSync(),
+ SyncProvider.NOTION: NotionSync(
+ notion_service=SyncNotionService(NotionRepository(self.session))
+ ),
+ SyncProvider.GITHUB: GitHubSync(),
+ }
+ else:
+ self.sync_provider_mapping = sync_provider_mapping
+
+ async def create_sync(
+ self,
+ sync_user_input: SyncCreateInput,
+ ) -> Sync:
"""
- Retrieve active syncs from the database.
+ Create a new sync user in the database.
Args:
- user_id (str): The user ID to filter active syncs.
+ sync_user_input (SyncsUserInput): The input data for creating a sync user.
Returns:
- List[SyncsActive]: A list of active syncs matching the criteria.
"""
- logger.info("Retrieving active syncs for user_id: %s", user_id)
- response = (
- self.db.from_("syncs_active")
- .select("*, syncs_user(*)")
- .eq("user_id", user_id)
- .execute()
- )
- if response.data:
- logger.info("Active syncs retrieved successfully: %s", response.data)
- return [SyncsActive(**sync) for sync in response.data]
- logger.warning("No active syncs found for user_id: %s", user_id)
- return []
+ logger.info(f"Creating sync user with input: {sync_user_input}")
+ try:
+ sync = Sync.model_validate(sync_user_input.model_dump())
+ self.session.add(sync)
+ await self.session.commit()
+ await self.session.refresh(sync)
+ return sync
+ except IntegrityError:
+ await self.session.rollback()
+ raise
+ except Exception:
+ await self.session.rollback()
+ raise
- def update_sync_active(
- self, sync_id: int | str, sync_active_input: SyncsActiveUpdateInput
- ) -> SyncsActive | None:
+ async def get_sync_id(self, sync_id: int, user_id: UUID | None = None) -> Sync:
"""
- Update an active sync in the database.
-
- Args:
- sync_id (int): The ID of the active sync.
- sync_active_input (SyncsActiveUpdateInput): The input data for updating the active sync.
-
- Returns:
- dict or None: The updated active sync data or None if update failed.
+ Retrieve sync users from the database.
"""
- logger.info(
- "Updating active sync with sync_id: %s, input: %s",
- sync_id,
- sync_active_input,
- )
-
- response = (
- self.db.from_("syncs_active")
- .update(sync_active_input.model_dump(exclude_unset=True))
- .eq("id", sync_id)
- .execute()
- )
+ query = select(Sync).where(Sync.id == sync_id)
- if response.data:
- logger.info("Active sync updated successfully: %s", response.data[0])
- return SyncsActive.model_validate(response.data[0])
+ if user_id:
+ query = query.where(Sync.user_id == user_id)
+ result = await self.session.exec(query)
+ sync = result.first()
- logger.error("Failed to update active sync with sync_id: %s", sync_id)
+ if not sync:
+ logger.error(
+ f"No sync user found for sync_id: {sync_id}",
+ )
+ raise SyncNotFoundException()
+ return sync
- def delete_sync_active(self, sync_active_id: int, user_id: UUID):
+ async def get_syncs(self, user_id: UUID, sync_id: int | None = None):
"""
- Delete an active sync from the database.
+ Retrieve sync users from the database.
Args:
- sync_active_id (int): The ID of the active sync.
- user_id (str): The user ID associated with the active sync.
+ user_id (str): The user ID to filter sync users.
+ sync_user_id (int, optional): The sync user ID to filter sync users. Defaults to None.
Returns:
- dict or None: The deleted active sync data or None if deletion failed.
+ list: A list of sync users matching the criteria.
"""
logger.info(
- "Deleting active sync with sync_active_id: %s, user_id: %s",
- sync_active_id,
- user_id,
- )
- response = (
- self.db.from_("syncs_active")
- .delete()
- .eq("id", sync_active_id)
- .eq("user_id", str(user_id))
- .execute()
+ f"Retrieving sync users for user_id: {user_id}, sync_user_id: {sync_id}",
)
- if response.data:
- logger.info("Active sync deleted successfully: %s", response.data[0])
- return response.data[0]
- logger.warning(
- "Failed to delete active sync with sync_active_id: %s, user_id: %s",
- sync_active_id,
- user_id,
- )
- return None
+ query = select(Sync).where(Sync.user_id == user_id)
+ if sync_id is not None:
+ query = query.where(Sync.id == sync_id)
+ result = await self.session.exec(query)
+ return list(result.all())
- def get_details_sync_active(self, sync_active_id: int):
+ async def get_sync_user_by_state(self, state: dict) -> Sync:
"""
- Retrieve details of an active sync, including associated sync user data.
+ Retrieve a sync user by their state.
Args:
- sync_active_id (int): The ID of the active sync.
-
- Returns:
- dict or None: The detailed active sync data or None if not found.
- """
- logger.info(
- "Retrieving details for active sync with sync_active_id: %s", sync_active_id
- )
- response = (
- self.db.table("syncs_active")
- .select("*, syncs_user(provider, credentials)")
- .eq("id", sync_active_id)
- .execute()
- )
- if response.data:
- logger.info(
- "Details for active sync retrieved successfully: %s", response.data[0]
- )
- return response.data[0]
- logger.warning(
- "No details found for active sync with sync_active_id: %s", sync_active_id
- )
- return None
-
- async def get_syncs_active_in_interval(self) -> List[SyncsActive]:
- """
- Retrieve active syncs that are due for synchronization based on their interval.
+ state (dict): The state to filter sync users.
Returns:
- list: A list of active syncs that are due for synchronization.
+ dict or None: The sync user data matching the state or None if not found.
"""
- logger.info("Retrieving active syncs due for synchronization")
+ logger.info(f"Getting sync user by state: {state}")
- current_time = datetime.now()
-
- # The Query filters the active syncs based on the sync_interval_minutes field and last_synced timestamp
- response = (
- self.db.table("syncs_active")
- .select("*")
- .lt("last_synced", (current_time - timedelta(minutes=360)).isoformat())
- .execute()
- )
+ query = select(Sync).where(Sync.state == state)
+ result = await self.session.exec(query)
+ sync = result.first()
+ if not sync:
+ raise SyncNotFoundException()
+ return sync
- force_sync = (
- self.db.table("syncs_active").select("*").eq("force_sync", True).execute()
- )
- merge_data = response.data + force_sync.data
- if merge_data:
- logger.info("Active syncs retrieved successfully: %s", merge_data)
- return [SyncsActive(**sync) for sync in merge_data]
- logger.info("No active syncs found due for synchronization")
- return []
-
-
-class NotionRepository(BaseRepository):
- def __init__(self, session: AsyncSession):
- super().__init__(session)
- self.session = session
- self.db = get_supabase_client()
+ return None
- async def get_user_notion_files(
- self, user_id: UUID, sync_user_id: int
- ) -> Sequence[NotionSyncFile]:
- query = select(NotionSyncFile).where(
- NotionSyncFile.user_id == user_id
- and NotionSyncFile.sync_user_id == sync_user_id
+ async def delete_sync(self, sync_id: int, user_id: UUID):
+ logger.info(f"Deleting sync user with sync_id: {sync_id}, user_id: {user_id}")
+ await self.session.execute(
+ delete(Sync).where(Sync.id == sync_id).where(Sync.user_id == user_id) # type: ignore
)
- response = await self.session.exec(query)
- return response.all()
+ logger.info("Sync user deleted successfully")
- async def create_notion_files(
- self, notion_files: List[NotionSyncFile]
- ) -> List[NotionSyncFile]:
+ async def update_sync(
+ self, sync: Sync, sync_input: SyncUpdateInput | dict[str, Any]
+ ):
+ logger.debug(f"Updating sync user with user_id: {sync.id}")
try:
- self.session.add_all(notion_files)
- await self.session.commit()
- except IntegrityError:
- await self.session.rollback()
- raise Exception("Integrity error while creating notion files.")
- except Exception as e:
- await self.session.rollback()
- raise e
-
- return notion_files
-
- async def update_notion_file(self, updated_notion_file: NotionSyncFile) -> bool:
- try:
- is_update = False
- query = select(NotionSyncFile).where(
- NotionSyncFile.notion_id == updated_notion_file.notion_id
- )
- result = await self.session.exec(query)
- existing_page = result.one_or_none()
-
- if existing_page:
- # Update existing page
- existing_page.name = updated_notion_file.name
- existing_page.last_modified = updated_notion_file.last_modified
- self.session.add(existing_page)
- is_update = True
+ if isinstance(sync_input, dict):
+ update_data = sync_input
else:
- # Add new page
- self.session.add(updated_notion_file)
+ update_data = sync_input.model_dump(exclude_unset=True)
+ for field in update_data:
+ setattr(sync, field, update_data[field])
+ self.session.add(sync)
await self.session.commit()
-
- # Refresh the object that's actually in the session
- refreshed_file = existing_page if existing_page else updated_notion_file
- await self.session.refresh(refreshed_file)
-
- logger.info(f"Updated notion file in notion repo: {refreshed_file}")
- return is_update
-
- except IntegrityError as ie:
- logger.error(f"IntegrityError occurred: {ie}")
- await self.session.rollback()
- raise Exception("Integrity error while updating notion file.")
- except Exception as e:
- logger.error(f"Exception occurred: {e}")
+ await self.session.refresh(sync)
+ return sync
+ except IntegrityError as e:
await self.session.rollback()
- raise
+ logger.error(f"Error updating knowledge {e}")
+ raise SyncUpdateError
- async def get_notion_files_by_ids(self, ids: List[str]) -> Sequence[NotionSyncFile]:
- query = select(NotionSyncFile).where(NotionSyncFile.notion_id.in_(ids)) # type: ignore
- response = await self.session.exec(query)
- return response.all()
+ def get_all_notion_user_syncs(self):
+ """
+ Retrieve all Notion sync users from the database.
- async def get_notion_files_by_parent_id(
- self, parent_id: str | None, sync_user_id: int
- ) -> Sequence[NotionSyncFile]:
- query = (
- select(NotionSyncFile)
- .where(NotionSyncFile.parent_id == parent_id)
- .where(NotionSyncFile.sync_user_id == sync_user_id)
+ Returns:
+ list: A list of Notion sync users.
+ """
+ logger.info("Retrieving all Notion sync users")
+ response = (
+ self.db.from_("syncs_user").select("*").eq("provider", "Notion").execute()
)
- response = await self.session.exec(query)
- return response.all()
+ if response.data:
+ logger.info("Notion sync users retrieved successfully")
+ return response.data
+ return []
- async def get_all_notion_files(self) -> Sequence[NotionSyncFile]:
- query = select(NotionSyncFile)
- response = await self.session.exec(query)
- return response.all()
+ async def get_files_folder_user_sync(
+ self,
+ sync_id: int,
+ user_id: UUID,
+ folder_id: str | None = None,
+ recursive: bool = False,
+ ) -> List[SyncFile] | None:
+ logger.info(
+ f"Retrieving files for user sync with sync_active_id: {sync_id}, user_id: {user_id}, folder_id: {folder_id}",
+ )
+ sync = await self.get_sync_id(sync_id=sync_id, user_id=user_id)
+ if not sync:
+ logger.error(
+ f"No sync user found for sync_active_id: {sync_id}, user_id: {user_id}",
+ )
+ return None
- async def is_folder_page(self, page_id: str) -> bool:
- query = select(NotionSyncFile).where(NotionSyncFile.parent_id == page_id)
- response = await self.session.exec(query)
- return response.first() is not None
+ try:
+ sync_provider = self.sync_provider_mapping[
+ SyncProvider(sync.provider.lower())
+ ]
+ except KeyError:
+ raise SyncProviderError
- async def delete_notion_page(self, notion_id: UUID):
- query = select(NotionSyncFile).where(NotionSyncFile.notion_id == notion_id)
- response = await self.session.exec(query)
- notion_file = response.first()
- if notion_file:
- await self.session.delete(notion_file)
- await self.session.commit()
- return notion_file
- return None
+ if sync.credentials is None:
+ raise SyncEmptyCredentials
- async def delete_notion_pages(self, notion_ids: List[UUID]):
- query = select(NotionSyncFile).where(
- or_(
- col(NotionSyncFile.notion_id).in_(notion_ids),
- col(NotionSyncFile.parent_id).in_(notion_ids),
- )
+ return await sync_provider.aget_files(
+ sync.credentials, folder_id if folder_id else "", recursive
)
- response = await self.session.exec(query)
- notion_files = response.all()
- if notion_files:
- for notion_file in notion_files:
- await self.session.delete(notion_file)
- await self.session.commit()
- return notion_files
- return None
diff --git a/backend/api/quivr_api/modules/sync/repository/sync_user.py b/backend/api/quivr_api/modules/sync/repository/sync_user.py
deleted file mode 100644
index 09ff5007d7b4..000000000000
--- a/backend/api/quivr_api/modules/sync/repository/sync_user.py
+++ /dev/null
@@ -1,319 +0,0 @@
-import json
-from typing import List, Literal
-from uuid import UUID
-
-from quivr_api.logger import get_logger
-from quivr_api.modules.dependencies import get_supabase_client
-from quivr_api.modules.sync.dto.inputs import (
- SyncsUserInput,
- SyncUserUpdateInput,
-)
-from quivr_api.modules.sync.entity.sync_models import SyncFile, SyncsUser
-from quivr_api.modules.sync.service.sync_notion import SyncNotionService
-from quivr_api.modules.sync.utils.sync import (
- AzureDriveSync,
- BaseSync,
- DropboxSync,
- GitHubSync,
- GoogleDriveSync,
- NotionSync,
-)
-
-logger = get_logger(__name__)
-
-
-class SyncUserRepository:
- def __init__(self):
- """
- Initialize the Sync class with a Supabase client.
- """
- supabase_client = get_supabase_client()
- self.db = supabase_client # type: ignore
- logger.debug("Supabase client initialized")
-
- def create_sync_user(
- self,
- sync_user_input: SyncsUserInput,
- ):
- """
- Create a new sync user in the database.
-
- Args:
- sync_user_input (SyncsUserInput): The input data for creating a sync user.
-
- Returns:
- dict or None: The created sync user data or None if creation failed.
- """
- logger.info("Creating sync user with input: %s", sync_user_input)
- response = (
- self.db.from_("syncs_user")
- .insert(sync_user_input.model_dump(exclude_none=True, exclude_unset=True))
- .execute()
- )
-
- if response.data:
- logger.info("Sync user created successfully: %s", response.data[0])
- return response.data[0]
- logger.warning("Failed to create sync user")
-
- def get_sync_user_by_id(self, sync_id: int) -> SyncsUser | None:
- """
- Retrieve sync users from the database.
- """
- response = self.db.from_("syncs_user").select("*").eq("id", sync_id).execute()
- if response.data:
- logger.info("Sync user found: %s", response.data[0])
- return SyncsUser.model_validate(response.data[0])
- logger.error("No sync user found for sync_id: %s", sync_id)
-
- def clean_notion_user_syncs(self):
- """
- Clean all Removed Notion sync users from the database.
- """
- logger.info("Cleaning all Removed Notion sync users")
- self.db.from_("syncs_user").delete().eq("provider", "Notion").eq(
- "status", "REMOVED"
- ).execute()
- logger.info("Removed Notion sync users cleaned successfully")
-
- def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None):
- """
- Retrieve sync users from the database.
-
- Args:
- user_id (str): The user ID to filter sync users.
- sync_user_id (int, optional): The sync user ID to filter sync users. Defaults to None.
-
- Returns:
- list: A list of sync users matching the criteria.
- """
- logger.info(
- "Retrieving sync users for user_id: %s, sync_user_id: %s",
- user_id,
- sync_user_id,
- )
- query = (
- self.db.from_("syncs_user").select("*").eq("user_id", user_id)
- # .neq("status", "REMOVED")
- )
- if sync_user_id:
- query = query.eq("id", str(sync_user_id))
- response = query.execute()
- if response.data:
- # logger.info("Sync users retrieved successfully: %s", response.data)
- return response.data
- logger.warning(
- "No sync users found for user_id: %s, sync_user_id: %s",
- user_id,
- sync_user_id,
- )
- return []
-
- def get_sync_user_by_state(self, state: dict) -> SyncsUser | None:
- """
- Retrieve a sync user by their state.
-
- Args:
- state (dict): The state to filter sync users.
-
- Returns:
- dict or None: The sync user data matching the state or None if not found.
- """
- logger.info("Getting sync user by state: %s", state)
-
- state_str = json.dumps(state)
- response = (
- self.db.from_("syncs_user").select("*").eq("state", state_str).execute()
- )
- if response.data and len(response.data) > 0:
- logger.info("Sync user found by state: %s", response.data[0])
- sync_user = SyncsUser.model_validate(response.data[0])
- return sync_user
- logger.error("No sync user found for state: %s", state)
- return None
-
- def delete_sync_user(self, sync_id: int, user_id: UUID | str):
- """
- Delete a sync user from the database.
-
- Args:
- provider (str): The provider of the sync user.
- user_id (str): The user ID of the sync user.
- """
- logger.info(
- "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id
- )
- self.db.from_("syncs_user").delete().eq("id", sync_id).eq(
- "user_id", user_id
- ).execute()
-
- logger.info("Sync user deleted successfully")
-
- def update_sync_user(
- self, sync_user_id: UUID, state: dict, sync_user_input: SyncUserUpdateInput
- ):
- """
- Update a sync user in the database.
-
- Args:
- sync_user_id (str): The user ID of the sync user.
- state (dict): The state to filter sync users.
- sync_user_input (SyncUserUpdateInput): The input data for updating the sync user.
- """
- logger.info(
- "Updating sync user with user_id: %s, state: %s, input: %s",
- sync_user_id,
- state,
- sync_user_input,
- )
-
- state_str = json.dumps(state)
- self.db.from_("syncs_user").update(
- sync_user_input.model_dump(exclude_unset=True)
- ).eq("user_id", str(sync_user_id)).eq("state", state_str).execute()
- logger.info("Sync user updated successfully")
-
- def update_sync_user_status(self, sync_user_id: int, status: str):
- """
- Update the status of a sync user in the database.
-
- Args:
- sync_user_id (str): The user ID of the sync user.
- status (str): The new status of the sync user.
- """
- logger.info(
- "Updating sync user status with user_id: %s, status: %s",
- sync_user_id,
- status,
- )
-
- self.db.from_("syncs_user").update({"status": status}).eq(
- "id", str(sync_user_id)
- ).execute()
- logger.info("Sync user status updated successfully")
-
- def get_all_notion_user_syncs(self):
- """
- Retrieve all Notion sync users from the database.
-
- Returns:
- list: A list of Notion sync users.
- """
- logger.info("Retrieving all Notion sync users")
- response = (
- self.db.from_("syncs_user").select("*").eq("provider", "Notion").execute()
- )
- if response.data:
- logger.info("Notion sync users retrieved successfully")
- return response.data
- logger.warning("No Notion sync users found")
- return []
-
- async def get_files_folder_user_sync(
- self,
- sync_active_id: int,
- user_id: UUID,
- notion_service: SyncNotionService | None,
- folder_id: str | None = None,
- recursive: bool = False,
- ) -> None | dict[str, List[SyncFile]] | Literal["No sync found"]:
- """
- Retrieve files from a user's sync folder, either from Google Drive or Azure.
-
- Args:
- sync_active_id (int): The ID of the active sync.
- user_id (str): The user ID associated with the active sync.
- folder_id (str, optional): The folder ID to filter files. Defaults to None.
-
- Returns:
- dict or str: A dictionary containing the list of files or a string indicating the sync provider.
- """
- logger.info(
- "Retrieving files for user sync with sync_active_id: %s, user_id: %s, folder_id: %s",
- sync_active_id,
- user_id,
- folder_id,
- )
- # Check whether the sync is Google or Azure
- sync_user = self.get_syncs_user(user_id=user_id, sync_user_id=sync_active_id)
- if not sync_user:
- logger.warning(
- "No sync user found for sync_active_id: %s, user_id: %s",
- sync_active_id,
- user_id,
- )
- return None
-
- sync_user = sync_user[0]
- sync: BaseSync
-
- provider = sync_user["provider"].lower()
- if provider == "google":
- logger.info("Getting files for Google sync")
- sync = GoogleDriveSync()
- return {"files": sync.get_files(sync_user["credentials"], folder_id)}
- elif provider == "azure":
- logger.info("Getting files for Azure sync")
- sync = AzureDriveSync()
- return {
- "files": sync.get_files(sync_user["credentials"], folder_id, recursive)
- }
- elif provider == "dropbox":
- logger.info("Getting files for Drop Box sync")
- sync = DropboxSync()
- return {
- "files": sync.get_files(
- sync_user["credentials"], folder_id if folder_id else "", recursive
- )
- }
- elif provider == "notion":
- if notion_service is None:
- raise ValueError("provider notion but notion_service is None")
- logger.info("Getting files for Notion sync")
- sync = NotionSync(notion_service=notion_service)
- return {
- "files": await sync.aget_files(
- sync_user["credentials"],
- folder_id if folder_id else "",
- recursive,
- sync_user["id"],
- )
- }
- elif provider == "github":
- logger.info("Getting files for GitHub sync")
- sync = GitHubSync()
- return {
- "files": sync.get_files(
- sync_user["credentials"], folder_id if folder_id else "", recursive
- )
- }
-
- else:
- logger.warning(
- "No sync found for provider: %s", sync_user["provider"], recursive
- )
- return "No sync found"
-
- def get_corresponding_deleted_sync(self, user_id: str) -> SyncsUser | None:
- """
- Retrieve the deleted sync user from the database.
- """
- logger.info(
- "Retrieving notion deleted sync user for user_id: %s",
- user_id,
- )
- response = (
- self.db.from_("syncs_user")
- .select("*")
- .eq("user_id", user_id)
- .eq("provider", "Notion")
- .eq("status", "REMOVED")
- .execute()
- )
- if response.data:
- logger.info(
- "Deleted sync user retrieved successfully: %s", response.data[0]
- )
- return SyncsUser.model_validate(response.data[0])
- logger.error("No deleted notion sync user found for user_id: %s", user_id)
- return None
diff --git a/backend/api/quivr_api/modules/sync/service/sync_notion.py b/backend/api/quivr_api/modules/sync/service/sync_notion.py
index 5eca27d57fda..a96ccf514c5d 100644
--- a/backend/api/quivr_api/modules/sync/service/sync_notion.py
+++ b/backend/api/quivr_api/modules/sync/service/sync_notion.py
@@ -9,7 +9,7 @@
from quivr_api.modules.dependencies import BaseService
from quivr_api.modules.sync.entity.notion_page import NotionPage, NotionSearchResult
from quivr_api.modules.sync.entity.sync_models import NotionSyncFile
-from quivr_api.modules.sync.repository.sync_repository import NotionRepository
+from quivr_api.modules.sync.repository.notion_repository import NotionRepository
logger = get_logger(__name__)
diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py
index 498ba3bc94fe..ea440ac9884e 100644
--- a/backend/api/quivr_api/modules/sync/service/sync_service.py
+++ b/backend/api/quivr_api/modules/sync/service/sync_service.py
@@ -1,119 +1,86 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from typing import Any
from uuid import UUID
+from fastapi import HTTPException
+
from quivr_api.logger import get_logger
+from quivr_api.modules.dependencies import BaseService
from quivr_api.modules.sync.dto.inputs import (
- SyncsActiveInput,
- SyncsActiveUpdateInput,
- SyncsUserInput,
- SyncsUserStatus,
- SyncUserUpdateInput,
+ SyncCreateInput,
+ SyncStatus,
+ SyncUpdateInput,
)
-from quivr_api.modules.sync.entity.sync_models import SyncsActive, SyncsUser
-from quivr_api.modules.sync.repository.sync_repository import Sync
-from quivr_api.modules.sync.repository.sync_user import SyncUserRepository
-from quivr_api.modules.sync.service.sync_notion import SyncNotionService
+from quivr_api.modules.sync.dto.outputs import SyncProvider, SyncsOutput
+from quivr_api.modules.sync.repository.sync_repository import SyncsRepository
+from quivr_api.modules.sync.utils.oauth2 import Oauth2BaseState, Oauth2State
logger = get_logger(__name__)
-class ISyncUserService(ABC):
- @abstractmethod
- def get_syncs_user(self, user_id: UUID, sync_user_id: Union[int, None] = None):
- pass
+class SyncsService(BaseService[SyncsRepository]):
+ repository_cls = SyncsRepository
- @abstractmethod
- def create_sync_user(self, sync_user_input: SyncsUserInput):
- pass
+ def __init__(self, repository: SyncsRepository):
+ self.repository = repository
- @abstractmethod
- def delete_sync_user(self, sync_id: int, user_id: str):
- pass
+ async def create_sync_user(self, sync_user_input: SyncCreateInput) -> SyncsOutput:
+ sync = await self.repository.create_sync(sync_user_input)
+ return sync.to_dto()
- @abstractmethod
- def get_sync_user_by_state(self, state: Dict) -> Union["SyncsUser", None]:
- pass
+ async def get_user_syncs(self, user_id: UUID, sync_id: int | None = None):
+ return await self.repository.get_syncs(user_id=user_id, sync_id=sync_id)
- @abstractmethod
- def get_sync_user_by_id(self, sync_id: int):
- pass
+ async def delete_sync(self, sync_id: int, user_id: UUID):
+ await self.repository.delete_sync(sync_id, user_id)
- @abstractmethod
- def update_sync_user(
- self, sync_user_id: UUID, state: Dict, sync_user_input: SyncUserUpdateInput
- ):
- pass
+ async def get_sync_by_id(self, sync_id: int):
+ return await self.repository.get_sync_id(sync_id)
- @abstractmethod
- def get_all_notion_user_syncs(self):
- pass
+ async def get_from_oauth2_state(self, state: Oauth2State) -> SyncsOutput:
+ assert state.sync_id, "state should have associated sync_id"
+ sync = await self.get_sync_by_id(state.sync_id)
- @abstractmethod
- async def get_files_folder_user_sync(
+ # TODO: redo these exceptions
+ if (
+ not sync
+ or not sync.state
+ or state.model_dump_json(exclude={"sync_id"}) != sync.state["state"]
+ ):
+ logger.error("Invalid state parameter")
+ raise HTTPException(status_code=400, detail="Invalid state parameter")
+ if sync.user_id != state.user_id:
+ raise HTTPException(status_code=400, detail="Invalid user")
+ return sync.to_dto()
+
+ async def create_oauth2_state(
self,
- sync_active_id: int,
+ provider: SyncProvider,
+ name: str,
user_id: UUID,
- folder_id: Union[str, None] = None,
- recursive: bool = False,
- notion_service: Union["SyncNotionService", None] = None,
- ):
- pass
-
-
-class SyncUserService(ISyncUserService):
- def __init__(self):
- self.repository = SyncUserRepository()
-
- def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None):
- return self.repository.get_syncs_user(user_id, sync_user_id)
-
- def create_sync_user(self, sync_user_input: SyncsUserInput):
- if sync_user_input.provider == "Notion":
- response = self.repository.get_corresponding_deleted_sync(
- user_id=sync_user_input.user_id
- )
- if response:
- raise ValueError("User removed this connection less than 24 hours ago")
-
- return self.repository.create_sync_user(sync_user_input)
-
- def delete_sync_user(self, sync_id: int, user_id: str):
- sync_user = self.repository.get_sync_user_by_id(sync_id)
- if sync_user and sync_user.provider == "Notion":
- sync_user_input = SyncUserUpdateInput(
- email=str(sync_user.email),
- credentials=sync_user.credentials,
- state=sync_user.state,
- status=str(SyncsUserStatus.REMOVED),
- )
- self.repository.update_sync_user(
- sync_user_id=sync_user.user_id,
- state=sync_user.state,
- sync_user_input=sync_user_input,
- )
- return None
- else:
- return self.repository.delete_sync_user(sync_id, user_id)
-
- def clean_notion_user_syncs(self):
- return self.repository.clean_notion_user_syncs()
-
- def get_sync_user_by_state(self, state: dict) -> SyncsUser | None:
- return self.repository.get_sync_user_by_state(state)
-
- def get_sync_user_by_id(self, sync_id: int):
- return self.repository.get_sync_user_by_id(sync_id)
-
- def update_sync_user(
- self, sync_user_id: UUID, state: dict, sync_user_input: SyncUserUpdateInput
- ):
- return self.repository.update_sync_user(sync_user_id, state, sync_user_input)
+ additional_data: dict[str, Any] = {},
+ ) -> Oauth2State:
+ state_struct = Oauth2BaseState(name=name, user_id=user_id)
+ state = state_struct.model_dump_json()
+ sync_user_input = SyncCreateInput(
+ name=name,
+ user_id=user_id,
+ provider=provider,
+ credentials={},
+ state={"state": state},
+ additional_data=additional_data,
+ status=SyncStatus.SYNCING,
+ )
+ sync = await self.create_sync_user(sync_user_input)
+ return Oauth2State(sync_id=sync.id, **state_struct.model_dump())
- def update_sync_user_status(self, sync_user_id: int, status: str):
- return self.repository.update_sync_user_status(sync_user_id, status)
+ async def update_sync(
+ self, sync_id: int, sync_user_input: SyncUpdateInput
+ ) -> SyncsOutput:
+ sync = await self.repository.get_sync_id(sync_id)
+ sync = await self.repository.update_sync(sync, sync_user_input)
+ return sync.to_dto()
- def get_all_notion_user_syncs(self):
+ async def get_all_notion_user_syncs(self):
return self.repository.get_all_notion_user_syncs()
async def get_files_folder_user_sync(
@@ -122,69 +89,10 @@ async def get_files_folder_user_sync(
user_id: UUID,
folder_id: str | None = None,
recursive: bool = False,
- notion_service: SyncNotionService | None = None,
):
return await self.repository.get_files_folder_user_sync(
- sync_active_id=sync_active_id,
+ sync_id=sync_active_id,
user_id=user_id,
folder_id=folder_id,
recursive=recursive,
- notion_service=notion_service,
)
-
-
-class ISyncService(ABC):
- @abstractmethod
- def create_sync_active(
- self, sync_active_input: SyncsActiveInput, user_id: str
- ) -> Union["SyncsActive", None]:
- pass
-
- @abstractmethod
- def get_syncs_active(self, user_id: str) -> List[SyncsActive]:
- pass
-
- @abstractmethod
- def update_sync_active(
- self, sync_id: int, sync_active_input: SyncsActiveUpdateInput
- ):
- pass
-
- @abstractmethod
- def delete_sync_active(self, sync_active_id: int, user_id: UUID):
- pass
-
- @abstractmethod
- async def get_syncs_active_in_interval(self) -> List[SyncsActive]:
- pass
-
- @abstractmethod
- def get_details_sync_active(self, sync_active_id: int):
- pass
-
-
-class SyncService(ISyncService):
- def __init__(self):
- self.repository = Sync()
-
- def create_sync_active(
- self, sync_active_input: SyncsActiveInput, user_id: str
- ) -> SyncsActive | None:
- return self.repository.create_sync_active(sync_active_input, user_id)
-
- def get_syncs_active(self, user_id: str) -> List[SyncsActive]:
- return self.repository.get_syncs_active(user_id)
-
- def update_sync_active(
- self, sync_id: int, sync_active_input: SyncsActiveUpdateInput
- ):
- return self.repository.update_sync_active(sync_id, sync_active_input)
-
- def delete_sync_active(self, sync_active_id: int, user_id: UUID):
- return self.repository.delete_sync_active(sync_active_id, user_id)
-
- async def get_syncs_active_in_interval(self) -> List[SyncsActive]:
- return await self.repository.get_syncs_active_in_interval()
-
- def get_details_sync_active(self, sync_active_id: int):
- return self.repository.get_details_sync_active(sync_active_id)
diff --git a/backend/api/quivr_api/modules/sync/tests/conftest.py b/backend/api/quivr_api/modules/sync/tests/conftest.py
index 32392aeeb875..7c71f124828d 100644
--- a/backend/api/quivr_api/modules/sync/tests/conftest.py
+++ b/backend/api/quivr_api/modules/sync/tests/conftest.py
@@ -1,22 +1,14 @@
import json
import os
-import uuid
-from collections import defaultdict
-from datetime import datetime, timedelta
+from datetime import datetime
from io import BytesIO
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Union
from uuid import UUID, uuid4
import pytest
-import pytest_asyncio
-from dotenv import load_dotenv
-from sqlmodel import select, text
-
-from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
-from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
-from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
-from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
+from sqlmodel import select
+
from quivr_api.modules.notification.dto.inputs import (
CreateNotification,
NotificationUpdatableProperties,
@@ -28,17 +20,6 @@
from quivr_api.modules.notification.repository.notifications_interface import (
NotificationInterface,
)
-from quivr_api.modules.notification.service.notification_service import (
- NotificationService,
-)
-from quivr_api.modules.sync.dto.inputs import (
- SyncFileInput,
- SyncFileUpdateInput,
- SyncsActiveInput,
- SyncsActiveUpdateInput,
- SyncsUserInput,
- SyncUserUpdateInput,
-)
from quivr_api.modules.sync.entity.notion_page import (
BlockParent,
DatabaseParent,
@@ -52,29 +33,14 @@
WorkspaceParent,
)
from quivr_api.modules.sync.entity.sync_models import (
- DBSyncFile,
- NotionSyncFile,
SyncFile,
- SyncsActive,
- SyncsUser,
-)
-from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface
-from quivr_api.modules.sync.service.sync_notion import SyncNotionService
-from quivr_api.modules.sync.service.sync_service import (
- ISyncService,
- ISyncUserService,
- SyncUserService,
)
from quivr_api.modules.sync.utils.sync import (
BaseSync,
)
-from quivr_api.modules.sync.utils.syncutils import (
- SyncUtils,
-)
from quivr_api.modules.user.entity.user_identity import User
pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
-load_dotenv()
@pytest.fixture(scope="function")
@@ -357,8 +323,8 @@ async def aget_files_by_id(
id=fid,
name=f"file_{fid}",
is_folder=False,
- last_modified=datetime.now().strftime(self.datetime_format),
- mime_type="txt",
+ last_modified_at=datetime.now(),
+ extension="txt",
web_view_link=f"{self.name}/{fid}",
)
for fid in file_ids
@@ -377,8 +343,8 @@ async def aget_files(
id=str(uuid4()),
name=f"file_in_{folder_id}",
is_folder=False,
- last_modified=datetime.now().strftime(self.datetime_format),
- mime_type="txt",
+ last_modified_at=datetime.now(),
+ extension="txt",
web_view_link=f"{self.name}/{fid}",
)
for fid in range(n_files)
@@ -441,377 +407,320 @@ def remove_notification_by_id(self, notification_id: UUID):
del self.received[notification_id]
-class MockSyncService(ISyncService):
- def __init__(self, sync_active: SyncsActive):
- self.syncs_active_user = {}
- self.syncs_active_id = {}
- self.syncs_active_user[sync_active.user_id] = sync_active
- self.syncs_active_id[sync_active.id] = sync_active
-
- def create_sync_active(
- self,
- sync_active_input: SyncsActiveInput,
- user_id: str,
- ) -> SyncsActive | None:
- sactive = SyncsActive(
- id=len(self.syncs_active_user) + 1,
- user_id=UUID(user_id),
- **sync_active_input.model_dump(),
- )
- self.syncs_active_user[user_id] = sactive
- return sactive
-
- def get_syncs_active(self, user_id: str) -> List[SyncsActive]:
- return self.syncs_active_user[user_id]
-
- def update_sync_active(
- self, sync_id: int, sync_active_input: SyncsActiveUpdateInput
- ):
- sync = self.syncs_active_id[sync_id]
- sync = SyncsActive(**sync.model_dump(), **sync_active_input.model_dump())
- self.syncs_active_id[sync_id] = sync
- return sync
-
- def delete_sync_active(self, sync_active_id: int, user_id: UUID):
- del self.syncs_active_id[sync_active_id]
- del self.syncs_active_user[user_id]
-
- async def get_syncs_active_in_interval(self) -> List[SyncsActive]:
- return list(self.syncs_active_id.values())
-
- def get_details_sync_active(self, sync_active_id: int):
- return
-
-
-class MockSyncUserService(ISyncUserService):
- def __init__(self, sync_user: SyncsUser):
- self.map_id = {}
- self.map_userid = {}
- self.map_id[sync_user.id] = sync_user
- self.map_userid[sync_user.id] = sync_user
-
- def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None):
- return self.map_userid[user_id]
-
- def get_sync_user_by_id(self, sync_id: int):
- return self.map_id[sync_id]
-
- def create_sync_user(self, sync_user_input: SyncsUserInput):
- id = len(self.map_userid) + 1
- self.map_userid[sync_user_input.user_id] = SyncsUser(
- id=id, **sync_user_input.model_dump()
- )
- self.map_id[id] = self.map_userid[sync_user_input.user_id]
- return self.map_id[id]
-
- def delete_sync_user(self, sync_id: int, user_id: str):
- del self.map_userid[user_id]
- del self.map_userid[sync_id]
-
- def get_sync_user_by_state(self, state: dict) -> SyncsUser | None:
- return list(self.map_userid.values())[-1]
-
- def update_sync_user(
- self, sync_user_id: UUID, state: dict, sync_user_input: SyncUserUpdateInput
- ):
- return
-
- def get_all_notion_user_syncs(self):
- return
-
- async def get_files_folder_user_sync(
- self,
- sync_active_id: int,
- user_id: UUID,
- folder_id: str | None = None,
- recursive: bool = False,
- notion_service: SyncNotionService | None = None,
- ):
- return
-
-
-class MockSyncFilesRepository(SyncFileInterface):
- def __init__(self):
- self.files_store = defaultdict(list)
- self.next_id = 1
-
- def create_sync_file(self, sync_file_input: SyncFileInput) -> Optional[DBSyncFile]:
- supported = sync_file_input.supported if sync_file_input.supported else True
- new_file = DBSyncFile(
- id=self.next_id,
- path=sync_file_input.path,
- syncs_active_id=sync_file_input.syncs_active_id,
- last_modified=sync_file_input.last_modified,
- brain_id=sync_file_input.brain_id,
- supported=supported,
- )
- self.files_store[sync_file_input.syncs_active_id].append(new_file)
- self.next_id += 1
- return new_file
-
- def get_sync_files(self, sync_active_id: int) -> List[DBSyncFile]:
- """
- Retrieve sync files from the mock database.
-
- Args:
- sync_active_id (int): The ID of the active sync.
-
- Returns:
- List[DBSyncFile]: A list of sync files matching the criteria.
- """
- return self.files_store[sync_active_id]
-
- def update_sync_file(
- self, sync_file_id: int, sync_file_input: SyncFileUpdateInput
- ) -> None:
- for sync_files in self.files_store.values():
- for file in sync_files:
- if file.id == sync_file_id:
- update_data = sync_file_input.model_dump(exclude_unset=True)
- if "last_modified" in update_data:
- file.last_modified = update_data["last_modified"]
- if "supported" in update_data:
- file.supported = update_data["supported"]
- return
-
- def update_or_create_sync_file(
- self,
- file: SyncFile,
- sync_active: SyncsActive,
- previous_file: Optional[DBSyncFile],
- supported: bool,
- ) -> Optional[DBSyncFile]:
- if previous_file:
- self.update_sync_file(
- previous_file.id,
- SyncFileUpdateInput(
- last_modified=file.last_modified,
- supported=previous_file.supported or supported,
- ),
- )
- return previous_file
- else:
- return self.create_sync_file(
- SyncFileInput(
- path=file.name,
- syncs_active_id=sync_active.id,
- last_modified=file.last_modified,
- brain_id=str(sync_active.brain_id),
- supported=supported,
- )
- )
-
- def delete_sync_file(self, sync_file_id: int) -> None:
- for sync_active_id, sync_files in self.files_store.items():
- self.files_store[sync_active_id] = [
- file for file in sync_files if file.id != sync_file_id
- ]
-
-
-@pytest.fixture
-def sync_file():
- file = SyncFile(
- id=str(uuid4()),
- name="test_file.txt",
- is_folder=False,
- last_modified=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
- mime_type=".txt",
- web_view_link="",
- notification_id=uuid4(), #
- )
- return file
-
-
-@pytest.fixture
-def prev_file():
- file = SyncFile(
- id=str(uuid4()),
- name="test_file.txt",
- is_folder=False,
- last_modified=(datetime.now() - timedelta(hours=1)).strftime(
- "%Y-%m-%d %H:%M:%S"
- ),
- mime_type="txt",
- web_view_link="",
- notification_id=uuid4(), #
- )
- return file
-
-
-@pytest_asyncio.fixture(scope="function")
-async def brain_user_setup(
- session,
-) -> Tuple[Brain, User]:
- user_1 = (
- await session.exec(select(User).where(User.email == "admin@quivr.app"))
- ).one()
- # Brain data
- brain_1 = Brain(
- name="test_brain",
- description="this is a test brain",
- brain_type=BrainType.integration,
- )
-
- session.add(brain_1)
- await session.refresh(user_1)
- await session.commit()
- assert user_1
- assert brain_1.brain_id
- return brain_1, user_1
-
-
-@pytest_asyncio.fixture(scope="function")
-async def sync_user_notion_setup(
- session,
-):
- sync_user_service = SyncUserService()
- user_1 = (
- await session.exec(select(User).where(User.email == "admin@quivr.app"))
- ).one()
-
- # Sync User
- sync_user_input = SyncsUserInput(
- user_id=str(user_1.id),
- name="sync_user_1",
- provider="notion",
- credentials={},
- state={},
- additional_data={},
- status="",
- )
- sync_user = SyncsUser.model_validate(
- sync_user_service.create_sync_user(sync_user_input)
- )
- assert sync_user.id
-
- # Notion pages
- notion_page_1 = NotionSyncFile(
- notion_id=uuid.uuid4(),
- sync_user_id=sync_user.id,
- user_id=sync_user.user_id,
- name="test",
- last_modified=datetime.now() - timedelta(hours=5),
- mime_type="txt",
- web_view_link="",
- icon="",
- is_folder=False,
- )
-
- notion_page_2 = NotionSyncFile(
- notion_id=uuid.uuid4(),
- sync_user_id=sync_user.id,
- user_id=sync_user.user_id,
- name="test_2",
- last_modified=datetime.now() - timedelta(hours=5),
- mime_type="txt",
- web_view_link="",
- icon="",
- is_folder=False,
- )
- session.add(notion_page_1)
- session.add(notion_page_2)
- yield sync_user
- await session.execute(
- text("DELETE FROM syncs_user WHERE id = :sync_id"), {"sync_id": sync_user.id}
- )
-
-
-@pytest_asyncio.fixture(scope="function")
-async def setup_syncs_data(
- brain_user_setup,
-) -> Tuple[SyncsUser, SyncsActive]:
- brain_1, user_1 = brain_user_setup
-
- sync_user = SyncsUser(
- id=0,
- user_id=user_1.id,
- name="c8xfz3g566b8xa1ajiesdh",
- provider="mock",
- credentials={},
- state={},
- additional_data={},
- status="",
- )
- sync_active = SyncsActive(
- id=0,
- name="test",
- syncs_user_id=sync_user.id,
- user_id=sync_user.user_id,
- settings={},
- last_synced=str(datetime.now() - timedelta(hours=5)),
- sync_interval_minutes=1,
- brain_id=brain_1.brain_id,
- )
-
- return (sync_user, sync_active)
-
-
-@pytest.fixture
-def syncutils(
- sync_file: SyncFile,
- prev_file: SyncFile,
- setup_syncs_data: Tuple[SyncsUser, SyncsActive],
- session,
-) -> SyncUtils:
- (sync_user, sync_active) = setup_syncs_data
- assert sync_file.notification_id
- sync_active_service = MockSyncService(sync_active)
- sync_user_service = MockSyncUserService(sync_user)
- sync_files_repo_service = MockSyncFilesRepository()
- knowledge_service = KnowledgeService(KnowledgeRepository(session))
- notification_service = NotificationService(
- repository=MockNotification(
- [sync_file.notification_id, prev_file.notification_id], # type: ignore
- sync_user.user_id,
- sync_active.brain_id,
- )
- )
- brain_vectors = BrainsVectors()
- sync_cloud = MockSyncCloud()
-
- sync_util = SyncUtils(
- sync_user_service=sync_user_service,
- sync_active_service=sync_active_service,
- sync_files_repo=sync_files_repo_service,
- sync_cloud=sync_cloud,
- notification_service=notification_service,
- brain_vectors=brain_vectors,
- knowledge_service=knowledge_service,
- )
-
- return sync_util
-
-
-@pytest.fixture
-def syncutils_notion(
- sync_file: SyncFile,
- prev_file: SyncFile,
- setup_syncs_data: Tuple[SyncsUser, SyncsActive],
- session,
-) -> SyncUtils:
- (sync_user, sync_active) = setup_syncs_data
- assert sync_file.notification_id
- sync_active_service = MockSyncService(sync_active)
- sync_user_service = MockSyncUserService(sync_user)
- sync_files_repo_service = MockSyncFilesRepository()
- knowledge_service = KnowledgeService(KnowledgeRepository(session))
- notification_service = NotificationService(
- repository=MockNotification(
- [sync_file.notification_id, prev_file.notification_id], # type: ignore
- sync_user.user_id,
- sync_active.brain_id,
- )
- )
- brain_vectors = BrainsVectors()
- sync_cloud = MockSyncCloudNotion()
- sync_util = SyncUtils(
- sync_user_service=sync_user_service,
- sync_active_service=sync_active_service,
- sync_files_repo=sync_files_repo_service,
- sync_cloud=sync_cloud,
- notification_service=notification_service,
- brain_vectors=brain_vectors,
- knowledge_service=knowledge_service,
- )
-
- return sync_util
+# class MockSyncService(ISyncService):
+# def __init__(self, sync_active: SyncsActive):
+# self.syncs_active_user = {}
+# self.syncs_active_id = {}
+# self.syncs_active_user[sync_active.user_id] = sync_active
+# self.syncs_active_id[sync_active.id] = sync_active
+
+# def create_sync_active(
+# self,
+# sync_active_input: SyncsActiveInput,
+# user_id: str,
+# ) -> SyncsActive | None:
+# sactive = SyncsActive(
+# id=len(self.syncs_active_user) + 1,
+# user_id=UUID(user_id),
+# **sync_active_input.model_dump(),
+# )
+# self.syncs_active_user[user_id] = sactive
+# return sactive
+
+# def get_syncs_active(self, user_id: str) -> List[SyncsActive]:
+# return self.syncs_active_user[user_id]
+
+# def update_sync_active(
+# self, sync_id: int, sync_active_input: SyncsActiveUpdateInput
+# ):
+# sync = self.syncs_active_id[sync_id]
+# sync = SyncsActive(**sync.model_dump(), **sync_active_input.model_dump())
+# self.syncs_active_id[sync_id] = sync
+# return sync
+
+# def delete_sync_active(self, sync_active_id: int, user_id: UUID):
+# del self.syncs_active_id[sync_active_id]
+# del self.syncs_active_user[user_id]
+
+# async def get_syncs_active_in_interval(self) -> List[SyncsActive]:
+# return list(self.syncs_active_id.values())
+
+# def get_details_sync_active(self, sync_active_id: int):
+# return
+
+
+# class MockSyncUserService(ISyncUserService):
+# def __init__(self, sync_user: SyncsUser):
+# self.map_id = {}
+# self.map_userid = {}
+# self.map_id[sync_user.id] = sync_user
+# self.map_userid[sync_user.id] = sync_user
+
+# def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None):
+# return self.map_userid[user_id]
+
+# def get_sync_user_by_id(self, sync_id: int):
+# return self.map_id[sync_id]
+
+# def create_sync_user(self, sync_user_input: SyncCreateInput):
+# id = len(self.map_userid) + 1
+# self.map_userid[sync_user_input.user_id] = SyncsUser(
+# id=id, **sync_user_input.model_dump()
+# )
+# self.map_id[id] = self.map_userid[sync_user_input.user_id]
+# return self.map_id[id]
+
+# def delete_sync_user(self, sync_id: int, user_id: str):
+# del self.map_userid[user_id]
+# del self.map_userid[sync_id]
+
+# def get_sync_user_by_state(self, state: dict) -> SyncsUser | None:
+# return list(self.map_userid.values())[-1]
+
+# def update_sync_user(
+# self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput
+# ):
+# return
+
+# def get_all_notion_user_syncs(self):
+# return
+
+# async def get_files_folder_user_sync(
+# self,
+# sync_active_id: int,
+# user_id: UUID,
+# folder_id: str | None = None,
+# recursive: bool = False,
+# notion_service: SyncNotionService | None = None,
+# ):
+# return
+
+
+# class MockSyncFilesRepository(SyncFileInterface):
+# def __init__(self):
+# self.files_store = defaultdict(list)
+# self.next_id = 1
+
+# def create_sync_file(self, sync_file_input: SyncFileInput) -> Optional[DBSyncFile]:
+# supported = sync_file_input.supported if sync_file_input.supported else True
+# new_file = DBSyncFile(
+# id=self.next_id,
+# path=sync_file_input.path,
+# syncs_active_id=sync_file_input.syncs_active_id,
+# last_modified=sync_file_input.last_modified,
+# brain_id=sync_file_input.brain_id,
+# supported=supported,
+# )
+# self.files_store[sync_file_input.syncs_active_id].append(new_file)
+# self.next_id += 1
+# return new_file
+
+# def get_sync_files(self, sync_active_id: int) -> List[DBSyncFile]:
+# """
+# Retrieve sync files from the mock database.
+
+# Args:
+# sync_active_id (int): The ID of the active sync.
+
+# Returns:
+# List[DBSyncFile]: A list of sync files matching the criteria.
+# """
+# return self.files_store[sync_active_id]
+
+# def update_sync_file(
+# self, sync_file_id: int, sync_file_input: SyncFileUpdateInput
+# ) -> None:
+# for sync_files in self.files_store.values():
+# for file in sync_files:
+# if file.id == sync_file_id:
+# update_data = sync_file_input.model_dump(exclude_unset=True)
+# if "last_modified" in update_data:
+# file.last_modified = update_data["last_modified"]
+# if "supported" in update_data:
+# file.supported = update_data["supported"]
+# return
+
+# def update_or_create_sync_file(
+# self,
+# file: SyncFile,
+# sync_active: SyncsActive,
+# previous_file: Optional[DBSyncFile],
+# supported: bool,
+# ) -> Optional[DBSyncFile]:
+# if previous_file:
+# self.update_sync_file(
+# previous_file.id,
+# SyncFileUpdateInput(
+# last_modified=file.last_modified,
+# supported=previous_file.supported or supported,
+# ),
+# )
+# return previous_file
+# else:
+# return self.create_sync_file(
+# SyncFileInput(
+# path=file.name,
+# syncs_active_id=sync_active.id,
+# last_modified=file.last_modified,
+# brain_id=str(sync_active.brain_id),
+# supported=supported,
+# )
+# )
+
+# def delete_sync_file(self, sync_file_id: int) -> None:
+# for sync_active_id, sync_files in self.files_store.items():
+# self.files_store[sync_active_id] = [
+# file for file in sync_files if file.id != sync_file_id
+# ]
+
+
+# @pytest.fixture
+# def sync_file():
+# file = SyncFile(
+# id=str(uuid4()),
+# name="test_file.txt",
+# is_folder=False,
+# last_modified=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+# mime_type=".txt",
+# web_view_link="",
+# notification_id=uuid4(), #
+# )
+# return file
+
+
+# @pytest.fixture
+# def prev_file():
+# file = SyncFile(
+# id=str(uuid4()),
+# name="test_file.txt",
+# is_folder=False,
+# last_modified=(datetime.now() - timedelta(hours=1)).strftime(
+# "%Y-%m-%d %H:%M:%S"
+# ),
+# mime_type="txt",
+# web_view_link="",
+# notification_id=uuid4(), #
+# )
+# return file
+
+
+# @pytest_asyncio.fixture(scope="function")
+# async def brain_user_setup(
+# session,
+# ) -> Tuple[Brain, User]:
+# user_1 = (
+# await session.exec(select(User).where(User.email == "admin@quivr.app"))
+# ).one()
+# # Brain data
+# brain_1 = Brain(
+# name="test_brain",
+# description="this is a test brain",
+# brain_type=BrainType.integration,
+# )
+
+# session.add(brain_1)
+# await session.refresh(user_1)
+# await session.commit()
+# assert user_1
+# assert brain_1.brain_id
+# return brain_1, user_1
+
+
+# @pytest_asyncio.fixture(scope="function")
+# async def setup_syncs_data(
+# brain_user_setup,
+# ) -> Tuple[SyncsUser, SyncsActive]:
+# brain_1, user_1 = brain_user_setup
+
+# sync_user = SyncsUser(
+# id=0,
+# user_id=user_1.id,
+# name="c8xfz3g566b8xa1ajiesdh",
+# provider="mock",
+# credentials={},
+# state={},
+# additional_data={},
+# )
+# sync_active = SyncsActive(
+# id=0,
+# name="test",
+# syncs_user_id=sync_user.id,
+# user_id=sync_user.user_id,
+# settings={},
+# last_synced=str(datetime.now() - timedelta(hours=5)),
+# sync_interval_minutes=1,
+# brain_id=brain_1.brain_id,
+# )
+
+# return (sync_user, sync_active)
+
+
+# @pytest.fixture
+# def syncutils(
+# sync_file: SyncFile,
+# prev_file: SyncFile,
+# setup_syncs_data: Tuple[SyncsUser, SyncsActive],
+# session,
+# ) -> SyncUtils:
+# (sync_user, sync_active) = setup_syncs_data
+# assert sync_file.notification_id
+# sync_active_service = MockSyncService(sync_active)
+# sync_user_service = MockSyncUserService(sync_user)
+# sync_files_repo_service = MockSyncFilesRepository()
+# knowledge_service = KnowledgeService(KnowledgeRepository(session))
+# notification_service = NotificationService(
+# repository=MockNotification(
+# [sync_file.notification_id, prev_file.notification_id], # type: ignore
+# sync_user.user_id,
+# sync_active.brain_id,
+# )
+# )
+# brain_vectors = BrainsVectors()
+# sync_cloud = MockSyncCloud()
+
+# sync_util = SyncUtils(
+# sync_user_service=sync_user_service,
+# sync_active_service=sync_active_service,
+# sync_files_repo=sync_files_repo_service,
+# sync_cloud=sync_cloud,
+# notification_service=notification_service,
+# brain_vectors=brain_vectors,
+# knowledge_service=knowledge_service,
+# )
+
+# return sync_util
+
+
+# @pytest.fixture
+# def syncutils_notion(
+# sync_file: SyncFile,
+# prev_file: SyncFile,
+# setup_syncs_data: Tuple[SyncsUser, SyncsActive],
+# session,
+# ) -> SyncUtils:
+# (sync_user, sync_active) = setup_syncs_data
+# assert sync_file.notification_id
+# sync_active_service = MockSyncService(sync_active)
+# sync_user_service = MockSyncUserService(sync_user)
+# sync_files_repo_service = MockSyncFilesRepository()
+# knowledge_service = KnowledgeService(KnowledgeRepository(session))
+# notification_service = NotificationService(
+# repository=MockNotification(
+# [sync_file.notification_id, prev_file.notification_id], # type: ignore
+# sync_user.user_id,
+# sync_active.brain_id,
+# )
+# )
+# brain_vectors = BrainsVectors()
+# sync_cloud = MockSyncCloudNotion()
+# sync_util = SyncUtils(
+# sync_user_service=sync_user_service,
+# sync_active_service=sync_active_service,
+# sync_files_repo=sync_files_repo_service,
+# sync_cloud=sync_cloud,
+# notification_service=notification_service,
+# brain_vectors=brain_vectors,
+# knowledge_service=knowledge_service,
+# )
+
+# return sync_util
diff --git a/backend/api/quivr_api/modules/sync/tests/test_notion_service.py b/backend/api/quivr_api/modules/sync/tests/test_notion_service.py
index 526114c5ef8b..23a921c38526 100644
--- a/backend/api/quivr_api/modules/sync/tests/test_notion_service.py
+++ b/backend/api/quivr_api/modules/sync/tests/test_notion_service.py
@@ -1,5 +1,4 @@
from datetime import datetime
-from typing import Tuple
import httpx
import pytest
@@ -8,8 +7,7 @@
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionPage
from quivr_api.modules.sync.entity.notion_page import NotionSearchResult
-from quivr_api.modules.sync.entity.sync_models import SyncsActive, SyncsUser
-from quivr_api.modules.sync.repository.sync_repository import NotionRepository
+from quivr_api.modules.sync.repository.notion_repository import NotionRepository
from quivr_api.modules.sync.service.sync_notion import (
SyncNotionService,
fetch_limit_notion_pages,
@@ -74,31 +72,31 @@ def handler(request):
assert len(result) == 0
-@pytest.mark.skip(
- reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'"
-)
-@pytest.mark.asyncio(loop_scope="session")
-async def test_store_notion_pages_success(
- session: AsyncSession,
- notion_search_result: NotionSearchResult,
- setup_syncs_data: Tuple[SyncsUser, SyncsActive],
- sync_user_notion_setup: SyncsUser,
- user_1: User,
-):
- assert user_1.id
+# @pytest.mark.skip(
+# reason="Bug: httpx.ConnectError: [Errno -2] Name or service not known'"
+# )
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_store_notion_pages_success(
+# session: AsyncSession,
+# notion_search_result: NotionSearchResult,
+# setup_syncs_data: Tuple[SyncsUser, SyncsActive],
+# sync_user_notion_setup: SyncsUser,
+# user_1: User,
+# ):
+# assert user_1.id
- notion_repository = NotionRepository(session)
- notion_service = SyncNotionService(notion_repository)
- sync_files = await store_notion_pages(
- notion_search_result.results,
- notion_service,
- user_1.id,
- sync_user_id=sync_user_notion_setup.id,
- )
- assert sync_files
- assert len(sync_files) == 1
- assert sync_files[0].notion_id == notion_search_result.results[0].id
- assert sync_files[0].mime_type == "md"
+# notion_repository = NotionRepository(session)
+# notion_service = SyncNotionService(notion_repository)
+# sync_files = await store_notion_pages(
+# notion_search_result.results,
+# notion_service,
+# user_1.id,
+# sync_user_id=sync_user_notion_setup.id,
+# )
+# assert sync_files
+# assert len(sync_files) == 1
+# assert sync_files[0].notion_id == notion_search_result.results[0].id
+# assert sync_files[0].mime_type == "md"
@pytest.mark.asyncio(loop_scope="session")
diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py
new file mode 100644
index 000000000000..df4780ba0af0
--- /dev/null
+++ b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py
@@ -0,0 +1,224 @@
+import os
+from datetime import datetime
+from io import BytesIO
+from typing import Dict, List, Union
+
+import pytest
+import pytest_asyncio
+from httpx import ASGITransport, AsyncClient
+from quivr_core.models import KnowledgeStatus
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from quivr_api.main import app
+from quivr_api.middlewares.auth.auth_bearer import get_current_user
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
+from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
+from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
+from quivr_api.modules.knowledge.tests.conftest import FakeStorage
+from quivr_api.modules.sync.controller.sync_routes import (
+ get_knowledge_service,
+ get_sync_service,
+)
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile
+from quivr_api.modules.sync.repository.sync_repository import SyncsRepository
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.sync import BaseSync
+from quivr_api.modules.user.entity.user_identity import User, UserIdentity
+
+# TODO: move to top layer
+MAX_SYNC_FILES = 1000
+N_GET_FILES = 2
+FOLDER_SYNC_FILE_IDS = [f"file-{str(idx)}" for idx in range(MAX_SYNC_FILES)]
+
+
+class FakeSync(BaseSync):
+ name = "FakeProvider"
+ lower_name = "google"
+ datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+
+ def __init__(self, provider_name: str | None = None, n_get_files: int = 2):
+ super().__init__()
+ self.n_get_files = n_get_files
+ if n_get_files > MAX_SYNC_FILES:
+ raise ValueError("can't create fake sync")
+ self.folder_sync_file_ids = FOLDER_SYNC_FILE_IDS[:n_get_files]
+ if provider_name:
+ self.lower_name = provider_name
+
+ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]:
+ return [
+ SyncFile(
+ id=str(fid),
+ name=f"file_{fid}",
+ extension=".txt",
+ web_view_link=f"test.com/{fid}",
+ is_folder=False,
+ last_modified_at=datetime.now(),
+ )
+ for fid in file_ids
+ ]
+
+ async def aget_files_by_id(
+ self, credentials: Dict, file_ids: List[str]
+ ) -> List[SyncFile]:
+ return self.get_files_by_id(
+ credentials=credentials,
+ file_ids=file_ids,
+ )
+
+ def get_files(
+ self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
+ ) -> List[SyncFile]:
+ return [
+ SyncFile(
+ id=fid,
+ name=f"file_{fid}",
+ extension=".txt",
+ web_view_link=f"test.com/{fid}",
+ parent_id=folder_id,
+ is_folder=idx % 2 == 1,
+ last_modified_at=datetime.now(),
+ )
+ for idx, fid in enumerate(self.folder_sync_file_ids)
+ ]
+
+ async def aget_files(
+ self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
+ ) -> List[SyncFile]:
+ return self.get_files(
+ credentials=credentials, folder_id=folder_id, recursive=recursive
+ )
+
+ def check_and_refresh_access_token(self, credentials: dict) -> Dict:
+ raise NotImplementedError
+
+ def download_file(
+ self, credentials: Dict, file: SyncFile
+ ) -> Dict[str, Union[str, BytesIO]]:
+ raise NotImplementedError
+
+ async def adownload_file(
+ self, credentials: Dict, file: SyncFile
+ ) -> Dict[str, Union[str, BytesIO]]:
+ return {"content": str(os.urandom(24))}
+
+
+@pytest_asyncio.fixture(scope="function")
+async def user(session: AsyncSession) -> User:
+ user_1 = (
+ await session.exec(select(User).where(User.email == "admin@quivr.app"))
+ ).one()
+ assert user_1.id
+ return user_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync(session: AsyncSession, user: User) -> Sync:
+ assert user.id
+ sync = Sync(
+ name="test_sync",
+ email="test@test.com",
+ user_id=user.id,
+ credentials={"test": "test"},
+ provider=SyncProvider.GOOGLE,
+ )
+
+ session.add(sync)
+ await session.commit()
+ await session.refresh(sync)
+ return sync
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain(session):
+ brain_1 = Brain(
+ name="test_brain",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain_1)
+ await session.commit()
+ return brain_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def knowledge_sync(session, user: User, sync: Sync, brain: Brain):
+ assert user.id
+ km = KnowledgeDB(
+ file_name="sync_file_1.txt",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source="test_source",
+ source_link="test_source_link",
+ file_size=100,
+ file_sha1="test_sha1",
+ brains=[brain],
+ user_id=user.id,
+ sync=sync,
+ sync_file_id=FOLDER_SYNC_FILE_IDS[0],
+ )
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def test_client(session: AsyncSession, user: User):
+ def default_current_user() -> UserIdentity:
+ assert user.id
+ return UserIdentity(email=user.email, id=user.id)
+
+ async def _sync_service():
+ fake_provider: dict[SyncProvider, BaseSync] = {
+ provider: FakeSync(n_get_files=N_GET_FILES)
+ for provider in list(SyncProvider)
+ }
+ repository = SyncsRepository(session)
+ repository.sync_provider_mapping = fake_provider
+ return SyncsService(repository)
+
+ async def _km_service():
+ storage = FakeStorage()
+ repository = KnowledgeRepository(session)
+ return KnowledgeService(repository, storage)
+
+ app.dependency_overrides[get_current_user] = default_current_user
+ app.dependency_overrides[get_knowledge_service] = _km_service
+ app.dependency_overrides[get_sync_service] = _sync_service
+
+ async with AsyncClient(
+ transport=ASGITransport(app=app), # type: ignore
+ base_url="http://test",
+ ) as ac:
+ yield ac
+ app.dependency_overrides = {}
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_list_sync_no_knowledge(test_client: AsyncClient, sync: Sync):
+ params = {"folder_id": 12}
+ response = await test_client.get(f"/sync/{sync.id}/files", params=params)
+ assert response.status_code == 200
+ kms = response.json()
+ assert len(kms) == N_GET_FILES
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_list_sync_with_knowledge(
+ test_client: AsyncClient, sync: Sync, knowledge_sync
+):
+ params = {"folder_id": 12}
+ response = await test_client.get(f"/sync/{sync.id}/files", params=params)
+ assert response.status_code == 200
+ kms = response.json()
+
+ assert len(kms) == N_GET_FILES
+ km = next(
+ filter(lambda x: x["id"] == str(knowledge_sync.id), kms),
+ )
+ assert km, "at least one knowledge should "
+ assert len(km["brains"]) == 1
diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_service.py b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py
new file mode 100644
index 000000000000..19fc5f041f7d
--- /dev/null
+++ b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py
@@ -0,0 +1,40 @@
+import pytest
+import pytest_asyncio
+from sqlmodel import select
+
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync
+from quivr_api.modules.sync.repository.sync_repository import SyncsRepository
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.user.entity.user_identity import User
+
+
+@pytest_asyncio.fixture(scope="function")
+async def user(session):
+ user_1 = (
+ await session.exec(select(User).where(User.email == "admin@quivr.app"))
+ ).one()
+ return user_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def test_sync(session, user):
+ assert user.id
+
+ sync = Sync(
+ user_id=user.id,
+ name="test_sync",
+ provider=SyncProvider.GOOGLE,
+ )
+
+ session.add(sync)
+ await session.commit()
+ await session.refresh(sync)
+ return sync
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_sync_delete_sync(session, test_sync):
+ assert test_sync.id
+ service = SyncsService(SyncsRepository(session))
+ await service.delete_sync(test_sync.id, test_sync.user_id)
diff --git a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py
index 3c20f70d9679..5685965fddf5 100644
--- a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py
+++ b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py
@@ -1,452 +1,428 @@
-from datetime import datetime, timedelta, timezone
-from typing import Tuple
-from uuid import uuid4
-
-import pytest
-
-from quivr_api.modules.brain.entity.brain_entity import Brain
-from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum
-from quivr_api.modules.sync.entity.sync_models import (
- DBSyncFile,
- SyncFile,
- SyncsActive,
- SyncsUser,
-)
-from quivr_api.modules.sync.utils.syncutils import (
- SyncUtils,
- filter_on_supported_files,
- should_download_file,
-)
-from quivr_api.modules.upload.service.upload_file import check_file_exists
-from quivr_api.modules.user.entity.user_identity import User
-
-
-def test_filter_on_supported_files_empty_existing():
- files = [
- SyncFile(
- id="1",
- name="file_name",
- is_folder=True,
- last_modified=str(datetime.now()),
- mime_type="txt",
- web_view_link="link",
- )
- ]
- existing_file = {}
-
- assert [(files[0], None)] == filter_on_supported_files(files, existing_file)
-
-
-def test_filter_on_supported_files_prev_not_supported():
- files = [
- SyncFile(
- id=f"{idx}",
- name=f"file_name_{idx}",
- is_folder=False,
- last_modified=str(datetime.now()),
- mime_type="txt",
- web_view_link="link",
- )
- for idx in range(3)
- ]
- existing_files = {
- file.name: DBSyncFile(
- id=idx,
- path=file.name,
- syncs_active_id=1,
- last_modified=str(datetime.now()),
- brain_id=str(uuid4()),
- supported=idx % 2 == 0,
- )
- for idx, file in enumerate(files)
- }
-
- assert [
- (files[idx], existing_files[f"file_name_{idx}"])
- for idx in range(3)
- if idx % 2 == 0
- ] == filter_on_supported_files(files, existing_files)
-
-
-def test_should_download_file_no_sync_time_not_folder():
- datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
- file_not_folder = SyncFile(
- id="1",
- name="file_name",
- is_folder=False,
- last_modified=datetime.now().strftime(datetime_format),
- mime_type="txt",
- web_view_link="link",
- )
- assert should_download_file(
- file=file_not_folder,
- last_updated_sync_active=None,
- provider_name="google",
- datetime_format=datetime_format,
- )
-
-
-def test_should_download_file_no_sync_time_folder():
- datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
- file_not_folder = SyncFile(
- id="1",
- name="file_name",
- is_folder=True,
- last_modified=datetime.now().strftime(datetime_format),
- mime_type="txt",
- web_view_link="link",
- )
- assert not should_download_file(
- file=file_not_folder,
- last_updated_sync_active=None,
- provider_name="google",
- datetime_format=datetime_format,
- )
-
-
-def test_should_download_file_notiondb():
- datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
- file_not_folder = SyncFile(
- id="1",
- name="file_name",
- is_folder=False,
- last_modified=datetime.now().strftime(datetime_format),
- mime_type="db",
- web_view_link="link",
- )
-
- assert not should_download_file(
- file=file_not_folder,
- last_updated_sync_active=(datetime.now() - timedelta(hours=5)).astimezone(
- timezone.utc
- ),
- provider_name="notion",
- datetime_format=datetime_format,
- )
-
-
-def test_should_download_file_not_notiondb():
- datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
- file_not_folder = SyncFile(
- id="1",
- name="file_name",
- is_folder=False,
- last_modified=datetime.now().strftime(datetime_format),
- mime_type="md",
- web_view_link="link",
- )
-
- assert should_download_file(
- file=file_not_folder,
- last_updated_sync_active=None,
- provider_name="notion",
- datetime_format=datetime_format,
- )
-
-
-def test_should_download_file_lastsynctime_before():
- datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
- file_not_folder = SyncFile(
- id="1",
- name="file_name",
- is_folder=False,
- last_modified=datetime.now().strftime(datetime_format),
- mime_type="txt",
- web_view_link="link",
- )
- last_sync_time = (datetime.now() - timedelta(hours=5)).astimezone(timezone.utc)
-
- assert should_download_file(
- file=file_not_folder,
- last_updated_sync_active=last_sync_time,
- provider_name="google",
- datetime_format=datetime_format,
- )
-
-
-def test_should_download_file_lastsynctime_after():
- datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
- file_not_folder = SyncFile(
- id="1",
- name="file_name",
- is_folder=False,
- last_modified=(datetime.now() - timedelta(hours=5)).strftime(datetime_format),
- mime_type="txt",
- web_view_link="link",
- )
- last_sync_time = datetime.now().astimezone(timezone.utc)
-
- assert not should_download_file(
- file=file_not_folder,
- last_updated_sync_active=last_sync_time,
- provider_name="google",
- datetime_format=datetime_format,
- )
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_get_syncfiles_from_ids_nofolder(syncutils: SyncUtils):
- files = await syncutils.get_syncfiles_from_ids(
- credentials={}, files_ids=[str(uuid4())], folder_ids=[], sync_user_id=1
- )
- assert len(files) == 1
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_get_syncfiles_from_ids_folder(syncutils: SyncUtils):
- files = await syncutils.get_syncfiles_from_ids(
- credentials={},
- files_ids=[str(uuid4())],
- folder_ids=[str(uuid4())],
- sync_user_id=0,
- )
- assert len(files) == 2
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_get_syncfiles_from_ids_notion(syncutils_notion: SyncUtils):
- files = await syncutils_notion.get_syncfiles_from_ids(
- credentials={},
- files_ids=[str(uuid4())],
- folder_ids=[str(uuid4())],
- sync_user_id=0,
- )
- assert len(files) == 3
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_download_file(syncutils: SyncUtils):
- file = SyncFile(
- id=str(uuid4()),
- name="test_file.txt",
- is_folder=False,
- last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format),
- mime_type="txt",
- web_view_link="",
- )
- dfile = await syncutils.download_file(file, {})
- assert dfile.extension == ".txt"
- assert dfile.file_name == file.name
- assert len(dfile.file_data.read()) > 0
-
-
-@pytest.mark.asyncio(loop_scope="session")
-async def test_process_sync_file_not_supported(syncutils: SyncUtils):
- file = SyncFile(
- id=str(uuid4()),
- name="test_file.asldkjfalsdkjf",
- is_folder=False,
- last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format),
- mime_type="txt",
- web_view_link="",
- notification_id=uuid4(), #
- )
- brain_id = uuid4()
- sync_user = SyncsUser(
- id=1,
- user_id=uuid4(),
- name="c8xfz3g566b8xa1ajiesdh",
- provider="mock",
- credentials={},
- state={},
- additional_data={},
- status="",
- )
- sync_active = SyncsActive(
- id=1,
- name="test",
- syncs_user_id=1,
- user_id=sync_user.user_id,
- settings={},
- last_synced=str(datetime.now() - timedelta(hours=5)),
- sync_interval_minutes=1,
- brain_id=brain_id,
- )
-
- with pytest.raises(ValueError):
- await syncutils.process_sync_file(
- file=file,
- previous_file=None,
- current_user=sync_user,
- sync_active=sync_active,
- )
-
-
-@pytest.mark.skip(
- reason="Bug: UnboundLocalError: cannot access local variable 'response'"
-)
-@pytest.mark.asyncio(loop_scope="session")
-async def test_process_sync_file_noprev(
- monkeypatch,
- brain_user_setup: Tuple[Brain, User],
- setup_syncs_data: Tuple[SyncsUser, SyncsActive],
- syncutils: SyncUtils,
- sync_file: SyncFile,
-):
- task = {}
-
- def _send_task(*args, **kwargs):
- task["args"] = args
- task["kwargs"] = {**kwargs["kwargs"]}
-
- monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
-
- brain_1, _ = brain_user_setup
- assert brain_1.brain_id
- (sync_user, sync_active) = setup_syncs_data
- await syncutils.process_sync_file(
- file=sync_file,
- previous_file=None,
- current_user=sync_user,
- sync_active=sync_active,
- )
-
- # Check notification inserted
- assert (
- sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore
- )
- assert (
- syncutils.notification_service.repository.received[ # type: ignore
- sync_file.notification_id # type: ignore
- ].status
- == NotificationsStatusEnum.SUCCESS
- )
-
- # Check Syncfile created
- dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id)
- assert len(dbfiles) == 1
- assert dbfiles[0].brain_id == str(brain_1.brain_id)
- assert dbfiles[0].syncs_active_id == sync_active.id
- assert dbfiles[0].supported
-
- # Check knowledge created
- all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain(
- brain_1.brain_id
- )
- assert len(all_km) == 1
- created_km = all_km[0]
- assert created_km.file_name == sync_file.name
- assert created_km.extension == ".txt"
- assert created_km.file_sha1 is None
- assert created_km.created_at is not None
- assert created_km.metadata == {"sync_file_id": "1"}
- assert len(created_km.brains) > 0
- assert created_km.brains[0]["brain_id"] == brain_1.brain_id
-
- # Assert celery task in correct
- assert task["args"] == ("process_file_task",)
- minimal_task_kwargs = {
- "brain_id": brain_1.brain_id,
- "knowledge_id": created_km.id,
- "file_original_name": sync_file.name,
- "source": syncutils.sync_cloud.name,
- "notification_id": sync_file.notification_id,
- }
- all(
- minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore
- for key in minimal_task_kwargs
- )
-
-
-@pytest.mark.skip(
- reason="Bug: UnboundLocalError: cannot access local variable 'response'"
-)
-@pytest.mark.asyncio(loop_scope="session")
-async def test_process_sync_file_with_prev(
- monkeypatch,
- supabase_client,
- brain_user_setup: Tuple[Brain, User],
- setup_syncs_data: Tuple[SyncsUser, SyncsActive],
- syncutils: SyncUtils,
- sync_file: SyncFile,
- prev_file: SyncFile,
-):
- task = {}
-
- def _send_task(*args, **kwargs):
- task["args"] = args
- task["kwargs"] = {**kwargs["kwargs"]}
-
- monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
- brain_1, _ = brain_user_setup
- assert brain_1.brain_id
- (sync_user, sync_active) = setup_syncs_data
-
- # Run process_file on prev_file first
- await syncutils.process_sync_file(
- file=prev_file,
- previous_file=None,
- current_user=sync_user,
- sync_active=sync_active,
- )
- dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id)
- assert len(dbfiles) == 1
- prev_dbfile = dbfiles[0]
-
- assert check_file_exists(str(brain_1.brain_id), prev_file.name)
- prev_file_data = supabase_client.storage.from_("quivr").download(
- f"{brain_1.brain_id}/{prev_file.name}"
- )
-
- #####
- # Run process_file on newer file
- await syncutils.process_sync_file(
- file=sync_file,
- previous_file=prev_dbfile,
- current_user=sync_user,
- sync_active=sync_active,
- )
-
- # Check notification inserted
- assert (
- sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore
- )
- assert (
- syncutils.notification_service.repository.received[ # type: ignore
- sync_file.notification_id # type: ignore
- ].status
- == NotificationsStatusEnum.SUCCESS
- )
-
- # Check Syncfile created
- dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id)
- assert len(dbfiles) == 1
- assert dbfiles[0].brain_id == str(brain_1.brain_id)
- assert dbfiles[0].syncs_active_id == sync_active.id
- assert dbfiles[0].supported
-
- # Check prev file was deleted and replaced with the new
- all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain(
- brain_1.brain_id
- )
- assert len(all_km) == 1
- created_km = all_km[0]
- assert created_km.file_name == sync_file.name
- assert created_km.extension == ".txt"
- assert created_km.file_sha1 is None
- assert created_km.updated_at
- assert created_km.created_at
- assert created_km.updated_at == created_km.created_at # new line
- assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
- assert created_km.brains[0]["brain_id"] == brain_1.brain_id
-
- # Check file content changed
- assert check_file_exists(str(brain_1.brain_id), sync_file.name)
- new_file_data = supabase_client.storage.from_("quivr").download(
- f"{brain_1.brain_id}/{sync_file.name}"
- )
- assert new_file_data != prev_file_data, "Same file in prev_file and new file"
-
- # Assert celery task in correct
- assert task["args"] == ("process_file_task",)
- minimal_task_kwargs = {
- "brain_id": brain_1.brain_id,
- "knowledge_id": created_km.id,
- "file_original_name": sync_file.name,
- "source": syncutils.sync_cloud.name,
- "notification_id": sync_file.notification_id,
- }
- all(
- minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore
- for key in minimal_task_kwargs
- )
+# from datetime import datetime, timedelta, timezone
+
+
+# from quivr_api.modules.sync.entity.sync_models import (
+# SyncFile,
+# )
+# from quivr_api.modules.sync.utils.syncutils import (
+# filter_on_supported_files,
+# should_download_file,
+# )
+
+
+# def test_filter_on_supported_files_empty_existing():
+# files = [
+# SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=True,
+# last_modified=str(datetime.now()),
+# mime_type="txt",
+# web_view_link="link",
+# )
+# ]
+# existing_file = {}
+
+# assert [(files[0], None)] == filter_on_supported_files(files, existing_file)
+
+
+# def test_filter_on_supported_files_prev_not_supported():
+# files = [
+# SyncFile(
+# id=f"{idx}",
+# name=f"file_name_{idx}",
+# is_folder=False,
+# last_modified=str(datetime.now()),
+# mime_type="txt",
+# web_view_link="link",
+# )
+# for idx in range(3)
+# ]
+# existing_files = {
+# file.name: DBSyncFile(
+# id=idx,
+# path=file.name,
+# syncs_active_id=1,
+# last_modified=str(datetime.now()),
+# brain_id=str(uuid4()),
+# supported=idx % 2 == 0,
+# )
+# for idx, file in enumerate(files)
+# }
+
+# assert [
+# (files[idx], existing_files[f"file_name_{idx}"])
+# for idx in range(3)
+# if idx % 2 == 0
+# ] == filter_on_supported_files(files, existing_files)
+
+
+# def test_should_download_file_no_sync_time_not_folder():
+# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+# file_not_folder = SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=False,
+# last_modified=datetime.now().strftime(datetime_format),
+# mime_type="txt",
+# web_view_link="link",
+# )
+# assert should_download_file(
+# file=file_not_folder,
+# last_updated_sync_active=None,
+# provider_name="google",
+# datetime_format=datetime_format,
+# )
+
+
+# def test_should_download_file_no_sync_time_folder():
+# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+# file_not_folder = SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=True,
+# last_modified=datetime.now().strftime(datetime_format),
+# mime_type="txt",
+# web_view_link="link",
+# )
+# assert not should_download_file(
+# file=file_not_folder,
+# last_updated_sync_active=None,
+# provider_name="google",
+# datetime_format=datetime_format,
+# )
+
+
+# def test_should_download_file_notiondb():
+# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+# file_not_folder = SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=False,
+# last_modified=datetime.now().strftime(datetime_format),
+# mime_type="db",
+# web_view_link="link",
+# )
+
+# assert not should_download_file(
+# file=file_not_folder,
+# last_updated_sync_active=(datetime.now() - timedelta(hours=5)).astimezone(
+# timezone.utc
+# ),
+# provider_name="notion",
+# datetime_format=datetime_format,
+# )
+
+
+# def test_should_download_file_not_notiondb():
+# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+# file_not_folder = SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=False,
+# last_modified=datetime.now().strftime(datetime_format),
+# mime_type="md",
+# web_view_link="link",
+# )
+
+# assert should_download_file(
+# file=file_not_folder,
+# last_updated_sync_active=None,
+# provider_name="notion",
+# datetime_format=datetime_format,
+# )
+
+
+# def test_should_download_file_lastsynctime_before():
+# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+# file_not_folder = SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=False,
+# last_modified=datetime.now().strftime(datetime_format),
+# mime_type="txt",
+# web_view_link="link",
+# )
+# last_sync_time = (datetime.now() - timedelta(hours=5)).astimezone(timezone.utc)
+
+# assert should_download_file(
+# file=file_not_folder,
+# last_updated_sync_active=last_sync_time,
+# provider_name="google",
+# datetime_format=datetime_format,
+# )
+
+
+# def test_should_download_file_lastsynctime_after():
+# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+# file_not_folder = SyncFile(
+# id="1",
+# name="file_name",
+# is_folder=False,
+# last_modified=(datetime.now() - timedelta(hours=5)).strftime(datetime_format),
+# mime_type="txt",
+# web_view_link="link",
+# )
+# last_sync_time = datetime.now().astimezone(timezone.utc)
+
+# assert not should_download_file(
+# file=file_not_folder,
+# last_updated_sync_active=last_sync_time,
+# provider_name="google",
+# datetime_format=datetime_format,
+# )
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_get_syncfiles_from_ids_nofolder(syncutils: SyncUtils):
+# files = await syncutils.get_syncfiles_from_ids(
+# credentials={}, files_ids=[str(uuid4())], folder_ids=[]
+# )
+# assert len(files) == 1
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_get_syncfiles_from_ids_folder(syncutils: SyncUtils):
+# files = await syncutils.get_syncfiles_from_ids(
+# credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())]
+# )
+# assert len(files) == 2
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_get_syncfiles_from_ids_notion(syncutils_notion: SyncUtils):
+# files = await syncutils_notion.get_syncfiles_from_ids(
+# credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())]
+# )
+# assert len(files) == 3
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_download_file(syncutils: SyncUtils):
+# file = SyncFile(
+# id=str(uuid4()),
+# name="test_file.txt",
+# is_folder=False,
+# last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format),
+# mime_type="txt",
+# web_view_link="",
+# )
+# dfile = await syncutils.download_file(file, {})
+# assert dfile.extension == ".txt"
+# assert dfile.file_name == file.name
+# assert len(dfile.file_data.read()) > 0
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_process_sync_file_not_supported(syncutils: SyncUtils):
+# file = SyncFile(
+# id=str(uuid4()),
+# name="test_file.asldkjfalsdkjf",
+# is_folder=False,
+# last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format),
+# mime_type="txt",
+# web_view_link="",
+# notification_id=uuid4(), #
+# )
+# brain_id = uuid4()
+# sync_user = SyncsUser(
+# id=1,
+# user_id=uuid4(),
+# name="c8xfz3g566b8xa1ajiesdh",
+# provider="mock",
+# credentials={},
+# state={},
+# additional_data={},
+# )
+# sync_active = SyncsActive(
+# id=1,
+# name="test",
+# syncs_user_id=1,
+# user_id=sync_user.user_id,
+# settings={},
+# last_synced=str(datetime.now() - timedelta(hours=5)),
+# sync_interval_minutes=1,
+# brain_id=brain_id,
+# )
+
+# with pytest.raises(ValueError):
+# await syncutils.process_sync_file(
+# file=file,
+# previous_file=None,
+# current_user=sync_user,
+# sync_active=sync_active,
+# )
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_process_sync_file_noprev(
+# monkeypatch,
+# brain_user_setup: Tuple[Brain, User],
+# setup_syncs_data: Tuple[SyncsUser, SyncsActive],
+# syncutils: SyncUtils,
+# sync_file: SyncFile,
+# ):
+# task = {}
+
+# def _send_task(*args, **kwargs):
+# task["args"] = args
+# task["kwargs"] = {**kwargs["kwargs"]}
+
+# monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+
+# brain_1, _ = brain_user_setup
+# assert brain_1.brain_id
+# (sync_user, sync_active) = setup_syncs_data
+# await syncutils.process_sync_file(
+# file=sync_file,
+# previous_file=None,
+# current_user=sync_user,
+# sync_active=sync_active,
+# )
+
+# # Check notification inserted
+# assert (
+# sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore
+# )
+# assert (
+# syncutils.notification_service.repository.received[ # type: ignore
+# sync_file.notification_id # type: ignore
+# ].status
+# == NotificationsStatusEnum.SUCCESS
+# )
+
+# # Check Syncfile created
+# dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id)
+# assert len(dbfiles) == 1
+# assert dbfiles[0].brain_id == str(brain_1.brain_id)
+# assert dbfiles[0].syncs_active_id == sync_active.id
+# assert dbfiles[0].supported
+
+# # Check knowledge created
+# all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain(
+# brain_1.brain_id
+# )
+# assert len(all_km) == 1
+# created_km = all_km[0]
+# assert created_km.file_name == sync_file.name
+# assert created_km.extension == ".txt"
+# assert created_km.file_sha1 is None
+# assert created_km.created_at is not None
+# assert created_km.metadata == {"sync_file_id": "1"}
+# assert len(created_km.brains) > 0
+# assert created_km.brains[0]["brain_id"] == brain_1.brain_id
+
+# # Assert celery task in correct
+# assert task["args"] == ("process_file_task",)
+# minimal_task_kwargs = {
+# "brain_id": brain_1.brain_id,
+# "knowledge_id": created_km.id,
+# "file_original_name": sync_file.name,
+# "source": syncutils.sync_cloud.name,
+# "notification_id": sync_file.notification_id,
+# }
+# all(
+# minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore
+# for key in minimal_task_kwargs
+# )
+
+
+# @pytest.mark.asyncio(loop_scope="session")
+# async def test_process_sync_file_with_prev(
+# monkeypatch,
+# supabase_client,
+# brain_user_setup: Tuple[Brain, User],
+# setup_syncs_data: Tuple[SyncsUser, SyncsActive],
+# syncutils: SyncUtils,
+# sync_file: SyncFile,
+# prev_file: SyncFile,
+# ):
+# task = {}
+
+# def _send_task(*args, **kwargs):
+# task["args"] = args
+# task["kwargs"] = {**kwargs["kwargs"]}
+
+# monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task)
+# brain_1, _ = brain_user_setup
+# assert brain_1.brain_id
+# (sync_user, sync_active) = setup_syncs_data
+
+# # Run process_file on prev_file first
+# await syncutils.process_sync_file(
+# file=prev_file,
+# previous_file=None,
+# current_user=sync_user,
+# sync_active=sync_active,
+# )
+# dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id)
+# assert len(dbfiles) == 1
+# prev_dbfile = dbfiles[0]
+
+# assert check_file_exists(str(brain_1.brain_id), prev_file.name)
+# prev_file_data = supabase_client.storage.from_("quivr").download(
+# f"{brain_1.brain_id}/{prev_file.name}"
+# )
+
+# #####
+# # Run process_file on newer file
+# await syncutils.process_sync_file(
+# file=sync_file,
+# previous_file=prev_dbfile,
+# current_user=sync_user,
+# sync_active=sync_active,
+# )
+
+# # Check notification inserted
+# assert (
+# sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore
+# )
+# assert (
+# syncutils.notification_service.repository.received[ # type: ignore
+# sync_file.notification_id # type: ignore
+# ].status
+# == NotificationsStatusEnum.SUCCESS
+# )
+
+# # Check Syncfile created
+# dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id)
+# assert len(dbfiles) == 1
+# assert dbfiles[0].brain_id == str(brain_1.brain_id)
+# assert dbfiles[0].syncs_active_id == sync_active.id
+# assert dbfiles[0].supported
+
+# # Check prev file was deleted and replaced with the new
+# all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain(
+# brain_1.brain_id
+# )
+# assert len(all_km) == 1
+# created_km = all_km[0]
+# assert created_km.file_name == sync_file.name
+# assert created_km.extension == ".txt"
+# assert created_km.file_sha1 is None
+# assert created_km.updated_at
+# assert created_km.created_at
+# assert created_km.updated_at == created_km.created_at # new line
+# assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)}
+# assert created_km.brains[0]["brain_id"] == brain_1.brain_id
+
+# # Check file content changed
+# assert check_file_exists(str(brain_1.brain_id), sync_file.name)
+# new_file_data = supabase_client.storage.from_("quivr").download(
+# f"{brain_1.brain_id}/{sync_file.name}"
+# )
+# assert new_file_data != prev_file_data, "Same file in prev_file and new file"
+
+# # Assert celery task in correct
+# assert task["args"] == ("process_file_task",)
+# minimal_task_kwargs = {
+# "brain_id": brain_1.brain_id,
+# "knowledge_id": created_km.id,
+# "file_original_name": sync_file.name,
+# "source": syncutils.sync_cloud.name,
+# "notification_id": sync_file.notification_id,
+# }
+# all(
+# minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore
+# for key in minimal_task_kwargs
+# )
diff --git a/backend/api/quivr_api/modules/sync/utils/oauth2.py b/backend/api/quivr_api/modules/sync/utils/oauth2.py
new file mode 100644
index 000000000000..e2344caf4225
--- /dev/null
+++ b/backend/api/quivr_api/modules/sync/utils/oauth2.py
@@ -0,0 +1,31 @@
+from uuid import UUID
+
+from fastapi import HTTPException, status
+from pydantic import BaseModel
+
+from quivr_api.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+class Oauth2BaseState(BaseModel):
+ name: str
+ user_id: UUID
+
+
+class Oauth2State(Oauth2BaseState):
+ sync_id: int
+
+
+def parse_oauth2_state(state_str: str | None) -> Oauth2State:
+ if not state_str:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid state parameter"
+ )
+
+ state = Oauth2State.model_validate_json(state_str)
+ if state.sync_id is None:
+ raise HTTPException(
+ status_code=400, detail="Invalid state parameter. Unknown sync"
+ )
+ return state
diff --git a/backend/api/quivr_api/modules/sync/utils/sync.py b/backend/api/quivr_api/modules/sync/utils/sync.py
index a60ccb0aa75f..363c5f512ee9 100644
--- a/backend/api/quivr_api/modules/sync/utils/sync.py
+++ b/backend/api/quivr_api/modules/sync/utils/sync.py
@@ -3,7 +3,7 @@
import os
import time
from abc import ABC, abstractmethod
-from datetime import datetime
+from datetime import datetime, timezone
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
@@ -93,7 +93,7 @@ def download_file(
) -> Dict[str, Union[str, BytesIO]]:
file_id = file.id
file_name = file.name
- mime_type = file.mime_type
+ mime_type = file.extension
if not self.creds:
self.check_and_refresh_access_token(credentials)
if not self.service:
@@ -193,8 +193,10 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi
is_folder=(
result["mimeType"] == "application/vnd.google-apps.folder"
),
- last_modified=result["modifiedTime"],
- mime_type=result["mimeType"],
+ last_modified_at=datetime.strptime(
+ result["modifiedTime"], self.datetime_format
+ ).replace(tzinfo=timezone.utc),
+ extension=result["mimeType"],
web_view_link=result["webViewLink"],
size=result.get("size", None),
)
@@ -206,6 +208,8 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi
return files
except HTTPError as error:
+ if error.response.status_code == 404:
+ raise FileNotFoundError
logger.error(
"An error occurred while retrieving Google Drive files: %s", error
)
@@ -269,8 +273,10 @@ def get_files(
is_folder=(
item["mimeType"] == "application/vnd.google-apps.folder"
),
- last_modified=item["modifiedTime"],
- mime_type=item["mimeType"],
+ last_modified_at=datetime.strptime(
+ item["modifiedTime"], self.datetime_format
+ ).replace(tzinfo=timezone.utc),
+ extension=item["mimeType"],
web_view_link=item["webViewLink"],
size=item.get("size", None),
)
@@ -447,8 +453,10 @@ def fetch_files(endpoint, headers, max_retries=1):
else f'{site_id}:{item.get("id")}'
),
is_folder="folder" in item or not site_folder_id,
- last_modified=item.get("lastModifiedDateTime"),
- mime_type=item.get("file", {}).get("mimeType", "folder"),
+ last_modified_at=datetime.strptime(
+ item["lastModiedDateTime"], self.datetime_format
+ ).replace(tzinfo=timezone.utc),
+ extension=item.get("file", {}).get("mimeType", "folder"),
web_view_link=item.get("webUrl"),
size=item.get("size", None),
)
@@ -467,8 +475,8 @@ def fetch_files(endpoint, headers, max_retries=1):
name="My Drive",
id="root:",
is_folder=True,
- last_modified="",
- mime_type="folder",
+ last_modified_at=None,
+ extension="folder",
web_view_link="https://onedrive.live.com",
)
)
@@ -530,8 +538,10 @@ def get_files_by_id(self, credentials: dict, file_ids: List[str]) -> List[SyncFi
name=result.get("name"),
id=f'{site_id}:{result.get("id")}',
is_folder="folder" in result,
- last_modified=result.get("lastModifiedDateTime"),
- mime_type=result.get("file", {}).get("mimeType", "folder"),
+ last_modified_at=datetime.strptime(
+ result.get("lastModifiedDateTime"), self.datetime_format
+ ).replace(tzinfo=timezone.utc),
+ extension=result.get("file", {}).get("mimeType", "folder"),
web_view_link=result.get("webUrl"),
size=result.get("size", None),
)
@@ -641,10 +651,14 @@ def fetch_files(metadata):
name=file.name,
id=file.id,
is_folder=is_folder,
- last_modified=(
- str(file.client_modified) if not is_folder else ""
+ last_modified_at=(
+ datetime.strptime(
+ file.client_modified, self.datetime_format
+ )
+ if not is_folder
+ else None
),
- mime_type=(
+ extension=(
file.path_lower.split(".")[-1] if not is_folder else ""
),
web_view_link=shared_link,
@@ -720,10 +734,14 @@ def get_files_by_id(
name=metadata.name,
id=metadata.id,
is_folder=is_folder,
- last_modified=(
- str(metadata.client_modified) if not is_folder else ""
+ last_modified_at=(
+ datetime.strptime(
+ metadata.client_modified, self.datetime_format
+ )
+ if not is_folder
+ else None
),
- mime_type=(
+ extension=(
metadata.path_lower.split(".")[-1] if not is_folder else ""
),
web_view_link=shared_link,
@@ -819,8 +837,8 @@ async def aget_files(
name=page.name,
id=str(page.notion_id),
is_folder=await self.notion_service.is_folder_page(page.notion_id),
- last_modified=str(page.last_modified),
- mime_type=page.mime_type,
+ last_modified_at=page.last_modified,
+ extension=page.mime_type,
web_view_link=page.web_view_link,
icon=page.icon,
)
@@ -862,8 +880,8 @@ async def aget_files_by_id(
name=page.name,
id=str(page.notion_id),
is_folder=await self.notion_service.is_folder_page(page.notion_id),
- last_modified=str(page.last_modified),
- mime_type=page.mime_type,
+ last_modified_at=page.last_modified,
+ extension=page.mime_type,
web_view_link=page.web_view_link,
icon=page.icon,
)
@@ -1082,8 +1100,8 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi
name=remove_special_characters(result.get("name")),
id=f"{repo_name}:{result.get('path')}",
is_folder=False,
- last_modified=datetime.now().strftime(self.datetime_format),
- mime_type=result.get("type"),
+ last_modified_at=datetime.now(),
+ extension=result.get("type"),
web_view_link=result.get("html_url"),
size=result.get("size", None),
)
@@ -1156,8 +1174,8 @@ def fetch_repos(endpoint, headers):
name=remove_special_characters(item.get("name")),
id=f"{item.get('full_name')}:",
is_folder=True,
- last_modified=str(item.get("updated_at")),
- mime_type="repository",
+ last_modified_at=item.get("updated_at"),
+ extension="repository",
web_view_link=item.get("html_url"),
size=item.get("size", None),
)
@@ -1203,8 +1221,8 @@ def fetch_files(endpoint, headers):
name=remove_special_characters(item.get("name")),
id=f"{repo_name}:{item.get('path')}",
is_folder=item.get("type") == "dir",
- last_modified=str(item.get("updated_at")),
- mime_type=item.get("type"),
+ last_modified_at=str(item.get("updated_at")),
+ extension=item.get("type"),
web_view_link=item.get("html_url"),
size=item.get("size", None),
)
diff --git a/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py b/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py
new file mode 100644
index 000000000000..5d2ad17fcf31
--- /dev/null
+++ b/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py
@@ -0,0 +1,36 @@
+class SyncException(Exception):
+ def __init__(self, message="A sync-related error occurred"):
+ self.message = message
+ super().__init__(self.message)
+
+
+class SyncCreationError(SyncException):
+ def __init__(self, message="An error occurred while creating"):
+ super().__init__(message)
+
+
+class SyncUpdateError(SyncException):
+ def __init__(self, message="An error occurred while updating"):
+ super().__init__(message)
+
+
+class SyncDeleteError(SyncException):
+ def __init__(self, message="An error occurred while deleting"):
+ super().__init__(message)
+
+
+class SyncEmptyCredentials(SyncException):
+ def __init__(
+ self, message="You do not have credentials to access files from this sync."
+ ):
+ super().__init__(message)
+
+
+class SyncNotFoundException(SyncException):
+ def __init__(self, message="The requested sync was not found"):
+ super().__init__(message)
+
+
+class SyncProviderError(SyncException):
+ def __init__(self, message="Unknown provider"):
+ super().__init__(message)
diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py
deleted file mode 100644
index 5fe9f53105b0..000000000000
--- a/backend/api/quivr_api/modules/sync/utils/syncutils.py
+++ /dev/null
@@ -1,414 +0,0 @@
-import io
-import os
-from datetime import datetime, timezone
-from typing import Any, List, Tuple
-from uuid import UUID, uuid4
-
-from quivr_api.celery_config import celery
-from quivr_api.logger import get_logger
-from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
-from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
-from quivr_api.modules.notification.dto.inputs import (
- CreateNotification,
- NotificationUpdatableProperties,
-)
-from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum
-from quivr_api.modules.notification.service.notification_service import (
- NotificationService,
-)
-from quivr_api.modules.sync.dto.inputs import SyncsActiveUpdateInput
-from quivr_api.modules.sync.entity.sync_models import (
- DBSyncFile,
- DownloadedSyncFile,
- SyncFile,
- SyncsActive,
- SyncsUser,
-)
-from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface
-from quivr_api.modules.sync.service.sync_service import (
- ISyncService,
- ISyncUserService,
-)
-from quivr_api.modules.sync.utils.sync import BaseSync
-from quivr_api.modules.upload.service.upload_file import (
- check_file_exists,
- upload_file_storage,
-)
-
-logger = get_logger(__name__)
-
-celery_inspector = celery.control.inspect()
-
-
-# NOTE: we are filtering based on file path names in sync !
-def filter_on_supported_files(
- files: list[SyncFile], existing_files: dict[str, DBSyncFile]
-) -> list[Tuple[SyncFile, DBSyncFile | None]]:
- res = []
- for new_file in files:
- prev_file = existing_files.get(new_file.name, None)
- if (prev_file and prev_file.supported) or prev_file is None:
- res.append((new_file, prev_file))
-
- return res
-
-
-def should_download_file(
- file: SyncFile,
- last_updated_sync_active: datetime | None,
- provider_name: str,
- datetime_format: str,
-) -> bool:
- file_last_modified_utc = datetime.strptime(
- file.last_modified, datetime_format
- ).replace(tzinfo=timezone.utc)
-
- should_download = (
- last_updated_sync_active is None
- or file_last_modified_utc > last_updated_sync_active
- )
-
- # TODO: Handle notion database
- if provider_name == "notion":
- should_download &= file.mime_type != "db"
- else:
- should_download &= not file.is_folder
-
- return should_download
-
-
-class SyncUtils:
- def __init__(
- self,
- sync_user_service: ISyncUserService,
- sync_active_service: ISyncService,
- knowledge_service: KnowledgeService,
- sync_files_repo: SyncFileInterface,
- sync_cloud: BaseSync,
- notification_service: NotificationService,
- brain_vectors: BrainsVectors,
- ) -> None:
- self.sync_user_service = sync_user_service
- self.sync_active_service = sync_active_service
- self.knowledge_service = knowledge_service
- self.sync_files_repo = sync_files_repo
- self.sync_cloud = sync_cloud
- self.notification_service = notification_service
- self.brain_vectors = brain_vectors
-
- # TODO: This modifies the file, we should treat it as such
- def create_sync_bulk_notification(
- self, files: list[SyncFile], current_user: UUID, brain_id: UUID, bulk_id: UUID
- ) -> list[SyncFile]:
- res = []
- # TODO: bulk insert in batch
- for file in files:
- upload_notification = self.notification_service.add_notification(
- CreateNotification(
- user_id=current_user,
- bulk_id=bulk_id,
- status=NotificationsStatusEnum.INFO,
- title=file.name,
- category="sync",
- brain_id=str(brain_id),
- )
- )
- file.notification_id = upload_notification.id
- res.append(file)
- return res
-
- async def download_file(
- self, file: SyncFile, credentials: dict[str, Any]
- ) -> DownloadedSyncFile:
- logger.info(f"Downloading {file} using {self.sync_cloud}")
- file_response = await self.sync_cloud.adownload_file(credentials, file)
- logger.debug(f"Fetch sync file response: {file_response}")
- file_name = str(file_response["file_name"])
- raw_data = file_response["content"]
- file_data = (
- io.BufferedReader(raw_data) # type: ignore
- if isinstance(raw_data, io.BytesIO)
- else io.BufferedReader(raw_data.encode("utf-8")) # type: ignore
- )
- extension = os.path.splitext(file_name)[-1].lower()
- dfile = DownloadedSyncFile(
- file_name=file_name,
- file_data=file_data,
- extension=extension,
- )
- logger.debug(f"Successfully downloaded sync file : {dfile}")
- return dfile
-
- # TODO: REDO THIS MESS !!!!
- # REMOVE ALL SYNC TABLES and start from scratch
-
- async def process_sync_file(
- self,
- file: SyncFile,
- previous_file: DBSyncFile | None,
- current_user: SyncsUser,
- sync_active: SyncsActive,
- ):
- logger.info("Processing file: %s", file.name)
- brain_id = sync_active.brain_id
- source, source_link = self.sync_cloud.name, file.web_view_link
- downloaded_file = await self.download_file(file, current_user.credentials)
- storage_path = f"{brain_id}/{downloaded_file.file_name}"
- exists_in_storage = check_file_exists(str(brain_id), file.name)
-
- if downloaded_file.extension not in [
- ".pdf",
- ".txt",
- ".md",
- ".csv",
- ".docx",
- ".xlsx",
- ".pptx",
- ".doc",
- ]:
- raise ValueError(f"Incompatible file extension for {downloaded_file}")
-
- response = await upload_file_storage(
- downloaded_file.file_data,
- storage_path,
- upsert=exists_in_storage,
- )
- assert response, f"Error uploading {downloaded_file} to {storage_path}"
- self.notification_service.update_notification_by_id(
- file.notification_id,
- NotificationUpdatableProperties(
- status=NotificationsStatusEnum.SUCCESS,
- description="File downloaded successfully",
- ),
- )
- # TODO : why knowledge + syncfile, drop syncfile ...
- # FIXME : Simplify this logic in KMS plzzz
- sync_file_db = self.sync_files_repo.update_or_create_sync_file(
- file=file,
- previous_file=previous_file,
- sync_active=sync_active,
- supported=True,
- )
- knowledge = await self.knowledge_service.update_or_create_knowledge_sync(
- brain_id=brain_id,
- file=file,
- new_sync_file=sync_file_db,
- prev_sync_file=previous_file,
- downloaded_file=downloaded_file,
- source=source,
- source_link=source_link,
- user_id=current_user.user_id,
- )
-
- # Send file for processing
- celery.send_task(
- "process_file_task",
- kwargs={
- "brain_id": brain_id,
- "knowledge_id": knowledge.id,
- "file_name": storage_path,
- "file_original_name": file.name,
- "source": source,
- "source_link": source_link,
- "notification_id": file.notification_id,
- },
- )
- return file
-
- async def process_sync_files(
- self,
- files: List[SyncFile],
- current_user: SyncsUser,
- sync_active: SyncsActive,
- ):
- logger.info(f"Processing {len(files)} for sync_active: {sync_active.id}")
- current_user.credentials = self.sync_cloud.check_and_refresh_access_token(
- current_user.credentials
- )
-
- bulk_id = uuid4()
- downloaded_files = []
- list_existing_files = self.sync_files_repo.get_sync_files(sync_active.id)
- existing_files = {f.path: f for f in list_existing_files}
-
- supported_files = filter_on_supported_files(files, existing_files)
-
- files = self.create_sync_bulk_notification(
- files, current_user.user_id, sync_active.brain_id, bulk_id
- )
-
- for file, prev_file in supported_files:
- try:
- result = await self.process_sync_file(
- file=file,
- previous_file=prev_file,
- current_user=current_user,
- sync_active=sync_active,
- )
- if result is not None:
- downloaded_files.append(result)
-
- self.notification_service.update_notification_by_id(
- file.notification_id,
- NotificationUpdatableProperties(
- status=NotificationsStatusEnum.SUCCESS,
- description="File downloaded successfully",
- ),
- )
-
- except Exception as e:
- logger.error(
- "An error occurred while syncing %s files: %s",
- self.sync_cloud.name,
- e,
- )
- # TODO: this process_sync_file could fail for a LOT of reason redo this logic
- # File isn't supported so we set it as so ?
- self.sync_files_repo.update_or_create_sync_file(
- file=file,
- sync_active=sync_active,
- previous_file=prev_file,
- supported=False,
- )
- self.notification_service.update_notification_by_id(
- file.notification_id,
- NotificationUpdatableProperties(
- status=NotificationsStatusEnum.ERROR,
- description="Error downloading file",
- ),
- )
-
- return {"downloaded_files": downloaded_files}
-
- async def get_files_to_download(
- self, sync_active: SyncsActive, user_sync: SyncsUser
- ) -> list[SyncFile]:
- # Get the folder id from the settings from sync_active
- folders = sync_active.settings.get("folders", [])
- files_ids = sync_active.settings.get("files", [])
-
- files = await self.get_syncfiles_from_ids(
- user_sync.credentials,
- files_ids=files_ids,
- folder_ids=folders,
- sync_user_id=user_sync.id,
- )
-
- logger.debug(f"original files to download for {sync_active.id} : {files}")
-
- last_synced_time = (
- datetime.fromisoformat(sync_active.last_synced).astimezone(timezone.utc)
- if sync_active.last_synced
- else None
- )
-
- files_ids = [
- file
- for file in files
- if should_download_file(
- file=file,
- last_updated_sync_active=last_synced_time,
- provider_name=self.sync_cloud.lower_name,
- datetime_format=self.sync_cloud.datetime_format,
- )
- ]
-
- logger.debug(f"filter files to download for {sync_active} : {files_ids}")
- return files_ids
-
- async def get_syncfiles_from_ids(
- self,
- credentials: dict[str, Any],
- files_ids: list[str],
- folder_ids: list[str],
- sync_user_id: int,
- ) -> list[SyncFile]:
- files = []
- if self.sync_cloud.lower_name == "notion":
- files_ids += folder_ids
-
- for folder_id in folder_ids:
- logger.debug(
- f"Recursively getting file_ids from {self.sync_cloud.name}. folder_id={folder_id}"
- )
- files.extend(
- await self.sync_cloud.aget_files(
- credentials=credentials,
- sync_user_id=sync_user_id,
- folder_id=folder_id,
- recursive=True,
- )
- )
- if len(files_ids) > 0:
- files.extend(
- await self.sync_cloud.aget_files_by_id(
- credentials=credentials,
- file_ids=files_ids,
- )
- )
- return files
-
- async def direct_sync(
- self,
- sync_active: SyncsActive,
- user_sync: SyncsUser,
- files_ids: list[str],
- folder_ids: list[str],
- ):
- files = await self.get_syncfiles_from_ids(
- user_sync.credentials, files_ids, folder_ids, user_sync.id
- )
- processed_files = await self.process_sync_files(
- files=files,
- current_user=user_sync,
- sync_active=sync_active,
- )
-
- # Update the last_synced timestamp
- self.sync_active_service.update_sync_active(
- sync_active.id,
- SyncsActiveUpdateInput(
- last_synced=datetime.now().astimezone().isoformat(), force_sync=False
- ),
- )
- logger.info(
- f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.",
- )
- return processed_files
-
- async def sync(
- self,
- sync_active: SyncsActive,
- user_sync: SyncsUser,
- ):
- """
- Check if the Specific sync has not been synced and download the folders and files based on the settings.
-
- Args:
- sync_active_id (int): The ID of the active sync.
- user_id (str): The user ID associated with the active sync.
- """
- logger.info(
- "Starting %s sync for sync_active: %s",
- self.sync_cloud.lower_name,
- sync_active,
- )
-
- files_to_download = await self.get_files_to_download(sync_active, user_sync)
- processed_files = await self.process_sync_files(
- files=files_to_download,
- current_user=user_sync,
- sync_active=sync_active,
- )
-
- # Update the last_synced timestamp
- self.sync_active_service.update_sync_active(
- sync_active.id,
- SyncsActiveUpdateInput(
- last_synced=datetime.now().astimezone().isoformat(), force_sync=False
- ),
- )
- logger.info(
- f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.",
- )
- return processed_files
diff --git a/backend/api/quivr_api/modules/user/entity/user_identity.py b/backend/api/quivr_api/modules/user/entity/user_identity.py
index 3f734f1a66cc..22e4940af882 100644
--- a/backend/api/quivr_api/modules/user/entity/user_identity.py
+++ b/backend/api/quivr_api/modules/user/entity/user_identity.py
@@ -2,6 +2,7 @@
from uuid import UUID, uuid4
from pydantic import BaseModel
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
from sqlmodel import Field, Relationship, SQLModel
@@ -17,6 +18,10 @@ class User(SQLModel, table=True):
onboarded: bool | None = None
chats: List["Chat"] | None = Relationship(back_populates="user") # type: ignore
notion_syncs: List["NotionSyncFile"] | None = Relationship(back_populates="user") # type: ignore
+ brains: List["Brain"] = Relationship(
+ back_populates="users",
+ link_model=BrainUserDB,
+ )
class UserIdentity(BaseModel):
diff --git a/backend/api/quivr_api/modules/vector/entity/vector.py b/backend/api/quivr_api/modules/vector/entity/vector.py
index b0583f64034a..c8c2bec775ff 100644
--- a/backend/api/quivr_api/modules/vector/entity/vector.py
+++ b/backend/api/quivr_api/modules/vector/entity/vector.py
@@ -4,7 +4,6 @@
from pgvector.sqlalchemy import Vector as PGVector
from pydantic import BaseModel
from quivr_api.models.settings import settings
-from sqlalchemy import Column
from sqlmodel import JSON, Column, Field, SQLModel, text
from sqlmodel import UUID as PGUUID
@@ -20,10 +19,10 @@ class Vector(SQLModel, table=True):
),
)
content: str = Field(default=None)
- metadata_: dict = Field(default={}, sa_column=Column("metadata", JSON, default={}))
embedding: Optional[PGVector] = Field(
sa_column=Column(PGVector(settings.embedding_dim)),
- ) # Verify with text_ada -> put it in Env variabme
+ )
+ metadata_: dict = Field(default={}, sa_column=Column("metadata", JSON, default={}))
knowledge_id: UUID = Field(default=None, foreign_key="knowledge.id")
class Config:
diff --git a/backend/api/quivr_api/modules/vector/repository/vectors_repository.py b/backend/api/quivr_api/modules/vector/repository/vectors_repository.py
index a7a5a4b63b85..73cf281f11eb 100644
--- a/backend/api/quivr_api/modules/vector/repository/vectors_repository.py
+++ b/backend/api/quivr_api/modules/vector/repository/vectors_repository.py
@@ -5,42 +5,44 @@
from quivr_api.modules.dependencies import BaseRepository
from quivr_api.modules.vector.entity.vector import SimilaritySearchOutput, Vector
from sqlalchemy import exc, text
-from sqlmodel import Session, select
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
logger = get_logger(__name__)
class VectorRepository(BaseRepository):
- def __init__(self, session: Session):
+ def __init__(self, session: AsyncSession):
super().__init__(session)
self.session = session
- def create_vectors(self, new_vectors: List[Vector]) -> List[Vector]:
+ async def create_vectors(
+ self, new_vectors: List[Vector], autocommit: bool
+ ) -> List[Vector]:
try:
- # Use SQLAlchemy session to add and commit the new vector
self.session.add_all(new_vectors)
- self.session.commit()
+ # FIXME: @AmineDiro : check if this is possible with nested transactions
+ if autocommit:
+ await self.session.commit()
+ for vector in new_vectors:
+ await self.session.refresh(vector)
+ await self.session.flush()
+ return new_vectors
except exc.IntegrityError:
# Rollback the session if there’s an IntegrityError
- self.session.rollback()
+ await self.session.rollback()
raise Exception("Integrity error occurred while creating vector.")
except Exception as e:
- self.session.rollback()
+ await self.session.rollback()
print(f"Error: {e}")
raise Exception(f"An error occurred while creating vector: {e}")
- # Refresh the session to get any updated fields (like auto-generated IDs)
- for vector in new_vectors:
- self.session.refresh(vector)
-
- return new_vectors
-
- def get_vectors_by_knowledge_id(self, knowledge_id: UUID) -> Sequence[Vector]:
+ async def get_vectors_by_knowledge_id(self, knowledge_id: UUID) -> Sequence[Vector]:
query = select(Vector).where(Vector.knowledge_id == knowledge_id)
- results = self.session.execute(query)
+ results = await self.session.execute(query)
return results.scalars().all()
- def similarity_search(
+ async def similarity_search(
self,
query_embedding: List[float],
brain_id: UUID,
@@ -94,13 +96,13 @@ def similarity_search(
""")
params = {
- "query_embedding": query_embedding,
+ "query_embedding": str(query_embedding),
"p_brain_id": brain_id,
"k": k,
"max_chunk_sum": max_chunk_sum,
}
- result = self.session.execute(sql_query, params=params)
+ result = await self.session.execute(sql_query, params=params)
full_results = result.all()
formated_result = [
SimilaritySearchOutput(
diff --git a/backend/api/quivr_api/modules/vector/service/vector_service.py b/backend/api/quivr_api/modules/vector/service/vector_service.py
index 6a775dd6e1fb..2a0d577c44f1 100644
--- a/backend/api/quivr_api/modules/vector/service/vector_service.py
+++ b/backend/api/quivr_api/modules/vector/service/vector_service.py
@@ -13,18 +13,26 @@
class VectorService(BaseService[VectorRepository]):
repository_cls = VectorRepository
- _embedding: Embeddings = get_embedding_client()
- def __init__(self, repository: VectorRepository):
+ def __init__(
+ self, repository: VectorRepository, embedder: Embeddings | None = None
+ ):
+ if embedder is None:
+ self.embedder = get_embedding_client()
+ else:
+ self.embedder = embedder
+
self.repository = repository
- def create_vectors(self, chunks: List[Document], knowledge_id: UUID) -> List[UUID]:
+ async def create_vectors(
+ self, chunks: List[Document], knowledge_id: UUID, autocommit: bool = True
+ ) -> List[UUID]:
# Vector is created upon the user's first question asked
logger.info(
f"New vector entry in vectors table for knowledge_id {knowledge_id}"
)
# FIXME ADD a check in case of failure
- embeddings = self._embedding.embed_documents(
+ embeddings = self.embedder.embed_documents(
[chunk.page_content for chunk in chunks]
)
new_vectors = [
@@ -36,14 +44,14 @@ def create_vectors(self, chunks: List[Document], knowledge_id: UUID) -> List[UUI
)
for i, chunk in enumerate(chunks)
]
- created_vector = self.repository.create_vectors(new_vectors)
+ created_vector = await self.repository.create_vectors(new_vectors, autocommit)
return [vector.id for vector in created_vector if vector.id]
- def similarity_search(self, query: str, brain_id: UUID, k: int = 40):
- vectors = self._embedding.embed_documents([query])
+ async def similarity_search(self, query: str, brain_id: UUID, k: int = 40):
+ vectors = self.embedder.embed_documents([query])
query_embedding = vectors[0]
- vectors = self.repository.similarity_search(
+ vectors = await self.repository.similarity_search(
query_embedding=query_embedding, brain_id=brain_id, k=k
)
diff --git a/backend/api/quivr_api/modules/vector/tests/test_vectors.py b/backend/api/quivr_api/modules/vector/tests/test_vectors.py
index ce4b6f04bf2a..d64d772b04bb 100644
--- a/backend/api/quivr_api/modules/vector/tests/test_vectors.py
+++ b/backend/api/quivr_api/modules/vector/tests/test_vectors.py
@@ -1,10 +1,13 @@
from typing import List, Tuple
import pytest
+import pytest_asyncio
from langchain.docstore.document import Document
from langchain_core.embeddings import DeterministicFakeEmbedding
-from sqlmodel import Session, select
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
+from quivr_api.models.settings import settings
from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.user.entity.user_identity import User
@@ -19,13 +22,13 @@
@pytest.fixture(scope="module")
def embedder():
- return DeterministicFakeEmbedding(size=1536)
+ return DeterministicFakeEmbedding(size=settings.embedding_dim)
-@pytest.fixture(scope="function")
-def test_data(sync_session: Session, embedder) -> TestData:
+@pytest_asyncio.fixture(scope="function")
+async def test_data(session: AsyncSession, embedder) -> TestData:
user_1 = (
- sync_session.exec(select(User).where(User.email == "admin@quivr.app"))
+ await session.exec(select(User).where(User.email == "admin@quivr.app"))
).one()
assert user_1.id
vectors = embedder.embed_documents(
@@ -51,9 +54,9 @@ def test_data(sync_session: Session, embedder) -> TestData:
brains=[brain_1],
user_id=user_1.id,
)
- sync_session.add(knowledge_1)
- sync_session.commit()
- sync_session.refresh(knowledge_1)
+ session.add(knowledge_1)
+ await session.commit()
+ await session.refresh(knowledge_1)
assert knowledge_1.id, "Knowledge ID not generated"
@@ -71,39 +74,44 @@ def test_data(sync_session: Session, embedder) -> TestData:
knowledge_id=knowledge_1.id,
)
- sync_session.add(vector_1)
- sync_session.add(vector_2)
+ session.add(vector_1)
+ session.add(vector_2)
- sync_session.commit()
+ await session.commit()
return ([vector_1, vector_2], knowledge_1, brain_1)
-def test_create_vectors_service(sync_session: Session, test_data: TestData, embedder):
+@pytest.mark.asyncio(loop_scope="session")
+async def test_create_vectors_service(
+ session: AsyncSession, test_data: TestData, embedder
+):
_, knowledge, _ = test_data
assert knowledge.id
- repo = VectorRepository(sync_session)
+ repo = VectorRepository(session)
service = VectorService(repo)
- service._embedding = embedder
+ service.embedder = embedder
chunk_1 = Document(page_content="I love eating pasta with tomato sauce")
chunk_2 = Document(page_content="I love eating pizza with extra cheese")
# Create vectors from documents
- new_vectors_id: List[int] = service.create_vectors([chunk_1, chunk_2], knowledge.id) # type: ignore
+ new_vectors_id: List[int] = await service.create_vectors(
+ [chunk_1, chunk_2], knowledge.id
+ ) # type: ignore
# Verify the correct number of vectors were created
assert len(new_vectors_id) == 2, f"Expected 2 vectors, got {len(new_vectors_id)}"
# Verify the content of the first vector matches the corresponding document
vector_1_content = (
- sync_session.execute(select(Vector).where(Vector.id == new_vectors_id[0]))
+ (await session.execute(select(Vector).where(Vector.id == new_vectors_id[0])))
.scalars()
.first()
.content
)
vector_2_content = (
- sync_session.execute(select(Vector).where(Vector.id == new_vectors_id[1]))
+ (await session.execute(select(Vector).where(Vector.id == new_vectors_id[1])))
.scalars()
.first()
.content
@@ -117,12 +125,13 @@ def test_create_vectors_service(sync_session: Session, test_data: TestData, embe
), "The content of the second vector does not match"
-def test_get_vectors_by_knowledge_id(sync_session: Session, test_data: TestData):
+@pytest.mark.asyncio(loop_scope="session")
+async def test_get_vectors_by_knowledge_id(session: AsyncSession, test_data: TestData):
vectors, knowledge, _ = test_data
assert knowledge.id
- repo = VectorRepository(sync_session)
- results = repo.get_vectors_by_knowledge_id(knowledge.id) # type: ignore
+ repo = VectorRepository(session)
+ results = await repo.get_vectors_by_knowledge_id(knowledge.id) # type: ignore
assert len(results) == 2, f"Expected 2 vectors, got {len(results)}"
assert (
@@ -133,71 +142,75 @@ def test_get_vectors_by_knowledge_id(sync_session: Session, test_data: TestData)
), f"Expected {vectors[1].content}, got {results[1].content}"
-def test_service_similarity_search(
- sync_session: Session, test_data: TestData, embedder
+@pytest.mark.asyncio(loop_scope="session")
+async def test_service_similarity_search(
+ session: AsyncSession, test_data: TestData, embedder
):
vectors, knowledge, brain = test_data
assert knowledge.id
assert brain.brain_id
- repo = VectorRepository(sync_session)
- service = VectorService(repo)
- service._embedding = embedder
+ repo = VectorRepository(session)
+ service = VectorService(repo, embedder=embedder)
k = 2
- results = service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore
+ results = await service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore
assert len(results) == k
assert results[0].page_content == vectors[0].content
- results = service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore
+ results = await service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore
assert results[0].page_content == vectors[1].content
k = 1
- results = service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore
+ results = await service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore
assert len(results) == k
assert results[0].page_content == vectors[0].content
- results = service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore
+ results = await service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore
assert results[0].page_content == vectors[1].content
-def test_similarity_search(sync_session: Session, test_data: TestData):
+@pytest.mark.asyncio(loop_scope="session")
+async def test_similarity_search(session: AsyncSession, test_data: TestData):
vectors, knowledge, brain = test_data
assert knowledge.id
assert brain.brain_id
- repo = VectorRepository(sync_session)
+ repo = VectorRepository(session)
k = 2
- results = repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore
+ results = await repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore
assert len(results) == k
assert results[0].content == vectors[0].content
- results = repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore
+ results = await repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore
assert results[0].content == vectors[1].content
k = 1
- results = repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore
+ results = await repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore
assert len(results) == k
assert results[0].content == vectors[0].content
- results = repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore
+ results = await repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore
assert results[0].content == vectors[1].content
-def test_similarity_with_oversized_chunk(sync_session: Session, test_data: TestData):
+@pytest.mark.asyncio(loop_scope="session")
+async def test_similarity_with_oversized_chunk(
+ session: AsyncSession, test_data: TestData
+):
vectors, knowledge, brain = test_data
assert knowledge.id
assert brain.brain_id
- repo = VectorRepository(sync_session)
+ repo = VectorRepository(session)
k = 2
- results = repo.similarity_search(
+ results = await repo.similarity_search(
vectors[0].embedding, # type: ignore
brain.brain_id,
k=k,
diff --git a/backend/api/quivr_api/utils/knowledge_utils.py b/backend/api/quivr_api/utils/knowledge_utils.py
new file mode 100644
index 000000000000..5a5a24dc7fb6
--- /dev/null
+++ b/backend/api/quivr_api/utils/knowledge_utils.py
@@ -0,0 +1,10 @@
+from quivr_core.files.file import FileExtension
+
+
+def parse_file_extension(file_name: str) -> FileExtension | str:
+ if file_name.startswith(".") and file_name.count(".") == 1:
+ return ""
+ if "." not in file_name or file_name.endswith("."):
+ return ""
+
+ return FileExtension(f".{file_name.split('.')[-1]}")
diff --git a/backend/api/quivr_api/vectorstore/supabase.py b/backend/api/quivr_api/vectorstore/supabase.py
index dfca682f0099..a4dcd8f79183 100644
--- a/backend/api/quivr_api/vectorstore/supabase.py
+++ b/backend/api/quivr_api/vectorstore/supabase.py
@@ -69,18 +69,17 @@ def find_brain_closest_query(
]
return brain_details
- def similarity_search(
+ async def asimilarity_search(
self,
query: str,
k: int = 40,
- table: str = "match_vectors",
threshold: float = 0.5,
**kwargs: Any,
) -> List[Document]:
logger.debug(f"Similarity search for query: {query}")
assert self.brain_id, "Brain ID is required for similarity search"
- match_result = self.vector_service.similarity_search(
+ match_result = await self.vector_service.similarity_search(
query, brain_id=self.brain_id, k=k
)
diff --git a/backend/benchmarks/benchmark_kms.sh b/backend/benchmarks/benchmark_kms.sh
new file mode 100755
index 000000000000..21976ec1ec78
--- /dev/null
+++ b/backend/benchmarks/benchmark_kms.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+# Function to handle cleanup on exit
+cleanup() {
+ echo "Cleaning up..."
+ # Stop Uvicorn server if running
+ if [[ ! -z "$UVICORN_PID" ]]; then
+ kill "$UVICORN_PID"
+ wait "$UVICORN_PID"
+ fi
+ exit 0
+}
+
+# Trap signals (like Ctrl+C, SIGTERM) to run the cleanup function
+#trap cleanup SIGINT SIGTERM
+
+# Reset and start Supabase
+supabase db reset && supabase stop && supabase start
+
+# Remove old benchmark data
+rm -f benchmarks/data.json
+
+# Load new data
+rye run python benchmarks/load_data.py
+
+# Start Uvicorn server in the background
+LOG_LEVEL=info rye run uvicorn quivr_api.main:app --log-level info --host 0.0.0.0 --port 5050 --workers 5 --loop uvloop &
+UVICORN_PID=$!
+
+# Wait a bit to ensure the server is running
+sleep 1
+
+# Run Locust for benchmarking
+rye run locust -f benchmarks/locustfile_kms.py -H http://localhost:5050
+
+# Wait for all background processes (including Uvicorn) to finish
+wait "$UVICORN_PID"
diff --git a/backend/benchmarks/load_data.py b/backend/benchmarks/load_data.py
new file mode 100644
index 000000000000..4e47b2331754
--- /dev/null
+++ b/backend/benchmarks/load_data.py
@@ -0,0 +1,173 @@
+import os
+from typing import List
+from uuid import UUID
+
+import numpy as np
+from pydantic import BaseModel
+from quivr_api.logger import get_logger
+from quivr_api.models.settings import settings
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
+from quivr_api.modules.user.entity.user_identity import User
+from quivr_api.modules.vector.entity.vector import Vector
+from sqlmodel import Session, create_engine, select
+
+N_BRAINS = 100
+N_USERS = 1
+KNOWLEDGE_PER_BRAIN_MAX = 50
+KNOWLEDGE_PER_BRAIN_MIN = 20
+MEAN_VECTORS_PER_KNOWLEDGE = 500
+STD_VECTORS_PER_KNOWLEDGE = 200
+SAVE_PATH = "benchmarks/data.json"
+
+
+logger = get_logger("load_testing")
+pg_database_base_url = "postgresql://postgres:postgres@localhost:54322/postgres"
+
+
+class Data(BaseModel):
+ brains_ids: List[UUID]
+ knowledges_ids: List[UUID]
+ vectors_ids: List[UUID]
+
+
+def setup_brains(session: Session, user_id: UUID):
+ brains = []
+ brains_users = []
+
+ for idx in range(N_BRAINS):
+ brain = Brain(
+ name=f"brain_{idx}",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ status="private",
+ )
+ brains.append(brain)
+
+ session.add_all(brains)
+ session.commit()
+ [session.refresh(b) for b in brains]
+
+ for brain in brains:
+ brain_user = BrainUserDB(
+ brain_id=brain.brain_id,
+ user_id=user_id,
+ default_brain=True,
+ rights="Owner",
+ )
+ brains_users.append(brain_user)
+ session.add_all(brains_users)
+ session.commit()
+
+ return brains
+
+
+def setup_knowledge_brain(session: Session, brain: Brain, n_km: int, user_id: UUID):
+ kms = []
+
+ for idx in range(n_km):
+ knowledge = KnowledgeDB(
+ file_name=f"test_file_{idx}_brain_{idx}",
+ extension="txt",
+ status="UPLOADED",
+ source="test_source",
+ source_link="test_source_link",
+ file_size=100,
+ file_sha1=f"{os.urandom(128)}",
+ brains=[brain],
+ user_id=user_id,
+ )
+ kms.append(knowledge)
+
+ return kms
+
+
+def setup_vectors_knowledge(session: Session, knowledge: KnowledgeDB, n_vecs: int):
+ vecs = []
+ assert knowledge.id
+ for idx in range(n_vecs):
+ vector = Vector(
+ content=f"vector_{idx}",
+ metadata_={"file_name": f"{knowledge.file_name}", "chunk_size": 96},
+ embedding=np.random.randn(settings.embedding_dim), # type: ignore
+ knowledge_id=knowledge.id,
+ )
+
+ vecs.append(vector)
+
+ return vecs
+
+
+def setup_all(session: Session):
+ user = (session.exec(select(User).where(User.email == "admin@quivr.app"))).one()
+ assert user.id
+ brains = setup_brains(session, user.id)
+ logger.info(f"Inserted all {len(brains)} brains")
+ # all_km = []
+ # all_vecs = []
+ # for brain in brains:
+ # assert brain
+ # n_knowledges = random.randint(KNOWLEDGE_PER_BRAIN_MIN, KNOWLEDGE_PER_BRAIN_MAX)
+ # knowledges = setup_knowledge_brain(
+ # session, brain=brain, n_km=n_knowledges, user_id=user.id
+ # )
+ # logger.info(f"Inserted all {len(knowledges)} kms for {brain.name}")
+ # all_km.extend(knowledges)
+
+ # session.add_all(all_km)
+ # session.commit()
+ # [session.refresh(b) for b in all_km]
+
+ # n_vecs = np.random.normal(
+ # MEAN_VECTORS_PER_KNOWLEDGE, STD_VECTORS_PER_KNOWLEDGE, len(all_km)
+ # ).tolist()
+ # for n_vecs_km, knowledge in zip(n_vecs, all_km, strict=False):
+ # vecs = setup_vectors_knowledge(session, knowledge, int(n_vecs_km))
+ # all_vecs.extend(vecs)
+
+ # logger.info(f"Inserting all {len(all_vecs)} vecs for knowledge {knowledge.id}")
+ # session.add_all(all_vecs)
+ # session.commit()
+ # [session.refresh(b) for b in all_km]
+ # [session.refresh(b) for b in all_vecs]
+
+ return Data(
+ brains_ids=[b.brain_id for b in brains],
+ knowledges_ids=[], # [k.id for k in all_km],
+ vectors_ids=[], # [v.id for v in all_vecs],
+ )
+
+
+def setup_data():
+ logger.info(f"""Starting load data script
+ N_BRAINS = {N_BRAINS},
+ N_USERS = {N_USERS},
+ KNOWLEDGE_PER_BRAIN_MIN = {KNOWLEDGE_PER_BRAIN_MIN},
+ KNOWLEDGE_PER_BRAIN_MAX = {KNOWLEDGE_PER_BRAIN_MAX },
+ MEAN_VECTORS_PER_KNOWLEDGE = {MEAN_VECTORS_PER_KNOWLEDGE}
+ STD_VECTORS_PER_KNOWLEDGE ={STD_VECTORS_PER_KNOWLEDGE}
+ """)
+ sync_engine = create_engine(
+ pg_database_base_url,
+ echo=True if os.getenv("ORM_DEBUG") else False,
+ future=True,
+ # NOTE: pessimistic bound on
+ pool_pre_ping=True,
+ pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6
+ pool_recycle=1800,
+ )
+
+ with Session(sync_engine, expire_on_commit=False, autoflush=False) as session:
+ data = setup_all(session)
+
+ logger.info(
+ f"Insert {len(data.brains_ids)} brains, {len(data.knowledges_ids)} knowledges, {len(data.vectors_ids)} vectors"
+ )
+
+ with open(SAVE_PATH, "w") as f:
+ f.write(data.model_dump_json())
+
+
+if __name__ == "__main__":
+ setup_data()
diff --git a/backend/benchmarks/locustfile.py b/backend/benchmarks/locustfile.py
new file mode 100644
index 000000000000..065d100501f5
--- /dev/null
+++ b/backend/benchmarks/locustfile.py
@@ -0,0 +1,101 @@
+import os
+import random
+import tempfile
+from typing import List
+from uuid import UUID
+
+from locust import between, task
+from locust.contrib.fasthttp import FastHttpUser
+from pydantic import BaseModel
+
+FILE_SIZE = 1024 * 1024
+
+
+class Data(BaseModel):
+ brains_ids: List[UUID]
+ knowledges_ids: List[UUID]
+ vectors_ids: List[UUID]
+
+
+with open("data.json", "r") as f:
+ data = Data.model_validate_json(f.read())
+
+
+class QuivrUser(FastHttpUser):
+ wait_time = between(0.2, 1) # Wait 1-5 seconds between tasks
+ host = "http://localhost:5050"
+ auth_headers = {
+ "Authorization": "Bearer 123",
+ }
+ query_params = "?brain_id=40ba47d7-51b2-4b2a-9247-89e29619efb0"
+
+ def on_start(self):
+ # Prepare the file to be uploaded
+ self.file_path = "test_file.txt"
+ with open(self.file_path, "wb") as f:
+ f.write(os.urandom(1024)) # 1 KB
+
+ @task(10)
+ def upload_file(self):
+ with tempfile.NamedTemporaryFile(suffix="_file.txt") as fp:
+ fp.write(os.urandom(1024)) # 1 KB
+ fp.flush()
+ files = {
+ "uploadFile": fp,
+ }
+ response = self.client.post(
+ f"/upload{self.query_params}",
+ files=files,
+ headers={"Content-Type": "multipart/form-data", **self.auth_headers},
+ )
+
+ # Check if the upload was successful
+ if response.status_code == 200:
+ print(f"File uploaded successfully. Response: {response.text}")
+ else:
+ print(f"File upload failed. Status code: {response.status_code}")
+
+ upload_file.__name__ = f"{upload_file.__name__}_1MB"
+
+ @task(10)
+ def get_brains(self):
+ self.client.get("/brains", headers=self.auth_headers)
+
+ @task(10)
+ def get_brain_by_id(self):
+ random_brain_id = random.choice(data.brains_ids)
+ self.client.get(f"/brains/{random_brain_id}", headers=self.auth_headers)
+
+ @task(10)
+ def get_knowledge_by_id(self):
+ random_brain_id = random.choice(data.brains_ids)
+ self.client.get(
+ f"/knowledge?brain_id={random_brain_id}", headers=self.auth_headers
+ )
+
+ @task(2)
+ def get_knowledge_signed_url(self):
+ random_knowledge = random.choice(data.knowledges_ids)
+ self.client.get(
+ f"/knowledge/{random_knowledge}/signed_download_url",
+ headers=self.auth_headers,
+ )
+
+ @task(1)
+ def delete_knowledge(self):
+ random_knowledge = random.choice(data.knowledges_ids)
+ data.knowledges_ids.remove(random_knowledge)
+ self.client.delete(
+ f"/knowledge/{random_knowledge}",
+ headers=self.auth_headers,
+ )
+
+ def on_stop(self):
+ # Clean up the test file
+ if os.path.exists(self.file_path):
+ os.remove(self.file_path)
+
+
+# GET Knowledge brain
+# DELETE knowledge cascades on vectors
+# GET /knowledge/{knowledge_id}/signed_download_url
diff --git a/backend/benchmarks/locustfile_kms.py b/backend/benchmarks/locustfile_kms.py
new file mode 100644
index 000000000000..8e31def3d55a
--- /dev/null
+++ b/backend/benchmarks/locustfile_kms.py
@@ -0,0 +1,261 @@
+import json
+import os
+import random
+from typing import List
+from uuid import UUID
+
+from locust import between, task
+from locust.contrib.fasthttp import FastHttpUser
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
+from quivr_api.modules.dependencies import get_supabase_client
+from quivr_api.modules.knowledge.dto.inputs import (
+ KnowledgeUpdate,
+ LinkKnowledgeBrain,
+ UnlinkKnowledgeBrain,
+)
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
+from quivr_api.modules.user.entity.user_identity import User
+from sqlmodel import Session, create_engine, select, text
+
+pg_database_base_url = "postgresql://postgres:postgres@localhost:54322/postgres"
+
+
+load_params = {
+ "n_brains": 100,
+ "file_size": 1024 * 1024, # 1MB
+ "parent_prob": 0.3,
+ "folder_prob": 0.2,
+ "km_root_prob": 0.2,
+ # Task rates
+ "create_km_rate": 10,
+ "list_km_rate": 10,
+ "link_brain_rate": 5,
+ "max_link_brains": 3,
+ "delete_km_rate": 2,
+ "unlink_brain_rate": 5,
+ "update_km_rate": 2,
+}
+
+
+all_kms: List[KnowledgeDTO] = []
+brains_ids: List[UUID] = []
+
+
+def is_folder() -> bool:
+ return random.random() < load_params["folder_prob"]
+
+
+def get_parent_id() -> str | None:
+ if random.random() < load_params["parent_prob"] and len(all_kms) > 0:
+ folders = list(filter(lambda k: k.is_folder, all_kms))
+ if len(folders) == 0:
+ return None
+ folder = random.choice(folders)
+ return str(folder.id)
+ return None
+
+
+def setup_brains(session: Session, user_id: UUID) -> List[Brain]:
+ brains = []
+ brains_users = []
+
+ for idx in range(load_params["n_brains"]):
+ brain = Brain(
+ name=f"brain_{idx}",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ status="private",
+ )
+ brains.append(brain)
+
+ session.add_all(brains)
+ session.commit()
+ [session.refresh(b) for b in brains]
+
+ for brain in brains:
+ brain_user = BrainUserDB(
+ brain_id=brain.brain_id,
+ user_id=user_id,
+ default_brain=True,
+ rights="Owner",
+ )
+ brains_users.append(brain_user)
+ session.add_all(brains_users)
+ session.commit()
+
+ return brains
+
+
+class QuivrUser(FastHttpUser):
+ # Wait 1-5 seconds between tasks
+ wait_time = between(1, 5)
+ host = "http://localhost:5050"
+ auth_headers = {
+ "Authorization": "Bearer 123",
+ }
+
+ data = os.urandom(load_params["file_size"])
+ sync_engine = create_engine(
+ pg_database_base_url,
+ echo=False,
+ )
+
+ def on_start(self) -> None:
+ global brains_ids
+
+ with Session(self.sync_engine) as session:
+ user = (
+ session.exec(select(User).where(User.email == "admin@quivr.app"))
+ ).one()
+ assert user.id
+ brains = setup_brains(session, user.id)
+ brains_ids = [b.brain_id for b in brains] # type: ignore
+
+ @task(load_params["create_km_rate"])
+ def create_knowledge(self):
+ km_data = {
+ "file_name": "test_file.txt",
+ "source": "local",
+ "is_folder": is_folder(),
+ "parent_id": get_parent_id(),
+ }
+
+ multipart_data = {
+ "knowledge_data": (None, json.dumps(km_data), "application/json"),
+ "file": ("test_file.txt", self.data, "application/octet-stream"),
+ }
+ response = self.client.post(
+ "/knowledge/",
+ headers=self.auth_headers,
+ files=multipart_data,
+ )
+ returned_km = KnowledgeDTO.model_validate_json(response.text)
+ all_kms.append(returned_km)
+
+ @task(load_params["link_brain_rate"])
+ def link_to_brains(self):
+ global all_kms
+ if len(all_kms) == 0:
+ return
+ nb_brains = random.randint(1, load_params["max_link_brains"])
+ random_brains = [random.choice(brains_ids) for _ in range(nb_brains)]
+ random_idx = random.choice(range(len(all_kms)))
+ random_km = all_kms.pop(random_idx)
+ json_data = LinkKnowledgeBrain(
+ brain_ids=random_brains, knowledge=random_km
+ ).model_dump_json()
+ response = self.client.post(
+ "/knowledge/link_to_brains/",
+ data=json_data,
+ headers={
+ "Content-Type": "application/json",
+ **self.auth_headers,
+ },
+ )
+ response.raise_for_status()
+ kms = [KnowledgeDTO.model_validate(r) for r in response.json()]
+ all_kms.extend(kms)
+
+ @task(load_params["list_km_rate"])
+ def list_knowledge_files(self):
+ if random.random() < load_params["km_root_prob"] or len(all_kms) == 0:
+ self.client.get(
+ "/knowledge/files",
+ headers=self.auth_headers,
+ name="/knowledge/files",
+ )
+ else:
+ random_km = random.choice(all_kms)
+ self.client.get(
+ f"/knowledge/files?parent_id={str(random_km.id)}",
+ headers=self.auth_headers,
+ name="/knowledge/files",
+ )
+
+ @task(load_params["unlink_brain_rate"])
+ def unlink_knowledge_brain(self):
+ global all_kms
+ if len(all_kms) == 0:
+ return
+ random_idx = random.choice(range(len(all_kms)))
+ random_km = all_kms.pop(random_idx)
+ if len(random_km.brains) == 0:
+ return
+ random_brain = random.choice(random_km.brains)
+ assert random_km.id
+ json_data = UnlinkKnowledgeBrain(
+ brain_ids=[UUID(random_brain["brain_id"])],
+ knowledge_id=random_km.id,
+ ).model_dump_json()
+ self.client.delete(
+ "/knowledge/unlink_from_brains/",
+ data=json_data,
+ headers={
+ "Content-Type": "application/json",
+ **self.auth_headers,
+ },
+ )
+
+ @task(load_params["delete_km_rate"])
+ def delete_knowledge_files(self):
+ global all_kms
+ only_files = [idx for idx, km in enumerate(all_kms) if not km.is_folder]
+ if len(only_files) == 0:
+ return
+ random_index = random.choice(only_files)
+ random_km = all_kms.pop(random_index)
+ children_ids = [c.id for c in random_km.children]
+ all_kms[:] = [k for k in all_kms if k.id not in children_ids]
+ self.client.delete(
+ f"/knowledge/{str(random_km.id)}",
+ headers=self.auth_headers,
+ name="/knowledge/delete",
+ )
+
+ @task(load_params["update_km_rate"])
+ def update_knowledge(self):
+ global all_kms
+ if len(all_kms) == 0:
+ return
+ random_idx = random.choice(range(len(all_kms)))
+ random_km = all_kms.pop(random_idx)
+ assert random_km.id
+ json_data = KnowledgeUpdate(file_name=f"file-{uuid4()}").model_dump_json(
+ exclude_unset=True
+ )
+ response = self.client.patch(
+ f"/knowledge/{random_km.id}/",
+ data=json_data,
+ name="/knowledge/update",
+ headers={
+ "Content-Type": "application/json",
+ **self.auth_headers,
+ },
+ )
+ assert response and response.text
+ km = KnowledgeDTO.model_validate_json(response.text)
+ all_kms.append(km)
+
+ # CRUD operations
+ create_knowledge.__name__ = "create_knowledge_1MB"
+ update_knowledge.__name__ = "update_knowledge"
+ delete_knowledge_files.__name__ = "delete_knowledge_file"
+ list_knowledge_files.__name__ = "list_knowledge_files"
+ # Special linking/unlinking brains
+ link_to_brains.__name__ = "link_to_brain"
+ unlink_knowledge_brain.__name__ = "unlink_knowledge_brain"
+
+ def on_stop(self):
+ global brains_ids
+ global all_kms
+ all_kms = []
+ brains_ids = []
+ # Cleanup db
+ with Session(self.sync_engine) as session:
+ session.execute(text("DELETE FROM brains;"))
+ session.execute(text("DELETE FROM knowledge;"))
+ session.commit()
+ # Cleanup storage
+ client = get_supabase_client()
+ client.storage.empty_bucket("quivr")
diff --git a/backend/benchmarks/serialization_dto.py b/backend/benchmarks/serialization_dto.py
new file mode 100644
index 000000000000..77623e25c77e
--- /dev/null
+++ b/backend/benchmarks/serialization_dto.py
@@ -0,0 +1,158 @@
+"""
+Small experiment debugging json serializer for KMS.
+Compare three serialization libs: pydantic, msgspec, orjson
+"""
+
+import statistics
+import timeit
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+from uuid import UUID
+
+import msgspec
+import orjson
+from pydantic import BaseModel
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
+from quivr_core.models import KnowledgeStatus
+from rich.console import Console
+from rich.table import Table
+
+n_dto = 1000
+num_runs = 100
+
+
+class ListKM(BaseModel):
+ kms: List[KnowledgeDTO]
+
+
+def serialize_orjson(kms: list[KnowledgeDTO]):
+ return orjson.dumps([k.model_dump() for k in kms])
+
+
+def serialize_orjson_single(kms: ListKM):
+ return orjson.dumps(kms.model_dump())
+
+
+def serialize_pydantic(kms: list[KnowledgeDTO]):
+ return [km.model_dump_json() for km in kms]
+
+
+def serialize_pydantic_obj(kms: ListKM):
+ return kms.model_dump_json()
+
+
+def evaluate(name, func):
+ times = timeit.repeat(
+ lambda: func(), globals=globals(), repeat=num_runs, number=1
+ ) # Change repeat=5 for desired runs
+ average_time = sum(times) / len(times)
+ std_dev = statistics.stdev(times)
+ return name, average_time * 1000, std_dev * 1000
+
+
+class KnowledgeMsg(msgspec.Struct):
+ updated_at: datetime
+ created_at: datetime
+ user_id: UUID
+ brains: List[Dict[str, Any]]
+ id: Optional[UUID] = None
+ status: Optional[KnowledgeStatus] = None
+ file_size: int = 0
+ file_name: Optional[str] = None
+ url: Optional[str] = None
+ extension: str = ".txt"
+ is_folder: bool = False
+ source: Optional[str] = None
+ source_link: Optional[str] = None
+ file_sha1: Optional[str] = None
+ metadata: Optional[Dict[str, str]] = None
+ parent: Optional["KnowledgeDTO"] = None
+ children: List["KnowledgeDTO"] = []
+ sync_id: Optional[int] = None
+ sync_file_id: Optional[str] = None
+
+
+def print_table(results):
+ console = Console()
+ table = Table(title=f"Serialization Performance, n_obj={n_dto}", show_lines=True)
+
+ # Define table columns
+ table.add_column("Function Name", justify="left", style="cyan")
+ table.add_column("Average Time (ms)", justify="right", style="magenta")
+ table.add_column("Standard Deviation (ms)", justify="right", style="green")
+
+ # Add rows with evaluation results
+ for name, avg_time, std_dev in results:
+ table.add_row(name, f"{avg_time:.6f}", f"{std_dev:.6f}")
+
+ # Print the table to the console
+ console.print(table)
+
+
+def main():
+ data = {
+ "id": "24185498-9025-44ea-ae70-b5a1a342f97c",
+ "file_size": 57210,
+ "status": "UPLOADED",
+ "file_name": "0000993.pdf",
+ "url": None,
+ "extension": ".pdf",
+ "is_folder": False,
+ "updated_at": "2024-09-26T19:01:23.881842Z",
+ "created_at": "2024-09-26T19:00:57.110967Z",
+ "source": "local",
+ "source_link": None,
+ "file_sha1": "1488859a8d85a309b2bff4c669177e688997bfe9",
+ "metadata": None,
+ "user_id": "155b9ab3-e649-4f8a-b5cf-8150728a9202",
+ "brains": [
+ {
+ "name": "all_kms",
+ "description": "kms",
+ "temperature": 0,
+ "brain_type": "doc",
+ "brain_id": "a035b4e5-a385-468a-8f41-2d8344cc6a8f",
+ "status": "private",
+ "model": None,
+ "max_tokens": 2000,
+ "last_update": "2024-09-26T19:31:16.352708",
+ "prompt_id": None,
+ }
+ ],
+ "sync_id": None,
+ "sync_file_id": None,
+ "parent": None,
+ "children": [],
+ }
+
+ km = KnowledgeDTO.model_validate(data)
+ # print(isinstance([km]*N,BaseModel))
+ list_dto = [km] * n_dto
+ single_obj = ListKM(kms=list_dto)
+ km_msgspec = msgspec.json.decode(msgspec.json.encode(data), type=KnowledgeMsg)
+ list_msgspec = [km_msgspec] * n_dto
+
+ # Evaluation
+ results = []
+ results.append(evaluate("serialize_pydantic", lambda: serialize_pydantic(list_dto)))
+ results.append(
+ evaluate(
+ "serialize_pydantic_single_obj", lambda: serialize_pydantic_obj(single_obj)
+ )
+ )
+ results.append(evaluate("serialize_orjson", lambda: serialize_orjson(list_dto)))
+ results.append(
+ evaluate("serialize_orjson_single", lambda: serialize_orjson_single(single_obj))
+ )
+ results.append(
+ evaluate(
+ "serialize_msgspec",
+ lambda: [msgspec.json.encode(msg) for msg in list_msgspec],
+ )
+ )
+
+ print_table(results)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/backend/core/quivr_core/files/file.py b/backend/core/quivr_core/files/file.py
index 9f4089b103fb..e7923d34df5b 100644
--- a/backend/core/quivr_core/files/file.py
+++ b/backend/core/quivr_core/files/file.py
@@ -112,9 +112,9 @@ def __init__(
id: UUID,
original_filename: str,
path: Path,
- brain_id: UUID,
file_sha1: str,
file_extension: FileExtension | str,
+ brain_id: UUID | None = None,
file_size: int | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
@@ -153,7 +153,6 @@ def metadata(self) -> dict[str, Any]:
def serialize(self) -> QuivrFileSerialized:
return QuivrFileSerialized(
id=self.id,
- brain_id=self.brain_id,
path=self.path.absolute(),
original_filename=self.original_filename,
file_size=self.file_size,
diff --git a/backend/core/quivr_core/models.py b/backend/core/quivr_core/models.py
index 0dc304c67b77..101d2d2355e3 100644
--- a/backend/core/quivr_core/models.py
+++ b/backend/core/quivr_core/models.py
@@ -39,10 +39,11 @@ class ChatMessage(BaseModelV1):
class KnowledgeStatus(str, Enum):
- PROCESSING = "PROCESSING"
- UPLOADED = "UPLOADED"
ERROR = "ERROR"
RESERVED = "RESERVED"
+ PROCESSING = "PROCESSING"
+ PROCESSED = "PROCESSED"
+ UPLOADED = "UPLOADED"
class Source(BaseModel):
diff --git a/backend/core/quivr_core/processor/processor_base.py b/backend/core/quivr_core/processor/processor_base.py
index 1b8cbbe39423..a108dca85eed 100644
--- a/backend/core/quivr_core/processor/processor_base.py
+++ b/backend/core/quivr_core/processor/processor_base.py
@@ -2,7 +2,6 @@
from abc import ABC, abstractmethod
from importlib.metadata import PackageNotFoundError, version
from typing import Any
-from uuid import uuid4
from langchain_core.documents import Document
@@ -13,7 +12,6 @@
# TODO: processors should be cached somewhere ?
# The processor should be cached by processor type
-# The cache should use a single
class ProcessorBase(ABC):
supported_extensions: list[FileExtension | str]
@@ -43,7 +41,6 @@ async def process_file(self, file: QuivrFile) -> list[Document]:
"utf-8"
)
doc.metadata = {
- "id": uuid4(),
"chunk_index": idx,
"quivr_core_version": qvr_version,
**file.metadata,
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 80298fb942b8..66d3c13edccc 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -41,10 +41,21 @@ dev-dependencies = [
"pytest-cov>=5.0.0",
"tox>=4.0.0",
"chainlit>=1.1.306",
+ "pytest-profiling>=1.7.0",
+ "locust>=2.31.7",
]
[tool.rye.workspace]
-members = [".", "core", "worker", "api", "docs", "core/examples/chatbot", "core/MegaParse", "worker/diff-assistant"]
+members = [
+ ".",
+ "core",
+ "worker",
+ "api",
+ "docs",
+ "core/examples/chatbot",
+ "core/MegaParse",
+ "worker/diff-assistant",
+]
[tool.hatch.metadata]
allow-direct-references = true
diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock
index 498ff5d20dc1..98e7bcad88a4 100644
--- a/backend/requirements-dev.lock
+++ b/backend/requirements-dev.lock
@@ -87,11 +87,15 @@ black==24.8.0
# via flake8-black
bleach==6.1.0
# via nbconvert
+blinker==1.8.2
+ # via flask
boto3==1.35.2
# via cohere
botocore==1.35.2
# via boto3
# via s3transfer
+brotli==1.1.0
+ # via geventhttpclient
cachetools==5.5.0
# via google-auth
# via tox
@@ -100,6 +104,7 @@ celery==5.4.0
# via quivr-api
# via quivr-worker
certifi==2022.12.7
+ # via geventhttpclient
# via httpcore
# via httpx
# via requests
@@ -107,6 +112,7 @@ certifi==2022.12.7
# via unstructured-client
cffi==1.17.0 ; platform_python_implementation != 'PyPy' or implementation_name == 'pypy'
# via cryptography
+ # via gevent
# via pyzmq
cfgv==3.4.0
# via pre-commit
@@ -127,6 +133,7 @@ click==8.1.7
# via click-didyoumean
# via click-plugins
# via click-repl
+ # via flask
# via litellm
# via mkdocs
# via mkdocstrings
@@ -159,6 +166,8 @@ colorlog==6.8.2
# via quivr-api
comm==0.2.2
# via ipykernel
+configargparse==1.7
+ # via locust
contourpy==1.2.1
# via matplotlib
coverage==7.6.1
@@ -247,6 +256,14 @@ fire==0.6.0
flake8==7.1.1
# via flake8-black
flake8-black==0.3.6
+flask==3.0.3
+ # via flask-cors
+ # via flask-login
+ # via locust
+flask-cors==5.0.0
+ # via locust
+flask-login==0.6.3
+ # via locust
flatbuffers==24.3.25
# via onnxruntime
flower==2.0.1
@@ -265,6 +282,11 @@ fsspec==2024.2.0
# via llama-index-core
# via llama-index-legacy
# via torch
+gevent==24.2.1
+ # via geventhttpclient
+ # via locust
+geventhttpclient==2.3.1
+ # via locust
ghp-import==2.1.0
# via mkdocs
google-api-core==2.19.1
@@ -292,7 +314,10 @@ googleapis-common-protos==1.63.2
# via opentelemetry-exporter-otlp-proto-http
gotrue==2.7.0
# via supabase
+gprof2dot==2024.6.6
+ # via pytest-profiling
greenlet==3.0.3
+ # via gevent
# via playwright
# via sqlalchemy
griffe==1.2.0
@@ -369,9 +394,12 @@ ipykernel==6.29.5
# via mkdocs-jupyter
ipython==8.26.0
# via ipykernel
+itsdangerous==2.2.0
+ # via flask
jedi==0.19.1
# via ipython
jinja2==3.1.3
+ # via flask
# via litellm
# via mkdocs
# via mkdocs-material
@@ -525,6 +553,7 @@ llama-parse==0.5.6
# via quivr-api
llvmlite==0.43.0
# via numba
+locust==2.31.8
lxml==5.3.0
# via pikepdf
# via python-docx
@@ -551,6 +580,7 @@ markupsafe==2.1.5
# via mkdocs-autorefs
# via mkdocstrings
# via nbconvert
+ # via werkzeug
marshmallow==3.22.0
# via dataclasses-json
# via marshmallow-enum
@@ -607,6 +637,8 @@ mpmath==1.3.0
# via sympy
msal==1.30.0
# via quivr-api
+msgpack==1.1.0
+ # via locust
multidict==6.0.5
# via aiohttp
# via yarl
@@ -849,6 +881,7 @@ protobuf==4.25.4
# via transformers
psutil==6.0.0
# via ipykernel
+ # via locust
# via unstructured
psycopg2-binary==2.9.9
# via quivr-api
@@ -943,12 +976,14 @@ pytest==8.3.2
# via pytest-cov
# via pytest-dotenv
# via pytest-mock
+ # via pytest-profiling
# via pytest-xdist
pytest-asyncio==0.24.0
pytest-benchmark==4.0.0
pytest-cov==5.0.0
pytest-dotenv==0.5.2
pytest-mock==3.14.0
+pytest-profiling==1.7.0
pytest-xdist==3.6.1
python-dateutil==2.9.0.post0
# via botocore
@@ -1001,8 +1036,9 @@ python-socketio==5.11.3
pytz==2024.1
# via flower
# via pandas
-pywin32==306 ; (platform_python_implementation != 'PyPy' and sys_platform == 'win32') or platform_system == 'Windows'
+pywin32==306 ; sys_platform == 'win32' or platform_system == 'Windows'
# via jupyter-core
+ # via locust
# via portalocker
pyyaml==6.0.2
# via huggingface-hub
@@ -1025,6 +1061,7 @@ pyyaml-env-tag==0.1
pyzmq==26.1.1
# via ipykernel
# via jupyter-client
+ # via locust
rapidfuzz==3.9.6
# via python-doctr
# via unstructured
@@ -1053,6 +1090,7 @@ requests==2.32.3
# via litellm
# via llama-index-core
# via llama-index-legacy
+ # via locust
# via mkdocs-material
# via msal
# via opentelemetry-exporter-otlp-proto-http
@@ -1093,6 +1131,8 @@ sentry-sdk==2.13.0
# via quivr-api
setuptools==70.0.0
# via opentelemetry-instrumentation
+ # via zope-event
+ # via zope-interface
shapely==2.0.6
# via python-doctr
simple-websocket==1.0.0
@@ -1106,6 +1146,7 @@ six==1.16.0
# via langdetect
# via markdownify
# via posthog
+ # via pytest-profiling
# via python-dateutil
# via stone
# via unstructured-client
@@ -1287,6 +1328,7 @@ uritemplate==4.1.1
# via google-api-python-client
urllib3==1.26.13
# via botocore
+ # via geventhttpclient
# via requests
# via sentry-sdk
# via unstructured-client
@@ -1313,6 +1355,10 @@ webencodings==0.5.1
# via tinycss2
websockets==12.0
# via realtime
+werkzeug==3.0.4
+ # via flask
+ # via flask-login
+ # via locust
wrapt==1.16.0
# via deprecated
# via llama-index-core
@@ -1328,3 +1374,7 @@ yarl==1.9.4
# via aiohttp
zipp==3.20.0
# via importlib-metadata
+zope-event==5.0
+ # via gevent
+zope-interface==7.0.3
+ # via gevent
diff --git a/backend/supabase/migrations/20240905153004_knowledge-folders.sql b/backend/supabase/migrations/20240920153014_knowledge-folders.sql
similarity index 100%
rename from backend/supabase/migrations/20240905153004_knowledge-folders.sql
rename to backend/supabase/migrations/20240920153014_knowledge-folders.sql
diff --git a/backend/supabase/migrations/20240920180003_knowledge-sync.sql b/backend/supabase/migrations/20240920180003_knowledge-sync.sql
new file mode 100644
index 000000000000..ff0995141a96
--- /dev/null
+++ b/backend/supabase/migrations/20240920180003_knowledge-sync.sql
@@ -0,0 +1,23 @@
+-- Renamed syncs
+ALTER TABLE syncs_user
+ RENAME TO syncs;
+-- Add column foreign key sync in knowledge
+ALTER TABLE "public"."knowledge"
+ADD COLUMN "sync_id" INTEGER;
+ALTER TABLE "public"."knowledge"
+ADD CONSTRAINT "public_knowledge_sync_id_fkey" FOREIGN KEY (sync_id) REFERENCES syncs(id) ON DELETE CASCADE;
+-- Add column for sync_file_ids
+ALTER TABLE "public"."knowledge"
+ADD COLUMN "last_synced_at" timestamp with time zone;
+ALTER TABLE "public"."knowledge"
+ADD COLUMN "sync_file_id" TEXT;
+CREATE INDEX knowledge_sync_id_pkey ON public.knowledge USING btree (sync_id);
+CREATE INDEX knowledge_sync_file_id_pkey ON public.knowledge USING btree (sync_file_id);
+-- Add columns syncs
+alter table "public"."syncs"
+add column "created_at" timestamp with time zone default now();
+alter table "public"."syncs"
+add column "updated_at" timestamp with time zone default now();
+-- Drop files
+DROP TABLE IF EXISTS "public"."syncs_active" CASCADE;
+DROP TABLE IF EXISTS "public"."syncs_files" CASCADE;
diff --git a/backend/supabase/seed.sql b/backend/supabase/seed.sql
index b15761bf6421..0ccc144eb1e3 100644
--- a/backend/supabase/seed.sql
+++ b/backend/supabase/seed.sql
@@ -298,7 +298,7 @@ INSERT INTO "public"."user_daily_usage" ("user_id", "email", "date", "daily_requ
--
INSERT INTO "public"."user_identity" ("user_id", "openai_api_key", "company", "onboarded", "username", "company_size", "usage_purpose") VALUES
- ('39418e3b-0258-4452-af60-7acfcc1263ff', NULL, 'Stan', true, 'Stan', NULL, '');
+ ('39418e3b-0258-4452-af60-7acfcc1263ff', NULL, 'Quivr Local', true, 'Quivr Local', NULL, '');
--
@@ -330,20 +330,6 @@ SELECT pg_catalog.setval('"public"."integrations_user_id_seq"', 6, true);
SELECT pg_catalog.setval('"public"."product_to_features_id_seq"', 1, false);
---
--- Name: syncs_active_id_seq; Type: SEQUENCE SET; Schema: public; Owner: postgres
---
-
-SELECT pg_catalog.setval('"public"."syncs_active_id_seq"', 1, false);
-
-
---
--- Name: syncs_files_id_seq; Type: SEQUENCE SET; Schema: public; Owner: postgres
---
-
-SELECT pg_catalog.setval('"public"."syncs_files_id_seq"', 1, false);
-
-
--
-- Name: syncs_user_id_seq; Type: SEQUENCE SET; Schema: public; Owner: postgres
--
diff --git a/backend/worker/quivr_worker/assistants/assistants.py b/backend/worker/quivr_worker/assistants/assistants.py
index 1571072bb0b7..17fcea1cad7e 100644
--- a/backend/worker/quivr_worker/assistants/assistants.py
+++ b/backend/worker/quivr_worker/assistants/assistants.py
@@ -1,13 +1,43 @@
import os
+from quivr_api.modules.assistant.repository.tasks import TasksRepository
from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.upload.service.upload_file import (
upload_file_storage,
)
+from sqlalchemy.ext.asyncio import AsyncEngine
from quivr_worker.assistants.cdp_use_case_2 import process_cdp_use_case_2
from quivr_worker.assistants.cdp_use_case_3 import process_cdp_use_case_3
from quivr_worker.utils.pdf_generator.pdf_generator import PDFGenerator, PDFModel
+from quivr_worker.utils.services import _start_session
+
+
+async def aprocess_assistant_task(
+ engine: AsyncEngine,
+ assistant_id: str,
+ notification_uuid: str,
+ task_id: int,
+ user_id: str,
+):
+ async with _start_session(engine) as async_session:
+ try:
+ tasks_repository = TasksRepository(async_session)
+ tasks_service = TasksService(tasks_repository)
+
+ await process_assistant(
+ assistant_id,
+ notification_uuid,
+ task_id,
+ tasks_service,
+ user_id,
+ )
+
+ except Exception as e:
+ await async_session.rollback()
+ raise e
+ finally:
+ await async_session.close()
async def process_assistant(
@@ -30,9 +60,8 @@ async def process_assistant(
assistant_id, notification_uuid, task_id, tasks_service, user_id
)
else:
- new_task = await tasks_service.update_task(task_id, {"status": "processing"})
+ await tasks_service.update_task(task_id, {"status": "processing"})
# Add a random delay of 10 to 20 seconds
-
task_result = {"status": "completed", "answer": output}
output_dir = f"{assistant_id}/{notification_uuid}"
diff --git a/backend/worker/quivr_worker/celery_monitor.py b/backend/worker/quivr_worker/celery_monitor.py
index 546e4b9d8206..f1483a3a302e 100644
--- a/backend/worker/quivr_worker/celery_monitor.py
+++ b/backend/worker/quivr_worker/celery_monitor.py
@@ -1,4 +1,5 @@
import asyncio
+import os
import threading
from enum import Enum
from queue import Queue
@@ -8,10 +9,8 @@
from celery.result import AsyncResult
from quivr_api.celery_config import celery
from quivr_api.logger import get_logger, setup_logger
-from quivr_api.modules.dependencies import async_engine
+from quivr_api.models.settings import settings
from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
-from quivr_api.modules.assistant.repository.tasks import TasksRepository
-from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
from quivr_api.modules.notification.dto.inputs import NotificationUpdatableProperties
from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum
@@ -19,8 +18,22 @@
NotificationService,
)
from quivr_core.models import KnowledgeStatus
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlmodel import text
from sqlmodel.ext.asyncio.session import AsyncSession
+async_engine = create_async_engine(
+ settings.pg_database_async_url,
+ connect_args={"server_settings": {"application_name": "quivr-monitor"}},
+ echo=True if os.getenv("ORM_DEBUG") else False,
+ future=True,
+ pool_pre_ping=True,
+ max_overflow=0,
+ pool_size=5, # NOTE: no bouncer for now, if 6 process workers => 6
+ pool_recycle=1800,
+ isolation_level="AUTOCOMMIT",
+)
+
setup_logger("notifier.log", send_log_server=False)
logger = get_logger("notifier_service")
notification_service = NotificationService()
@@ -41,80 +54,97 @@ class TaskIdentifier(str, Enum):
@dataclass
class TaskEvent:
task_id: str
- brain_id: UUID | None
task_name: TaskIdentifier
notification_id: str
knowledge_id: UUID | None
status: TaskStatus
-async def handler_loop():
- session = AsyncSession(async_engine, expire_on_commit=False, autoflush=False)
- knowledge_service = KnowledgeService(KnowledgeRepository(session))
- task_service = TasksService(TasksRepository(session))
+async def handle_error_task(
+ task: TaskEvent,
+ knowledge_service: KnowledgeService,
+ notification_service: NotificationService,
+):
+ logger.error(
+ f"task {task.task_id} process_file_task. Sending notifition {task.notification_id}"
+ )
+ notification_service.update_notification_by_id(
+ task.notification_id,
+ NotificationUpdatableProperties(
+ status=NotificationsStatusEnum.ERROR,
+ description=("An error occurred while processing the file"),
+ ),
+ )
+ logger.error(
+ f"task {task.task_id} process_file_task failed. Updating knowledge {task.knowledge_id} to Error"
+ )
+ if task.knowledge_id:
+ await knowledge_service.update_status_knowledge(
+ task.knowledge_id, KnowledgeStatus.ERROR
+ )
+ logger.error(
+ f"task {task.task_id} process_file_task . Updating knowledge {task.knowledge_id} status to Error"
+ )
+
+
+async def handle_success_task(
+ task: TaskEvent,
+ knowledge_service: KnowledgeService,
+ notification_service: NotificationService,
+):
+ logger.info(
+ f"task {task.task_id} process_file_task succeeded. Sending notification {task.notification_id}"
+ )
+ notification_service.update_notification_by_id(
+ task.notification_id,
+ NotificationUpdatableProperties(
+ status=NotificationsStatusEnum.SUCCESS,
+ description=(
+ "Your file has been properly uploaded!"
+ if task.task_name == TaskIdentifier.PROCESS_FILE_TASK
+ else "Your URL has been properly crawled!"
+ ),
+ ),
+ )
+ if task.knowledge_id:
+ await knowledge_service.update_status_knowledge(
+ knowledge_id=task.knowledge_id,
+ status=KnowledgeStatus.UPLOADED,
+ )
+ logger.info(
+ f"task {task.task_id} process_file_task failed. Updating knowledge {task.knowledge_id} to UPLOADED"
+ )
- logger.info("Initialized knowledge_service. Listening to task event...")
- while True:
- try:
- event: TaskEvent = queue.get()
- if event.status == TaskStatus.FAILED:
- if event.task_name == TaskIdentifier.PROCESS_ASSISTANT_TASK:
- # Update the task status to error
- logger.info(f"task {event.task_id} process_assistant_task failed. Updating task {event.notification_id} to error")
- await task_service.update_task(int(event.notification_id), {"status": "error"})
- else:
- logger.error(
- f"task {event.task_id} process_file_task. Sending notifition {event.notification_id}"
- )
- notification_service.update_notification_by_id(
- event.notification_id,
- NotificationUpdatableProperties(
- status=NotificationsStatusEnum.ERROR,
- description=(
- "An error occurred while processing the file"
- if event.task_name == TaskIdentifier.PROCESS_FILE_TASK
- else "An error occurred while processing the URL"
- ),
- ),
- )
- logger.error(
- f"task {event.task_id} process_file_task failed. Updating knowledge {event.knowledge_id} to Error"
- )
- if event.knowledge_id:
- await knowledge_service.update_status_knowledge(
- event.knowledge_id, KnowledgeStatus.ERROR
- )
- logger.error(
- f"task {event.task_id} process_file_task . Updating knowledge {event.knowledge_id} status to Error"
+
+async def handler_loop():
+ async with AsyncSession(
+ async_engine, expire_on_commit=False, autoflush=False
+ ) as session:
+ await session.execute(
+ text("SET SESSION idle_in_transaction_session_timeout = '1min';")
+ )
+ knowledge_service = KnowledgeService(KnowledgeRepository(session))
+ notification_service = NotificationService()
+ logger.info("Initialized knowledge_service. Listening to task event...")
+ while True:
+ try:
+ event: TaskEvent = queue.get()
+ if event.status == TaskStatus.FAILED:
+ await handle_success_task(
+ task=event,
+ knowledge_service=knowledge_service,
+ notification_service=notification_service,
)
- if event.status == TaskStatus.SUCCESS:
- logger.info(
- f"task {event.task_id} process_file_task succeeded. Sending notification {event.notification_id}"
- )
- notification_service.update_notification_by_id(
- event.notification_id,
- NotificationUpdatableProperties(
- status=NotificationsStatusEnum.SUCCESS,
- description=(
- "Your file has been properly uploaded!"
- if event.task_name == TaskIdentifier.PROCESS_FILE_TASK
- else "Your URL has been properly crawled!"
- ),
- ),
- )
- if event.knowledge_id:
- await knowledge_service.update_status_knowledge(
- knowledge_id=event.knowledge_id,
- status=KnowledgeStatus.UPLOADED,
- brain_id=event.brain_id,
+ if event.status == TaskStatus.SUCCESS:
+ await handle_error_task(
+ task=event,
+ knowledge_service=knowledge_service,
+ notification_service=notification_service,
)
- logger.info(
- f"task {event.task_id} process_file_task failed. Updating knowledge {event.knowledge_id} to UPLOADED"
- )
- except Exception as e:
- logger.error(f"Excpetion occured handling event {event}: {e}")
+ except Exception as e:
+ logger.error(f"Excpetion occured handling event {event}: {e}")
def notifier(app):
@@ -127,29 +157,25 @@ def handle_task_event(event):
task_result = AsyncResult(task.id, app=app)
task_name, task_kwargs = task_result.name, task_result.kwargs
- if task_name == "process_file_task" or task_name == "process_crawl_task":
+ if task_name == TaskIdentifier.PROCESS_FILE_TASK:
logger.debug(f"Received Event : {task} - {task_name} {task_kwargs} ")
- notification_id = task_kwargs["notification_id"]
knowledge_id = task_kwargs.get("knowledge_id", None)
- brain_id = task_kwargs.get("brain_id", None)
+ notification_id = task_kwargs.get("notification_id", None)
event = TaskEvent(
task_id=task,
task_name=TaskIdentifier(task_name),
knowledge_id=knowledge_id,
- brain_id=brain_id,
notification_id=notification_id,
status=TaskStatus(event["type"]),
)
queue.put(event)
elif task_name == "process_assistant_task":
logger.debug(f"Received Event : {task} - {task_name} {task_kwargs} ")
- notification_uuid = task_kwargs["notification_uuid"]
task_id = task_kwargs["task_id"]
event = TaskEvent(
task_id=task,
task_name=TaskIdentifier(task_name),
knowledge_id=None,
- brain_id=None,
notification_id=task_id,
status=TaskStatus(event["type"]),
)
@@ -169,23 +195,6 @@ def handle_task_event(event):
recv.capture(limit=None, timeout=None, wakeup=True)
-def is_being_executed(task_name: str) -> bool:
- """Returns whether the task with given task_name is already being executed.
-
- Args:
- task_name: Name of the task to check if it is running currently.
- Returns: A boolean indicating whether the task with the given task name is
- running currently.
- """
- active_tasks = celery.control.inspect().active()
- for worker, running_tasks in active_tasks.items():
- for task in running_tasks:
- if task["name"] == task_name: # type: ignore
- return True
-
- return False
-
-
if __name__ == "__main__":
logger.info("Started quivr-notifier service...")
diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py
index dce4a301d150..d48738bf4fed 100644
--- a/backend/worker/quivr_worker/celery_worker.py
+++ b/backend/worker/quivr_worker/celery_worker.py
@@ -5,6 +5,7 @@
import structlog
import torch
from celery import signals
+from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded
from celery.schedules import crontab
from celery.signals import worker_process_init
from celery.utils.log import get_task_logger
@@ -12,42 +13,15 @@
from quivr_api.celery_config import celery
from quivr_api.logger import setup_logger
from quivr_api.models.settings import settings
-from quivr_api.modules.assistant.repository.tasks import TasksRepository
-from quivr_api.modules.assistant.services.tasks_service import TasksService
from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionConnector
-from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
-from quivr_api.modules.brain.service.brain_service import BrainService
from quivr_api.modules.dependencies import get_supabase_client
-from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
-from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
-from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
-from quivr_api.modules.notification.service.notification_service import (
- NotificationService,
-)
-from quivr_api.modules.sync.dto.inputs import SyncsUserStatus
-from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository
-from quivr_api.modules.sync.service.sync_notion import SyncNotionService
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
-from quivr_api.modules.vector.repository.vectors_repository import VectorRepository
-from quivr_api.modules.vector.service.vector_service import VectorService
from quivr_api.utils.telemetry import maybe_send_telemetry
-from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
-from sqlmodel import Session, text
-from sqlmodel.ext.asyncio.session import AsyncSession
-from quivr_worker.assistants.assistants import process_assistant
-from quivr_worker.celery_monitor import is_being_executed
+from quivr_worker.assistants.assistants import aprocess_assistant_task
from quivr_worker.check_premium import check_is_premium
-from quivr_worker.process.process_s3_file import process_uploaded_file
-from quivr_worker.process.process_url import process_url_func
-from quivr_worker.syncs.process_active_syncs import (
- SyncServices,
- process_all_active_syncs,
- process_notion_sync,
- process_sync,
-)
-from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async
+from quivr_worker.process import aprocess_file_task
+from quivr_worker.syncs.update_syncs import refresh_sync_files, refresh_sync_folders
from quivr_worker.utils.utils import _patch_json
torch.set_num_threads(1)
@@ -59,21 +33,8 @@
_patch_json()
-
-# FIXME: load at init time
-# Services
supabase_client = get_supabase_client()
-# document_vector_store = get_documents_vector_store()
-notification_service = NotificationService()
-sync_active_service = SyncService()
-sync_user_service = SyncUserService()
-sync_files_repo_service = SyncFilesRepository()
-brain_service = BrainService()
-brain_vectors = BrainsVectors()
-storage = SupabaseS3Storage()
-notion_service: SyncNotionService | None = None
async_engine: AsyncEngine | None = None
-engine: Engine | None = None
@signals.task_prerun.connect
@@ -86,26 +47,18 @@ def on_task_prerun(sender, task_id, task, args, kwargs, **_):
@worker_process_init.connect
def init_worker(**kwargs):
global async_engine
- global engine
if not async_engine:
async_engine = create_async_engine(
settings.pg_database_async_url,
+ connect_args={
+ "server_settings": {"application_name": f"quivr-worker-{os.getpid()}"}
+ },
echo=True if os.getenv("ORM_DEBUG") else False,
future=True,
- # NOTE: pessimistic bound on
+ # NOTE: pessimistic bound on reconnect
pool_pre_ping=True,
- pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6
- pool_recycle=1800,
- )
-
- if not engine:
- engine = create_engine(
- settings.pg_database_url,
- echo=True if os.getenv("ORM_DEBUG") else False,
- future=True,
- # NOTE: pessimistic bound on
- pool_pre_ping=True,
- pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6
+ # NOTE: no bouncer for now
+ pool_size=1,
pool_recycle=1800,
)
@@ -113,187 +66,87 @@ def init_worker(**kwargs):
@celery.task(
retries=3,
default_retry_delay=1,
- name="process_assistant_task",
- autoretry_for=(Exception,),
+ name="process_file_task",
+ time_limit=600, # 10 min
+ soft_time_limit=300,
+ autoretry_for=(Exception,), # SoftTimeLimitExceeded should not included?
+ dont_autoretry_for=(SoftTimeLimitExceeded, TimeLimitExceeded),
)
-def process_assistant_task(
- assistant_id: str,
- notification_uuid: str,
- task_id: int,
- user_id: str,
+def process_file_task(
+ knowledge_id: UUID,
+ notification_id: UUID | None = None,
):
+ if async_engine is None:
+ init_worker()
+ assert async_engine
logger.info(
- f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}"
+ f"Task process_file started for knowledge_id={knowledge_id}, notification_id={notification_id}"
)
-
loop = asyncio.get_event_loop()
loop.run_until_complete(
- aprocess_assistant_task(
- assistant_id,
- notification_uuid,
- task_id,
- user_id,
- )
+ aprocess_file_task(async_engine=async_engine, knowledge_id=knowledge_id)
)
-async def aprocess_assistant_task(
- assistant_id: str,
- notification_uuid: str,
- task_id: int,
- user_id: str,
-):
- global async_engine
+@celery.task(
+ retries=3,
+ default_retry_delay=1,
+ name="refresh_sync_files_task",
+ soft_time_limit=3600,
+ autoretry_for=(Exception,),
+)
+def refresh_sync_files_task():
+ if async_engine is None:
+ init_worker()
assert async_engine
- async with AsyncSession(async_engine) as async_session:
- try:
- await async_session.execute(
- text("SET SESSION idle_in_transaction_session_timeout = '5min';")
- )
- tasks_repository = TasksRepository(async_session)
- tasks_service = TasksService(tasks_repository)
-
- await process_assistant(
- assistant_id,
- notification_uuid,
- task_id,
- tasks_service,
- user_id,
- )
-
- except Exception as e:
- await async_session.rollback()
- raise e
- finally:
- await async_session.close()
+ logger.info("Update sync task started")
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(refresh_sync_files(async_engine=async_engine))
@celery.task(
retries=3,
default_retry_delay=1,
- name="process_file_task",
+ name="refresh_sync_folders_task",
autoretry_for=(Exception,),
- dont_autoretry_for=(FileExistsError,),
)
-def process_file_task(
- file_name: str,
- file_original_name: str,
- brain_id: UUID,
- notification_id: UUID,
- knowledge_id: UUID,
- source: str | None = None,
- source_link: str | None = None,
- delete_file: bool = False,
-):
+def refresh_sync_folders_task():
if async_engine is None:
init_worker()
-
+ assert async_engine
+ logger.info("Update sync task started")
loop = asyncio.get_event_loop()
- loop.run_until_complete(
- aprocess_file_task(
- file_name=file_name,
- file_original_name=file_original_name,
- brain_id=brain_id,
- notification_id=notification_id,
- knowledge_id=knowledge_id,
- source=source,
- source_link=source_link,
- delete_file=delete_file,
- )
- )
-
-
-async def aprocess_file_task(
- file_name: str,
- file_original_name: str,
- brain_id: UUID,
- notification_id: UUID,
- knowledge_id: UUID,
- source: str | None = None,
- source_link: str | None = None,
- delete_file: bool = False,
-):
- global engine
- assert engine
- async with AsyncSession(async_engine) as async_session:
- try:
- await async_session.execute(
- text("SET SESSION idle_in_transaction_session_timeout = '5min';")
- )
- with Session(engine, expire_on_commit=False, autoflush=False) as session:
- session.execute(
- text("SET SESSION idle_in_transaction_session_timeout = '5min';")
- )
- vector_repository = VectorRepository(session)
- vector_service = VectorService(
- vector_repository
- ) # FIXME @amine: fix to need AsyncSession in vector Service
- knowledge_repository = KnowledgeRepository(async_session)
- knowledge_service = KnowledgeService(knowledge_repository)
- await process_uploaded_file(
- supabase_client=supabase_client,
- brain_service=brain_service,
- vector_service=vector_service,
- knowledge_service=knowledge_service,
- file_name=file_name,
- brain_id=brain_id,
- file_original_name=file_original_name,
- knowledge_id=knowledge_id,
- integration=source,
- integration_link=source_link,
- delete_file=delete_file,
- )
- session.commit()
- await async_session.commit()
- except Exception as e:
- session.rollback()
- await async_session.rollback()
- raise e
- finally:
- session.close()
- await async_session.close()
+ loop.run_until_complete(refresh_sync_folders(async_engine=async_engine))
@celery.task(
retries=3,
default_retry_delay=1,
- name="process_crawl_task",
+ name="process_assistant_task",
autoretry_for=(Exception,),
)
-def process_crawl_task(
- crawl_website_url: str,
- brain_id: UUID,
- knowledge_id: UUID,
- notification_id: UUID | None = None,
+def process_assistant_task(
+ assistant_id: str,
+ notification_uuid: str,
+ task_id: int,
+ user_id: str,
):
+ if async_engine is None:
+ init_worker()
+ assert async_engine
logger.info(
- f"Task process_crawl_task started for url={crawl_website_url}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}"
+ f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}"
+ )
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(
+ aprocess_assistant_task(
+ async_engine,
+ assistant_id,
+ notification_uuid,
+ task_id,
+ user_id,
+ )
)
- global engine
- assert engine
- try:
- with Session(engine, expire_on_commit=False, autoflush=False) as session:
- session.execute(
- text("SET SESSION idle_in_transaction_session_timeout = '5min';")
- )
- vector_repository = VectorRepository(session)
- vector_service = VectorService(vector_repository)
- loop = asyncio.get_event_loop()
- loop.run_until_complete(
- process_url_func(
- url=crawl_website_url,
- brain_id=brain_id,
- knowledge_id=knowledge_id,
- brain_service=brain_service,
- vector_service=vector_service,
- )
- )
- session.commit()
- except Exception as e:
- session.rollback()
- raise e
- finally:
- session.close()
@celery.task(name="NotionConnectorLoad")
@@ -319,117 +172,21 @@ def check_is_premium_task():
check_is_premium(supabase_client)
-@celery.task(name="process_sync_task")
-def process_sync_task(
- sync_id: int, user_id: str, files_ids: list[str], folder_ids: list[str]
-):
- global async_engine
- assert async_engine
- sync = next(
- filter(lambda s: s.id == sync_id, sync_active_service.get_syncs_active(user_id))
- )
- loop = asyncio.get_event_loop()
- loop.run_until_complete(
- process_sync(
- sync=sync,
- files_ids=files_ids,
- folder_ids=folder_ids,
- services=SyncServices(
- async_engine=async_engine,
- sync_active_service=sync_active_service,
- sync_user_service=sync_user_service,
- sync_files_repo_service=sync_files_repo_service,
- storage=storage,
- brain_vectors=brain_vectors,
- notification_service=notification_service,
- ),
- )
- )
-
-
-@celery.task(name="process_active_syncs_task")
-def process_active_syncs_task():
- sync_already_running = is_being_executed("process_sync_task")
-
- if sync_already_running:
- logger.info("Sync already running, skipping")
- return
- global async_engine
- assert async_engine
- loop = asyncio.get_event_loop()
- loop.run_until_complete(
- process_all_active_syncs(
- SyncServices(
- async_engine=async_engine,
- sync_active_service=sync_active_service,
- sync_user_service=sync_user_service,
- sync_files_repo_service=sync_files_repo_service,
- storage=storage,
- brain_vectors=brain_vectors,
- notification_service=notification_service,
- ),
- )
- )
-
-
-@celery.task(name="process_notion_sync_task")
-def process_notion_sync_task():
- global async_engine
- assert async_engine
- loop = asyncio.get_event_loop()
- loop.run_until_complete(process_notion_sync(async_engine))
-
-
-@celery.task(name="fetch_and_store_notion_files_task")
-def fetch_and_store_notion_files_task(
- access_token: str, user_id: UUID, sync_user_id: int
-):
- if async_engine is None:
- init_worker()
- assert async_engine
- try:
- logger.debug("Fetching and storing Notion files")
- loop = asyncio.get_event_loop()
- loop.run_until_complete(
- fetch_and_store_notion_files_async(
- async_engine, access_token, user_id, sync_user_id
- )
- )
- sync_user_service.update_sync_user_status(
- sync_user_id=sync_user_id, status=str(SyncsUserStatus.SYNCED)
- )
- except Exception:
- logger.error("Error fetching and storing Notion files")
- sync_user_service.update_sync_user_status(
- sync_user_id=sync_user_id, status=str(SyncsUserStatus.ERROR)
- )
-
-
-@celery.task(name="clean_notion_user_syncs")
-def clean_notion_user_syncs():
- logger.debug("Cleaning Notion user syncs")
- sync_user_service.clean_notion_user_syncs()
-
-
celery.conf.beat_schedule = {
"ping_telemetry": {
"task": f"{__name__}.ping_telemetry",
"schedule": crontab(minute="*/30", hour="*"),
},
- "process_active_syncs": {
- "task": "process_active_syncs_task",
- "schedule": crontab(minute="*/1", hour="*"),
- },
"process_premium_users": {
"task": "check_is_premium_task",
"schedule": crontab(minute="*/1", hour="*"),
},
- "process_notion_sync": {
- "task": "process_notion_sync_task",
- "schedule": crontab(minute="0", hour="*/6"),
+ "refresh_sync_files": {
+ "task": "refresh_sync_files_task",
+ "schedule": crontab(hour="*/8"),
},
- "clean_notion_user_syncs": {
- "task": "clean_notion_user_syncs",
- "schedule": crontab(minute="0", hour="0"),
+ "refresh_sync_folders": {
+ "task": "refresh_sync_folders_task",
+ "schedule": crontab(hour="*/8"),
},
}
diff --git a/backend/worker/quivr_worker/files.py b/backend/worker/quivr_worker/files.py
deleted file mode 100644
index 8648c7ba9c8e..000000000000
--- a/backend/worker/quivr_worker/files.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import hashlib
-import time
-from contextlib import contextmanager
-from pathlib import Path
-from tempfile import NamedTemporaryFile
-from typing import Any
-from uuid import UUID
-
-from quivr_api.logger import get_logger
-from quivr_core.files.file import FileExtension, QuivrFile
-
-from quivr_worker.utils.utils import get_tmp_name
-
-logger = get_logger("celery_worker")
-
-
-def compute_sha1(content: bytes) -> str:
- m = hashlib.sha1()
- m.update(content)
- return m.hexdigest()
-
-
-@contextmanager
-def build_file(
- file_data: bytes,
- knowledge_id: UUID,
- file_name: str,
- original_file_name: str | None = None,
-):
- try:
- # TODO(@aminediro) : Maybe use fsspec file to be agnostic to where files are stored :?
- # We are reading the whole file to memory, which doesn't scale
- tmp_name, base_file_name, file_extension = get_tmp_name(file_name)
- tmp_file = NamedTemporaryFile(
- suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none
- )
- tmp_file.write(file_data)
- tmp_file.flush()
- file_sha1 = compute_sha1(file_data)
-
- file_instance = File(
- knowledge_id=knowledge_id,
- file_name=base_file_name,
- original_file_name=(
- original_file_name if original_file_name else base_file_name
- ),
- tmp_file_path=Path(tmp_file.name),
- file_size=len(file_data),
- file_extension=file_extension,
- file_sha1=file_sha1,
- )
- yield file_instance
- finally:
- # Code to release resource, e.g.:
- tmp_file.close()
-
-
-class File:
- __slots__ = [
- "id",
- "file_name",
- "tmp_file_path",
- "file_size",
- "file_extension",
- "file_sha1",
- "original_file_name",
- ]
-
- def __init__(
- self,
- knowledge_id: UUID,
- file_name: str,
- tmp_file_path: Path,
- file_size: int,
- file_extension: str,
- file_sha1: str,
- original_file_name: str,
- ):
- self.id = knowledge_id
- self.file_name = file_name
- self.tmp_file_path = tmp_file_path
- self.file_size = file_size
- self.file_sha1 = file_sha1
- self.file_extension = FileExtension(file_extension)
- self.original_file_name = original_file_name
-
- def is_empty(self):
- return self.file_size < 1 # pyright: ignore reportPrivateUsage=none
-
- def to_qfile(self, brain_id: UUID, metadata: dict[str, Any] = {}) -> QuivrFile:
- return QuivrFile(
- id=self.id,
- original_filename=self.file_name,
- path=self.tmp_file_path,
- brain_id=brain_id,
- file_sha1=self.file_sha1,
- file_extension=self.file_extension,
- file_size=self.file_size,
- metadata={
- "date": time.strftime("%Y%m%d"),
- "file_name": self.file_name,
- "original_file_name": self.original_file_name,
- "knowledge_id": self.id,
- **metadata,
- },
- )
diff --git a/backend/worker/quivr_worker/parsers/audio.py b/backend/worker/quivr_worker/parsers/audio.py
index 533357e28080..f4c6890500e2 100644
--- a/backend/worker/quivr_worker/parsers/audio.py
+++ b/backend/worker/quivr_worker/parsers/audio.py
@@ -3,11 +3,12 @@
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from openai import OpenAI
+from quivr_core.files.file import QuivrFile
-from quivr_worker.files import File, compute_sha1
+from quivr_worker.process.utils import compute_sha1
-def process_audio(file: File, model: str = "whisper=1"):
+def process_audio(file: QuivrFile, model: str = "whisper=1"):
# TODO(@aminediro): These should apear in the class processor
# Should be instanciated once per Processor
chunk_size = 500
@@ -19,7 +20,8 @@ def process_audio(file: File, model: str = "whisper=1"):
dateshort = time.strftime("%Y%m%d-%H%M%S")
file_meta_name = f"audiotranscript_{dateshort}.txt"
- with open(file.tmp_file_path, "rb") as audio_file:
+ # TODO: This reopens the file adding an additional FD
+ with open(file.path, "rb") as audio_file:
transcript = client.audio.transcriptions.create(model=model, file=audio_file)
transcript_txt = transcript.text.encode("utf-8")
diff --git a/backend/worker/quivr_worker/parsers/crawler.py b/backend/worker/quivr_worker/parsers/crawler.py
index b6bec671c231..d60f3b36de5b 100644
--- a/backend/worker/quivr_worker/parsers/crawler.py
+++ b/backend/worker/quivr_worker/parsers/crawler.py
@@ -20,7 +20,6 @@ class URL(BaseModel):
async def extract_from_url(url: URL) -> str:
# Extract and combine content recursively
loader = PlaywrightURLLoader(urls=[url.url], remove_selectors=["header", "footer"])
-
data = await loader.aload()
# Now turn the data into a string
logger.info(f"Extracted content from {len(data)} pages")
diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md
new file mode 100644
index 000000000000..d83060bb172d
--- /dev/null
+++ b/backend/worker/quivr_worker/process/README.md
@@ -0,0 +1,116 @@
+# Knowledge Processing Task
+
+## Steps for Processing
+
+1. The task receives a `knowledge_id: UUID`.
+2. The `KnowledgeProcessor.process_knowledge` method processes the knowledge:
+ - It constructs a processable tuple of `[Knowledge, QuivrFile]` stream:
+ - Retrieves the `KnowledgeDB` object from the database.
+ - Determines the processing steps based on the knowledge source:
+ - **Local**:
+ - Downloads the knowledge data from S3 storage and writes it to a temporary file.
+ - Yields the `[Knowledge, QuivrFile]`.
+ - **Web**: Processes similarly to the **Local** method.
+ - **[Syncs]**:
+ - Fetches the associated sync and verifies the credentials.
+ - Concurrently retrieves all knowledges for the user from the database associated with this sync, as well as the tree of sync files where this knowledge is the parent (using the sync provider).
+ - Downloads the knowledge and yields the initial `[Knowledge, QuivrFile]` that the task received.
+ - For all children of this knowledge (i.e., those fetched from the sync):
+ - If the child exists in the database (i.e., knowledge where `knowledge.sync_id == sync_file.id`):
+ - This implies that the sync's child knowledge might have been processed earlier in another brain.
+ - If the knowledge has been PROCESSED, link it to the parent brains and continue.
+ - If not, reprocess the file.
+ - If the child does not exist:
+ - Create the knowledge associated with the sync file and set it to `Processing`.
+ - Download the sync file's data and yield the `[Knowledge, QuivrFile]`.
+ - Skip processing of the tuple if the knowledge is a folder.
+ - Parse the `QuivrFile` using `quivr-core`.
+ - Store the resulting chunks in the database.
+ - Update the knowledge status to `PROCESSED`.
+
+### Handling Exceptions During Parsing Loop
+
+#### Catchable Errors:
+
+If an exception occurs during the parsing loop, the following steps are taken:
+
+1. Roll back the current transaction (this only affects the vectors) if they were set. The processing loop performs the following stateful operations in this order:
+
+ - Creating knowledges (with `Processing` status).
+ - Updating knowledges: linking them to brains.
+ - Creating vectors.
+ - Updating knowledges.
+
+ **Transaction Safety for Each Operation:**
+
+ - **Creating knowledge and linking to brains**: These operations can be retried safely. Knowledge is only recreated if it does not already exist in the database, allowing for safe retry.
+ - **Linking knowledge to brains**: Only links the brain if it is not already associated with the knowledge. Safe for retry.
+ - **Creating vectors**:
+ - This operation should be rolled back if an error occurs afterward. Otherwise, the knowledge could remain in `Processing` or `ERROR` status with associated vectors.
+ - Reprocessing the knowledge would result in reinserting the vectors into the database, leading to duplicate vectors for the same knowledge.
+
+**Transaction Safety for Each Operation:**
+
+- **Creating knowledge and linking to brains**: These operations can be retried safely. Knowledge is only recreated if it does not already exist in the database, allowing for safe retry.
+- **Downloading sync files**: This operation is idempotent but is safe to retry. If a change has occured, we would download the last version of the file.
+- **Linking knowledge to brains**: Only links the brain if it is not already associated with the knowledge. Safe for retry.
+- **Creating vectors**:
+ - This operation should be rolled back if an error occurs afterward. Otherwise, the knowledge could remain in `Processing` or `ERROR` status with associated vectors.
+ - Reprocessing the knowledge would result in reinserting the vectors into the database, leading to duplicate vectors for the same knowledge.
+
+1. Set the knowledge status to `ERROR`.
+2. Continue processing.
+
+| Note: This means that some knowledges will remain in an errored state. Currently, they are not automatically rescheduled for processing.
+
+#### Uncatchable Errors (e.g., worker process fails):
+
+- The task will be automatically retried three times, handled by Celery.
+- The notifier will receive an event indicating the task has failed.
+- The notifier will set the knowledge status to `ERROR` for the task.
+
+---
+
+🔴 **NOTE: Sync Error Handling for Version v0.1:**
+
+For `process_knowledge` tasks involving the processing of a sync folder, the folder's status will be set to `ERROR`. If child knowledges associated with the sync have already been created, their status cannot be set to `ERROR`. This would leave them stuck in `PROCESSING` status while their parent has an `ERROR` status.
+
+Why can’t we set all children to `ERROR`? This introduces a potential race condition: Sync knowledge can be added to a brain independently from its parent, so it’s unclear if the `PROCESSING` status is tied to the failed task. Although keeping a `task_id` associated with `knowledge_id` could help, it’s error-prone and impacts the database schema, which would have significant consequences.
+
+However, sync knowledge added to a brain will be reprocessed after some time through the sync update task, ensuring that their status will eventually be set to the correct state.
+
+# Syncing Knowledge task
+
+1. **Syncing Knowledge Syncs of Type Files:**
+ - Outdated file syncs are fetched in batches.
+ - For each file, if the remote file's `last_modified_at` is newer than the local `last_synced_at`, the file is updated.
+ - If the file is missing remotely, the db knowledge is deleted.
+2. **Syncing Knowledge Folders:**
+ - Outdated folder syncs are retrieved in batches.
+ - For each folder, its children (files and subfolders) are fetched from both the database and the remote provider.
+ - Remote children missing from the local database are added and processed.
+ - **If a Folder is Not Found:**
+ - If a folder no longer exists remotely, it is deleted locally, along with all associated knowledge entries.
+
+🔴 **Key Considerations**
+
+- **Batch Processing:**
+
+ - Both file and folder syncs are handled in batches, ensuring the system can process large data efficiently.
+
+- **Error Handling:**
+
+ - The system logs errors such as missing credentials or files, allowing the sync process to continue or fail gracefully.
+
+- **Savepoints and Rollback:**
+
+ - During file and folder processing, savepoints are created. If an error occurs, the transaction can be rolled back, ensuring the original knowledge remains unmodified.
+
+- **Deleting Folders:**
+ - If a folder is missing remotely, it triggers the deletion of the folder and all associated knowledge entries from the local system.
+
+---
+
+## Notification Steps
+
+To discuss: @StanGirard @Zewed
diff --git a/backend/worker/quivr_worker/process/__init__.py b/backend/worker/quivr_worker/process/__init__.py
index e69de29bb2d1..d44d4bd6bff5 100644
--- a/backend/worker/quivr_worker/process/__init__.py
+++ b/backend/worker/quivr_worker/process/__init__.py
@@ -0,0 +1,12 @@
+from uuid import UUID
+
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from quivr_worker.process.processor import KnowledgeProcessor
+from quivr_worker.utils.services import build_processor_services
+
+
+async def aprocess_file_task(async_engine: AsyncEngine, knowledge_id: UUID):
+ async with build_processor_services(async_engine) as processor_services:
+ km_processor = KnowledgeProcessor(services=processor_services)
+ await km_processor.process_knowledge(knowledge_id)
diff --git a/backend/worker/quivr_worker/process/process_file.py b/backend/worker/quivr_worker/process/process_file.py
index a13eb833f8ac..01fb9513f632 100644
--- a/backend/worker/quivr_worker/process/process_file.py
+++ b/backend/worker/quivr_worker/process/process_file.py
@@ -1,14 +1,12 @@
from typing import Any
-from uuid import UUID
from langchain_core.documents import Document
from quivr_api.logger import get_logger
-from quivr_api.modules.brain.entity.brain_entity import BrainEntity
-from quivr_api.modules.brain.service.brain_service import BrainService
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
from quivr_api.modules.vector.service.vector_service import VectorService
+from quivr_core.files.file import QuivrFile
from quivr_core.processor.registry import get_processor_class
-from quivr_worker.files import File
from quivr_worker.parsers.audio import process_audio
logger = get_logger("celery_worker")
@@ -25,92 +23,56 @@
}
-async def process_file(
- file_instance: File,
- brain: BrainEntity,
- brain_service: BrainService,
- vector_service: VectorService,
- integration: str | None,
- integration_link: str | None,
-):
- chunks = await parse_file(
- file=file_instance,
- brain=brain,
- integration=integration,
- integration_link=integration_link,
- )
- store_chunks(
- file=file_instance,
- brain_id=brain.brain_id,
- chunks=chunks,
- brain_service=brain_service,
- vector_service=vector_service,
- )
-
-
-def store_chunks(
+async def store_chunks(
*,
- file: File,
- brain_id: UUID,
+ knowledge: KnowledgeDB,
chunks: list[Document],
- brain_service: BrainService,
vector_service: VectorService,
):
- # vector_ids = document_vector_store.add_documents(chunks)
- vector_ids = vector_service.create_vectors(chunks, file.id)
- logger.debug(f"Inserted {len(chunks)} chunks in vectors table for {file}")
-
+ assert knowledge.id
+ vector_ids = await vector_service.create_vectors(
+ chunks, knowledge.id, autocommit=False
+ )
+ logger.debug(
+ f"Inserted {len(chunks)} chunks in vectors table for knowledge: {knowledge.id}"
+ )
if vector_ids is None or len(vector_ids) == 0:
- raise Exception(f"Error inserting chunks for file {file.file_name}")
+ raise Exception(f"Error inserting chunks for knowledge {knowledge.id}")
- brain_service.update_brain_last_update_time(brain_id)
-
-async def parse_file(
- file: File,
- brain: BrainEntity,
- integration: str | None = None,
- integration_link: str | None = None,
+async def parse_qfile(
+ *,
+ qfile: QuivrFile,
**processor_kwargs: dict[str, Any],
) -> list[Document]:
try:
# TODO(@aminediro): add audio procesors to quivr-core
- if file.file_extension in audio_extensions:
- logger.debug(f"processing audio file {file}")
- audio_docs = process_audio_file(file, brain)
+ if qfile.file_extension in audio_extensions:
+ logger.debug(f"processing audio file {qfile}")
+ audio_docs = process_audio_file(qfile)
return audio_docs
else:
- qfile = file.to_qfile(
- brain.brain_id,
- {
- "integration": integration or "",
- "integration_link": integration_link or "",
- },
- )
- processor_cls = get_processor_class(file.file_extension)
+ processor_cls = get_processor_class(qfile.file_extension)
processor = processor_cls(**processor_kwargs)
docs = await processor.process_file(qfile)
logger.debug(f"Parsed {qfile} to : {docs}")
return docs
except KeyError as e:
- raise ValueError(f"Can't parse {file}. No available processor") from e
+ raise ValueError(f"Can't parse {qfile}. No available processor") from e
+# TODO: Move this to quivr-core
def process_audio_file(
- file: File,
- brain: BrainEntity,
+ qfile: QuivrFile,
):
try:
- result = process_audio(file=file)
+ result = process_audio(file=qfile)
if result is None or result == 0:
logger.info(
- f"{file.file_name} has been uploaded to brain. There might have been an error while reading it, please make sure the file is not illformed or just an image", # pyright: ignore reportPrivateUsage=none
+ f"{qfile.file_name} has been uploaded to brain. There might have been an error while reading it, please make sure the file is not illformed or just an image", # pyright: ignore reportPrivateUsage=none
)
return []
- logger.info(
- f"{file.file_name} has been uploaded to brain {brain.name} in {result} chunks", # pyright: ignore reportPrivateUsage=none
- )
return result
except Exception as e:
- logger.exception(f"Error processing audio file {file}: {e}")
+ logger.exception(f"Error processing audio file {qfile}: {e}")
raise e
diff --git a/backend/worker/quivr_worker/process/process_s3_file.py b/backend/worker/quivr_worker/process/process_s3_file.py
deleted file mode 100644
index 99bc4e7360d1..000000000000
--- a/backend/worker/quivr_worker/process/process_s3_file.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from uuid import UUID
-
-from quivr_api.logger import get_logger
-from quivr_api.modules.brain.service.brain_service import BrainService
-from quivr_api.modules.knowledge.entity.knowledge import KnowledgeUpdate
-from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
-from quivr_api.modules.vector.service.vector_service import VectorService
-
-from quivr_worker.files import build_file
-from quivr_worker.process.process_file import process_file
-from supabase import Client
-
-logger = get_logger("celery_worker")
-
-
-async def process_uploaded_file(
- supabase_client: Client,
- brain_service: BrainService,
- vector_service: VectorService,
- knowledge_service: KnowledgeService,
- file_name: str,
- brain_id: UUID,
- file_original_name: str,
- knowledge_id: UUID,
- integration: str | None = None,
- integration_link: str | None = None,
- delete_file: bool = False,
- bucket_name: str = "quivr",
-):
- brain = brain_service.get_brain_by_id(brain_id)
- if brain is None:
- logger.exception(
- "It seems like you're uploading knowledge to an unknown brain."
- )
- raise ValueError("unknown brain")
- assert brain
- file_data = supabase_client.storage.from_(bucket_name).download(file_name)
- # TODO: Have the whole logic on do we process file or not
- # Don't process a file that already exists (file_sha1 in the table with STATUS=UPLOADED)
- #
- # - Check on file_sha1 and status
- # If we have some knowledge with error
- with build_file(file_data, knowledge_id, file_name) as file_instance:
- knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id)
- await knowledge_service.update_knowledge(
- knowledge,
- KnowledgeUpdate(file_sha1=file_instance.file_sha1), # type: ignore
- )
- await process_file(
- file_instance=file_instance,
- brain=brain,
- brain_service=brain_service,
- vector_service=vector_service,
- integration=integration,
- integration_link=integration_link,
- )
diff --git a/backend/worker/quivr_worker/process/process_url.py b/backend/worker/quivr_worker/process/process_url.py
deleted file mode 100644
index a5dabecd361d..000000000000
--- a/backend/worker/quivr_worker/process/process_url.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from uuid import UUID
-
-from quivr_api.logger import get_logger
-from quivr_api.modules.brain.service.brain_service import BrainService
-from quivr_api.modules.vector.service.vector_service import VectorService
-
-from quivr_worker.files import build_file
-from quivr_worker.parsers.crawler import URL, extract_from_url, slugify
-from quivr_worker.process.process_file import process_file
-
-logger = get_logger("celery_worker")
-
-
-async def process_url_func(
- url: str,
- brain_id: UUID,
- knowledge_id: UUID,
- brain_service: BrainService,
- vector_service: VectorService,
-):
- crawl_website = URL(url=url)
- extracted_content = await extract_from_url(crawl_website)
- extracted_content_bytes = extracted_content.encode("utf-8")
- file_name = slugify(crawl_website.url) + ".txt"
-
- brain = brain_service.get_brain_by_id(brain_id)
- if brain is None:
- logger.error("It seems like you're uploading knowledge to an unknown brain.")
- return 1
-
- with build_file(extracted_content_bytes, knowledge_id, file_name) as file_instance:
- # TODO(@StanGirard): fix bug
- # NOTE (@aminediro): I think this might be related to knowledge delete timeouts ?
- await process_file(
- file_instance=file_instance,
- brain=brain,
- brain_service=brain_service,
- integration=None,
- integration_link=None,
- vector_service=vector_service,
- )
diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py
new file mode 100644
index 000000000000..83bee5f61d8b
--- /dev/null
+++ b/backend/worker/quivr_worker/process/processor.py
@@ -0,0 +1,438 @@
+import asyncio
+from datetime import datetime, timedelta, timezone
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, AsyncGenerator, List, Optional, Tuple
+from uuid import UUID
+
+from quivr_api.logger import get_logger
+from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile, SyncType
+from quivr_api.modules.sync.utils.sync import BaseSync
+from quivr_core.files.file import QuivrFile
+from quivr_core.models import KnowledgeStatus
+from sqlalchemy.ext.asyncio import AsyncSessionTransaction
+
+from quivr_worker.parsers.crawler import URL, extract_from_url
+from quivr_worker.process.process_file import parse_qfile, store_chunks
+from quivr_worker.process.utils import (
+ build_qfile,
+ build_sync_file,
+ compute_sha1,
+ skip_process,
+)
+from quivr_worker.utils.services import ProcessorServices
+
+logger = get_logger("celery_worker")
+
+
+class KnowledgeProcessor:
+ def __init__(self, services: ProcessorServices):
+ self.services = services
+
+ async def fetch_db_knowledges_and_syncprovider(
+ self,
+ sync_id: int,
+ user_id: UUID,
+ folder_id: str | None,
+ ) -> Tuple[dict[str, KnowledgeDB], List[SyncFile] | None]:
+ map_knowledges_task = self.services.knowledge_service.map_syncs_knowledge_user(
+ sync_id=sync_id, user_id=user_id
+ )
+ sync_files_task = self.services.sync_service.get_files_folder_user_sync(
+ sync_id,
+ user_id,
+ folder_id,
+ )
+ return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821
+
+ async def yield_processable_knowledge(
+ self, knowledge_id: UUID
+ ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]:
+ """Should only yield ready to process knowledges:
+ Knowledge ready to process:
+ - Is either Local or Sync
+ - Is in a status: PROCESSING | ERROR
+ - Has an associated QuivrFile that is parsable
+ """
+ knowledge = await self.services.knowledge_service.get_knowledge(knowledge_id)
+ if knowledge.source == KnowledgeSource.LOCAL:
+ async for to_process in self._yield_local(knowledge):
+ yield to_process
+ elif knowledge.source in (
+ KnowledgeSource.AZURE,
+ KnowledgeSource.GOOGLE,
+ KnowledgeSource.DROPBOX,
+ KnowledgeSource.GITHUB,
+ # KnowledgeSource.NOTION,
+ ):
+ async for to_process in self._yield_syncs(knowledge):
+ yield to_process
+ elif knowledge.source == KnowledgeSource.WEB:
+ async for to_process in self._yield_web(knowledge):
+ yield to_process
+ else:
+ logger.error(
+ f"received knowledge : {knowledge.id} with unknown source: {knowledge.source}"
+ )
+ raise ValueError(f"Unknown knowledge source : {knowledge.source}")
+
+ async def _yield_local(
+ self, knowledge: KnowledgeDB
+ ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]:
+ if knowledge.id is None or knowledge.file_name is None:
+ logger.error(f"received unprocessable local knowledge : {knowledge.id} ")
+ raise ValueError(
+ f"received unprocessable local knowledge : {knowledge.id} "
+ )
+ if knowledge.is_folder:
+ yield (
+ knowledge,
+ QuivrFile(
+ id=knowledge.id,
+ original_filename=knowledge.file_name,
+ file_extension=knowledge.extension,
+ file_sha1="",
+ path=Path(),
+ ),
+ )
+ else:
+ file_data = await self.services.knowledge_service.storage.download_file(
+ knowledge
+ )
+ knowledge.file_sha1 = compute_sha1(file_data)
+ with build_qfile(knowledge, file_data) as qfile:
+ yield (knowledge, qfile)
+
+ async def _yield_web(
+ self, knowledge_db: KnowledgeDB
+ ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]:
+ if knowledge_db.id is None or knowledge_db.url is None:
+ logger.error(f"received unprocessable web knowledge : {knowledge_db.id} ")
+ raise ValueError(
+ f"received unprocessable web knowledge : {knowledge_db.id} "
+ )
+ crawl_website = URL(url=knowledge_db.url)
+ extracted_content = await extract_from_url(crawl_website)
+ extracted_content_bytes = extracted_content.encode("utf-8")
+ knowledge_db.file_sha1 = compute_sha1(extracted_content_bytes)
+ knowledge_db.file_size = len(extracted_content_bytes)
+ with build_qfile(knowledge_db, extracted_content_bytes) as qfile:
+ yield (knowledge_db, qfile)
+
+ async def _yield_syncs(
+ self, parent_knowledge: KnowledgeDB
+ ) -> AsyncGenerator[Optional[Tuple[KnowledgeDB, QuivrFile]], None]:
+ if parent_knowledge.id is None:
+ logger.error(f"received unprocessable knowledge: {parent_knowledge.id} ")
+ raise ValueError
+
+ if parent_knowledge.file_name is None:
+ logger.error(f"received unprocessable knowledge : {parent_knowledge.id} ")
+ raise ValueError(
+ f"received unprocessable knowledge : {parent_knowledge.id} "
+ )
+ if (
+ parent_knowledge.sync_file_id is None
+ or parent_knowledge.sync_id is None
+ or parent_knowledge.source_link is None
+ ):
+ logger.error(
+ f"unprocessable sync knowledge : {parent_knowledge.id}. no sync_file_id"
+ )
+ raise ValueError(
+ f"received unprocessable knowledge : {parent_knowledge.id} "
+ )
+ # Get associated sync
+ sync = await self._get_sync(parent_knowledge.sync_id)
+ if sync.credentials is None:
+ logger.error(
+ f"can't process knowledge: {parent_knowledge.id}. sync {sync.id} has no credentials"
+ )
+ raise ValueError("no associated credentials")
+
+ provider_name = SyncProvider(sync.provider.lower())
+ sync_provider = self.services.syncprovider_mapping[provider_name]
+
+ # Yield parent_knowledge as the first knowledge to process
+ async with build_sync_file(
+ file_knowledge=parent_knowledge,
+ credentials=sync.credentials,
+ sync_provider=sync_provider,
+ sync_file=SyncFile(
+ id=parent_knowledge.sync_file_id,
+ name=parent_knowledge.file_name,
+ extension=parent_knowledge.extension,
+ web_view_link=parent_knowledge.source_link,
+ is_folder=parent_knowledge.is_folder,
+ last_modified_at=parent_knowledge.updated_at,
+ ),
+ ) as f:
+ yield f
+
+ # Fetch children
+ (
+ syncfile_to_knowledge,
+ sync_files,
+ ) = await self.fetch_db_knowledges_and_syncprovider(
+ sync_id=parent_knowledge.sync_id,
+ user_id=parent_knowledge.user_id,
+ folder_id=parent_knowledge.sync_file_id,
+ )
+ if not sync_files:
+ return
+
+ for sync_file in sync_files:
+ file_knowledge = (
+ await self.services.knowledge_service.create_or_link_sync_knowledge(
+ syncfile_id_to_knowledge=syncfile_to_knowledge,
+ parent_knowledge=parent_knowledge,
+ sync_file=sync_file,
+ )
+ )
+ if file_knowledge.status == KnowledgeStatus.PROCESSED:
+ continue
+ async with build_sync_file(
+ file_knowledge=file_knowledge,
+ credentials=sync.credentials,
+ sync_provider=sync_provider,
+ sync_file=sync_file,
+ ) as f:
+ yield f
+
+ async def create_savepoint(self) -> AsyncSessionTransaction:
+ savepoint = (
+ await self.services.knowledge_service.repository.session.begin_nested()
+ )
+ return savepoint
+
+ async def process_knowledge(self, knowledge_id: UUID):
+ async for knowledge_tuple in self.yield_processable_knowledge(knowledge_id):
+ # FIXME(@AmineDiro) : nested transaction for making
+ savepoint = await self.create_savepoint()
+ if knowledge_tuple is None:
+ continue
+ knowledge, qfile = knowledge_tuple
+ try:
+ await self._process_inner(knowledge=knowledge, qfile=qfile)
+ await savepoint.commit()
+ except Exception as e:
+ await savepoint.rollback()
+ logger.error(f"Error processing knowledge {knowledge_id} : {e}")
+ # FIXME: This one can also fail if knowledge was deleted
+ await self.services.knowledge_service.update_knowledge(
+ knowledge,
+ KnowledgeUpdate(
+ status=KnowledgeStatus.ERROR,
+ ),
+ )
+
+ async def _process_inner(self, knowledge: KnowledgeDB, qfile: QuivrFile):
+ last_synced_at = datetime.now(timezone.utc)
+ if not skip_process(knowledge):
+ chunks = await parse_qfile(qfile=qfile)
+ await store_chunks(
+ knowledge=knowledge,
+ chunks=chunks,
+ vector_service=self.services.vector_service,
+ )
+ await self.services.knowledge_service.update_knowledge(
+ knowledge,
+ KnowledgeUpdate(
+ status=KnowledgeStatus.PROCESSED,
+ file_sha1=knowledge.file_sha1,
+ # Update sync
+ last_synced_at=last_synced_at if knowledge.sync_id else None,
+ ),
+ autocommit=False,
+ )
+
+ @lru_cache(maxsize=50) # noqa: B019
+ async def _get_sync(self, sync_id: int) -> Sync:
+ sync = await self.services.sync_service.get_sync_by_id(sync_id)
+ return sync
+
+ async def refresh_sync_folders(
+ self, timedelta_hour: int = 8, batch_size: int = 100
+ ):
+ last_time = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour)
+ km_sync_folders = await self.services.knowledge_service.get_outdated_syncs(
+ limit_time=last_time,
+ batch_size=batch_size,
+ km_sync_type=SyncType.FOLDER,
+ )
+ for sync_folder_km in km_sync_folders:
+ await self.refresh_sync_folder(sync_folder_km)
+
+ async def refresh_sync_folder(self, folder_km: KnowledgeDB) -> KnowledgeDB:
+ assert folder_km.sync_id, "can only update sync files with sync_id"
+ assert folder_km.sync_file_id, "can only update sync files with sync_file_id "
+ sync = await self._get_sync(folder_km.sync_id)
+ if sync.credentials is None:
+ logger.error(
+ f"can't process knowledge: {folder_km.id}. sync {sync.id} has no credentials"
+ )
+ raise ValueError(f"no associated credentials with knowledge {folder_km}")
+ provider_name = SyncProvider(sync.provider.lower())
+ sync_provider = self.services.syncprovider_mapping[provider_name]
+ km_children: List[KnowledgeDB] = await folder_km.awaitable_attrs.children
+ sync_children = {c.sync_file_id for c in km_children}
+ try:
+ sync_files = await sync_provider.aget_files(
+ credentials=sync.credentials,
+ folder_id=folder_km.sync_file_id,
+ recursive=False,
+ )
+ for sync_entry in filter(lambda s: s.id not in sync_children, sync_files):
+ await self.add_new_sync_entry(folder=folder_km, sync_entry=sync_entry)
+
+ except FileNotFoundError:
+ logger.info(
+ f"Knowledge {folder_km.id} not found in remote sync. Removing the folder"
+ )
+ await self.services.knowledge_service.remove_knowledge(
+ folder_km, autocommit=True
+ )
+ except Exception:
+ logger.exception(f"Exception occured processing folder: {folder_km.id}")
+ finally:
+ await self.services.knowledge_service.update_knowledge(
+ knowledge=folder_km,
+ payload=KnowledgeUpdate(last_synced_at=datetime.now(timezone.utc)),
+ )
+ return folder_km
+
+ async def add_new_sync_entry(self, folder: KnowledgeDB, sync_entry: SyncFile):
+ sync_km = await self.services.knowledge_service.create_knowledge(
+ user_id=folder.user_id,
+ knowledge_to_add=AddKnowledge(
+ file_name=sync_entry.name,
+ is_folder=sync_entry.is_folder,
+ extension=sync_entry.extension,
+ source=folder.source,
+ source_link=sync_entry.web_view_link,
+ parent_id=folder.id,
+ sync_id=folder.sync_id,
+ sync_file_id=sync_entry.id,
+ ),
+ status=KnowledgeStatus.PROCESSING,
+ upload_file=None,
+ autocommit=True,
+ process_async=False,
+ )
+ async for processable_tuple in self._yield_syncs(sync_km):
+ if processable_tuple is None:
+ continue
+ knowledge, qfile = processable_tuple
+ savepoint = await self.create_savepoint()
+ try:
+ await self._process_inner(knowledge=knowledge, qfile=qfile)
+ await savepoint.commit()
+ except Exception:
+ await savepoint.rollback()
+ logger.exception(f"Error occured processing :{knowledge.id}")
+
+ async def refresh_knowledge_sync_files(
+ self, timedelta_hour: int = 8, batch_size: int = 1000
+ ):
+ last_time = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour)
+ km_sync_files = await self.services.knowledge_service.get_outdated_syncs(
+ limit_time=last_time,
+ batch_size=batch_size,
+ km_sync_type=SyncType.FILE,
+ )
+ for old_km in km_sync_files:
+ try:
+ assert old_km.sync_id, "can only update sync files with sync_id"
+ assert (
+ old_km.sync_file_id
+ ), "can only update sync files with sync_file_id "
+ sync = await self._get_sync(old_km.sync_id)
+ if sync.credentials is None:
+ logger.error(
+ f"can't process knowledge: {old_km.id}. sync {sync.id} has no credentials"
+ )
+ raise ValueError(
+ f"no associated credentials with knowledge {old_km}"
+ )
+ provider_name = SyncProvider(sync.provider.lower())
+ sync_provider = self.services.syncprovider_mapping[provider_name]
+ new_sync_file = (
+ await sync_provider.aget_files_by_id(
+ credentials=sync.credentials, file_ids=[old_km.sync_file_id]
+ )
+ )[0]
+ await self.refresh_knowledge_entry(
+ old_km=old_km,
+ new_sync_file=new_sync_file,
+ sync_provider=sync_provider,
+ sync_credentials=sync.credentials,
+ )
+ except FileNotFoundError:
+ logger.info(
+ f"Knowledge {old_km.id} not found in remote sync. Removing the knowledge"
+ )
+ await self.services.knowledge_service.remove_knowledge(
+ old_km, autocommit=True
+ )
+ except Exception:
+ logger.exception(f"Exception occured processing km: {old_km.id}")
+
+ async def refresh_knowledge_entry(
+ self,
+ old_km: KnowledgeDB,
+ new_sync_file: SyncFile,
+ sync_provider: BaseSync,
+ sync_credentials: dict[str, Any],
+ ) -> KnowledgeDB | None:
+ assert (
+ old_km.last_synced_at
+ ), "can only update sync files without a last_synced_at"
+ if (
+ new_sync_file.last_modified_at
+ and new_sync_file.last_modified_at > old_km.last_synced_at
+ ) or new_sync_file.last_modified_at is None:
+ savepoint = await self.create_savepoint()
+ try:
+ new_km = await self.services.knowledge_service.create_knowledge(
+ user_id=old_km.user_id,
+ knowledge_to_add=AddKnowledge(
+ file_name=new_sync_file.name,
+ is_folder=new_sync_file.is_folder,
+ extension=new_sync_file.extension,
+ source=old_km.source,
+ source_link=new_sync_file.web_view_link,
+ parent_id=old_km.parent_id,
+ sync_id=old_km.sync_id,
+ sync_file_id=new_sync_file.id,
+ ),
+ status=KnowledgeStatus.PROCESSING,
+ link_brains=await old_km.awaitable_attrs.brains,
+ upload_file=None,
+ autocommit=False,
+ process_async=False,
+ )
+ async with build_sync_file(
+ new_km,
+ new_sync_file,
+ sync_provider=sync_provider,
+ credentials=sync_credentials,
+ ) as (
+ new_km,
+ qfile,
+ ):
+ await self._process_inner(new_km, qfile)
+ await self.services.knowledge_service.remove_knowledge(
+ old_km, autocommit=False
+ )
+ await savepoint.commit()
+ await savepoint.session.refresh(new_km)
+ return new_km
+
+ except Exception as e:
+ logger.exception(
+ f"Rolling back. Error occured updating sync {old_km.id}: {e}"
+ )
+ await savepoint.rollback()
diff --git a/backend/worker/quivr_worker/process/utils.py b/backend/worker/quivr_worker/process/utils.py
new file mode 100644
index 000000000000..f4b052a25a8a
--- /dev/null
+++ b/backend/worker/quivr_worker/process/utils.py
@@ -0,0 +1,149 @@
+import hashlib
+import os
+import time
+from contextlib import asynccontextmanager, contextmanager
+from io import BytesIO
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from typing import Any, AsyncGenerator, Generator, Tuple
+
+from quivr_api.celery_config import celery
+from quivr_api.logger import get_logger
+from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import SyncFile
+from quivr_api.modules.sync.utils.sync import (
+ AzureDriveSync,
+ BaseSync,
+ DropboxSync,
+ GitHubSync,
+ GoogleDriveSync,
+)
+from quivr_core.files.file import FileExtension, QuivrFile
+
+from quivr_worker.parsers.crawler import slugify
+
+celery_inspector = celery.control.inspect()
+
+logger = get_logger("celery_worker")
+
+
+def skip_process(knowledge: KnowledgeDTO | KnowledgeDB) -> bool:
+ return knowledge.is_folder and knowledge.source != KnowledgeSource.NOTION
+
+
+def build_syncprovider_mapping() -> dict[SyncProvider, BaseSync]:
+ mapping_sync_utils = {
+ SyncProvider.GOOGLE: GoogleDriveSync(),
+ SyncProvider.AZURE: AzureDriveSync(),
+ SyncProvider.DROPBOX: DropboxSync(),
+ SyncProvider.GITHUB: GitHubSync(),
+ # SyncProvider.NOTION: NotionSync(notion_service=notion_service),
+ }
+ return mapping_sync_utils
+
+
+def compute_sha1(content: bytes) -> str:
+ m = hashlib.sha1()
+ m.update(content)
+ return m.hexdigest()
+
+
+def get_tmp_name(file_name: str) -> Tuple[str, str, str]:
+ # Filepath is S3 based
+ tmp_name = file_name.replace("/", "_")
+ base_file_name = os.path.basename(file_name)
+ _, file_extension = os.path.splitext(base_file_name)
+ return tmp_name, base_file_name, file_extension
+
+
+@contextmanager
+def create_temp_file(
+ file_data: bytes,
+ file_name_ext: str,
+):
+ # TODO(@aminediro) :
+ # Maybe use fsspec file to be agnostic to where files are stored
+ # We are reading the whole file to memory, which doesn't scale
+ try:
+ tmp_name, _, _ = get_tmp_name(file_name_ext)
+ tmp_file = NamedTemporaryFile(
+ suffix="_" + tmp_name,
+ )
+ tmp_file.write(file_data)
+ tmp_file.flush()
+ yield Path(tmp_file.name)
+ finally:
+ tmp_file.close()
+
+
+async def download_sync_file(
+ sync_provider: BaseSync, file: SyncFile, credentials: dict[str, Any]
+) -> bytes:
+ logger.info(f"Downloading {file} using {sync_provider}")
+ file_response = await sync_provider.adownload_file(credentials, file)
+ logger.debug(f"Fetch sync file response: {file_response}")
+ raw_data = file_response["content"]
+ if isinstance(raw_data, BytesIO):
+ file_data = raw_data.read()
+ else:
+ file_data = raw_data.encode("utf-8")
+ logger.debug(f"Successfully downloaded sync file : {file}")
+ return file_data
+
+
+@asynccontextmanager
+async def build_sync_file(
+ file_knowledge: KnowledgeDB,
+ sync_file: SyncFile,
+ sync_provider: BaseSync,
+ credentials: dict[str, Any],
+) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile], None]:
+ file_data = await download_sync_file(
+ sync_provider=sync_provider,
+ file=sync_file,
+ credentials=credentials,
+ )
+ file_knowledge.file_sha1 = compute_sha1(file_data)
+ file_knowledge.file_size = len(file_data)
+ with build_qfile(file_knowledge, file_data) as qfile:
+ yield (file_knowledge, qfile)
+
+
+@contextmanager
+def build_qfile(
+ knowledge: KnowledgeDB, file_data: bytes
+) -> Generator[QuivrFile, None, None]:
+ assert knowledge.id
+ assert knowledge.file_sha1
+ if knowledge.source == KnowledgeSource.WEB:
+ file_name = slugify(knowledge.url) + ".txt"
+ extension = FileExtension.txt
+ else:
+ assert knowledge.file_name
+ file_name = knowledge.file_name
+ extension = FileExtension(knowledge.extension)
+
+ with create_temp_file(
+ file_data=file_data, file_name_ext=file_name
+ ) as tmp_file_path:
+ qfile = QuivrFile(
+ id=knowledge.id,
+ original_filename=file_name,
+ path=tmp_file_path,
+ file_sha1=knowledge.file_sha1,
+ file_extension=extension,
+ file_size=len(file_data),
+ metadata={
+ "date": time.strftime("%Y%m%d"),
+ "file_name": knowledge.file_name,
+ "knowledge_id": knowledge.id,
+ },
+ )
+ if knowledge.metadata_:
+ qfile.additional_metadata = {
+ **qfile.metadata,
+ **knowledge.metadata_,
+ }
+ yield qfile
diff --git a/backend/worker/quivr_worker/syncs/__init__.py b/backend/worker/quivr_worker/syncs/__init__.py
index 0ca8a21db9fe..e69de29bb2d1 100644
--- a/backend/worker/quivr_worker/syncs/__init__.py
+++ b/backend/worker/quivr_worker/syncs/__init__.py
@@ -1,3 +0,0 @@
-from .process_active_syncs import process_all_active_syncs
-
-__all__ = ["process_all_active_syncs"]
diff --git a/backend/worker/quivr_worker/syncs/process_active_syncs.py b/backend/worker/quivr_worker/syncs/process_active_syncs.py
index 196e54773909..4d884163ccba 100644
--- a/backend/worker/quivr_worker/syncs/process_active_syncs.py
+++ b/backend/worker/quivr_worker/syncs/process_active_syncs.py
@@ -7,14 +7,19 @@
from quivr_api.modules.notification.service.notification_service import (
NotificationService,
)
-from quivr_api.modules.sync.entity.sync_models import SyncsActive
-from quivr_api.modules.sync.repository.sync_repository import NotionRepository
+from quivr_api.modules.sync.repository.notion_repository import NotionRepository
from quivr_api.modules.sync.service.sync_notion import (
SyncNotionService,
fetch_limit_notion_pages,
update_notion_pages,
)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.sync import (
+ AzureDriveSync,
+ DropboxSync,
+ GitHubSync,
+ GoogleDriveSync,
+)
from quivr_api.modules.sync.utils.syncutils import SyncUtils
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlmodel import text
@@ -27,6 +32,17 @@
logger = get_logger("celery_worker")
+async def build_syncprovider_mapping():
+ mapping_sync_utils = {
+ "google": GoogleDriveSync(),
+ "azure": AzureDriveSync(),
+ "dropbox": DropboxSync(),
+ "github": GitHubSync(),
+ # "notion", NotionSync(notion_service=notion_service),
+ }
+ return mapping_sync_utils
+
+
async def process_sync(
sync: SyncsActive,
files_ids: list[str],
@@ -35,9 +51,7 @@ async def process_sync(
):
async with build_syncs_utils(services) as mapping_syncs_utils:
try:
- user_sync = services.sync_user_service.get_sync_user_by_id(
- sync.syncs_user_id
- )
+ user_sync = services.sync_user_service.get_sync_by_id(sync.syncs_user_id)
services.notification_service.remove_notification_by_id(
sync.notification_id
)
@@ -45,7 +59,7 @@ async def process_sync(
sync_util = mapping_syncs_utils[user_sync.provider.lower()]
await sync_util.direct_sync(
sync_active=sync,
- user_sync=user_sync,
+ sync_user=user_sync,
files_ids=files_ids,
folder_ids=folder_ids,
)
@@ -69,8 +83,8 @@ async def process_all_active_syncs(sync_services: SyncServices):
async def _process_all_active_syncs(
- sync_active_service: SyncService,
- sync_user_service: SyncUserService,
+ sync_active_service: SyncsService,
+ sync_user_service: SyncsService,
mapping_syncs_utils: dict[str, SyncUtils],
notification_service: NotificationService,
):
@@ -78,7 +92,7 @@ async def _process_all_active_syncs(
logger.debug(f"Found active syncs: {active_syncs}")
for sync in active_syncs:
try:
- user_sync = sync_user_service.get_sync_user_by_id(sync.syncs_user_id)
+ user_sync = sync_user_service.get_sync_by_id(sync.syncs_user_id)
# TODO: this should be global
# NOTE: Remove the global notification
notification_service.remove_notification_by_id(sync.notification_id)
@@ -104,7 +118,7 @@ async def process_notion_sync(
await session.execute(
text("SET SESSION idle_in_transaction_session_timeout = '5min';")
)
- sync_user_service = SyncUserService()
+ sync_user_service = SyncsService()
notion_repository = NotionRepository(session)
notion_service = SyncNotionService(notion_repository)
diff --git a/backend/worker/quivr_worker/syncs/store_notion.py b/backend/worker/quivr_worker/syncs/store_notion.py
index 2e44524de36c..482e986caf2a 100644
--- a/backend/worker/quivr_worker/syncs/store_notion.py
+++ b/backend/worker/quivr_worker/syncs/store_notion.py
@@ -3,7 +3,7 @@
from notion_client import Client
from quivr_api.logger import get_logger
-from quivr_api.modules.sync.repository.sync_repository import NotionRepository
+from quivr_api.modules.sync.repository.notion_repository import NotionRepository
from quivr_api.modules.sync.service.sync_notion import (
SyncNotionService,
fetch_limit_notion_pages,
diff --git a/backend/worker/quivr_worker/syncs/update_syncs.py b/backend/worker/quivr_worker/syncs/update_syncs.py
new file mode 100644
index 000000000000..960f38bd1997
--- /dev/null
+++ b/backend/worker/quivr_worker/syncs/update_syncs.py
@@ -0,0 +1,16 @@
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from quivr_worker.process.processor import KnowledgeProcessor
+from quivr_worker.utils.services import build_processor_services
+
+
+async def refresh_sync_files(async_engine: AsyncEngine):
+ async with build_processor_services(async_engine) as processor_services:
+ km_processor = KnowledgeProcessor(services=processor_services)
+ await km_processor.refresh_knowledge_sync_files()
+
+
+async def refresh_sync_folders(async_engine: AsyncEngine):
+ async with build_processor_services(async_engine) as processor_services:
+ km_processor = KnowledgeProcessor(services=processor_services)
+ await km_processor.refresh_knowledge_sync_files()
diff --git a/backend/worker/quivr_worker/syncs/utils.py b/backend/worker/quivr_worker/syncs/utils.py
deleted file mode 100644
index bbc3c75f8588..000000000000
--- a/backend/worker/quivr_worker/syncs/utils.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import AsyncGenerator
-
-from quivr_api.celery_config import celery
-from quivr_api.logger import get_logger
-from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
-from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
-from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
-from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
-from quivr_api.modules.notification.service.notification_service import (
- NotificationService,
-)
-from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository
-from quivr_api.modules.sync.repository.sync_repository import NotionRepository
-from quivr_api.modules.sync.service.sync_notion import (
- SyncNotionService,
-)
-from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
-from quivr_api.modules.sync.utils.sync import (
- AzureDriveSync,
- DropboxSync,
- GitHubSync,
- GoogleDriveSync,
- NotionSync,
-)
-from quivr_api.modules.sync.utils.syncutils import SyncUtils
-from sqlalchemy.ext.asyncio import AsyncEngine
-from sqlmodel import text
-from sqlmodel.ext.asyncio.session import AsyncSession
-
-celery_inspector = celery.control.inspect()
-
-logger = get_logger("celery_worker")
-
-
-@dataclass
-class SyncServices:
- async_engine: AsyncEngine
- sync_active_service: SyncService
- sync_user_service: SyncUserService
- sync_files_repo_service: SyncFilesRepository
- notification_service: NotificationService
- brain_vectors: BrainsVectors
- storage: SupabaseS3Storage
-
-
-@asynccontextmanager
-async def build_syncs_utils(
- deps: SyncServices,
-) -> AsyncGenerator[dict[str, SyncUtils], None]:
- try:
- async with AsyncSession(
- deps.async_engine, expire_on_commit=False, autoflush=False
- ) as session:
- await session.execute(
- text("SET SESSION idle_in_transaction_session_timeout = '5min';")
- )
- notion_repository = NotionRepository(session)
- notion_service = SyncNotionService(notion_repository)
- knowledge_service = KnowledgeService(KnowledgeRepository(session))
-
- mapping_sync_utils = {}
- for provider_name, sync_cloud in [
- ("google", GoogleDriveSync()),
- ("azure", AzureDriveSync()),
- ("dropbox", DropboxSync()),
- ("github", GitHubSync()),
- (
- "notion",
- NotionSync(notion_service=notion_service),
- ), # Fixed duplicate "github" key
- ]:
- provider_sync_util = SyncUtils(
- sync_user_service=deps.sync_user_service,
- sync_active_service=deps.sync_active_service,
- sync_files_repo=deps.sync_files_repo_service,
- sync_cloud=sync_cloud,
- notification_service=deps.notification_service,
- brain_vectors=deps.brain_vectors,
- knowledge_service=knowledge_service,
- )
- mapping_sync_utils[provider_name] = provider_sync_util
-
- yield mapping_sync_utils
- await session.commit()
- except Exception as e:
- await session.rollback()
- raise e
- finally:
- await session.close()
diff --git a/backend/worker/quivr_worker/utils/services.py b/backend/worker/quivr_worker/utils/services.py
new file mode 100644
index 000000000000..5bfbda44c156
--- /dev/null
+++ b/backend/worker/quivr_worker/utils/services.py
@@ -0,0 +1,85 @@
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import AsyncGenerator
+
+from quivr_api.logger import get_logger
+from quivr_api.modules.dependencies import get_supabase_async_client
+from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
+from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage
+from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.repository.sync_repository import SyncsRepository
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.utils.sync import (
+ AzureDriveSync,
+ BaseSync,
+ DropboxSync,
+ GitHubSync,
+ GoogleDriveSync,
+)
+from quivr_api.modules.vector.repository.vectors_repository import VectorRepository
+from quivr_api.modules.vector.service.vector_service import VectorService
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlmodel import text
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+logger = get_logger("celery_worker")
+
+
+def build_syncprovider_mapping() -> dict[SyncProvider, BaseSync]:
+ mapping_sync_utils = {
+ SyncProvider.GOOGLE: GoogleDriveSync(),
+ SyncProvider.AZURE: AzureDriveSync(),
+ SyncProvider.DROPBOX: DropboxSync(),
+ SyncProvider.GITHUB: GitHubSync(),
+ # SyncProvider.NOTION: NotionSync(notion_service=notion_service),
+ }
+ return mapping_sync_utils
+
+
+@dataclass
+class ProcessorServices:
+ sync_service: SyncsService
+ vector_service: VectorService
+ knowledge_service: KnowledgeService
+ syncprovider_mapping: dict[SyncProvider, BaseSync]
+
+
+@asynccontextmanager
+async def _start_session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
+ async with AsyncSession(engine) as session:
+ try:
+ await session.execute(
+ text("SET SESSION idle_in_transaction_session_timeout = '5min';")
+ )
+ yield session
+ await session.commit()
+ except Exception as e:
+ await session.rollback()
+ raise e
+ finally:
+ await session.close()
+
+
+@asynccontextmanager
+async def build_processor_services(
+ engine: AsyncEngine,
+) -> AsyncGenerator[ProcessorServices, None]:
+ async_client = await get_supabase_async_client()
+ storage = SupabaseS3Storage(async_client)
+ try:
+ async with _start_session(engine) as session:
+ vector_repository = VectorRepository(session)
+ vector_service = VectorService(vector_repository)
+ knowledge_repository = KnowledgeRepository(session)
+ knowledge_service = KnowledgeService(knowledge_repository, storage=storage)
+ sync_repository = SyncsRepository(session)
+ sync_service = SyncsService(sync_repository)
+ yield ProcessorServices(
+ knowledge_service=knowledge_service,
+ vector_service=vector_service,
+ sync_service=sync_service,
+ syncprovider_mapping=build_syncprovider_mapping(),
+ )
+ finally:
+ logger.info("Closing processor services")
diff --git a/backend/worker/quivr_worker/utils/utils.py b/backend/worker/quivr_worker/utils/utils.py
index 75978b27cd9c..80dfb034a917 100644
--- a/backend/worker/quivr_worker/utils/utils.py
+++ b/backend/worker/quivr_worker/utils/utils.py
@@ -1,16 +1,6 @@
-import os
import uuid
from json import JSONEncoder
from pathlib import PosixPath
-from typing import Tuple
-
-
-def get_tmp_name(file_name: str) -> Tuple[str, str, str]:
- # Filepath is S3 based
- tmp_name = file_name.replace("/", "_")
- base_file_name = os.path.basename(file_name)
- _, file_extension = os.path.splitext(base_file_name)
- return tmp_name, base_file_name, file_extension
# TODO: This is a hack for making uuid work with supabase clients
diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py
index 7d9828a365d0..d5f386f8d444 100644
--- a/backend/worker/tests/conftest.py
+++ b/backend/worker/tests/conftest.py
@@ -1,39 +1,615 @@
import os
+from datetime import datetime, timedelta, timezone
+from io import BytesIO
+from pathlib import Path
from uuid import uuid4
import pytest
-from quivr_worker.files import File
+import pytest_asyncio
+import sqlalchemy
+from fastapi import UploadFile
+from langchain_core.embeddings import DeterministicFakeEmbedding
+from quivr_api.models.settings import settings
+from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
+from quivr_api.modules.brain.entity.brain_user import BrainUserDB
+from quivr_api.modules.dependencies import get_supabase_client
+from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource
+from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository
+from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService
+from quivr_api.modules.knowledge.tests.conftest import FakeStorage
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import Sync
+from quivr_api.modules.sync.repository.sync_repository import SyncsRepository
+from quivr_api.modules.sync.service.sync_service import SyncsService
+from quivr_api.modules.sync.tests.test_sync_controller import FakeSync
+from quivr_api.modules.sync.utils.sync import BaseSync
+from quivr_api.modules.user.entity.user_identity import User
+from quivr_api.modules.vector.entity.vector import Vector
+from quivr_api.modules.vector.repository.vectors_repository import VectorRepository
+from quivr_api.modules.vector.service.vector_service import VectorService
+from quivr_core.files.file import QuivrFile
+from quivr_core.models import KnowledgeStatus
+from quivr_worker.utils.services import ProcessorServices
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
+
+
+async_engine = create_async_engine(
+ "postgresql+asyncpg://" + pg_database_base_url,
+ echo=True if os.getenv("ORM_DEBUG") else False,
+ future=True,
+)
+
+
+@pytest_asyncio.fixture(scope="function")
+async def session():
+ async with async_engine.connect() as conn:
+ trans = await conn.begin()
+ nested = await conn.begin_nested()
+ async_session = AsyncSession(
+ conn,
+ expire_on_commit=False,
+ autoflush=False,
+ autocommit=False,
+ )
+
+ @sqlalchemy.event.listens_for(
+ async_session.sync_session, "after_transaction_end"
+ )
+ def end_savepoint(session, transaction):
+ nonlocal nested
+ if not nested.is_active:
+ nested = conn.sync_connection.begin_nested() # type: ignore
+
+ yield async_session
+ await trans.rollback()
+ await async_session.close()
+
+
+@pytest.fixture(scope="session")
+def supabase_client():
+ return get_supabase_client()
+
+
+@pytest_asyncio.fixture(scope="function")
+async def user(session: AsyncSession) -> User:
+ user_1 = (
+ await session.exec(select(User).where(User.email == "admin@quivr.app"))
+ ).one()
+ assert user_1.id
+ return user_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain_user(session, user: User) -> Brain:
+ assert user.id
+ brain_1 = Brain(
+ name="test_brain",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain_1)
+ await session.commit()
+ await session.refresh(brain_1)
+ assert brain_1.brain_id
+ brain_user = BrainUserDB(
+ brain_id=brain_1.brain_id, user_id=user.id, default_brain=True, rights="Owner"
+ )
+ session.add(brain_user)
+ await session.commit()
+ return brain_1
+
+
+@pytest_asyncio.fixture(scope="function")
+async def brain_user2(session, user: User) -> Brain:
+ assert user.id
+ brain = Brain(
+ name="test_brain2",
+ description="this is a test brain",
+ brain_type=BrainType.integration,
+ )
+ session.add(brain)
+ await session.commit()
+ await session.refresh(brain)
+ assert brain.brain_id
+ brain_user = BrainUserDB(
+ brain_id=brain.brain_id, user_id=user.id, default_brain=True, rights="Owner"
+ )
+ session.add(brain_user)
+ await session.commit()
+ return brain
+
+
+# NOTE: param sets the number of sync file the provider returns
+@pytest_asyncio.fixture(scope="function")
+async def proc_services(session: AsyncSession, request) -> ProcessorServices:
+ n_get_files = getattr(request, "param", 0)
+
+ storage = FakeStorage()
+ embedder = DeterministicFakeEmbedding(size=settings.embedding_dim)
+ vector_repository = VectorRepository(session)
+ vector_service = VectorService(vector_repository, embedder=embedder)
+ knowledge_repository = KnowledgeRepository(session)
+ knowledge_service = KnowledgeService(knowledge_repository, storage=storage)
+ sync_provider_mapping: dict[SyncProvider, BaseSync] = {
+ provider: FakeSync(provider_name=str(provider), n_get_files=n_get_files)
+ for provider in list(SyncProvider)
+ }
+ sync_repository = SyncsRepository(
+ session, sync_provider_mapping=sync_provider_mapping
+ )
+ sync_service = SyncsService(sync_repository)
+
+ return ProcessorServices(
+ knowledge_service=knowledge_service,
+ vector_service=vector_service,
+ sync_service=sync_service,
+ syncprovider_mapping=sync_provider_mapping,
+ )
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync(session: AsyncSession, user: User) -> Sync:
+ assert user.id
+ sync = Sync(
+ name="test_sync",
+ email="test@test.com",
+ user_id=user.id,
+ credentials={"test": "test"},
+ provider=SyncProvider.GOOGLE,
+ )
+
+ session.add(sync)
+ await session.commit()
+ await session.refresh(sync)
+ return sync
+
+
+@pytest_asyncio.fixture(scope="function")
+async def local_knowledge_folder(
+ proc_services: ProcessorServices, user: User, brain_user: Brain
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+ service = proc_services.knowledge_service
+ km_to_add = AddKnowledge(
+ file_name="test",
+ source="local",
+ is_folder=True,
+ parent_id=None,
+ )
+ km = await service.create_knowledge(
+ user_id=user.id, knowledge_to_add=km_to_add, upload_file=None
+ )
+ # Link it to the brain
+ await service.link_knowledge_tree_brains(
+ km, brains_ids=[brain_user.brain_id], user_id=user.id
+ )
+ km = await service.update_knowledge(
+ knowledge=km,
+ payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING),
+ )
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def local_knowledge_folder_with_file(
+ proc_services: ProcessorServices, user: User, brain_user: Brain
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+ service = proc_services.knowledge_service
+ km_to_add = AddKnowledge(
+ file_name="test",
+ source="local",
+ is_folder=True,
+ parent_id=None,
+ )
+ folder_km = await service.create_knowledge(
+ user_id=user.id, knowledge_to_add=km_to_add, upload_file=None
+ )
+ km_to_add = AddKnowledge(
+ file_name="test_file",
+ source=KnowledgeSource.LOCAL,
+ is_folder=False,
+ parent_id=folder_km.id,
+ )
+ km_data = BytesIO(os.urandom(24))
+ _ = await service.create_knowledge(
+ user_id=user.id,
+ knowledge_to_add=km_to_add,
+ upload_file=UploadFile(file=km_data, size=24, filename=km_to_add.file_name),
+ )
+ # Link it to the brain
+ await service.link_knowledge_tree_brains(
+ folder_km, brains_ids=[brain_user.brain_id], user_id=user.id
+ )
+ await service.update_knowledge(
+ knowledge=folder_km,
+ payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING),
+ )
+ return folder_km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def local_knowledge_file(
+ proc_services: ProcessorServices, user: User, brain_user: Brain
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+ service = proc_services.knowledge_service
+ km_to_add = AddKnowledge(
+ file_name="test",
+ source="local",
+ is_folder=False,
+ parent_id=None,
+ )
+ km_data = BytesIO(os.urandom(24))
+ km = await service.create_knowledge(
+ user_id=user.id,
+ knowledge_to_add=km_to_add,
+ upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name),
+ )
+ # Link it to the brain
+ await service.link_knowledge_tree_brains(
+ km, brains_ids=[brain_user.brain_id], user_id=user.id
+ )
+ km = await service.update_knowledge(
+ knowledge=km,
+ payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING),
+ )
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_knowledge_file(
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ user: User,
+ brain_user: Brain,
+ sync: Sync,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+
+ km = KnowledgeDB(
+ file_name="test_file_1.txt",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSING,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/test",
+ file_size=0,
+ file_sha1=None,
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ sync_file_id="id1",
+ sync=sync,
+ last_synced_at=datetime.now(timezone.utc) - timedelta(days=2),
+ )
+
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+
+ return km
+
+
+@pytest.fixture(scope="module")
+def embedder():
+ return DeterministicFakeEmbedding(size=settings.embedding_dim)
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_knowledge_file_processed(
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ user: User,
+ brain_user: Brain,
+ sync: Sync,
+ embedder: DeterministicFakeEmbedding,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+
+ km = KnowledgeDB(
+ file_name="test_file_1.txt",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/test",
+ file_size=1233,
+ file_sha1="1234kj",
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ sync_file_id="id1",
+ sync=sync,
+ last_synced_at=datetime.now(timezone.utc) - timedelta(days=2),
+ )
+
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+
+ assert km.id
+
+ vec = Vector(
+ content="test",
+ metadata_={},
+ embedding=embedder.embed_query("test"), # type: ignore
+ knowledge_id=km.id,
+ )
+ session.add(vec)
+ await session.commit()
+
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_knowledge_folder(
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ user: User,
+ brain_user: Brain,
+ sync: Sync,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+
+ km = KnowledgeDB(
+ file_name="folder1",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSING,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/folder1",
+ file_size=0,
+ file_sha1=None,
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ is_folder=True,
+ sync_file_id="id1",
+ sync=sync,
+ )
+
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_knowledge_folder_with_file_in_other_brain(
+ session: AsyncSession,
+ user: User,
+ brain_user: Brain,
+ brain_user2: Brain,
+ sync: Sync,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+ file = KnowledgeDB(
+ file_name="file",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/file1",
+ file_size=10,
+ file_sha1="test",
+ user_id=user.id,
+ brains=[brain_user2],
+ parent=None,
+ is_folder=False,
+ # NOTE: See FakeSync Implementation
+ sync_file_id="file-0",
+ sync=sync,
+ )
+
+ km = KnowledgeDB(
+ file_name="folder1",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSING,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/folder1",
+ file_size=0,
+ file_sha1=None,
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ is_folder=True,
+ sync_file_id="id1",
+ sync=sync,
+ )
+
+ session.add(file)
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_knowledge_folder_with_file_in_brain(
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ user: User,
+ brain_user: Brain,
+ sync: Sync,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+ file = KnowledgeDB(
+ file_name="file",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/file1",
+ file_size=10,
+ file_sha1="test",
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ is_folder=False,
+ # NOTE: See FakeSync Implementation
+ sync_file_id="file-0",
+ sync=sync,
+ )
+
+ km = KnowledgeDB(
+ file_name="folder1",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSING,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/folder1",
+ file_size=0,
+ file_sha1=None,
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ is_folder=True,
+ sync_file_id="id1",
+ sync=sync,
+ )
+
+ session.add(file)
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+
+ return km
+
+
+@pytest_asyncio.fixture(scope="function")
+async def web_knowledge(
+ session: AsyncSession,
+ user: User,
+ brain_user: Brain,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+
+ km = KnowledgeDB(
+ file_name=None,
+ url="www.quivr.app",
+ extension=".html",
+ status=KnowledgeStatus.PROCESSING,
+ source=KnowledgeSource.WEB,
+ source_link="www.quivr.app",
+ file_size=0,
+ file_sha1=None,
+ user_id=user.id,
+ brains=[brain_user],
+ is_folder=False,
+ )
+
+ session.add(km)
+ await session.commit()
+ await session.refresh(km)
+
+ return km
@pytest.fixture
-def file_instance(tmp_path) -> File:
+def qfile_instance(tmp_path) -> QuivrFile:
data = "This is some test data."
temp_file = tmp_path / "data.txt"
temp_file.write_text(data)
knowledge_id = uuid4()
- return File(
- knowledge_id=knowledge_id,
+ return QuivrFile(
+ id=knowledge_id,
file_sha1="124",
file_extension=".txt",
- file_name=temp_file.name,
- original_file_name=temp_file.name,
+ original_filename=temp_file.name,
+ path=temp_file.absolute(),
file_size=len(data),
- tmp_file_path=temp_file.absolute(),
)
@pytest.fixture
-def audio_file(tmp_path) -> File:
+def audio_qfile(tmp_path) -> QuivrFile:
data = os.urandom(128)
temp_file = tmp_path / "data.mp4"
temp_file.write_bytes(data)
knowledge_id = uuid4()
- return File(
- knowledge_id=knowledge_id,
+ return QuivrFile(
+ id=knowledge_id,
file_sha1="124",
file_extension=".mp4",
- file_name=temp_file.name,
- original_file_name="data.mp4",
+ original_filename="data.mp4",
+ path=temp_file.absolute(),
file_size=len(data),
- tmp_file_path=temp_file.absolute(),
)
+
+
+@pytest.fixture
+def pdf_qfile(tmp_path) -> QuivrFile:
+ data = "This is some test data."
+ temp_file = tmp_path / "data.txt"
+ temp_file.write_text(data)
+ return QuivrFile(
+ id=uuid4(),
+ file_extension=".pdf",
+ original_filename="sample.pdf",
+ file_sha1="124",
+ file_size=1000,
+ path=Path("./tests/sample.pdf"),
+ )
+
+
+@pytest_asyncio.fixture(scope="function")
+async def sync_knowledge_folder_processed(
+ session: AsyncSession,
+ user: User,
+ brain_user: Brain,
+ sync: Sync,
+) -> KnowledgeDB:
+ assert user.id
+ assert brain_user.brain_id
+ folder = KnowledgeDB(
+ file_name="folder",
+ extension="",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/file1",
+ file_size=10,
+ file_sha1="test",
+ user_id=user.id,
+ brains=[brain_user],
+ parent=None,
+ is_folder=True,
+ # NOTE: See FakeSync Implementation
+ sync_file_id="folder-1",
+ sync=sync,
+ last_synced_at=datetime.now(timezone.utc) - timedelta(days=2),
+ )
+
+ km = KnowledgeDB(
+ file_name="file",
+ extension=".txt",
+ status=KnowledgeStatus.PROCESSED,
+ source=SyncProvider.GOOGLE,
+ source_link="drive://test/folder1",
+ file_size=0,
+ file_sha1=None,
+ user_id=user.id,
+ brains=[brain_user],
+ parent=folder,
+ is_folder=False,
+ sync_file_id="file-1",
+ sync=sync,
+ last_synced_at=datetime.now(timezone.utc) - timedelta(days=2),
+ )
+
+ session.add(folder)
+ session.add(km)
+ await session.commit()
+ await session.refresh(folder)
+
+ return folder
diff --git a/backend/worker/tests/test_process_file.py b/backend/worker/tests/test_process_file.py
index 9ca8a49b7153..2e41acd99fa8 100644
--- a/backend/worker/tests/test_process_file.py
+++ b/backend/worker/tests/test_process_file.py
@@ -1,48 +1,58 @@
-import datetime
-import os
-from pathlib import Path
+from random import randbytes
from uuid import uuid4
import pytest
-from quivr_api.modules.brain.entity.brain_entity import BrainEntity, BrainType
-from quivr_core.files.file import FileExtension
-from quivr_worker.files import File, build_file
-from quivr_worker.parsers.crawler import URL, slugify
-from quivr_worker.process.process_file import parse_file
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
+from quivr_worker.parsers.crawler import slugify
+from quivr_worker.process.process_file import parse_qfile
+from quivr_worker.process.utils import build_qfile
-def test_build_file():
- random_bytes = os.urandom(128)
- brain_id = uuid4()
- file_name = f"{brain_id}/test_file.txt"
- knowledge_id = uuid4()
+def test_build_qfile_fail(local_knowledge_file: KnowledgeDB):
+ random_bytes = randbytes(128)
+ local_knowledge_file.file_sha1 = None
+ with pytest.raises(AssertionError):
+ with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _:
+ pass
+ local_knowledge_file.file_sha1 = "sha1"
+ local_knowledge_file.id = None
+ with pytest.raises(AssertionError):
+ with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _:
+ pass
- with build_file(random_bytes, knowledge_id, file_name) as file:
+ local_knowledge_file.id = uuid4()
+
+
+def test_build_qfile_web(web_knowledge: KnowledgeDB):
+ random_bytes = randbytes(128)
+ web_knowledge.file_sha1 = "sha1"
+
+ with build_qfile(knowledge=web_knowledge, file_data=random_bytes) as file:
+ assert file.id == web_knowledge.id
+ assert file.file_size == 128
+ assert file.original_filename == slugify(web_knowledge.url) + ".txt"
+ assert file.file_extension == ".txt"
+ if web_knowledge.metadata_:
+ assert web_knowledge.metadata_.items() <= file.metadata.items()
+ assert file.brain_id is None
+
+
+def test_build_qfile(local_knowledge_file: KnowledgeDB):
+ random_bytes = randbytes(128)
+ local_knowledge_file.file_sha1 = "sha1"
+
+ with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as file:
+ assert file.id == local_knowledge_file.id
assert file.file_size == 128
- assert file.file_name == "test_file.txt"
- assert file.id == knowledge_id
- assert file.file_extension == FileExtension.txt
-
-
-def test_build_url():
- random_bytes = os.urandom(128)
- crawl_website = URL(url="http://url.url")
- file_name = slugify(crawl_website.url) + ".txt"
- knowledge_id = uuid4()
-
- with build_file(
- random_bytes,
- knowledge_id,
- file_name=file_name,
- original_file_name=crawl_website.url,
- ) as file:
- qfile = file.to_qfile(brain_id=uuid4())
- assert qfile.metadata["original_file_name"] == crawl_website.url
- assert qfile.metadata["file_name"] == file_name
-
-
-@pytest.mark.asyncio
-async def test_parse_audio(monkeypatch, audio_file):
+ assert file.original_filename == local_knowledge_file.file_name
+ assert file.file_extension == local_knowledge_file.extension
+ if local_knowledge_file.metadata_:
+ assert local_knowledge_file.metadata_.items() <= file.metadata.items()
+ assert file.brain_id is None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_parse_audio(monkeypatch, audio_qfile):
from openai.resources.audio.transcriptions import Transcriptions
from openai.types.audio.transcription import Transcription
@@ -50,56 +60,25 @@ def transcribe(*args, **kwargs):
return Transcription(text="audio data")
monkeypatch.setattr(Transcriptions, "create", transcribe)
- brain = BrainEntity(
- brain_id=uuid4(),
- name="test",
- brain_type=BrainType.doc,
- last_update=datetime.datetime.now(),
- )
- chunks = await parse_file(
- file=audio_file,
- brain=brain,
+ chunks = await parse_qfile(
+ qfile=audio_qfile,
)
assert len(chunks) > 0
assert chunks[0].page_content == "audio data"
-@pytest.mark.asyncio
-async def test_parse_file(file_instance):
- brain = BrainEntity(
- brain_id=uuid4(),
- name="test",
- brain_type=BrainType.doc,
- last_update=datetime.datetime.now(),
- )
- chunks = await parse_file(
- file=file_instance,
- brain=brain,
+@pytest.mark.asyncio(loop_scope="session")
+async def test_parse_file(qfile_instance):
+ chunks = await parse_qfile(
+ qfile=qfile_instance,
)
assert len(chunks) > 0
-@pytest.mark.asyncio
-async def test_parse_file_pdf():
- file_instance = File(
- knowledge_id=uuid4(),
- file_sha1="124",
- file_extension=".pdf",
- file_name="test",
- original_file_name="test",
- file_size=1000,
- tmp_file_path=Path("./tests/sample.pdf"),
+@pytest.mark.asyncio(loop_scope="session")
+async def test_parse_file_pdf(pdf_qfile):
+ chunks = await parse_qfile(
+ qfile=pdf_qfile,
)
- brain = BrainEntity(
- brain_id=uuid4(),
- name="test",
- brain_type=BrainType.doc,
- last_update=datetime.datetime.now(),
- )
- chunks = await parse_file(
- file=file_instance,
- brain=brain,
- )
-
assert len(chunks[0].page_content) > 0
assert len(chunks) > 0
diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py
index e69de29bb2d1..d98050a0a499 100644
--- a/backend/worker/tests/test_process_file_task.py
+++ b/backend/worker/tests/test_process_file_task.py
@@ -0,0 +1,349 @@
+from typing import Any
+
+import pytest
+from langchain_core.documents import Document
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
+from quivr_api.modules.vector.entity.vector import Vector
+from quivr_core.files.file import QuivrFile
+from quivr_core.models import KnowledgeStatus
+from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices
+from sqlmodel import col, select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+
+async def _parse_file_mock(
+ qfile: QuivrFile,
+ **processor_kwargs: dict[str, Any],
+) -> list[Document]:
+ with open(qfile.path, "rb") as f:
+ return [Document(page_content=str(f.read()), metadata={})]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_process_local_file(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ local_knowledge_file: KnowledgeDB,
+):
+ input_km = local_knowledge_file
+
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ assert input_km.id
+ assert input_km.brains
+ km_processor = KnowledgeProcessor(proc_services)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge processed
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert km.brains[0].brain_id == input_km.brains[0].brain_id
+ assert km.file_sha1 is not None
+
+ # Check vectors where added
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ assert len(vecs) > 0
+ assert vecs[0].metadata_ is not None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_process_local_folder(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ local_knowledge_folder: KnowledgeDB,
+):
+ input_km = local_knowledge_folder
+
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ assert input_km.id
+ assert input_km.brains
+ km_processor = KnowledgeProcessor(proc_services)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge processed
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert km.brains[0].brain_id == input_km.brains[0].brain_id
+ assert km.file_sha1 is None
+
+ # Check vectors where added
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ assert len(vecs) == 0
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_process_local_folder_with_file(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ local_knowledge_folder_with_file: KnowledgeDB,
+):
+ input_km = local_knowledge_folder_with_file
+
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ assert input_km.id
+ assert input_km.brains
+ km_processor = KnowledgeProcessor(proc_services)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge processed
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert km.brains[0].brain_id == input_km.brains[0].brain_id
+ assert km.file_sha1 is None
+
+ # Check vectors where added
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ assert len(vecs) == 0
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_process_web_file(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ web_knowledge: KnowledgeDB,
+):
+ input_km = web_knowledge
+
+ async def _extract_url(url: str) -> str:
+ return "quivr has the best rag"
+
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ monkeypatch.setattr("quivr_worker.process.processor.extract_from_url", _extract_url)
+ assert input_km.id
+ assert input_km.brains
+ km_processor = KnowledgeProcessor(proc_services)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge processed
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert km.brains[0].brain_id == input_km.brains[0].brain_id
+ assert km.file_sha1 is not None
+
+ # Check vectors where added
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ assert len(vecs) > 0
+ assert vecs[0].metadata_ is not None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_process_sync_file(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_file: KnowledgeDB,
+):
+ input_km = sync_knowledge_file
+ assert input_km.id
+ assert input_km.brains
+
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge set to processed
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert km.brains[0].brain_id == input_km.brains[0].brain_id
+ assert km.file_sha1 is not None
+ assert km.last_synced_at is not None
+
+ # Check vectors where added
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ assert len(vecs) > 0
+ assert vecs[0].metadata_ is not None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [4], indirect=True)
+async def test_process_sync_folder(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_folder: KnowledgeDB,
+):
+ input_km = sync_knowledge_folder
+ assert input_km.id
+ assert input_km.brains
+
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge set to processed
+ assert input_km.id
+ assert input_km.brains
+ assert input_km.brains[0]
+ knowledge_service = km_processor.services.knowledge_service
+ # FIXME (@AmineDiro): brain dto!!
+ kms = await knowledge_service.get_all_knowledge_in_brain(
+ input_km.brains[0].brain_id
+ )
+
+ # NOTE : this knowledge + 2 remote sync files
+ assert len(kms) == 5
+ for km in kms:
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert km.brains[0]["brain_id"]
+ assert km.brains[0]["brain_id"] == input_km.brains[0].brain_id
+ assert km.file_sha1 is not None
+ assert km.last_synced_at is not None
+
+ # Check vectors where added
+ vecs = list((await session.exec(select(Vector))).all())
+ # Fake sync return a folder half the time, we skip folders
+ assert len(vecs) >= 2
+ assert vecs[0].metadata_ is not None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [1], indirect=True)
+async def test_process_sync_folder_with_file_in_brain(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_folder_with_file_in_brain: KnowledgeDB,
+):
+ input_km = sync_knowledge_folder_with_file_in_brain
+ assert input_km.id
+ assert input_km.brains
+
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge set to processed
+ assert input_km.id
+ assert input_km.brains
+ assert input_km.brains[0]
+ knowledge_service = km_processor.services.knowledge_service
+ # FIXME (@AmineDiro): brain dto!!
+ kms = await knowledge_service.get_all_knowledge_in_brain(
+ input_km.brains[0].brain_id
+ )
+
+ # NOTE : this knowledge + 2 remote sync files
+ assert len(kms) == 2
+ for km in kms:
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert len(km.brains) == 1, "File added to the same brain multiple times"
+ assert km.brains[0]["brain_id"]
+ assert km.brains[0]["brain_id"] == input_km.brains[0].brain_id
+
+ # Check vectors
+ vecs = list((await session.exec(select(Vector))).all())
+ assert len(vecs) == 0, "File reprocessed, or folder processed "
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [1], indirect=True)
+async def test_process_sync_folder_with_file_in_other_brain(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_folder_with_file_in_other_brain: KnowledgeDB,
+):
+ input_km = sync_knowledge_folder_with_file_in_other_brain
+ assert input_km.id
+ assert input_km.brains
+
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge set to processed
+ assert input_km.id
+ assert input_km.brains
+ assert input_km.brains[0]
+ knowledge_service = km_processor.services.knowledge_service
+ # FIXME (@AmineDiro): brain dto!!
+ kms = await knowledge_service.get_all_knowledge_in_brain(
+ input_km.brains[0].brain_id
+ )
+
+ assert len(kms) == 2
+ for km in kms:
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert len(km.brains) >= 1, "File added to the same brain multiple times"
+ assert km.brains[0]["brain_id"]
+ assert input_km.brains[0].brain_id in {b["brain_id"] for b in km.brains}
+ if len(km.brains) > 1:
+ assert len({b["brain_id"] for b in km.brains}) == 2
+
+ # Check vectors
+ vecs = list((await session.exec(select(Vector))).all())
+ assert len(vecs) == 0, "File reprocessed, or folder processed "
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_process_km_rollback(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ local_knowledge_file: KnowledgeDB,
+):
+ input_km = local_knowledge_file
+ assert input_km.id
+ assert input_km.brains
+
+ async def _store_chunks_error(*args, **kwargs):
+ raise Exception("mock error")
+
+ monkeypatch.setattr(
+ "quivr_worker.process.processor.store_chunks", _store_chunks_error
+ )
+
+ km_processor = KnowledgeProcessor(proc_services)
+ await km_processor.process_knowledge(input_km.id)
+
+ # Check knowledge set to processed
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km.id)
+ assert km.status == KnowledgeStatus.ERROR
+ vecs = list((await session.exec(select(Vector))).all())
+ # Check we remove the vectors
+ assert len(vecs) == 0
diff --git a/backend/worker/tests/test_process_url_task.py b/backend/worker/tests/test_process_url_task.py
deleted file mode 100644
index a34501b52001..000000000000
--- a/backend/worker/tests/test_process_url_task.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import asyncio
-import os
-from typing import List, Tuple
-from uuid import uuid4
-
-import pytest
-import pytest_asyncio
-import sqlalchemy
-from quivr_api.celery_config import celery
-from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType
-from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
-from quivr_api.modules.user.entity.user_identity import User
-from quivr_worker.parsers.crawler import URL, extract_from_url
-from sqlalchemy.ext.asyncio import create_async_engine
-from sqlmodel import select
-from sqlmodel.ext.asyncio.session import AsyncSession
-
-pg_database_base_url = "postgres:postgres@localhost:54322/postgres"
-
-async_engine = create_async_engine(
- "postgresql+asyncpg://" + pg_database_base_url,
- echo=True if os.getenv("ORM_DEBUG") else False,
- future=True,
- pool_pre_ping=True,
- pool_size=10,
- pool_recycle=0.1,
-)
-
-
-TestData = Tuple[Brain, List[KnowledgeDB]]
-
-
-@pytest_asyncio.fixture(scope="function")
-async def session():
- print("\nSESSION_EVEN_LOOP", id(asyncio.get_event_loop()))
- async with async_engine.connect() as conn:
- trans = await conn.begin()
- nested = await conn.begin_nested()
- async_session = AsyncSession(
- conn,
- expire_on_commit=False,
- autoflush=False,
- autocommit=False,
- )
-
- @sqlalchemy.event.listens_for(
- async_session.sync_session, "after_transaction_end"
- )
- def end_savepoint(session, transaction):
- nonlocal nested
- if not nested.is_active:
- nested = conn.sync_connection.begin_nested()
-
- yield async_session
- await trans.rollback()
- await async_session.close()
-
-
-@pytest_asyncio.fixture()
-async def test_data(session: AsyncSession) -> TestData:
- user_1 = (
- await session.exec(select(User).where(User.email == "admin@quivr.app"))
- ).one()
- assert user_1.id
- # Brain data
- brain_1 = Brain(
- name="test_brain",
- description="this is a test brain",
- brain_type=BrainType.integration,
- )
-
- knowledge_brain_1 = KnowledgeDB(
- file_name="test_file_1",
- extension="txt",
- status="UPLOADED",
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1="test_sha1",
- brains=[brain_1],
- user_id=user_1.id,
- )
-
- knowledge_brain_2 = KnowledgeDB(
- file_name="test_file_2",
- extension="txt",
- status="UPLOADED",
- source="test_source",
- source_link="test_source_link",
- file_size=100,
- file_sha1="test_sha2",
- brains=[],
- user_id=user_1.id,
- )
-
- session.add(brain_1)
- session.add(knowledge_brain_1)
- session.add(knowledge_brain_2)
- await session.commit()
- return brain_1, [knowledge_brain_1, knowledge_brain_2]
-
-
-@pytest.mark.skip
-def test_crawl():
- url = "https://en.wikipedia.org/wiki/Python_(programming_language)"
- crawl_website = URL(url=url)
- extracted_content = extract_from_url(crawl_website)
-
- assert len(extracted_content) > 1
-
-
-@pytest.mark.skip
-def test_process_crawl_task(test_data: TestData):
- brain, [knowledge, _] = test_data
- url = "https://en.wikipedia.org/wiki/Python_(programming_language)"
- task = celery.send_task(
- "process_crawl_task",
- kwargs={
- "crawl_website_url": url,
- "brain_id": brain.brain_id,
- "knowledge_id": knowledge.id,
- "notification_id": uuid4(),
- },
- )
- result = task.wait() # noqa: F841
diff --git a/backend/worker/tests/test_update_syncs.py b/backend/worker/tests/test_update_syncs.py
new file mode 100644
index 000000000000..574ab30fb8e7
--- /dev/null
+++ b/backend/worker/tests/test_update_syncs.py
@@ -0,0 +1,220 @@
+import os
+from datetime import datetime, timedelta, timezone
+from io import BytesIO
+from typing import Any, Dict, List, Union
+
+import pytest
+from langchain_core.documents import Document
+from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB
+from quivr_api.modules.sync.dto.outputs import SyncProvider
+from quivr_api.modules.sync.entity.sync_models import SyncFile
+from quivr_api.modules.sync.tests.test_sync_controller import FakeSync
+from quivr_api.modules.sync.utils.sync import BaseSync
+from quivr_api.modules.vector.entity.vector import Vector
+from quivr_core.files.file import QuivrFile
+from quivr_core.models import KnowledgeStatus
+from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices
+from sqlmodel import col, select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+
+async def _parse_file_mock(
+ qfile: QuivrFile,
+ **processor_kwargs: dict[str, Any],
+) -> list[Document]:
+ with open(qfile.path, "rb") as f:
+ return [Document(page_content=str(f.read()), metadata={})]
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_refresh_sync_file(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_file: KnowledgeDB,
+):
+ input_km = sync_knowledge_file
+ assert input_km.id
+ assert input_km.brains
+ assert input_km.sync_file_id
+ assert input_km.file_name
+ assert input_km.source_link
+ assert input_km.last_synced_at
+
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ new_sync_file = SyncFile(
+ id=input_km.sync_file_id,
+ name=input_km.file_name,
+ extension=input_km.extension,
+ is_folder=False,
+ web_view_link=input_km.source_link,
+ last_modified_at=datetime.now(timezone.utc) - timedelta(hours=1),
+ )
+ sync_provider = FakeSync(provider_name=input_km.source, n_get_files=0)
+ new_km = await km_processor.refresh_knowledge_entry(
+ old_km=sync_knowledge_file,
+ new_sync_file=new_sync_file,
+ sync_provider=sync_provider,
+ sync_credentials={},
+ )
+
+ # Check knowledge was updated
+ assert new_km
+ assert new_km.id
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(new_km.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert {b.brain_id for b in km.brains} == {b.brain_id for b in input_km.brains}
+ assert km.parent_id == input_km.parent_id
+ assert km.file_sha1 is not None
+ assert km.last_synced_at
+ assert km.last_synced_at > input_km.last_synced_at
+
+ # Check vectors where removed
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ assert len(vecs) == 0
+
+ # Check vectors where added for the new km
+ vecs = list(
+ (
+ await session.exec(select(Vector).where(col(Vector.knowledge_id) == km.id))
+ ).all()
+ )
+ assert len(vecs) > 0
+ assert vecs[0].metadata_ is not None
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [0], indirect=True)
+async def test_refresh_sync_file_rollback(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_file_processed: KnowledgeDB,
+):
+ input_km = sync_knowledge_file_processed
+ assert input_km.id
+ assert input_km.brains
+ assert input_km.sync_file_id
+ assert input_km.file_name
+ assert input_km.source_link
+ assert input_km.last_synced_at
+
+ async def _parse_file_mock_error(
+ qfile: QuivrFile,
+ **processor_kwargs: dict[str, Any],
+ ) -> list[Document]:
+ raise Exception("error")
+
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr(
+ "quivr_worker.process.processor.parse_qfile", _parse_file_mock_error
+ )
+ new_sync_file = SyncFile(
+ id=input_km.sync_file_id,
+ name=input_km.file_name,
+ extension=input_km.extension,
+ is_folder=False,
+ web_view_link=input_km.source_link,
+ last_modified_at=datetime.now(timezone.utc) - timedelta(hours=1),
+ )
+ sync_provider = FakeSync(provider_name=input_km.source, n_get_files=0)
+ new_km = await km_processor.refresh_knowledge_entry(
+ old_km=input_km,
+ new_sync_file=new_sync_file,
+ sync_provider=sync_provider,
+ sync_credentials={},
+ )
+
+ # Check knowledge was not removed
+ assert new_km is None
+
+ # Check vectors where not removed
+ vecs = list(
+ (
+ await session.exec(
+ select(Vector).where(col(Vector.knowledge_id) == input_km.id)
+ )
+ ).all()
+ )
+ # Check nothing was added
+ assert len(vecs) == 1
+
+ # Check kms statyed correct
+ all_kms = list((await session.exec(select(KnowledgeDB))).unique().all())
+ assert len(all_kms) == 1
+ assert all_kms[0].id == input_km.id
+ assert all_kms[0].last_synced_at == input_km.last_synced_at
+
+
+@pytest.mark.asyncio(loop_scope="session")
+@pytest.mark.parametrize("proc_services", [2], indirect=True)
+async def test_refresh_sync_folder(
+ monkeypatch,
+ session: AsyncSession,
+ proc_services: ProcessorServices,
+ sync_knowledge_folder_processed: KnowledgeDB,
+):
+ input_km_folder = sync_knowledge_folder_processed
+ assert input_km_folder.id
+ assert input_km_folder.brains
+ assert input_km_folder.sync_file_id
+ assert input_km_folder.file_name
+ assert input_km_folder.source_link
+ assert input_km_folder.last_synced_at
+
+ class _MockSync:
+ name = "FakeProvider"
+ lower_name = "google"
+ datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ"
+
+ async def aget_files(
+ self, credentials: Dict, file_ids: List[str]
+ ) -> List[SyncFile]:
+ return self.get_files(credentials, file_ids)
+
+ def get_files(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]:
+ return [
+ SyncFile(
+ id="file_id_1",
+ name="new_file",
+ extension=".txt",
+ web_view_link="fake://test.com",
+ is_folder=False,
+ last_modified_at=datetime.now(),
+ )
+ ]
+
+ async def adownload_file(
+ self, credentials: Dict, file: SyncFile
+ ) -> Dict[str, Union[str, BytesIO]]:
+ return {"content": str(os.urandom(24))}
+
+ sync_provider_mapping: dict[SyncProvider, BaseSync] = {
+ provider: _MockSync() # type: ignore
+ for provider in list(SyncProvider)
+ }
+
+ input_km_children = await input_km_folder.awaitable_attrs.children
+ proc_services.syncprovider_mapping = sync_provider_mapping
+ km_processor = KnowledgeProcessor(proc_services)
+ monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock)
+ await km_processor.refresh_sync_folder(folder_km=input_km_folder)
+
+ # Check knowledge was updated
+ assert input_km_folder
+ assert input_km_folder.id
+ knowledge_service = km_processor.services.knowledge_service
+ km = await knowledge_service.get_knowledge(input_km_folder.id)
+ assert km.status == KnowledgeStatus.PROCESSED
+ assert {k.id for k in await km.awaitable_attrs.children}.issuperset(
+ {k.id for k in input_km_children}
+ )
diff --git a/frontend/app/App.tsx b/frontend/app/App.tsx
index 401b3258aa1b..f5679effe066 100644
--- a/frontend/app/App.tsx
+++ b/frontend/app/App.tsx
@@ -10,11 +10,7 @@ import { HelpWindow } from "@/lib/components/HelpWindow/HelpWindow";
import { Menu } from "@/lib/components/Menu/Menu";
import { useOutsideClickListener } from "@/lib/components/Menu/hooks/useOutsideClickListener";
import { SearchModal } from "@/lib/components/SearchModal/SearchModal";
-import {
- BrainProvider,
- ChatProvider,
- KnowledgeToFeedProvider,
-} from "@/lib/context";
+import { BrainProvider, ChatProvider } from "@/lib/context";
import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext";
import { ChatsProvider } from "@/lib/context/ChatsProvider";
import { HelpProvider } from "@/lib/context/HelpProvider/help-provider";
@@ -32,7 +28,6 @@ import { usePageTracking } from "@/services/analytics/june/usePageTracking";
import "../lib/config/LocaleConfig/i18n";
import styles from "./App.module.scss";
-import { FromConnectionsProvider } from "./chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/FromConnection-provider";
if (
process.env.NEXT_PUBLIC_POSTHOG_KEY != null &&
@@ -106,23 +101,19 @@ const AppWithQueryClient = ({ children }: PropsWithChildren): JSX.Element => {
- {isUrl ? enhanceUrlDisplay(title) : removeFileExtension(title)} -
-{t("drop", { ns: "upload" })}
- )} -