diff --git a/src/db/models/chunk_embedding.py b/src/db/models/chunk_embedding.py index 617a94f..4621d4a 100644 --- a/src/db/models/chunk_embedding.py +++ b/src/db/models/chunk_embedding.py @@ -2,7 +2,7 @@ from datetime import datetime from pgvector.sqlalchemy import Vector -from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func +from sqlalchemy import DateTime, ForeignKey, Integer, Text, func from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -27,9 +27,12 @@ class ChunkEmbedding(Base): embedding: Mapped[list[float]] = mapped_column( Vector(settings.vector_dimensions), nullable=False ) - source_document: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) - document = relationship("Document", lazy="joined") + document = relationship( + "Document", + back_populates="chunk_embeddings", + lazy="joined", + ) diff --git a/src/db/models/document.py b/src/db/models/document.py index 9be7257..952dc1f 100644 --- a/src/db/models/document.py +++ b/src/db/models/document.py @@ -1,12 +1,18 @@ +from __future__ import annotations + import uuid from datetime import datetime +from typing import TYPE_CHECKING from sqlalchemy import DateTime, Integer, String, Text, func from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from src.db.base import Base +if TYPE_CHECKING: + from src.db.models.chunk_embedding import ChunkEmbedding + class Document(Base): __tablename__ = "documents" @@ -27,3 +33,9 @@ class Document(Base): created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) + chunk_embeddings: Mapped[list["ChunkEmbedding"]] = relationship( + "ChunkEmbedding", + back_populates="document", + cascade="all, delete-orphan", + passive_deletes=True, + ) diff --git a/src/routers/ask.py b/src/routers/ask.py index eccf2bf..80bae50 100644 --- a/src/routers/ask.py +++ b/src/routers/ask.py @@ -8,6 +8,12 @@ router = APIRouter(tags=["ask"]) +def get_question_answer_service( + db: AsyncSession = Depends(get_db), +) -> QuestionAnswerService: + return QuestionAnswerService(db) + + @router.post( "/ask", response_model=AskResponse, @@ -18,6 +24,6 @@ ) async def ask_question( payload: AskRequest, - db: AsyncSession = Depends(get_db), + service: QuestionAnswerService = Depends(get_question_answer_service), ) -> AskResponse: - return await QuestionAnswerService(db).ask(payload) + return await service.ask(payload) diff --git a/src/routers/document.py b/src/routers/document.py index 0790a77..69db5c0 100644 --- a/src/routers/document.py +++ b/src/routers/document.py @@ -6,25 +6,31 @@ from src.db.session import get_db from src.schemas.document import ( DocumentCreateResponse, - DocumentErrorResponse, DocumentListItem, + ErrorResponse, ) from src.services.document import DocumentService router = APIRouter(prefix="/documents", tags=["documents"]) +def get_document_service( + db: AsyncSession = Depends(get_db), +) -> DocumentService: + return DocumentService(db) + + @router.post( "", response_model=DocumentCreateResponse, status_code=status.HTTP_202_ACCEPTED, - responses={status.HTTP_400_BAD_REQUEST: {"model": DocumentErrorResponse}}, + responses={status.HTTP_400_BAD_REQUEST: {"model": ErrorResponse}}, ) async def upload_document( file: UploadFile = File(...), - db: AsyncSession = Depends(get_db), + service: DocumentService = Depends(get_document_service), ) -> DocumentCreateResponse: - document = await DocumentService(db).upload_document(file) + document = await service.upload_document(file) return DocumentCreateResponse( document_id=document.id, status=document.status ) @@ -32,19 +38,19 @@ async def upload_document( @router.get("", response_model=list[DocumentListItem]) async def list_documents( - db: AsyncSession = Depends(get_db), + service: DocumentService = Depends(get_document_service), ) -> list[DocumentListItem]: - return await DocumentService(db).list_documents() + return await service.list_documents() @router.delete( "/{document_id}", status_code=status.HTTP_204_NO_CONTENT, - responses={status.HTTP_404_NOT_FOUND: {"model": DocumentErrorResponse}}, + responses={status.HTTP_404_NOT_FOUND: {"model": ErrorResponse}}, ) async def delete_document( document_id: UUID, - db: AsyncSession = Depends(get_db), + service: DocumentService = Depends(get_document_service), ) -> Response: - await DocumentService(db).delete_document(document_id) + await service.delete_document(document_id) return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/src/schemas/ask.py b/src/schemas/ask.py index 83400b3..1b96614 100644 --- a/src/schemas/ask.py +++ b/src/schemas/ask.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, Field +from src.schemas.document import ErrorResponse as AskErrorResponse + class AskRequest(BaseModel): question: str = Field(min_length=3) @@ -16,7 +18,3 @@ class AskSource(BaseModel): class AskResponse(BaseModel): answer: str sources: list[AskSource] - - -class AskErrorResponse(BaseModel): - detail: str diff --git a/src/schemas/document.py b/src/schemas/document.py index f83f709..4786b3f 100644 --- a/src/schemas/document.py +++ b/src/schemas/document.py @@ -17,7 +17,7 @@ class DocumentCreateResponse(BaseModel): class DocumentListItem(BaseModel): - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True, populate_by_name=True) id: UUID name: str @@ -26,5 +26,5 @@ class DocumentListItem(BaseModel): created_at: datetime -class DocumentErrorResponse(BaseModel): +class ErrorResponse(BaseModel): detail: str diff --git a/src/services/ai.py b/src/services/ai.py index 2e5e74c..c92ed21 100644 --- a/src/services/ai.py +++ b/src/services/ai.py @@ -32,6 +32,16 @@ def _generate(self, messages, stop=None, run_manager=None, **kwargs): generations=[ChatGeneration(message=AIMessage(content=response))] ) + async def _agenerate( + self, messages, stop=None, run_manager=None, **kwargs + ): + return self._generate( + messages=messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + def get_embeddings() -> Embeddings: if settings.embedding_provider == "openai" and settings.openai_api_key: diff --git a/src/services/document.py b/src/services/document.py index c24c3a4..98f1e9a 100644 --- a/src/services/document.py +++ b/src/services/document.py @@ -1,4 +1,5 @@ from uuid import UUID +from pathlib import Path import structlog from fastapi import HTTPException, UploadFile, status @@ -7,6 +8,7 @@ from src.repositories.chunk_embedding import ChunkEmbeddingRepository from src.repositories.document import DocumentRepository from src.schemas.document import DocumentListItem +from src.schemas.document import DocumentStatus from src.services.storage import FileStorageService from src.tasks.document import process_document_task @@ -28,11 +30,7 @@ async def upload_document(self, file: UploadFile): detail="File must have a filename", ) - suffix = ( - "." + file.filename.rsplit(".", maxsplit=1)[-1].lower() - if "." in file.filename - else "" - ) + suffix = Path(file.filename).suffix.lower() if suffix not in ALLOWED_SUFFIXES: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -48,7 +46,21 @@ async def upload_document(self, file: UploadFile): document_id=str(document.id), filename=document.name, ) - process_document_task.delay(str(document.id)) + try: + process_document_task.delay(str(document.id)) + except Exception as exc: + document.status = DocumentStatus.FAILED.value + document.error_message = "Unable to enqueue document processing" + await self.session.commit() + logger.exception( + "document_processing_enqueue_failed", + document_id=str(document.id), + error=str(exc), + ) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Unable to enqueue document processing task", + ) from exc return document async def list_documents(self) -> list[DocumentListItem]: @@ -68,7 +80,7 @@ async def delete_document(self, document_id: UUID) -> None: await self.embeddings.delete_by_document_id(document.id) await self.repository.delete(document) await self.session.commit() - self.storage.remove(document.file_path) + await self.storage.remove(document.file_path) logger.info( "document_deleted", document_id=str(document.id), diff --git a/src/services/document_processor.py b/src/services/document_processor.py index 3c36982..bc88b6c 100644 --- a/src/services/document_processor.py +++ b/src/services/document_processor.py @@ -1,4 +1,5 @@ from __future__ import annotations +import asyncio import re from pathlib import Path from typing import Any @@ -43,8 +44,9 @@ async def process_document(self, document_id: UUID) -> None: try: chunks = await self._load_and_split(document.file_path) - vectors = self.embedding_model.embed_documents( - [chunk.page_content for chunk in chunks] + vectors = await asyncio.to_thread( + self.embedding_model.embed_documents, + [chunk.page_content for chunk in chunks], ) await self._validate_vector_dimensions(vectors) @@ -54,7 +56,6 @@ async def process_document(self, document_id: UUID) -> None: chunk_id=index, content=chunk.page_content, embedding=vector, - source_document=document.name, ) for index, (chunk, vector) in enumerate( zip(chunks, vectors, strict=True), start=1 @@ -164,13 +165,13 @@ async def generate_summary(self, document_id: UUID) -> None: async def _load_and_split(self, file_path: str): loader = self._build_loader(file_path) - docs = loader.load() + docs = await asyncio.to_thread(loader.load) splitter = RecursiveCharacterTextSplitter( chunk_size=settings.chunk_size, chunk_overlap=settings.chunk_overlap, ) - return splitter.split_documents(docs) + return await asyncio.to_thread(splitter.split_documents, docs) def _build_loader(self, file_path: str) -> Any: suffix = Path(file_path).suffix.lower() @@ -180,9 +181,10 @@ def _build_loader(self, file_path: str) -> Any: async def _generate_summary(self, chunks) -> str: context = "\n\n".join(chunk.page_content for chunk in chunks[:3]) - response = self.llm.invoke( + response = await asyncio.to_thread( + self.llm.invoke, "Summarize the document briefly for cataloging purposes.\n\n" - f"{context}" + f"{context}", ) content = response.content return ( diff --git a/src/services/qa.py b/src/services/qa.py index 6f71ce5..efc5fa9 100644 --- a/src/services/qa.py +++ b/src/services/qa.py @@ -1,4 +1,5 @@ from __future__ import annotations +import asyncio import structlog from fastapi import HTTPException, status from langchain_core.documents import Document as LangChainDocument @@ -62,7 +63,9 @@ async def ask(self, payload: AskRequest) -> AskResponse: }, ) - question_embedding = self.embedding_model.embed_query(payload.question) + question_embedding = await asyncio.to_thread( + self.embedding_model.embed_query, payload.question + ) question_embedding_dimensions = len(question_embedding) chunks = await self.embedding_repository.similarity_search( embedding=question_embedding, @@ -84,7 +87,7 @@ async def ask(self, payload: AskRequest) -> AskResponse: LangChainDocument( page_content=chunk.content, metadata={ - "document": chunk.source_document, + "document": chunk.document.name, "chunk_id": chunk.chunk_id, }, ) @@ -100,7 +103,8 @@ async def ask(self, payload: AskRequest) -> AskResponse: prompt_message = self.prompt.invoke( {"question": payload.question, "context": context} ) - raw_answer = self.llm.invoke(prompt_message).content + llm_response = await asyncio.to_thread(self.llm.invoke, prompt_message) + raw_answer = llm_response.content answer = ( raw_answer if isinstance(raw_answer, str) @@ -119,7 +123,8 @@ async def ask(self, payload: AskRequest) -> AskResponse: answer=answer, sources=[ AskSource( - document=chunk.source_document, chunk_id=chunk.chunk_id + document=chunk.document.name, + chunk_id=chunk.chunk_id, ) for chunk in chunks ], diff --git a/src/services/storage.py b/src/services/storage.py index 534885c..ba01721 100644 --- a/src/services/storage.py +++ b/src/services/storage.py @@ -1,3 +1,4 @@ +import asyncio from pathlib import Path from uuid import uuid4 @@ -11,12 +12,20 @@ async def save(self, file: UploadFile) -> tuple[str, str]: safe_name = file.filename or f"document-{uuid4()}.txt" file_id = uuid4() destination = settings.upload_dir / f"{file_id}_{safe_name}" - content = await file.read() - destination.write_bytes(content) + await asyncio.to_thread( + destination.parent.mkdir, parents=True, exist_ok=True + ) + await asyncio.to_thread(destination.write_bytes, b"") + with destination.open("ab") as stream: + while True: + chunk = await file.read(1024 * 1024) + if not chunk: + break + await asyncio.to_thread(stream.write, chunk) await file.close() return safe_name, str(destination) - def remove(self, file_path: str) -> None: + async def remove(self, file_path: str) -> None: path = Path(file_path) if path.exists(): - path.unlink() + await asyncio.to_thread(path.unlink) diff --git a/src/tasks/document.py b/src/tasks/document.py index 7b4916a..1aab1db 100644 --- a/src/tasks/document.py +++ b/src/tasks/document.py @@ -1,4 +1,5 @@ import asyncio +import threading from uuid import UUID from collections.abc import Coroutine from typing import Any @@ -12,13 +13,15 @@ logger = structlog.get_logger(__name__) _worker_loop: asyncio.AbstractEventLoop | None = None +_worker_loop_lock = threading.Lock() def _get_worker_loop() -> asyncio.AbstractEventLoop: global _worker_loop - if _worker_loop is None or _worker_loop.is_closed(): - _worker_loop = asyncio.new_event_loop() - return _worker_loop + with _worker_loop_lock: + if _worker_loop is None or _worker_loop.is_closed(): + _worker_loop = asyncio.new_event_loop() + return _worker_loop def _run_in_worker_loop(coro: Coroutine[Any, Any, None]) -> None: @@ -53,6 +56,7 @@ def generate_summary_task(document_id: str) -> None: @worker_process_shutdown.connect def _close_worker_loop(**_: object) -> None: global _worker_loop - if _worker_loop is not None and not _worker_loop.is_closed(): - _worker_loop.close() - _worker_loop = None + with _worker_loop_lock: + if _worker_loop is not None and not _worker_loop.is_closed(): + _worker_loop.close() + _worker_loop = None