Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/db/models/chunk_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime

from pgvector.sqlalchemy import Vector
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy import DateTime, ForeignKey, Integer, Text, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship

Expand All @@ -27,9 +27,12 @@ class ChunkEmbedding(Base):
embedding: Mapped[list[float]] = mapped_column(
Vector(settings.vector_dimensions), nullable=False
)
source_document: Mapped[str] = mapped_column(String(255), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)

document = relationship("Document", lazy="joined")
document = relationship(
"Document",
back_populates="chunk_embeddings",
lazy="joined",
)
14 changes: 13 additions & 1 deletion src/db/models/document.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import uuid
from datetime import datetime
from typing import TYPE_CHECKING

from sqlalchemy import DateTime, Integer, String, Text, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm import Mapped, mapped_column, relationship

from src.db.base import Base

if TYPE_CHECKING:
from src.db.models.chunk_embedding import ChunkEmbedding


class Document(Base):
__tablename__ = "documents"
Expand All @@ -27,3 +33,9 @@ class Document(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
chunk_embeddings: Mapped[list["ChunkEmbedding"]] = relationship(
"ChunkEmbedding",
back_populates="document",
cascade="all, delete-orphan",
passive_deletes=True,
)
10 changes: 8 additions & 2 deletions src/routers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
router = APIRouter(tags=["ask"])


def get_question_answer_service(
    db: AsyncSession = Depends(get_db),
) -> QuestionAnswerService:
    """Build the question-answering service used by the ``/ask`` route.

    Intended to be consumed through ``Depends(...)`` so the route handler
    receives a ready service instead of constructing one itself.

    Args:
        db: Database session resolved through the ``get_db`` dependency.

    Returns:
        A ``QuestionAnswerService`` constructed with ``db``.
    """
    qa_service = QuestionAnswerService(db)
    return qa_service


@router.post(
"/ask",
response_model=AskResponse,
Expand All @@ -18,6 +24,6 @@
)
async def ask_question(
payload: AskRequest,
db: AsyncSession = Depends(get_db),
service: QuestionAnswerService = Depends(get_question_answer_service),
) -> AskResponse:
return await QuestionAnswerService(db).ask(payload)
return await service.ask(payload)
24 changes: 15 additions & 9 deletions src/routers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,51 @@
from src.db.session import get_db
from src.schemas.document import (
DocumentCreateResponse,
DocumentErrorResponse,
DocumentListItem,
ErrorResponse,
)
from src.services.document import DocumentService

router = APIRouter(prefix="/documents", tags=["documents"])


def get_document_service(
    db: AsyncSession = Depends(get_db),
) -> DocumentService:
    """Build the document service shared by the ``/documents`` routes.

    Intended to be consumed through ``Depends(...)`` so each route handler
    receives a ready service instead of constructing one itself.

    Args:
        db: Database session resolved through the ``get_db`` dependency.

    Returns:
        A ``DocumentService`` constructed with ``db``.
    """
    document_service = DocumentService(db)
    return document_service


@router.post(
"",
response_model=DocumentCreateResponse,
status_code=status.HTTP_202_ACCEPTED,
responses={status.HTTP_400_BAD_REQUEST: {"model": DocumentErrorResponse}},
responses={status.HTTP_400_BAD_REQUEST: {"model": ErrorResponse}},
)
async def upload_document(
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
service: DocumentService = Depends(get_document_service),
) -> DocumentCreateResponse:
document = await DocumentService(db).upload_document(file)
document = await service.upload_document(file)
return DocumentCreateResponse(
document_id=document.id, status=document.status
)


@router.get("", response_model=list[DocumentListItem])
async def list_documents(
db: AsyncSession = Depends(get_db),
service: DocumentService = Depends(get_document_service),
) -> list[DocumentListItem]:
return await DocumentService(db).list_documents()
return await service.list_documents()


@router.delete(
"/{document_id}",
status_code=status.HTTP_204_NO_CONTENT,
responses={status.HTTP_404_NOT_FOUND: {"model": DocumentErrorResponse}},
responses={status.HTTP_404_NOT_FOUND: {"model": ErrorResponse}},
)
async def delete_document(
document_id: UUID,
db: AsyncSession = Depends(get_db),
service: DocumentService = Depends(get_document_service),
) -> Response:
await DocumentService(db).delete_document(document_id)
await service.delete_document(document_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
6 changes: 2 additions & 4 deletions src/schemas/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from pydantic import BaseModel, Field

from src.schemas.document import ErrorResponse as AskErrorResponse


class AskRequest(BaseModel):
question: str = Field(min_length=3)
Expand All @@ -16,7 +18,3 @@ class AskSource(BaseModel):
class AskResponse(BaseModel):
answer: str
sources: list[AskSource]


class AskErrorResponse(BaseModel):
detail: str
4 changes: 2 additions & 2 deletions src/schemas/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DocumentCreateResponse(BaseModel):


class DocumentListItem(BaseModel):
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True, populate_by_name=True)

id: UUID
name: str
Expand All @@ -26,5 +26,5 @@ class DocumentListItem(BaseModel):
created_at: datetime


class DocumentErrorResponse(BaseModel):
class ErrorResponse(BaseModel):
detail: str
10 changes: 10 additions & 0 deletions src/services/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def _generate(self, messages, stop=None, run_manager=None, **kwargs):
generations=[ChatGeneration(message=AIMessage(content=response))]
)

async def _agenerate(
    self, messages, stop=None, run_manager=None, **kwargs
):
    """Async counterpart of ``_generate``.

    Runs the synchronous ``_generate`` in a worker thread and awaits its
    result. Calling ``_generate`` directly inside the coroutine would hold
    the event loop for the whole generation, stalling every other request
    served by the same loop.

    Args:
        messages: Chat messages, forwarded unchanged to ``_generate``.
        stop: Optional stop sequences, forwarded unchanged.
        run_manager: Optional callback manager, forwarded unchanged.
        **kwargs: Extra generation options, forwarded unchanged.

    Returns:
        Whatever ``_generate`` returns for the same arguments.
    """
    # Local import keeps this method self-contained; the module's import
    # block is not visible from here.
    import asyncio

    return await asyncio.to_thread(
        self._generate,
        messages=messages,
        stop=stop,
        run_manager=run_manager,
        **kwargs,
    )


def get_embeddings() -> Embeddings:
if settings.embedding_provider == "openai" and settings.openai_api_key:
Expand Down
26 changes: 19 additions & 7 deletions src/services/document.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from uuid import UUID
from pathlib import Path

import structlog
from fastapi import HTTPException, UploadFile, status
Expand All @@ -7,6 +8,7 @@
from src.repositories.chunk_embedding import ChunkEmbeddingRepository
from src.repositories.document import DocumentRepository
from src.schemas.document import DocumentListItem
from src.schemas.document import DocumentStatus
from src.services.storage import FileStorageService
from src.tasks.document import process_document_task

Expand All @@ -28,11 +30,7 @@ async def upload_document(self, file: UploadFile):
detail="File must have a filename",
)

suffix = (
"." + file.filename.rsplit(".", maxsplit=1)[-1].lower()
if "." in file.filename
else ""
)
suffix = Path(file.filename).suffix.lower()
if suffix not in ALLOWED_SUFFIXES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -48,7 +46,21 @@ async def upload_document(self, file: UploadFile):
document_id=str(document.id),
filename=document.name,
)
process_document_task.delay(str(document.id))
try:
process_document_task.delay(str(document.id))
except Exception as exc:
document.status = DocumentStatus.FAILED.value
document.error_message = "Unable to enqueue document processing"
await self.session.commit()
logger.exception(
"document_processing_enqueue_failed",
document_id=str(document.id),
error=str(exc),
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to enqueue document processing task",
) from exc
return document

async def list_documents(self) -> list[DocumentListItem]:
Expand All @@ -68,7 +80,7 @@ async def delete_document(self, document_id: UUID) -> None:
await self.embeddings.delete_by_document_id(document.id)
await self.repository.delete(document)
await self.session.commit()
self.storage.remove(document.file_path)
await self.storage.remove(document.file_path)
logger.info(
"document_deleted",
document_id=str(document.id),
Expand Down
16 changes: 9 additions & 7 deletions src/services/document_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import asyncio
import re
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -43,8 +44,9 @@ async def process_document(self, document_id: UUID) -> None:

try:
chunks = await self._load_and_split(document.file_path)
vectors = self.embedding_model.embed_documents(
[chunk.page_content for chunk in chunks]
vectors = await asyncio.to_thread(
self.embedding_model.embed_documents,
[chunk.page_content for chunk in chunks],
)
await self._validate_vector_dimensions(vectors)

Expand All @@ -54,7 +56,6 @@ async def process_document(self, document_id: UUID) -> None:
chunk_id=index,
content=chunk.page_content,
embedding=vector,
source_document=document.name,
)
for index, (chunk, vector) in enumerate(
zip(chunks, vectors, strict=True), start=1
Expand Down Expand Up @@ -164,13 +165,13 @@ async def generate_summary(self, document_id: UUID) -> None:

async def _load_and_split(self, file_path: str):
loader = self._build_loader(file_path)
docs = loader.load()
docs = await asyncio.to_thread(loader.load)
splitter = RecursiveCharacterTextSplitter(
chunk_size=settings.chunk_size,
chunk_overlap=settings.chunk_overlap,
)

return splitter.split_documents(docs)
return await asyncio.to_thread(splitter.split_documents, docs)

def _build_loader(self, file_path: str) -> Any:
suffix = Path(file_path).suffix.lower()
Expand All @@ -180,9 +181,10 @@ def _build_loader(self, file_path: str) -> Any:

async def _generate_summary(self, chunks) -> str:
context = "\n\n".join(chunk.page_content for chunk in chunks[:3])
response = self.llm.invoke(
response = await asyncio.to_thread(
self.llm.invoke,
"Summarize the document briefly for cataloging purposes.\n\n"
f"{context}"
f"{context}",
)
content = response.content
return (
Expand Down
13 changes: 9 additions & 4 deletions src/services/qa.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import asyncio
import structlog
from fastapi import HTTPException, status
from langchain_core.documents import Document as LangChainDocument
Expand Down Expand Up @@ -62,7 +63,9 @@ async def ask(self, payload: AskRequest) -> AskResponse:
},
)

question_embedding = self.embedding_model.embed_query(payload.question)
question_embedding = await asyncio.to_thread(
self.embedding_model.embed_query, payload.question
)
question_embedding_dimensions = len(question_embedding)
chunks = await self.embedding_repository.similarity_search(
embedding=question_embedding,
Expand All @@ -84,7 +87,7 @@ async def ask(self, payload: AskRequest) -> AskResponse:
LangChainDocument(
page_content=chunk.content,
metadata={
"document": chunk.source_document,
"document": chunk.document.name,
"chunk_id": chunk.chunk_id,
},
)
Expand All @@ -100,7 +103,8 @@ async def ask(self, payload: AskRequest) -> AskResponse:
prompt_message = self.prompt.invoke(
{"question": payload.question, "context": context}
)
raw_answer = self.llm.invoke(prompt_message).content
llm_response = await asyncio.to_thread(self.llm.invoke, prompt_message)
raw_answer = llm_response.content
answer = (
raw_answer
if isinstance(raw_answer, str)
Expand All @@ -119,7 +123,8 @@ async def ask(self, payload: AskRequest) -> AskResponse:
answer=answer,
sources=[
AskSource(
document=chunk.source_document, chunk_id=chunk.chunk_id
document=chunk.document.name,
chunk_id=chunk.chunk_id,
)
for chunk in chunks
],
Expand Down
17 changes: 13 additions & 4 deletions src/services/storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from pathlib import Path
from uuid import uuid4

Expand All @@ -11,12 +12,20 @@ async def save(self, file: UploadFile) -> tuple[str, str]:
safe_name = file.filename or f"document-{uuid4()}.txt"
file_id = uuid4()
destination = settings.upload_dir / f"{file_id}_{safe_name}"
content = await file.read()
destination.write_bytes(content)
await asyncio.to_thread(
destination.parent.mkdir, parents=True, exist_ok=True
)
await asyncio.to_thread(destination.write_bytes, b"")
with destination.open("ab") as stream:
while True:
chunk = await file.read(1024 * 1024)
if not chunk:
break
await asyncio.to_thread(stream.write, chunk)
await file.close()
return safe_name, str(destination)

async def remove(self, file_path: str) -> None:
    """Delete the stored file at ``file_path``; do nothing if it is absent.

    Args:
        file_path: Filesystem path of the stored upload to delete.

    Raises:
        OSError: If the path exists but cannot be unlinked (for example it
            is a directory or permissions forbid it).
    """
    # A single unlink(missing_ok=True) replaces the exists()-then-unlink()
    # pair: the file can vanish between the two calls, which would raise
    # FileNotFoundError, and exists() itself is a blocking stat call that
    # ran on the event loop.
    await asyncio.to_thread(Path(file_path).unlink, missing_ok=True)
Loading
Loading