From 207c33188fd854bc8cad46e126cb45ce3bdc66e1 Mon Sep 17 00:00:00 2001 From: Mark Botterill Date: Tue, 28 Jan 2025 15:50:42 +0000 Subject: [PATCH 1/6] Refactor part 1 --- .github/workflows/python-tests.yaml | 2 - Makefile | 48 +-- backend/app/__init__.py | 21 -- backend/app/auth/config.py | 3 - backend/app/auth/dependencies.py | 2 +- backend/app/config.py | 9 +- .../{utils => fetch_utils}/airtable_utils.py | 7 +- .../google_drive_utils.py | 63 +--- .../embedding_utils.py | 1 - .../ingestion/process_utils/openai_utils.py | 223 ++++++++++++ backend/app/ingestion/routers.py | 64 +--- .../gcp_storage_utils.py | 5 +- .../ingestion/utils/file_processing_utils.py | 132 ++++--- backend/app/ingestion/utils/openai_utils.py | 96 +---- .../app/ingestion/utils/record_processing.py | 343 +++++++++++------- backend/app/search/utils.py | 2 +- backend/app/utils.py | 4 +- backend/tests/api/test_ingestion.py | 2 +- 18 files changed, 563 insertions(+), 464 deletions(-) delete mode 100644 backend/app/auth/config.py rename backend/app/ingestion/{utils => fetch_utils}/airtable_utils.py (98%) rename backend/app/ingestion/{utils => fetch_utils}/google_drive_utils.py (57%) rename backend/app/ingestion/{utils => process_utils}/embedding_utils.py (99%) create mode 100644 backend/app/ingestion/process_utils/openai_utils.py rename backend/app/ingestion/{utils => storage_utils}/gcp_storage_utils.py (99%) diff --git a/.github/workflows/python-tests.yaml b/.github/workflows/python-tests.yaml index 69eb671..6823bd9 100644 --- a/.github/workflows/python-tests.yaml +++ b/.github/workflows/python-tests.yaml @@ -11,7 +11,6 @@ env: POSTGRES_PASSWORD: postgres-test-pw POSTGRES_USER: postgres-test-user POSTGRES_DB: postgres-test-db - REDIS_HOST: redis://redis:6379 jobs: container-job: runs-on: ubuntu-20.04 @@ -54,7 +53,6 @@ jobs: - name: Run Unit Tests env: PROMETHEUS_MULTIPROC_DIR: /tmp - REDIS_HOST: ${{ env.REDIS_HOST }} run: | cd backend export POSTGRES_HOST=postgres POSTGRES_USER=$POSTGRES_USER \ diff --git a/Makefile b/Makefile index b9e01a3..26195e7 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,4 @@ include ./deployment/docker-compose/.backend.env -include ./deployment/docker-compose/.base.env PROJECT_NAME=hew-ai CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate @@ -8,54 +7,25 @@ ENDPOINT_URL = localhost:8000 guard-%: @if [ -z '${${*}}' ]; then echo 'ERROR: environment variable $* not set' && exit 1; fi -# Note: Run `make fresh-env psycopg2-binary=true` to manually replace psycopg with psycopg2-binary -fresh-env : - conda remove --name $(PROJECT_NAME) --all -y - conda create --name $(PROJECT_NAME) python==3.12 -y - - $(CONDA_ACTIVATE) $(PROJECT_NAME); \ - pip install -r backend/requirements.txt --ignore-installed; \ - pip install -r requirements-dev.txt --ignore-installed; \ - pre-commit install - - if [ "$(psycopg2-binary)" = "true" ]; then \ - $(CONDA_ACTIVATE) $(PROJECT_NAME); \ - pip uninstall -y psycopg2==2.9.9; \ - pip install psycopg2-binary==2.9.9; \ - fi setup-db: guard-POSTGRES_USER guard-POSTGRES_PASSWORD guard-POSTGRES_DB - -@docker stop pg-hew-ai-local - -@docker rm pg-hew-ai-local + -@docker stop survey-accelerator + -@docker rm survey-accelerator @docker system prune -f @sleep 2 - @docker run --name pg-hew-ai-local \ - -e POSTGRES_USER=$(POSTGRES_USER) \ - -e POSTGRES_PASSWORD=$(POSTGRES_PASSWORD) \ - -e POSTGRES_DB=$(POSTGRES_DB) \ - -p 5432:5432 \ + @docker run --name survey-accelerator \ + -e POSTGRES_USER=${POSTGRES_USER} \ + -e 
POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \ + -e POSTGRES_DB=${POSTGRES_DB} \ + -p ${POSTGRES_PORT}:5432 \ -d pgvector/pgvector:pg16 @sleep 5 set -a && \ - source "$(CURDIR)/deployment/docker-compose/.base.env" && \ source "$(CURDIR)/deployment/docker-compose/.backend.env" && \ set +a && \ cd backend && \ python -m alembic upgrade head teardown-db: - @docker stop pg-hew-ai-local - @docker rm pg-hew-ai-local - -setup-redis: - -@docker stop redis-hew-ai-local - -@docker rm redis-hew-ai-local - @docker system prune -f - @sleep 2 - @docker run --name redis-hew-ai-local \ - -p 6379:6379 \ - -d redis:6.0-alpine - -make teardown-redis: - @docker stop redis-hew-ai-local - @docker rm redis-hew-ai-local + @docker stop survey-accelerator + @docker rm survey-accelerator diff --git a/backend/app/__init__.py b/backend/app/__init__.py index f363e53..cf0e874 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -1,12 +1,7 @@ -from contextlib import asynccontextmanager -from typing import AsyncIterator - from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from redis import asyncio as aioredis from . import ingestion, search -from .config import REDIS_HOST from .utils import setup_logger logger = setup_logger() @@ -14,21 +9,6 @@ tags_metadata = [ingestion.TAG_METADATA] -@asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncIterator[None]: - """ - Lifespan events for the FastAPI application. - """ - - logger.info("Application started") - app.state.redis = await aioredis.from_url(REDIS_HOST) - - yield - - await app.state.redis.close() - logger.info("Application finished") - - def create_app() -> FastAPI: """ Create a FastAPI application with the appropriate routers. @@ -36,7 +16,6 @@ def create_app() -> FastAPI: app = FastAPI( title="Survey Accelerator", openapi_tags=tags_metadata, - lifespan=lifespan, debug=True, ) diff --git a/backend/app/auth/config.py b/backend/app/auth/config.py deleted file mode 100644 index 2a50280..0000000 --- a/backend/app/auth/config.py +++ /dev/null @@ -1,3 +0,0 @@ -import os - -API_SECRET_KEY = os.getenv("API_SECRET_KEY", "kk") diff --git a/backend/app/auth/dependencies.py b/backend/app/auth/dependencies.py index e5abec8..7676853 100644 --- a/backend/app/auth/dependencies.py +++ b/backend/app/auth/dependencies.py @@ -4,7 +4,7 @@ HTTPBearer, ) -from .config import API_SECRET_KEY +from ..config import API_SECRET_KEY bearer = HTTPBearer() diff --git a/backend/app/config.py b/backend/app/config.py index e589410..8f35bd6 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,5 +1,9 @@ import os +# Auth +API_SECRET_KEY = os.getenv("API_SECRET_KEY", "kk") + + # PostgreSQL Configurations POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres") @@ -8,9 +12,6 @@ POSTGRES_DB = os.environ.get("POSTGRES_DB", "postgres") DB_POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", 20)) -# Redis Configuration -REDIS_HOST = os.environ.get("REDIS_HOST", "redis://localhost:6379") - # Backend Configuration BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "") LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING") @@ -50,6 +51,4 @@ ) # Other Configurations -MAX_PAGES = int(os.environ.get("MAX_PAGES", 3)) MAIN_DOWNLOAD_DIR = "downloaded_gdrives_sa" -XLSX_SUBDIR = os.path.join(MAIN_DOWNLOAD_DIR, "xlsx") diff --git a/backend/app/ingestion/utils/airtable_utils.py b/backend/app/ingestion/fetch_utils/airtable_utils.py similarity index 98% rename from 
backend/app/ingestion/utils/airtable_utils.py rename to backend/app/ingestion/fetch_utils/airtable_utils.py index a729304..0181949 100644 --- a/backend/app/ingestion/utils/airtable_utils.py +++ b/backend/app/ingestion/fetch_utils/airtable_utils.py @@ -2,13 +2,12 @@ from typing import Any, Dict -from pyairtable import Api -from sqlalchemy import select - from app.config import AIRTABLE_API_KEY, AIRTABLE_CONFIGS from app.database import get_async_session from app.ingestion.models import DocumentDB from app.utils import setup_logger +from pyairtable import Api +from sqlalchemy import select logger = setup_logger() @@ -17,7 +16,7 @@ raise EnvironmentError("Airtable API key not found in environment variables.") -def get_airtable_records() -> list: +async def get_airtable_records() -> list: """ Fetch records from Airtable and return a list of records. Raises exceptions if there are issues fetching the records. diff --git a/backend/app/ingestion/utils/google_drive_utils.py b/backend/app/ingestion/fetch_utils/google_drive_utils.py similarity index 57% rename from backend/app/ingestion/utils/google_drive_utils.py rename to backend/app/ingestion/fetch_utils/google_drive_utils.py index 44d8874..724bfc6 100644 --- a/backend/app/ingestion/utils/google_drive_utils.py +++ b/backend/app/ingestion/fetch_utils/google_drive_utils.py @@ -4,14 +4,13 @@ import os from typing import Optional +from app.config import SCOPES, SERVICE_ACCOUNT_FILE_PATH +from app.utils import setup_logger from google.oauth2 import service_account from googleapiclient.discovery import Resource as DriveResource from googleapiclient.discovery import build from googleapiclient.http import MediaIoBaseDownload -from app.config import SCOPES, SERVICE_ACCOUNT_FILE_PATH, XLSX_SUBDIR -from app.utils import setup_logger - logger = setup_logger() @@ -26,7 +25,7 @@ def get_drive_service() -> DriveResource: return drive_service -def extract_file_id(gdrive_url: str) -> str: +async def extract_file_id(gdrive_url: str) -> str: """ Extract the file ID from a Google Drive URL. @@ -68,7 +67,7 @@ def extract_file_id(gdrive_url: str) -> str: ) -def determine_file_type(file_name: str) -> str: +async def determine_file_type(file_name: str) -> str: """ Determine the file type based on the file extension. """ @@ -81,56 +80,22 @@ def determine_file_type(file_name: str) -> str: return "other" -def download_file( - file_id: str, file_name: str, file_type: str, drive_service: DriveResource -) -> Optional[io.BytesIO]: +def download_file(file_id: str, drive_service: DriveResource) -> Optional[io.BytesIO]: """ Download a file from Google Drive using its file ID and handle it based on file type For PDFs, download into memory and return the BytesIO object. For XLSX, download and save to disk. 
""" try: - # Get file metadata to determine MIME type - file_metadata = ( - drive_service.files().get(fileId=file_id, fields="mimeType, name").execute() - ) - mime_type = file_metadata.get("mimeType") - - if ( - file_type == "xlsx" - and mime_type == "application/vnd.google-apps.spreadsheet" - ): - # Export Google Sheets to Excel format - request = drive_service.files().export_media( - fileId=file_id, - mimeType="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - ) - save_path = os.path.join( - XLSX_SUBDIR, - ( - f"{file_name}.xlsx" - if not file_name.lower().endswith(".xlsx") - else file_name - ), - ) - with open(save_path, "wb") as xlsx_file: - downloader = MediaIoBaseDownload(xlsx_file, request) - done = False - while not done: - status, done = downloader.next_chunk() - return None # No need to return anything for XLSX - elif file_type == "pdf": - # Download PDF files into memory - request = drive_service.files().get_media(fileId=file_id) - pdf_buffer = io.BytesIO() - downloader = MediaIoBaseDownload(pdf_buffer, request) - done = False - while not done: - status, done = downloader.next_chunk() - pdf_buffer.seek(0) # Reset buffer position to the beginning - return pdf_buffer # Return the in-memory file for PDFs - else: - return None + logger.warning("Downloading PDF file...") + request = drive_service.files().get_media(fileId=file_id) + pdf_buffer = io.BytesIO() + downloader = MediaIoBaseDownload(pdf_buffer, request) + done = False + while not done: + status, done = downloader.next_chunk() + pdf_buffer.seek(0) # Reset buffer position to the beginning + return pdf_buffer # Return the in-memory file for PDFs except Exception as e: logger.error(f"Error downloading file: {e}") diff --git a/backend/app/ingestion/utils/embedding_utils.py b/backend/app/ingestion/process_utils/embedding_utils.py similarity index 99% rename from backend/app/ingestion/utils/embedding_utils.py rename to backend/app/ingestion/process_utils/embedding_utils.py index 758886a..3733b5c 100644 --- a/backend/app/ingestion/utils/embedding_utils.py +++ b/backend/app/ingestion/process_utils/embedding_utils.py @@ -1,7 +1,6 @@ # utils/embedding_utils.py import cohere - from app.config import COHERE_API_KEY from app.utils import setup_logger diff --git a/backend/app/ingestion/process_utils/openai_utils.py b/backend/app/ingestion/process_utils/openai_utils.py new file mode 100644 index 0000000..1e31e2d --- /dev/null +++ b/backend/app/ingestion/process_utils/openai_utils.py @@ -0,0 +1,223 @@ +# utils / openai_utils.py + +import json +import os + +import openai +from app.utils import setup_logger +from openai import AsyncOpenAI + +logger = setup_logger() + +# Instantiate the OpenAI client (you can keep this or comment it out) +client = openai.Client(api_key=os.getenv("OPENAI_API_KEY")) +async_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) + + +async def generate_contextual_summary(document_content: str, chunk_content: str) -> str: + """ + Generate a concise contextual summary for a chunk. + """ + # Construct the prompt (optional, you can comment this out if not needed) + prompt = f""" + + {document_content} + + + Here is a specific page from the document: + + {chunk_content} + + + Please provide a concise, contextually accurate summary for the above page based + strictly on its visible content. + DO NOT include generic survey topics (e.g., contraception) unless clearly + mentioned on the page. 
+ This is to improve precise search relevance and must reflect what is + explicitly covered on this page AND HOW IT SITUATES WITHIN THE LARGER DOCUMENT. + Answer only with the context, avoiding any inferred topics. + """ + try: + response = await async_client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "user", "content": prompt}, + ], + max_tokens=150, + temperature=0, + ) + summary = response.choices[0].message.content.strip() + + return summary + + except Exception as e: + logger.error(f"Unexpected error generating contextual summary: {e}") + return "" + + +async def generate_brief_summary(document_content: str) -> str: + """ + Generate a concise summary of the entire document in 10-15 words. + """ + # Construct the prompt (optional) + prompt = f"""Summarize the following document in 10 to 15 in a sentence that + starts with the word 'Covers' + e.g. 'Covers womens health and contraception awareness.' or + 'Covers the impact of climate change on agriculture.' + Respond only with the summary and nothing else. + ENSURE YOUR ANSWER NEVER EXCEEDS 15 WORDS. + Below is the content to summarize: + \n\n{document_content}""" + + try: + response = await async_client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "user", "content": prompt}, + ], + max_tokens=20, + temperature=0.2, + ) + summary = response.choices[0].message.content.strip() + return summary + except Exception as e: + logger.error(f"Error generating brief summary: {e}") + return "" + + +async def generate_smart_filename(file_name: str, document_content: str) -> str: + """ + Generate a specific and elegant filename devoid of year numbers. + + Args: + file_name: Original file name. + document_content: Content of the document. + + Returns: + A descriptive filename as a plain text string. + """ + # Construct the prompt with precise instructions + prompt = f""" + Given the original file name "{file_name}" and the content excerpt below, generate + a specific and descriptive filename that includes relevant information, such as the + organization name (e.g., 'USAID', 'UNICEF') and the main topic of the document. + + The filename should: + - Avoid any dates or year numbers + - Exclude prefixes like 'Filename:', asterisks, or other symbols + - Be concise, specific, and relevant to the document content + + Examples of appropriate filenames might include: + - 'USAID Family Health Survey' + - 'UNICEF Maternal Health Questionnaire' + - 'DHS Reproductive Health Analysis' + + Content excerpt: + {document_content[:2000]} + + Plain text filename: + """ + + try: + response = await async_client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "user", "content": prompt}, + ], + max_tokens=10, + temperature=0.2, + ) + smart_name = response.choices[0].message.content.strip() + + # Optional: Ensure there's no unexpected text or symbols + smart_name = smart_name.split(":", 1)[ + -1 + ].strip() # Remove anything before a colon + + return smart_name + except Exception as e: + logger.error(f"Error generating elegant filename: {e}") + return "" + + +async def extract_question_answer_from_page(chunk_content: str) -> list[dict]: + """ + Extract questions and answers from a chunk of text. + """ + prompt = f""" + You are to extract questions and their possible answers from the following survey + questionnaire chunk. The survey may be messy, but take time to reason through your + response. 
+ + **Provide the output as a JSON list of dictionaries in the following format without + any code block or markdown formatting**: + [ + {{"question": "Question text", "answers": ["Answer 1", "Answer 2"]}}, + ... + ] + + **Do not include any additional text outside the JSON array. Do not include any code + block notation, such as triple backticks or language identifiers like +json.** + + If no questions or answers are found, return an empty list: [] + + Chunk: + {chunk_content} + """ + + try: + response = await async_client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "user", "content": prompt}, + ], + max_tokens=1000, + temperature=0, + ) + qa_pairs_str = response.choices[0].message.content.strip() + qa_pairs = json.loads(qa_pairs_str) + return qa_pairs + except json.JSONDecodeError as e: + logger.error(f"JSON parsing error: {e}") + logger.error(f"Response was: {qa_pairs_str}") + return [] + except Exception as e: + logger.error(f"Error extracting questions and answers: {e}") + return [] + + +async def generate_query_match_explanation(query: str, chunk_content: str) -> str: + """ + Generate a short explanation of how the query matches the contextualized chunk. + """ + prompt = f""" + Given the following query: + "{query}" + + And the following chunk from a document: + "{chunk_content}" + + Provide a one-sentence, 12 word maximum explanation starting with "Mentions ..." + to explain why the + chunk matches the query. + + Be extremely specific to the document at hand and avoid generalizations + or inferences. Do not mention the query in the explanation. + Do not include any additional text outside the explanation. + """ + + try: + response = await async_client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "user", "content": prompt}, + ], + max_tokens=250, + temperature=0, + ) + explanation = response.choices[0].message.content.strip() + return explanation + except Exception as e: + logger.error(f"Error generating match explanation: {e}") + return "Unable to generate explanation." diff --git a/backend/app/ingestion/routers.py b/backend/app/ingestion/routers.py index e9da208..03aa0e8 100644 --- a/backend/app/ingestion/routers.py +++ b/backend/app/ingestion/routers.py @@ -1,13 +1,13 @@ # routers.py -from fastapi import APIRouter, Body, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status from app.auth.dependencies import authenticate_key -from app.ingestion.schemas import AirtableIngestionResponse -from app.ingestion.utils.airtable_utils import ( +from app.ingestion.fetch_utils.airtable_utils import ( get_airtable_records, get_missing_document_ids, ) +from app.ingestion.schemas import AirtableIngestionResponse from app.ingestion.utils.record_processing import ingest_records from app.utils import setup_logger @@ -25,58 +25,6 @@ } -@router.post("/airtable", response_model=AirtableIngestionResponse) -async def ingest_airtable( - ids: list[int] = Body(..., embed=True), -) -> AirtableIngestionResponse: - """ - Ingest documents from Airtable records with page-level progress logging. - Accepts a list of document IDs to process. 
- """ - try: - records = get_airtable_records() - - except KeyError as e: - logger.error(f"Airtable configuration error: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Airtable configuration error.", - ) from e - - except Exception as e: - logger.error(f"Error fetching records from Airtable: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error fetching records from Airtable: {e}", - ) from e - - if not records: - return AirtableIngestionResponse( - total_records_processed=0, - total_chunks_created=0, - message="No records found in Airtable.", - ) - # Map IDs to records - airtable_id_to_record = {} - for record in records: - fields = record.get("fields", {}) - id_value = fields.get("ID") - if id_value is not None: - airtable_id_to_record[id_value] = record - - records_to_process = [ - airtable_id_to_record[id_] for id_ in ids if id_ in airtable_id_to_record - ] - if not records_to_process: - return AirtableIngestionResponse( - total_records_processed=0, - total_chunks_created=0, - message="No matching records found for the provided IDs.", - ) - - return await ingest_records(records_to_process) - - @router.post("/airtable/refresh", response_model=AirtableIngestionResponse) async def airtable_refresh_and_ingest() -> AirtableIngestionResponse: """ @@ -84,7 +32,8 @@ async def airtable_refresh_and_ingest() -> AirtableIngestionResponse: 'document_id's. Automatically ingest the missing documents. """ try: - records = get_airtable_records() + records = await get_airtable_records() + records = records[:3] except KeyError as e: logger.error(f"Airtable configuration error: {e}") raise HTTPException( @@ -117,8 +66,7 @@ async def airtable_refresh_and_ingest() -> AirtableIngestionResponse: # Map IDs to records airtable_id_to_record = {} for record in records: - fields = record.get("fields", {}) - id_value = fields.get("ID") + id_value = record.get("fields").get("ID") if id_value is not None: airtable_id_to_record[id_value] = record diff --git a/backend/app/ingestion/utils/gcp_storage_utils.py b/backend/app/ingestion/storage_utils/gcp_storage_utils.py similarity index 99% rename from backend/app/ingestion/utils/gcp_storage_utils.py rename to backend/app/ingestion/storage_utils/gcp_storage_utils.py index b1b6742..75a0f62 100644 --- a/backend/app/ingestion/utils/gcp_storage_utils.py +++ b/backend/app/ingestion/storage_utils/gcp_storage_utils.py @@ -1,10 +1,9 @@ import io -from google.cloud import storage -from google.oauth2 import service_account - from app.config import SERVICE_ACCOUNT_FILE_PATH from app.utils import setup_logger +from google.cloud import storage +from google.oauth2 import service_account logger = setup_logger() diff --git a/backend/app/ingestion/utils/file_processing_utils.py b/backend/app/ingestion/utils/file_processing_utils.py index 9340fa2..e23983f 100644 --- a/backend/app/ingestion/utils/file_processing_utils.py +++ b/backend/app/ingestion/utils/file_processing_utils.py @@ -1,14 +1,11 @@ -# utils/file_processing_utils.py - import asyncio from typing import Any, BinaryIO, Dict import PyPDF2 import tiktoken -import tqdm -from app.ingestion.utils.embedding_utils import create_embedding -from app.ingestion.utils.openai_utils import ( +from app.ingestion.process_utils.embedding_utils import create_embedding +from app.ingestion.process_utils.openai_utils import ( extract_question_answer_from_page, generate_contextual_summary, ) @@ -52,15 +49,14 @@ async def process_file( file_buffer: BinaryIO, 
file_name: str, file_type: str, - progress_bar: tqdm.tqdm, - progress_lock: asyncio.Lock, - metadata: Dict[str, Any], # Pass metadata here + metadata: Dict[str, Any], ) -> list[Dict[str, Any]]: """ Process the file by parsing, generating summaries, and creating embeddings. Updates the progress bar after processing each page. + + The heavy computations are performed at the document level to improve efficiency. """ - # Ensure file_buffer is not None if file_buffer is None: logger.warning(f"No file buffer provided for file '{file_name}'. Skipping.") return [] @@ -71,16 +67,11 @@ async def process_file( logger.warning(f"No text extracted from file '{file_name}'. Skipping.") return [] - # Update the progress bar total with the number of pages - async with progress_lock: - progress_bar.total += len(chunks) - progress_bar.refresh() - # Combine all chunks into a single document_content document_content = "\n\n".join(chunks) # Initialize tiktoken encoder for the model you're using - encoding = tiktoken.encoding_for_model("gpt-4") # Adjust the model name if needed + encoding = tiktoken.encoding_for_model("gpt-4") # Define model's max context length MAX_CONTEXT_LENGTH = 8192 # Adjust according to your model's context length @@ -101,18 +92,21 @@ async def process_file( retrieval of the chunk. Answer only with the succinct context and nothing else. """ - - # Tokenize the prompt template without variables prompt_tokens = len( encoding.encode(prompt_template.format(document_content="", chunk_content="")) ) - # Create metadata string + # Tokenize document_content once + document_tokens = encoding.encode(document_content) + + # Create metadata string once metadata_string = create_metadata_string(metadata) metadata_section = f"Information about this document:\n{metadata_string}\n\n" processed_pages: list[Dict[str, Any]] = [] + # Prepare tasks for contextual summaries + contextual_summary_tasks = [] for page_num, page_text in enumerate(chunks): # Tokenize the chunk content chunk_tokens = len(encoding.encode(page_text)) @@ -120,71 +114,109 @@ async def process_file( # Calculate the available tokens for document_content available_tokens = MAX_CONTEXT_LENGTH - prompt_tokens - chunk_tokens - # Ensure available_tokens is positive if available_tokens <= 0: logger.warning( f"Chunk content on page {page_num + 1} is too long. Skipping." 
) + contextual_summary_tasks.append(None) # Placeholder for skipped page continue - # Tokenize document_content - document_tokens = encoding.encode(document_content) - # Truncate document_content if necessary if len(document_tokens) > available_tokens: - # Truncate document_content to fit available tokens truncated_document_tokens = document_tokens[:available_tokens] truncated_document_content = encoding.decode(truncated_document_tokens) else: truncated_document_content = document_content - # Generate contextual summary asynchronously - chunk_summary = await asyncio.to_thread( - generate_contextual_summary, truncated_document_content, page_text + # Create the task for generating contextual summary + task = generate_contextual_summary( + truncated_document_content, + page_text, ) + contextual_summary_tasks.append(task) + + # Run contextual summary tasks concurrently + chunk_summaries = await asyncio.gather( + *[task for task in contextual_summary_tasks if task is not None], + return_exceptions=True, + ) + + # Prepare tasks for embedding generation and QA extraction + embedding_tasks = [] + qa_extraction_tasks = [] + + summary_index = 0 # Index to keep track of successful summaries + for page_num, page_text in enumerate(chunks): + if contextual_summary_tasks[page_num] is None: + # Skip pages that couldn't generate a summary + continue + + chunk_summary = chunk_summaries[summary_index] + summary_index += 1 if not chunk_summary: logger.warning( - f"""Contextual summary generation failed for page {page_num + 1} - in file '{file_name}'. Skipping.""" + f"Contextual summary generation failed for page {page_num + 1} " + f"in file '{file_name}'. Skipping." ) continue # Combine metadata, context summary, and chunk text contextualized_chunk = f"{metadata_section}{chunk_summary}\n\n{page_text}" - # Create embedding asynchronously - embedding = await asyncio.to_thread(create_embedding, contextualized_chunk) + # Create the task for embedding generation + embedding_task = asyncio.to_thread(create_embedding, contextualized_chunk) + embedding_tasks.append((page_num, embedding_task)) - if embedding is None: - logger.warning( - f"""Embedding generation failed for page {page_num + 1} in - file '{file_name}'. Skipping.""" - ) - continue - - # Extract questions and answers - extracted_question_answers = extract_question_answer_from_page(page_text) - if not isinstance(extracted_question_answers, list): - logger.error( - f"""Extracted QA pairs on page {page_num + 1} is not a list. 
- Skipping QA pairs for this page.""" - ) - extracted_question_answers = [] + # Create the task for question-answer extraction + qa_task = extract_question_answer_from_page(page_text) + qa_extraction_tasks.append((page_num, qa_task)) + # Initialize processed page entry processed_pages.append( { "page_number": page_num + 1, "contextualized_chunk": contextualized_chunk, "chunk_summary": chunk_summary, - "embedding": embedding, - "extracted_question_answers": extracted_question_answers, + # Embedding and QA pairs will be added later } ) - # Update progress bar after processing each page - async with progress_lock: - progress_bar.update(1) + # Run embedding tasks concurrently + embeddings_results = await asyncio.gather( + *[task for _, task in embedding_tasks], return_exceptions=True + ) + + # Run QA extraction tasks concurrently + qa_results = await asyncio.gather( + *[task for _, task in qa_extraction_tasks], return_exceptions=True + ) + + # Associate embeddings and QA pairs with processed pages + for i, page in enumerate(processed_pages): + page_num = page["page_number"] - 1 # Zero-based index + + # Handle embedding + embedding_result = embeddings_results[i] + if isinstance(embedding_result, Exception) or embedding_result is None: + logger.warning( + f"Embedding generation failed for page {page_num + 1} in " + f"file '{file_name}'. Skipping embedding." + ) + page["embedding"] = None + else: + page["embedding"] = embedding_result + + # Handle QA pairs + qa_result = qa_results[i] + if not isinstance(qa_result, list): + logger.error( + f"Extracted QA pairs on page {page_num + 1} is not a list. " + f"Skipping QA pairs for this page." + ) + page["extracted_question_answers"] = [] + else: + page["extracted_question_answers"] = qa_result return processed_pages diff --git a/backend/app/ingestion/utils/openai_utils.py b/backend/app/ingestion/utils/openai_utils.py index 5c17f18..8d0c073 100644 --- a/backend/app/ingestion/utils/openai_utils.py +++ b/backend/app/ingestion/utils/openai_utils.py @@ -4,9 +4,8 @@ import os import openai -from openai import AsyncOpenAI - from app.utils import setup_logger +from openai import AsyncOpenAI logger = setup_logger() @@ -15,7 +14,7 @@ async_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) -def generate_contextual_summary(document_content: str, chunk_content: str) -> str: +async def generate_contextual_summary(document_content: str, chunk_content: str) -> str: """ Generate a concise contextual summary for a chunk. """ @@ -39,7 +38,7 @@ def generate_contextual_summary(document_content: str, chunk_content: str) -> st Answer only with the context, avoiding any inferred topics. """ try: - response = client.chat.completions.create( + response = await async_client.chat.completions.create( model="gpt-4o-mini", messages=[ {"role": "user", "content": prompt}, @@ -56,92 +55,7 @@ def generate_contextual_summary(document_content: str, chunk_content: str) -> st return "" -def generate_brief_summary(document_content: str) -> str: - """ - Generate a concise summary of the entire document in 10-15 words. - """ - # Construct the prompt (optional) - prompt = f"""Summarize the following document in 10 to 15 in a sentence that - starts with the word 'Covers' - e.g. 'Covers womens health and contraception awareness.' or - 'Covers the impact of climate change on agriculture.' - Respond only with the summary and nothing else. - ENSURE YOUR ANSWER NEVER EXCEEDS 15 WORDS. 
- Below is the content to summarize: - \n\n{document_content}""" - - try: - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "user", "content": prompt}, - ], - max_tokens=20, - temperature=0.2, - ) - summary = response.choices[0].message.content.strip() - return summary - except Exception as e: - logger.error(f"Error generating brief summary: {e}") - return "" - - -def generate_smart_filename(file_name: str, document_content: str) -> str: - """ - Generate a specific and elegant filename devoid of year numbers. - - Args: - file_name: Original file name. - document_content: Content of the document. - - Returns: - A descriptive filename as a plain text string. - """ - # Construct the prompt with precise instructions - prompt = f""" - Given the original file name "{file_name}" and the content excerpt below, generate - a specific and descriptive filename that includes relevant information, such as the - organization name (e.g., 'USAID', 'UNICEF') and the main topic of the document. - - The filename should: - - Avoid any dates or year numbers - - Exclude prefixes like 'Filename:', asterisks, or other symbols - - Be concise, specific, and relevant to the document content - - Examples of appropriate filenames might include: - - 'USAID Family Health Survey' - - 'UNICEF Maternal Health Questionnaire' - - 'DHS Reproductive Health Analysis' - - Content excerpt: - {document_content[:2000]} - - Plain text filename: - """ - - try: - response = client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "user", "content": prompt}, - ], - max_tokens=10, - temperature=0.2, - ) - smart_name = response.choices[0].message.content.strip() - - # Optional: Ensure there's no unexpected text or symbols - smart_name = smart_name.split(":", 1)[ - -1 - ].strip() # Remove anything before a colon - - return smart_name - except Exception as e: - logger.error(f"Error generating elegant filename: {e}") - return "" - - -def extract_question_answer_from_page(chunk_content: str) -> list[dict]: +async def extract_question_answer_from_page(chunk_content: str) -> list[dict]: """ Extract questions and answers from a chunk of text. 
""" @@ -168,7 +82,7 @@ def extract_question_answer_from_page(chunk_content: str) -> list[dict]: """ try: - response = client.chat.completions.create( + response = await async_client.chat.completions.create( model="gpt-4o", messages=[ {"role": "user", "content": prompt}, diff --git a/backend/app/ingestion/utils/record_processing.py b/backend/app/ingestion/utils/record_processing.py index 0e00cfc..5f28066 100644 --- a/backend/app/ingestion/utils/record_processing.py +++ b/backend/app/ingestion/utils/record_processing.py @@ -1,198 +1,277 @@ -# utils/record_processing.py - import asyncio import io import os -from typing import Any, Dict, Tuple - -import tqdm +from typing import Any, Dict, List from app.database import get_async_session -from app.ingestion.models import save_document_to_db -from app.ingestion.schemas import AirtableIngestionResponse -from app.ingestion.utils.file_processing_utils import process_file -from app.ingestion.utils.gcp_storage_utils import upload_file_buffer_to_gcp_bucket -from app.ingestion.utils.google_drive_utils import ( +from app.ingestion.fetch_utils.google_drive_utils import ( determine_file_type, download_file, extract_file_id, get_drive_service, ) -from app.ingestion.utils.openai_utils import ( - generate_brief_summary, - generate_smart_filename, +from app.ingestion.models import save_document_to_db +from app.ingestion.schemas import AirtableIngestionResponse +from app.ingestion.storage_utils.gcp_storage_utils import ( + upload_file_buffer_to_gcp_bucket, ) +from app.ingestion.utils.file_processing_utils import process_file from app.utils import setup_logger logger = setup_logger() +MAX_CONCURRENT_TASKS = 5 # Adjust as needed + -async def process_record( - record: Dict[str, Any], - progress_bar: tqdm.tqdm, - progress_lock: asyncio.Lock, -) -> Tuple[int, int]: +async def extract_metadata(record: Dict[str, Any]) -> Dict[str, Any] | None: """ - Process a single record from Airtable. + Extracts and validates metadata from a record. """ fields = record.get("fields", {}) file_name = fields.get("File name") + survey_name = fields.get("Survey name") gdrive_link = fields.get("Drive link") document_id = fields.get("ID") if not file_name or not gdrive_link or document_id is None: logger.warning( - f"""Record {record.get('id')} is missing 'File name', - 'Drive link', or 'ID'. Skipping.""" + f"""Record {record.get('id')} is missing 'File name', 'Drive link', or 'ID'. + Skipping.""" ) - return (0, 0) + return None # Determine file type - file_type = determine_file_type(file_name) + file_type = await determine_file_type(file_name) if file_type == "other": logger.warning(f"File '{file_name}' has an unsupported extension. Skipping.") - return (0, 0) + return None # Extract file ID from Drive link try: - file_id = extract_file_id(gdrive_link) + file_id = await extract_file_id(gdrive_link) except ValueError as ve: logger.error(f"Error processing file '{file_name}': {ve}. Skipping.") - return (0, 0) - - # Get a Google Drive service instance - drive_service = get_drive_service() + return None + + # Collect metadata + metadata = { + "record_id": record.get("id"), + "fields": fields, + "file_name": file_name, + "file_type": file_type, + "file_id": file_id, + "document_id": document_id, + "survey_name": survey_name, + } + return metadata + + +async def download_files( + metadata_list: List[Dict[str, Any]], + semaphore: asyncio.Semaphore, + drive_service, +) -> List[Dict[str, Any]]: + """ + Downloads files concurrently and updates metadata with file buffers. 
+ """ - # Download the file asynchronously - try: - file_buffer = await asyncio.to_thread( - download_file, file_id, file_name, file_type, drive_service - ) - except Exception as e: - logger.error(f"Error downloading file '{file_name}': {e}") - return (0, 0) - - if not file_buffer and file_type != "xlsx": - logger.error(f"Failed to download file '{file_name}'. Skipping.") - return (0, 0) - - # If the file is an Excel file, handle accordingly (skipped in this context) - if file_type == "xlsx": - logger.info(f"Excel file '{file_name}' processing is not implemented.") - return (0, 0) - - # Read the file content into bytes - file_buffer.seek(0) - file_bytes = file_buffer.read() - - # Create a new BytesIO object for processing - processing_file_buffer = io.BytesIO(file_bytes) - - # Process the file asynchronously - processed_pages = await process_file( - processing_file_buffer, - file_name, - file_type, - progress_bar, - progress_lock, - metadata=fields, - ) + async def download(metadata): + """ + Downloads a single file and updates metadata with file buffer. + """ + async with semaphore: + file_name = metadata["file_name"] + file_id = metadata["file_id"] - if not processed_pages: - logger.warning(f"No processed pages for file '{file_name}'. Skipping.") - return (0, 0) + try: + file_buffer = await asyncio.to_thread( + download_file, file_id, drive_service + ) - # Combine all page texts to create document content - document_content = "\n\n".join( - [page.get("contextualized_chunk", "") for page in processed_pages] - ) - generated_title = generate_smart_filename( - file_name=file_name, document_content=document_content - ) + metadata["file_buffer"] = file_buffer + return metadata - # Generate brief summary - brief_summary = generate_brief_summary(document_content) - if not brief_summary: - logger.warning(f"Failed to generate summary for document '{file_name}'") + except Exception as e: + logger.error(f"Error downloading file '{file_name}': {e}") + return None - # Create another BytesIO object for uploading - uploading_file_buffer = io.BytesIO(file_bytes) + tasks = [download(metadata) for metadata in metadata_list if metadata] + results = await asyncio.gather(*tasks) + return [res for res in results if res] - # Upload the file to GCP bucket - bucket_name = os.getenv("GCP_BUCKET_NAME", "survey_accelerator_files") - pdf_url = upload_file_buffer_to_gcp_bucket( - uploading_file_buffer, bucket_name, file_name - ) +async def process_files( + metadata_list: List[Dict[str, Any]], + semaphore: asyncio.Semaphore, +) -> List[Dict[str, Any]]: + """ + Processes files concurrently and updates metadata with processed pages. + """ - if not pdf_url: - logger.error(f"Failed to upload document '{file_name}' to GCP bucket") - return (0, 0) - - # Save to database with metadata - async for asession in get_async_session(): - try: - await save_document_to_db( - file_name=file_name, - processed_pages=processed_pages, - file_id=file_id, - asession=asession, + async def process(metadata): + """ + Processes a single file and updates metadata with processed pages. 
+ """ + async with semaphore: + file_name = metadata["file_name"] + file_type = metadata["file_type"] + fields = metadata["fields"] + file_buffer = metadata["file_buffer"] + + # Read the file content into bytes + file_buffer.seek(0) + file_bytes = file_buffer.read() + + # Create a new BytesIO object for processing + processing_file_buffer = io.BytesIO(file_bytes) + + # Process the file asynchronously + processed_pages = await process_file( + processing_file_buffer, + file_name, + file_type, metadata=fields, - pdf_url=pdf_url, - summary=brief_summary, - title=generated_title, ) - break - except Exception as e: - logger.error(f"Error saving document '{file_name}' to database: {e}") - return (0, 0) - finally: - await asession.close() - logger.info(f"File '{file_name}' processed successfully.") + if not processed_pages: + logger.warning(f"No processed pages for file '{file_name}'. Skipping.") + return None + + metadata["processed_pages"] = processed_pages + metadata["file_bytes"] = file_bytes + return metadata + + tasks = [process(metadata) for metadata in metadata_list] + results = await asyncio.gather(*tasks) + return [res for res in results if res] + + +async def upload_files_to_gcp( + metadata_list: List[Dict[str, Any]], + semaphore: asyncio.Semaphore, +) -> List[Dict[str, Any]]: + """ + Uploads files to GCP concurrently and updates metadata with PDF URLs. + """ + + async def upload(metadata): + async with semaphore: + file_name = metadata["file_name"] + file_bytes = metadata["file_bytes"] + uploading_file_buffer = io.BytesIO(file_bytes) + + # Upload the file to GCP bucket + bucket_name = os.getenv("GCP_BUCKET_NAME", "survey_accelerator_files") + + pdf_url = await asyncio.to_thread( + upload_file_buffer_to_gcp_bucket, + uploading_file_buffer, + bucket_name, + file_name, + ) + + if not pdf_url: + logger.error(f"Failed to upload document '{file_name}' to GCP bucket") + return None + + metadata["pdf_url"] = pdf_url + return metadata - return (1, len(processed_pages)) + tasks = [upload(metadata) for metadata in metadata_list] + results = await asyncio.gather(*tasks) + return [res for res in results if res] -async def ingest_records(records: list[Dict[str, Any]]) -> AirtableIngestionResponse: +async def save_single_document(metadata: Dict[str, Any], asession): + """ + Saves a single document to the database and prints a success message. """ - Ingest a list of Airtable records. + await save_document_to_db( + file_name=metadata["file_name"], + processed_pages=metadata["processed_pages"], + file_id=metadata["file_id"], + asession=asession, + metadata=metadata["fields"], + pdf_url=metadata["pdf_url"], + summary=metadata.get("brief_summary"), + title=metadata.get("generated_title"), + ) + print(f"File '{metadata['file_name']}' processed and saved successfully.") - Args: - records (list[Dict[str, Any]]): The list of Airtable records to process. - Returns: - AirtableIngestionResponse: The response containing ingestion results. +async def save_all_to_db(metadata_list: List[Dict[str, Any]]): + """ + Saves all documents to the database concurrently. 
+ """ + try: + async with get_async_session() as asession: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) + + async def save_with_semaphore(metadata, session): + async with semaphore: + await save_single_document(metadata, session) + + save_tasks = [ + save_with_semaphore(metadata, asession) for metadata in metadata_list + ] + await asyncio.gather(*save_tasks) + except Exception as e: + logger.error(f"Error saving documents to database: {e}") + + +async def ingest_records(records: List[Dict[str, Any]]) -> AirtableIngestionResponse: + """ + Ingests a list of Airtable records in parallel stages. """ total_records_processed = 0 total_chunks_created = 0 + # Step 1: Extract metadata + metadata_list = [await extract_metadata(record) for record in records] + metadata_list = [m for m in metadata_list if m] + + if not metadata_list: + return AirtableIngestionResponse( + total_records_processed=0, + total_chunks_created=0, + message="No valid records to process.", + ) - # Create an asynchronous lock for progress bar updates - progress_lock = asyncio.Lock() + # Step 2: Download files + semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) + drive_service = get_drive_service() # Get Google Drive service once - # Initialize the progress bar without a total - progress_bar = tqdm.tqdm(desc="Processing pages", unit="page", total=0) + files = await download_files(metadata_list, semaphore, drive_service) + if not metadata_list: + return AirtableIngestionResponse( + total_records_processed=0, + total_chunks_created=0, + message="Failed to download any files.", + ) - # Limit the number of concurrent tasks to prevent resource exhaustion - semaphore = asyncio.Semaphore(2) # Adjust the number as needed + # Step 3: Process files + semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) - async def process_with_semaphore(record: Dict[str, Any]) -> Tuple[int, int]: - """Semaphore-protected function to process a record.""" - async with semaphore: - return await process_record(record, progress_bar, progress_lock) + updated_metadata_list = await process_files(files, semaphore) + if not metadata_list: + return AirtableIngestionResponse( + total_records_processed=0, + total_chunks_created=0, + message="No files were processed.", + ) - # Create a list of tasks to process records concurrently - tasks = [process_with_semaphore(record) for record in records] + # Step 4: Upload files to GCP + semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) - # Run tasks concurrently - results = await asyncio.gather(*tasks) + metadata_list = await upload_files_to_gcp(updated_metadata_list, semaphore) - # Sum up the totals from each task - for records_processed, chunks_created in results: - total_records_processed += records_processed - total_chunks_created += chunks_created + # Step 5: Save all to database + await save_all_to_db(metadata_list) - progress_bar.close() + # Update totals + total_records_processed = len(metadata_list) + total_chunks_created = sum( + len(metadata["processed_pages"]) for metadata in metadata_list + ) return AirtableIngestionResponse( total_records_processed=total_records_processed, diff --git a/backend/app/search/utils.py b/backend/app/search/utils.py index 60149d1..902173b 100644 --- a/backend/app/search/utils.py +++ b/backend/app/search/utils.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.ingestion.models import DocumentDB -from app.ingestion.utils.openai_utils import generate_query_match_explanation +from app.ingestion.process_utils.openai_utils import 
generate_query_match_explanation from app.search.schemas import ( DocumentMetadata, DocumentSearchResult, diff --git a/backend/app/utils.py b/backend/app/utils.py index b538a7a..03abe5a 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -8,9 +8,7 @@ from logging import Logger from uuid import uuid4 -from .config import ( - LOG_LEVEL, -) +from .config import LOG_LEVEL # To make 32-byte API keys (results in 43 characters) SECRET_KEY_N_BYTES = 32 diff --git a/backend/tests/api/test_ingestion.py b/backend/tests/api/test_ingestion.py index a8679d7..d375853 100644 --- a/backend/tests/api/test_ingestion.py +++ b/backend/tests/api/test_ingestion.py @@ -3,7 +3,7 @@ import pytest from fastapi.testclient import TestClient -from backend.app.auth.config import API_SECRET_KEY +from backend.app.config import API_SECRET_KEY @pytest.mark.parametrize( From 5451a5569d2a2788f3098a3013ac02c035d807a1 Mon Sep 17 00:00:00 2001 From: Mark Botterill Date: Tue, 28 Jan 2025 15:50:50 +0000 Subject: [PATCH 2/6] Refactor part 1 --- backend/app/ingestion/utils/openai_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/app/ingestion/utils/openai_utils.py b/backend/app/ingestion/utils/openai_utils.py index 8d0c073..f629007 100644 --- a/backend/app/ingestion/utils/openai_utils.py +++ b/backend/app/ingestion/utils/openai_utils.py @@ -4,9 +4,10 @@ import os import openai -from app.utils import setup_logger from openai import AsyncOpenAI +from app.utils import setup_logger + logger = setup_logger() # Instantiate the OpenAI client (you can keep this or comment it out) From 1a74829863745ad047dd9830d7573fa5c3c1a7f5 Mon Sep 17 00:00:00 2001 From: Mark Botterill Date: Tue, 28 Jan 2025 15:54:07 +0000 Subject: [PATCH 3/6] Refactor part 2 --- .../ingestion/process_utils/openai_utils.py | 85 ------------------- 1 file changed, 85 deletions(-) diff --git a/backend/app/ingestion/process_utils/openai_utils.py b/backend/app/ingestion/process_utils/openai_utils.py index 1e31e2d..8d0c073 100644 --- a/backend/app/ingestion/process_utils/openai_utils.py +++ b/backend/app/ingestion/process_utils/openai_utils.py @@ -55,91 +55,6 @@ async def generate_contextual_summary(document_content: str, chunk_content: str) return "" -async def generate_brief_summary(document_content: str) -> str: - """ - Generate a concise summary of the entire document in 10-15 words. - """ - # Construct the prompt (optional) - prompt = f"""Summarize the following document in 10 to 15 in a sentence that - starts with the word 'Covers' - e.g. 'Covers womens health and contraception awareness.' or - 'Covers the impact of climate change on agriculture.' - Respond only with the summary and nothing else. - ENSURE YOUR ANSWER NEVER EXCEEDS 15 WORDS. - Below is the content to summarize: - \n\n{document_content}""" - - try: - response = await async_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "user", "content": prompt}, - ], - max_tokens=20, - temperature=0.2, - ) - summary = response.choices[0].message.content.strip() - return summary - except Exception as e: - logger.error(f"Error generating brief summary: {e}") - return "" - - -async def generate_smart_filename(file_name: str, document_content: str) -> str: - """ - Generate a specific and elegant filename devoid of year numbers. - - Args: - file_name: Original file name. - document_content: Content of the document. - - Returns: - A descriptive filename as a plain text string. 
- """ - # Construct the prompt with precise instructions - prompt = f""" - Given the original file name "{file_name}" and the content excerpt below, generate - a specific and descriptive filename that includes relevant information, such as the - organization name (e.g., 'USAID', 'UNICEF') and the main topic of the document. - - The filename should: - - Avoid any dates or year numbers - - Exclude prefixes like 'Filename:', asterisks, or other symbols - - Be concise, specific, and relevant to the document content - - Examples of appropriate filenames might include: - - 'USAID Family Health Survey' - - 'UNICEF Maternal Health Questionnaire' - - 'DHS Reproductive Health Analysis' - - Content excerpt: - {document_content[:2000]} - - Plain text filename: - """ - - try: - response = await async_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "user", "content": prompt}, - ], - max_tokens=10, - temperature=0.2, - ) - smart_name = response.choices[0].message.content.strip() - - # Optional: Ensure there's no unexpected text or symbols - smart_name = smart_name.split(":", 1)[ - -1 - ].strip() # Remove anything before a colon - - return smart_name - except Exception as e: - logger.error(f"Error generating elegant filename: {e}") - return "" - - async def extract_question_answer_from_page(chunk_content: str) -> list[dict]: """ Extract questions and answers from a chunk of text. From bb6b82f263cdccf94d13db0a3311ac47cb3b31ef Mon Sep 17 00:00:00 2001 From: Mark Botterill Date: Mon, 3 Mar 2025 10:49:40 +0300 Subject: [PATCH 4/6] Follow tutorial for unified logger --- backend/main.py | 130 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 112 insertions(+), 18 deletions(-) diff --git a/backend/main.py b/backend/main.py index f2f5b82..043e839 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,36 +1,130 @@ import logging +import os +import sys +from types import FrameType +from typing import Any, Dict, Optional -import uvicorn from app import create_app from app.config import BACKEND_ROOT_PATH +from gunicorn.app.base import BaseApplication +from gunicorn.glogging import Logger +from loguru import logger from uvicorn.workers import UvicornWorker -# Configure root logger -logging.basicConfig( - level=logging.INFO, # Set to INFO to reduce log verbosity - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) +LOG_LEVEL = logging.getLevelName(os.environ.get("LOG_LEVEL", "DEBUG")) +JSON_LOGS = True if os.environ.get("JSON_LOGS", "0") == "1" else False +WORKERS = int(os.environ.get("GUNICORN_WORKERS", "5")) -# Set your application's logger to DEBUG -app_logger = logging.getLogger("app") -app_logger.setLevel(logging.DEBUG) -# Suppress debug logs from third-party libraries -logging.getLogger("openai").setLevel(logging.WARNING) -logging.getLogger("httpx").setLevel(logging.WARNING) -logging.getLogger("uvicorn").setLevel(logging.WARNING) -logging.getLogger("httpcore").setLevel(logging.WARNING) -logging.getLogger("httpcore").setLevel(logging.WARNING) +class InterceptHandler(logging.Handler): + """ + Intercept standard logging messages and redirect them to Loguru. + """ + def emit(self, record: logging.LogRecord) -> None: + """ + Emit a log record. 
+ """ + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno -app = create_app() + frame: Optional[FrameType] = sys._getframe(6) + depth = 6 + while frame and frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + +class StubbedGunicornLogger(Logger): + """ + This class is used to stub out the Gunicorn logger. + """ + + def setup(self, cfg: Any) -> None: + """ + Setup the logger. + """ + handler = logging.NullHandler() + self.error_logger = logging.getLogger("gunicorn.error") + self.error_logger.addHandler(handler) + self.access_logger = logging.getLogger("gunicorn.access") + self.access_logger.addHandler(handler) + self.error_logger.setLevel(LOG_LEVEL) + self.access_logger.setLevel(LOG_LEVEL) class Worker(UvicornWorker): - """Custom worker class to allow root_path to be passed to Uvicorn""" + """Custom worker class to allow root_path to be passed to Uvicorn.""" CONFIG_KWARGS = {"root_path": BACKEND_ROOT_PATH} +class StandaloneApplication(BaseApplication): + """Our Gunicorn application.""" + + def __init__(self, app: Any, options: Optional[Dict[str, Any]] = None) -> None: + """ + Initialize the application with the given app and options. + """ + self.options = options or {} + self.application = app + super().__init__() + + def load_config(self) -> None: + """ + Load the configuration. + """ + config = { + key: value + for key, value in self.options.items() + if key in self.cfg.settings and value is not None + } + for key, value in config.items(): + self.cfg.set(key.lower(), value) + + def load(self) -> Any: + """ + Load the application. + """ + return self.application + + if __name__ == "__main__": - uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, log_level="info") + intercept_handler = InterceptHandler() + logging.root.setLevel(LOG_LEVEL) + + seen = set() + for name in [ + *logging.root.manager.loggerDict.keys(), + "gunicorn", + "gunicorn.access", + "gunicorn.error", + "uvicorn", + "uvicorn.access", + "uvicorn.error", + ]: + if name not in seen: + seen.add(name.split(".")[0]) + logging.getLogger(name).handlers = [intercept_handler] + + logger.configure(handlers=[{"sink": sys.stdout, "serialize": JSON_LOGS}]) + + # Instantiate your FastAPI app + app = create_app() + + options = { + "bind": "0.0.0.0", + "workers": WORKERS, + "accesslog": "-", + "errorlog": "-", + "worker_class": "__main__.Worker", # use import path as string + "logger_class": StubbedGunicornLogger, + } + + StandaloneApplication(app, options).run() From 88854bdd9171175b232098d29797b3b546619693 Mon Sep 17 00:00:00 2001 From: Mark Botterill Date: Mon, 3 Mar 2025 11:35:12 +0300 Subject: [PATCH 5/6] Change default configs + allow FastAPI to run from correct dir --- backend/app/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/app/config.py b/backend/app/config.py index 8f35bd6..ddf5999 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -13,8 +13,8 @@ DB_POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", 20)) # Backend Configuration -BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "") -LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING") +BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "./") +LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO") # PGVector Configuration PGVECTOR_VECTOR_SIZE = int(os.environ.get("PGVECTOR_VECTOR_SIZE", 1024)) From 5269e5f09a0cb159f792665084cefc44316e9fcc Mon 
Sep 17 00:00:00 2001 From: Mark Botterill Date: Mon, 3 Mar 2025 16:59:01 +0300 Subject: [PATCH 6/6] Refactor phase one done --- .../fetch_utils/google_drive_utils.py | 131 +++++--- backend/app/ingestion/models.py | 112 ++++--- .../process_utils/embedding_utils.py | 60 ++++ .../ingestion/process_utils/openai_utils.py | 2 +- backend/app/ingestion/routers.py | 112 ++++--- .../storage_utils/gcp_storage_utils.py | 41 +++ .../ingestion/utils/file_processing_utils.py | 30 +- .../app/ingestion/utils/record_processing.py | 280 ------------------ backend/app/search/schemas.py | 1 - backend/app/search/utils.py | 1 - backend/main.py | 3 +- ...3_03_aaa104d6c3ed_remove_file_id_column.py | 35 +++ frontend/src/interfaces.ts | 1 - 13 files changed, 389 insertions(+), 420 deletions(-) delete mode 100644 backend/app/ingestion/utils/record_processing.py create mode 100644 backend/migrations/versions/2025_03_03_aaa104d6c3ed_remove_file_id_column.py diff --git a/backend/app/ingestion/fetch_utils/google_drive_utils.py b/backend/app/ingestion/fetch_utils/google_drive_utils.py index 724bfc6..4d0ab66 100644 --- a/backend/app/ingestion/fetch_utils/google_drive_utils.py +++ b/backend/app/ingestion/fetch_utils/google_drive_utils.py @@ -1,41 +1,19 @@ -# utils/google_drive_utils.py - import io import os -from typing import Optional +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any, Dict, List, Optional from app.config import SCOPES, SERVICE_ACCOUNT_FILE_PATH from app.utils import setup_logger from google.oauth2 import service_account -from googleapiclient.discovery import Resource as DriveResource from googleapiclient.discovery import build from googleapiclient.http import MediaIoBaseDownload logger = setup_logger() -def get_drive_service() -> DriveResource: - """ - Authenticate using a service account and return the Drive service. - """ - creds = service_account.Credentials.from_service_account_file( - SERVICE_ACCOUNT_FILE_PATH, scopes=SCOPES - ) - drive_service = build("drive", "v3", credentials=creds) - return drive_service - - -async def extract_file_id(gdrive_url: str) -> str: - """ - Extract the file ID from a Google Drive URL. - - Supports URLs of the form: - - https://drive.google.com/file/d/FILE_ID/view?usp=sharing - - https://drive.google.com/open?id=FILE_ID - - Any URL containing 'id=FILE_ID' - """ +def extract_file_id(gdrive_url: str) -> str: if "id=" in gdrive_url: - # URL contains 'id=FILE_ID' file_id = gdrive_url.split("id=")[1].split("&")[0] if file_id: return file_id @@ -44,7 +22,6 @@ async def extract_file_id(gdrive_url: str) -> str: "Invalid Google Drive URL format: missing file ID after 'id='." ) else: - # Handle URLs of the form 'https://drive.google.com/file/d/FILE_ID/...' parts = gdrive_url.strip("/").split("/") if "d" in parts: d_index = parts.index("d") @@ -62,15 +39,12 @@ async def extract_file_id(gdrive_url: str) -> str: ) from e else: raise ValueError( - """URL format not recognized. Ensure it contains 'id=' or - follows the standard Drive URL format.""" + """URL format not recognized. Ensure it contains 'id=' or follows the + standard Drive URL format.""" ) -async def determine_file_type(file_name: str) -> str: - """ - Determine the file type based on the file extension. 
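The rewritten extract_file_id above accepts both "...open?id=FILE_ID" links and "/file/d/FILE_ID/..." links. As an aside, a standalone sketch of the same parsing built on urllib.parse; the function name and the FILE_ID value below are placeholders, not project code.

# Sketch: pull a Drive file ID out of either supported URL shape.
from urllib.parse import parse_qs, urlparse


def drive_file_id(gdrive_url: str) -> str:
    parsed = urlparse(gdrive_url)
    query = parse_qs(parsed.query)
    if query.get("id"):
        # https://drive.google.com/open?id=FILE_ID
        return query["id"][0]
    parts = [p for p in parsed.path.split("/") if p]
    if "d" in parts and parts.index("d") + 1 < len(parts):
        # https://drive.google.com/file/d/FILE_ID/view
        return parts[parts.index("d") + 1]
    raise ValueError("URL format not recognized")


print(drive_file_id("https://drive.google.com/file/d/FILE_ID/view?usp=sharing"))  # FILE_ID
print(drive_file_id("https://drive.google.com/open?id=FILE_ID"))                  # FILE_ID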
- """ +def determine_file_type(file_name: str) -> str: _, ext = os.path.splitext(file_name.lower()) if ext == ".pdf": return "pdf" @@ -80,23 +54,96 @@ async def determine_file_type(file_name: str) -> str: return "other" -def download_file(file_id: str, drive_service: DriveResource) -> Optional[io.BytesIO]: +def download_file(file_id: str) -> Optional[io.BytesIO]: """ - Download a file from Google Drive using its file ID and handle it based on file type - For PDFs, download into memory and return the BytesIO object. - For XLSX, download and save to disk. + Download a file from Google Drive using the Drive API. """ + creds = service_account.Credentials.from_service_account_file( + SERVICE_ACCOUNT_FILE_PATH, scopes=SCOPES + ) + drive_service = build("drive", "v3", credentials=creds) try: - logger.warning("Downloading PDF file...") request = drive_service.files().get_media(fileId=file_id) - pdf_buffer = io.BytesIO() - downloader = MediaIoBaseDownload(pdf_buffer, request) + file_buffer = io.BytesIO() + downloader = MediaIoBaseDownload(file_buffer, request) done = False while not done: status, done = downloader.next_chunk() - pdf_buffer.seek(0) # Reset buffer position to the beginning - return pdf_buffer # Return the in-memory file for PDFs + if file_buffer.getbuffer().nbytes == 0: + raise RuntimeError("No content was downloaded from the file.") + file_buffer.seek(0) + return file_buffer + except Exception as e: + logger.error(f"Error downloading file with ID '{file_id}': {e}") + return None + + +def download_file_wrapper(record: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """A wrapper function to process each record and download its file.""" + try: + fields = record.get("fields", {}) + gdrive_link = fields.get("Drive link") + file_name = fields.get("File name") + document_id = fields.get("ID") + survey_name = fields.get("Survey name") + description = fields.get("Description") + + if not gdrive_link or not file_name: + logger.error("Record is missing 'Drive link' or 'File name'") + return None + + # Check if the file is a PDF (we only want to process PDFs) + file_type = determine_file_type(file_name) + if file_type != "pdf": + logger.info(f"Skipping non-PDF file: '{file_name}'") + return None + file_id = extract_file_id(gdrive_link) + if not file_id: + logger.error(f"Could not extract file ID from link '{gdrive_link}'") + return None + + logger.info(f"Starting download of file '{file_name}'") + file_buffer = download_file(file_id) + if file_buffer is None: + logger.error(f"Failed to download file '{file_name}'") + return None + + logger.info(f"Completed download of file '{file_name}'") + return { + "file_name": file_name, + "file_buffer": file_buffer, + "file_type": file_type, + "document_id": document_id, + "survey_name": survey_name, + "summary": description, + "fields": fields, + } except Exception as e: - logger.error(f"Error downloading file: {e}") + logger.error(f"Error downloading file '{file_name}': {e}") return None + + +def download_all_files( + records: List[Dict[str, Any]], n_max_workers: int +) -> List[Dict[str, Any]]: + """Download all files concurrently using ThreadPoolExecutor.""" + downloaded_files = [] + + with ThreadPoolExecutor(max_workers=n_max_workers) as executor: + # Map each record to a future + future_to_record = { + executor.submit(download_file_wrapper, record): record for record in records + } + + for future in as_completed(future_to_record): + record = future_to_record[future] + file_name = record.get("fields", {}).get("File name", "Unknown") + try: + result = 
future.result() + if result is not None: + downloaded_files.append(result) + except Exception as e: + logger.error(f"Error downloading file '{file_name}': {e}") + + return downloaded_files diff --git a/backend/app/ingestion/models.py b/backend/app/ingestion/models.py index a80e2bf..847045f 100644 --- a/backend/app/ingestion/models.py +++ b/backend/app/ingestion/models.py @@ -3,7 +3,7 @@ """ from datetime import datetime, timezone -from typing import Optional +from typing import Any, Dict, List, Optional from pgvector.sqlalchemy import Vector from sqlalchemy import DateTime, ForeignKey, Index, Integer, String, Text @@ -17,12 +17,16 @@ PGVECTOR_M, PGVECTOR_VECTOR_SIZE, ) +from app.database import get_async_session from app.utils import setup_logger from ..models import Base logger = setup_logger() +# Maximum number of concurrent database operations +MAX_CONCURRENT_DB_TASKS = 5 + class QAPairDB(Base): """ORM for managing question-answer pairs associated with each document.""" @@ -65,7 +69,6 @@ class DocumentDB(Base): ) id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False) - file_id: Mapped[str] = mapped_column(String(length=36), nullable=False) file_name: Mapped[str] = mapped_column(String(length=150), nullable=False) title: Mapped[Optional[str]] = mapped_column(String(length=150), nullable=True) summary: Mapped[Optional[str]] = mapped_column(Text, nullable=True) @@ -101,13 +104,13 @@ class DocumentDB(Base): async def save_document_to_db( *, processed_pages: list[dict], - file_id: str, file_name: str, asession: AsyncSession, metadata: dict, pdf_url: str, - summary: str, - title: str, + summary: str = None, + title: str = None, + document_id: int, ) -> None: """ Save documents and their associated QA pairs to the database. @@ -132,20 +135,13 @@ async def save_document_to_db( logger.error(f"Error processing metadata for file '{file_name}': {e}") raise - documents = [] - - logger.debug(f"Processing {len(processed_pages)} pages for file '{file_name}'.") - + # Process pages one by one to avoid transaction issues for idx, page in enumerate(processed_pages): logger.debug(f"Processing page {idx + 1}/{len(processed_pages)}.") try: - # Log the page content keys - logger.debug(f"Page keys: {list(page.keys())}") - # Create DocumentDB instance for each page document = DocumentDB( - file_id=file_id, file_name=file_name, page_number=page["page_number"], contextualized_chunk=page["contextualized_chunk"], @@ -158,17 +154,14 @@ async def save_document_to_db( regions=regions, notes=notes, drive_link=drive_link, - year=year, + year=metadata.get("Year"), date_added=date_added, - document_id=document_id, + document_id=document_id, # Using the document_id passed to the function pdf_url=pdf_url, summary=summary, title=title, ) - # Log the created DocumentDB instance - logger.debug(f"Created DocumentDB instance: {document}") - # Create QAPairDB instances and associate them with the document extracted_qa_pairs = page.get("extracted_question_answers", []) if not isinstance(extracted_qa_pairs, list): @@ -179,10 +172,6 @@ async def save_document_to_db( qa_pairs = [] for qa_idx, qa_pair in enumerate(extracted_qa_pairs): - logger.debug( - f"""Processing QA pair {qa_idx + 1}/{len(extracted_qa_pairs)} - on page {idx + 1}.""" - ) try: question = qa_pair.get("question", "") answers = qa_pair.get("answers", []) @@ -199,31 +188,72 @@ async def save_document_to_db( answer=answer_text, ) - # Log the created QAPairDB instance - logger.debug(f"Created QAPairDB instance: {qa_pair_instance}") - 
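The download_all_files helper added in google_drive_utils.py above fans downloads out over a thread pool and collects results with as_completed. A minimal standalone sketch of that pattern, with a sleep standing in for the Drive API call; the record shape and names are illustrative only.

# Sketch: bounded thread-pool fan-out with per-future error handling.
import time
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_download(record: dict) -> dict:
    """Stand-in for the per-record download; sleeps instead of calling Drive."""
    time.sleep(0.1)
    return {"file_name": record["name"], "size": len(record["name"])}


records = [{"name": f"survey_{i}.pdf"} for i in range(5)]
results = []

with ThreadPoolExecutor(max_workers=3) as executor:
    future_to_record = {executor.submit(fake_download, r): r for r in records}
    for future in as_completed(future_to_record):
        record = future_to_record[future]
        try:
            results.append(future.result())
        except Exception as exc:  # mirror the per-future error handling above
            print(f"download of {record['name']} failed: {exc}")

print(len(results), "files downloaded")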
qa_pairs.append(qa_pair_instance) except Exception as e: logger.error(f"Error processing QA pair on page {idx + 1}: {e}") document.qa_pairs = qa_pairs - documents.append(document) - except Exception as e: - logger.error(f"Error processing page {idx + 1} for file '{file_name}': {e}") + # Add document to session and commit immediately + asession.add(document) + await asession.commit() + logger.debug(f"Saved page {idx + 1} for file '{file_name}'") - try: - # Add all documents (and their QA pairs) to the session - logger.debug( - f"Adding {len(documents)} documents to the session for file '{file_name}'." - ) - asession.add_all(documents) - logger.debug("Committing the session.") - await asession.commit() # Commit all at once - logger.debug("Session committed successfully.") - except Exception as e: - logger.error(f"Error committing to database for file '{file_name}': {e}") - await asession.rollback() - raise + except Exception as e: + logger.error(f"Error saving page {idx + 1} for file '{file_name}': {e}") + await asession.rollback() + # Continue with next page logger.debug(f"Finished save_document_to_db for file '{file_name}'.") + + +async def save_single_document(metadata: Dict[str, Any]) -> bool: + """ + Saves a single document to the database using the iterator-style session. + """ + success = False + + # Use the iterator style from the original code + async for asession in get_async_session(): + try: + await save_document_to_db( + file_name=metadata["file_name"], + processed_pages=metadata["processed_pages"], + asession=asession, + metadata=metadata.get("fields", {}), + pdf_url=metadata.get("pdf_url", ""), # Use empty string as fallback + title=metadata.get("survey_name"), + summary=metadata.get("summary"), + document_id=metadata.get( + "document_id" + ), # Pass document_id from metadata + ) + logger.info( + f"File '{metadata['file_name']}' processed and saved successfully." + ) + success = True + break # Success, exit the loop + except Exception as e: + logger.error( + f"Error saving document '{metadata['file_name']}' to database: {e}" + ) + finally: + # Explicitly close the session + await asession.close() + + return success + + +async def save_all_to_db(metadata_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Saves all documents to the database. + """ + saved_documents = [] + + # Process each document sequentially + for metadata in metadata_list: + success = await save_single_document(metadata) + if success: + saved_documents.append(metadata) + + return saved_documents diff --git a/backend/app/ingestion/process_utils/embedding_utils.py b/backend/app/ingestion/process_utils/embedding_utils.py index 3733b5c..79957cd 100644 --- a/backend/app/ingestion/process_utils/embedding_utils.py +++ b/backend/app/ingestion/process_utils/embedding_utils.py @@ -1,5 +1,9 @@ # utils/embedding_utils.py +import asyncio +import io +from typing import Any, Dict, List + import cohere from app.config import COHERE_API_KEY from app.utils import setup_logger @@ -28,3 +32,59 @@ def create_embedding(text: str) -> list: except Exception as e: logger.error(f"Error generating embedding: {e}") raise e + + +async def process_files( + metadata_list: List[Dict[str, Any]], + semaphore: asyncio.Semaphore, +) -> List[Dict[str, Any]]: + """ + Processes files concurrently and updates metadata with processed pages. 
+ """ + # Import here to avoid circular imports + from app.ingestion.utils.file_processing_utils import process_file + + async def process(metadata): + """ + Processes a single file and updates metadata with processed pages. + """ + async with semaphore: + file_name = metadata["file_name"] + file_type = metadata.get( + "file_type", "pdf" + ) # Default to PDF if not specified + fields = metadata.get("fields", {}) + file_buffer = metadata["file_buffer"] + + # Read the file content into bytes + file_buffer.seek(0) + file_bytes = file_buffer.read() + + # Create a new BytesIO object for processing + processing_file_buffer = io.BytesIO(file_bytes) + + try: + # Process the file asynchronously + processed_pages = await process_file( + processing_file_buffer, + file_name, + file_type, + metadata=fields, + ) + + if not processed_pages: + logger.warning( + f"No processed pages for file '{file_name}'. Skipping." + ) + return None + except Exception as e: + logger.error(f"Error processing file '{file_name}': {e}") + return None + + metadata["processed_pages"] = processed_pages + metadata["file_bytes"] = file_bytes + return metadata + + tasks = [process(metadata) for metadata in metadata_list] + results = await asyncio.gather(*tasks) + return [res for res in results if res] diff --git a/backend/app/ingestion/process_utils/openai_utils.py b/backend/app/ingestion/process_utils/openai_utils.py index 8d0c073..85e3549 100644 --- a/backend/app/ingestion/process_utils/openai_utils.py +++ b/backend/app/ingestion/process_utils/openai_utils.py @@ -87,7 +87,7 @@ async def extract_question_answer_from_page(chunk_content: str) -> list[dict]: messages=[ {"role": "user", "content": prompt}, ], - max_tokens=1000, + max_tokens=2500, temperature=0, ) qa_pairs_str = response.choices[0].message.content.strip() diff --git a/backend/app/ingestion/routers.py b/backend/app/ingestion/routers.py index 03aa0e8..42eee0c 100644 --- a/backend/app/ingestion/routers.py +++ b/backend/app/ingestion/routers.py @@ -1,17 +1,20 @@ -# routers.py +import asyncio +import logging -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends from app.auth.dependencies import authenticate_key from app.ingestion.fetch_utils.airtable_utils import ( get_airtable_records, get_missing_document_ids, ) +from app.ingestion.fetch_utils.google_drive_utils import download_all_files +from app.ingestion.models import save_all_to_db +from app.ingestion.process_utils.embedding_utils import process_files from app.ingestion.schemas import AirtableIngestionResponse -from app.ingestion.utils.record_processing import ingest_records -from app.utils import setup_logger +from app.ingestion.storage_utils.gcp_storage_utils import upload_files_to_gcp -logger = setup_logger() +logger = logging.getLogger(__name__) router = APIRouter( dependencies=[Depends(authenticate_key)], @@ -23,57 +26,84 @@ "name": "Ingestion", "description": "Endpoints for ingesting documents from Airtable records", } +MAX_CONCURRENT_DOWNLOADS = 10 +MAX_CONCURRENT_UPLOADS = 10 +MAX_CONCURRENT_PROCESSING = 5 # Processing is CPU intensive @router.post("/airtable/refresh", response_model=AirtableIngestionResponse) -async def airtable_refresh_and_ingest() -> AirtableIngestionResponse: +async def airtable_refresh_and_ingest(): """ Refresh the list of documents by comparing Airtable 'ID' fields with database 'document_id's. Automatically ingest the missing documents. 
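process_files above (and upload_files_to_gcp later in this patch) cap concurrency by wrapping each task in an asyncio.Semaphore before gathering, then dropping the None results of failed items. A small self-contained sketch of that pattern; the work function and item list are placeholders.

# Sketch: semaphore-bounded asyncio.gather with failed items filtered out.
import asyncio
from typing import Optional


async def process_one(item: str, semaphore: asyncio.Semaphore) -> Optional[str]:
    async with semaphore:          # at most max_concurrency items in flight
        await asyncio.sleep(0.1)   # stand-in for parsing/embedding work
        return item.upper()


async def process_all(items: list[str], max_concurrency: int = 5) -> list[str]:
    semaphore = asyncio.Semaphore(max_concurrency)
    results = await asyncio.gather(*(process_one(i, semaphore) for i in items))
    return [r for r in results if r is not None]  # drop items that returned None


print(asyncio.run(process_all([f"page_{i}" for i in range(12)])))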
""" - try: - records = await get_airtable_records() - records = records[:3] - except KeyError as e: - logger.error(f"Airtable configuration error: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Airtable configuration error.", - ) from e - except Exception as e: - logger.error(f"Error fetching records from Airtable: {e}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error fetching records from Airtable: {e}", - ) from e - - if not records: + logger.info("Starting Airtable refresh and ingestion process.") + # Get all Airtable records + online_records = await get_airtable_records() + # Compare IDs with SQL DB IDs to get missing IDS + missing_ids = await get_missing_document_ids(online_records) + + # Map IDs to records + missing_records = [ + record for record in online_records if record["fields"]["ID"] in missing_ids + ] + + # Limit to 15 records for development + missing_records = ( + missing_records[:20] if len(missing_records) > 15 else missing_records + ) + + logger.info(f"Found {len(missing_records)} records to process") + + # Download the files concurrently + downloaded_files = download_all_files(missing_records, MAX_CONCURRENT_DOWNLOADS) + + if not downloaded_files: return AirtableIngestionResponse( total_records_processed=0, total_chunks_created=0, - message="No records found in Airtable.", + message="No files were downloaded.", ) - # Get missing IDs - missing_ids = await get_missing_document_ids(records) - if not missing_ids: + logger.info(f"Downloaded {len(downloaded_files)} files") + + # Upload files to GCP + upload_semaphore = asyncio.Semaphore(MAX_CONCURRENT_UPLOADS) + uploaded_files = await upload_files_to_gcp(downloaded_files, upload_semaphore) + + if not uploaded_files: return AirtableIngestionResponse( total_records_processed=0, total_chunks_created=0, - message="No new documents to ingest.", + message="No files were uploaded to GCP.", ) - # Map IDs to records - airtable_id_to_record = {} - for record in records: - id_value = record.get("fields").get("ID") - if id_value is not None: - airtable_id_to_record[id_value] = record - - records_to_process = [ - airtable_id_to_record[id_] - for id_ in missing_ids - if id_ in airtable_id_to_record - ] + logger.info(f"Uploaded {len(uploaded_files)} files to GCP") + + # Process files (text extraction and embedding) + process_semaphore = asyncio.Semaphore(MAX_CONCURRENT_PROCESSING) + processed_files = await process_files(uploaded_files, process_semaphore) + + if not processed_files: + return AirtableIngestionResponse( + total_records_processed=0, + total_chunks_created=0, + message="No files were processed successfully.", + ) + + logger.info(f"Processed {len(processed_files)} files") + + # Save processed files to database + saved_files = await save_all_to_db(processed_files) + + logger.info(f"Saved {len(saved_files)} files to database") + + # Count total pages processed + total_chunks_created = sum(len(x.get("processed_pages", [])) for x in saved_files) - return await ingest_records(records_to_process) + # Return information about processed files + return AirtableIngestionResponse( + total_records_processed=len(saved_files), + total_chunks_created=total_chunks_created, + message=f"Ingestion completed. 
Processed {len(saved_files)} documents.", + ) diff --git a/backend/app/ingestion/storage_utils/gcp_storage_utils.py b/backend/app/ingestion/storage_utils/gcp_storage_utils.py index 75a0f62..4a4ae1c 100644 --- a/backend/app/ingestion/storage_utils/gcp_storage_utils.py +++ b/backend/app/ingestion/storage_utils/gcp_storage_utils.py @@ -1,4 +1,7 @@ +import asyncio import io +import os +from typing import Any, Dict, List from app.config import SERVICE_ACCOUNT_FILE_PATH from app.utils import setup_logger @@ -56,3 +59,41 @@ def upload_file_buffer_to_gcp_bucket( except Exception as e: logger.error(f"Error uploading file to GCP bucket: {e}") return "" + + +async def upload_files_to_gcp( + metadata_list: List[Dict[str, Any]], + semaphore: asyncio.Semaphore, +) -> List[Dict[str, Any]]: + """ + Uploads files to GCP concurrently and updates metadata with PDF URLs. + """ + + async def upload(metadata): + async with semaphore: + file_name = metadata["file_name"] + file_buffer = metadata["file_buffer"] + + # Rewind buffer to ensure we're at the start + file_buffer.seek(0) + + # Upload the file to GCP bucket + bucket_name = os.getenv("GCP_BUCKET_NAME", "survey_accelerator_files") + + pdf_url = await asyncio.to_thread( + upload_file_buffer_to_gcp_bucket, + file_buffer, + bucket_name, + file_name, + ) + + if not pdf_url: + logger.error(f"Failed to upload document '{file_name}' to GCP bucket") + return None + + metadata["pdf_url"] = pdf_url + return metadata + + tasks = [upload(metadata) for metadata in metadata_list] + results = await asyncio.gather(*tasks) + return [res for res in results if res] diff --git a/backend/app/ingestion/utils/file_processing_utils.py b/backend/app/ingestion/utils/file_processing_utils.py index e23983f..ae220c2 100644 --- a/backend/app/ingestion/utils/file_processing_utils.py +++ b/backend/app/ingestion/utils/file_processing_utils.py @@ -32,17 +32,25 @@ def parse_pdf_file(file_buffer: BinaryIO) -> list[str]: """ Synchronously parse a PDF file into a list of page texts. 
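For the page-level extraction that parse_pdf_file performs, a short standalone sketch of the same defensive PyPDF2 usage, reading from a file path instead of an in-memory buffer. It assumes PyPDF2 is installed; "example.pdf" is a placeholder path.

# Sketch: per-page text extraction that logs and returns [] instead of raising.
import PyPDF2


def pdf_pages(path: str) -> list[str]:
    try:
        with open(path, "rb") as fh:
            reader = PyPDF2.PdfReader(fh)
            pages: list[str] = []
            for page in reader.pages:
                text = page.extract_text()
                if text and text.strip():
                    pages.append(text.strip())
        if not pages:
            print(f"No text could be extracted from {path}")
        return pages
    except PyPDF2.errors.PdfReadError as exc:
        print(f"Could not read {path}: {exc}")
        return []


for i, page_text in enumerate(pdf_pages("example.pdf"), start=1):
    print(i, page_text[:80])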
""" - pdf_reader = PyPDF2.PdfReader(file_buffer) - chunks: list[str] = [] - num_pages = len(pdf_reader.pages) - for page_num in range(num_pages): - page = pdf_reader.pages[page_num] - page_text = page.extract_text() - if page_text and page_text.strip(): - chunks.append(page_text.strip()) - if not chunks: - raise RuntimeError("No text could be extracted from the uploaded PDF file.") - return chunks + try: + pdf_reader = PyPDF2.PdfReader(file_buffer) + chunks: list[str] = [] + num_pages = len(pdf_reader.pages) + for page_num in range(num_pages): + page = pdf_reader.pages[page_num] + page_text = page.extract_text() + if page_text and page_text.strip(): + chunks.append(page_text.strip()) + if not chunks: + logger.warning("No text could be extracted from the uploaded PDF file.") + return [] + return chunks + except PyPDF2.errors.PdfReadError as e: + logger.error(f"Error reading PDF file: {e}") + return [] + except Exception as e: + logger.error(f"Unexpected error parsing PDF file: {e}") + return [] async def process_file( diff --git a/backend/app/ingestion/utils/record_processing.py b/backend/app/ingestion/utils/record_processing.py deleted file mode 100644 index 5f28066..0000000 --- a/backend/app/ingestion/utils/record_processing.py +++ /dev/null @@ -1,280 +0,0 @@ -import asyncio -import io -import os -from typing import Any, Dict, List - -from app.database import get_async_session -from app.ingestion.fetch_utils.google_drive_utils import ( - determine_file_type, - download_file, - extract_file_id, - get_drive_service, -) -from app.ingestion.models import save_document_to_db -from app.ingestion.schemas import AirtableIngestionResponse -from app.ingestion.storage_utils.gcp_storage_utils import ( - upload_file_buffer_to_gcp_bucket, -) -from app.ingestion.utils.file_processing_utils import process_file -from app.utils import setup_logger - -logger = setup_logger() - -MAX_CONCURRENT_TASKS = 5 # Adjust as needed - - -async def extract_metadata(record: Dict[str, Any]) -> Dict[str, Any] | None: - """ - Extracts and validates metadata from a record. - """ - fields = record.get("fields", {}) - file_name = fields.get("File name") - survey_name = fields.get("Survey name") - gdrive_link = fields.get("Drive link") - document_id = fields.get("ID") - - if not file_name or not gdrive_link or document_id is None: - logger.warning( - f"""Record {record.get('id')} is missing 'File name', 'Drive link', or 'ID'. - Skipping.""" - ) - return None - - # Determine file type - file_type = await determine_file_type(file_name) - if file_type == "other": - logger.warning(f"File '{file_name}' has an unsupported extension. Skipping.") - return None - - # Extract file ID from Drive link - try: - file_id = await extract_file_id(gdrive_link) - except ValueError as ve: - logger.error(f"Error processing file '{file_name}': {ve}. Skipping.") - return None - - # Collect metadata - metadata = { - "record_id": record.get("id"), - "fields": fields, - "file_name": file_name, - "file_type": file_type, - "file_id": file_id, - "document_id": document_id, - "survey_name": survey_name, - } - return metadata - - -async def download_files( - metadata_list: List[Dict[str, Any]], - semaphore: asyncio.Semaphore, - drive_service, -) -> List[Dict[str, Any]]: - """ - Downloads files concurrently and updates metadata with file buffers. - """ - - async def download(metadata): - """ - Downloads a single file and updates metadata with file buffer. 
- """ - async with semaphore: - file_name = metadata["file_name"] - file_id = metadata["file_id"] - - try: - file_buffer = await asyncio.to_thread( - download_file, file_id, drive_service - ) - - metadata["file_buffer"] = file_buffer - return metadata - - except Exception as e: - logger.error(f"Error downloading file '{file_name}': {e}") - return None - - tasks = [download(metadata) for metadata in metadata_list if metadata] - results = await asyncio.gather(*tasks) - return [res for res in results if res] - - -async def process_files( - metadata_list: List[Dict[str, Any]], - semaphore: asyncio.Semaphore, -) -> List[Dict[str, Any]]: - """ - Processes files concurrently and updates metadata with processed pages. - """ - - async def process(metadata): - """ - Processes a single file and updates metadata with processed pages. - """ - async with semaphore: - file_name = metadata["file_name"] - file_type = metadata["file_type"] - fields = metadata["fields"] - file_buffer = metadata["file_buffer"] - - # Read the file content into bytes - file_buffer.seek(0) - file_bytes = file_buffer.read() - - # Create a new BytesIO object for processing - processing_file_buffer = io.BytesIO(file_bytes) - - # Process the file asynchronously - processed_pages = await process_file( - processing_file_buffer, - file_name, - file_type, - metadata=fields, - ) - - if not processed_pages: - logger.warning(f"No processed pages for file '{file_name}'. Skipping.") - return None - - metadata["processed_pages"] = processed_pages - metadata["file_bytes"] = file_bytes - return metadata - - tasks = [process(metadata) for metadata in metadata_list] - results = await asyncio.gather(*tasks) - return [res for res in results if res] - - -async def upload_files_to_gcp( - metadata_list: List[Dict[str, Any]], - semaphore: asyncio.Semaphore, -) -> List[Dict[str, Any]]: - """ - Uploads files to GCP concurrently and updates metadata with PDF URLs. - """ - - async def upload(metadata): - async with semaphore: - file_name = metadata["file_name"] - file_bytes = metadata["file_bytes"] - uploading_file_buffer = io.BytesIO(file_bytes) - - # Upload the file to GCP bucket - bucket_name = os.getenv("GCP_BUCKET_NAME", "survey_accelerator_files") - - pdf_url = await asyncio.to_thread( - upload_file_buffer_to_gcp_bucket, - uploading_file_buffer, - bucket_name, - file_name, - ) - - if not pdf_url: - logger.error(f"Failed to upload document '{file_name}' to GCP bucket") - return None - - metadata["pdf_url"] = pdf_url - return metadata - - tasks = [upload(metadata) for metadata in metadata_list] - results = await asyncio.gather(*tasks) - return [res for res in results if res] - - -async def save_single_document(metadata: Dict[str, Any], asession): - """ - Saves a single document to the database and prints a success message. - """ - await save_document_to_db( - file_name=metadata["file_name"], - processed_pages=metadata["processed_pages"], - file_id=metadata["file_id"], - asession=asession, - metadata=metadata["fields"], - pdf_url=metadata["pdf_url"], - summary=metadata.get("brief_summary"), - title=metadata.get("generated_title"), - ) - print(f"File '{metadata['file_name']}' processed and saved successfully.") - - -async def save_all_to_db(metadata_list: List[Dict[str, Any]]): - """ - Saves all documents to the database concurrently. 
- """ - try: - async with get_async_session() as asession: - semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) - - async def save_with_semaphore(metadata, session): - async with semaphore: - await save_single_document(metadata, session) - - save_tasks = [ - save_with_semaphore(metadata, asession) for metadata in metadata_list - ] - await asyncio.gather(*save_tasks) - except Exception as e: - logger.error(f"Error saving documents to database: {e}") - - -async def ingest_records(records: List[Dict[str, Any]]) -> AirtableIngestionResponse: - """ - Ingests a list of Airtable records in parallel stages. - """ - total_records_processed = 0 - total_chunks_created = 0 - # Step 1: Extract metadata - metadata_list = [await extract_metadata(record) for record in records] - metadata_list = [m for m in metadata_list if m] - - if not metadata_list: - return AirtableIngestionResponse( - total_records_processed=0, - total_chunks_created=0, - message="No valid records to process.", - ) - - # Step 2: Download files - semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) - drive_service = get_drive_service() # Get Google Drive service once - - files = await download_files(metadata_list, semaphore, drive_service) - if not metadata_list: - return AirtableIngestionResponse( - total_records_processed=0, - total_chunks_created=0, - message="Failed to download any files.", - ) - - # Step 3: Process files - semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) - - updated_metadata_list = await process_files(files, semaphore) - if not metadata_list: - return AirtableIngestionResponse( - total_records_processed=0, - total_chunks_created=0, - message="No files were processed.", - ) - - # Step 4: Upload files to GCP - semaphore = asyncio.Semaphore(MAX_CONCURRENT_TASKS) - - metadata_list = await upload_files_to_gcp(updated_metadata_list, semaphore) - - # Step 5: Save all to database - await save_all_to_db(metadata_list) - - # Update totals - total_records_processed = len(metadata_list) - total_chunks_created = sum( - len(metadata["processed_pages"]) for metadata in metadata_list - ) - - return AirtableIngestionResponse( - total_records_processed=total_records_processed, - total_chunks_created=total_chunks_created, - message=f"Ingestion completed. 
Processed {total_records_processed} documents.", - ) diff --git a/backend/app/search/schemas.py b/backend/app/search/schemas.py index 3f5303b..44a403a 100644 --- a/backend/app/search/schemas.py +++ b/backend/app/search/schemas.py @@ -19,7 +19,6 @@ class DocumentMetadata(BaseModel): """Schema for the document metadata.""" id: int - file_id: str file_name: str title: str summary: str diff --git a/backend/app/search/utils.py b/backend/app/search/utils.py index 902173b..917de6c 100644 --- a/backend/app/search/utils.py +++ b/backend/app/search/utils.py @@ -54,7 +54,6 @@ def create_metadata(doc: DocumentDB) -> DocumentMetadata: """ return DocumentMetadata( id=doc.id, - file_id=doc.file_id, file_name=doc.file_name, title=doc.title, summary=doc.summary, diff --git a/backend/main.py b/backend/main.py index 043e839..df36404 100644 --- a/backend/main.py +++ b/backend/main.py @@ -11,7 +11,7 @@ from loguru import logger from uvicorn.workers import UvicornWorker -LOG_LEVEL = logging.getLevelName(os.environ.get("LOG_LEVEL", "DEBUG")) +LOG_LEVEL = logging.getLevelName(os.environ.get("LOG_LEVEL", "INFO")) JSON_LOGS = True if os.environ.get("JSON_LOGS", "0") == "1" else False WORKERS = int(os.environ.get("GUNICORN_WORKERS", "5")) @@ -121,6 +121,7 @@ def load(self) -> Any: options = { "bind": "0.0.0.0", "workers": WORKERS, + "timeout": 120, "accesslog": "-", "errorlog": "-", "worker_class": "__main__.Worker", # use import path as string diff --git a/backend/migrations/versions/2025_03_03_aaa104d6c3ed_remove_file_id_column.py b/backend/migrations/versions/2025_03_03_aaa104d6c3ed_remove_file_id_column.py new file mode 100644 index 0000000..8f62c02 --- /dev/null +++ b/backend/migrations/versions/2025_03_03_aaa104d6c3ed_remove_file_id_column.py @@ -0,0 +1,35 @@ +"""remove_file_id_column + +Revision ID: aaa104d6c3ed +Revises: ddc3a53cac8d +Create Date: 2025-03-03 16:50:59.814301 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "aaa104d6c3ed" +down_revision: Union[str, None] = "ddc3a53cac8d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("documents", "file_id") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "documents", + sa.Column( + "file_id", sa.VARCHAR(length=36), autoincrement=False, nullable=False + ), + ) + # ### end Alembic commands ### diff --git a/frontend/src/interfaces.ts b/frontend/src/interfaces.ts index d9caf35..153407d 100644 --- a/frontend/src/interfaces.ts +++ b/frontend/src/interfaces.ts @@ -2,7 +2,6 @@ export interface Metadata { id: number; - file_id: string; file_name: string; title: string; summary: string;
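The aaa104d6c3ed migration above drops documents.file_id to match the ORM change. If it is ever useful to apply it from Python rather than through the Makefile's alembic call, a minimal sketch using Alembic's command API; it assumes an alembic.ini in the backend directory, and the path below is a placeholder.

# Sketch: apply (or roll back) the migration programmatically via Alembic.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("backend/alembic.ini")

command.upgrade(alembic_cfg, "head")     # applies aaa104d6c3ed and any earlier revisions
# command.downgrade(alembic_cfg, "-1")   # restores the file_id column if needed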