Cleanup #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (6 commits) on Mar 3, 2025
2 changes: 0 additions & 2 deletions .github/workflows/python-tests.yaml
@@ -11,7 +11,6 @@ env:
   POSTGRES_PASSWORD: postgres-test-pw
   POSTGRES_USER: postgres-test-user
   POSTGRES_DB: postgres-test-db
-  REDIS_HOST: redis://redis:6379
 jobs:
   container-job:
     runs-on: ubuntu-20.04
@@ -54,7 +53,6 @@ jobs:
       - name: Run Unit Tests
         env:
           PROMETHEUS_MULTIPROC_DIR: /tmp
-          REDIS_HOST: ${{ env.REDIS_HOST }}
         run: |
           cd backend
           export POSTGRES_HOST=postgres POSTGRES_USER=$POSTGRES_USER \
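For reference, the slimmed-down test environment can be reproduced locally along these lines (a sketch: the pytest invocation is an assumption, while the variable values mirror the workflow above):

import os
import subprocess

# Values copied from the workflow env above; POSTGRES_HOST points at a
# local container instead of the CI `postgres` service.
env = dict(
    os.environ,
    POSTGRES_HOST="localhost",
    POSTGRES_USER="postgres-test-user",
    POSTGRES_PASSWORD="postgres-test-pw",
    POSTGRES_DB="postgres-test-db",
    PROMETHEUS_MULTIPROC_DIR="/tmp",
)
subprocess.run(["python", "-m", "pytest"], cwd="backend", env=env, check=True)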
48 changes: 9 additions & 39 deletions Makefile
@@ -1,5 +1,4 @@
-include ./deployment/docker-compose/.backend.env
 include ./deployment/docker-compose/.base.env
 
 PROJECT_NAME=hew-ai
 CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate
@@ -8,54 +7,25 @@ ENDPOINT_URL = localhost:8000
 guard-%:
 	@if [ -z '${${*}}' ]; then echo 'ERROR: environment variable $* not set' && exit 1; fi
 
-# Note: Run `make fresh-env psycopg2-binary=true` to manually replace psycopg with psycopg2-binary
-fresh-env :
-	conda remove --name $(PROJECT_NAME) --all -y
-	conda create --name $(PROJECT_NAME) python==3.12 -y
-
-	$(CONDA_ACTIVATE) $(PROJECT_NAME); \
-	pip install -r backend/requirements.txt --ignore-installed; \
-	pip install -r requirements-dev.txt --ignore-installed; \
-	pre-commit install
-
-	if [ "$(psycopg2-binary)" = "true" ]; then \
-		$(CONDA_ACTIVATE) $(PROJECT_NAME); \
-		pip uninstall -y psycopg2==2.9.9; \
-		pip install psycopg2-binary==2.9.9; \
-	fi
-
 setup-db: guard-POSTGRES_USER guard-POSTGRES_PASSWORD guard-POSTGRES_DB
-	-@docker stop pg-hew-ai-local
-	-@docker rm pg-hew-ai-local
+	-@docker stop survey-accelerator
+	-@docker rm survey-accelerator
 	@docker system prune -f
 	@sleep 2
-	@docker run --name pg-hew-ai-local \
-		-e POSTGRES_USER=$(POSTGRES_USER) \
-		-e POSTGRES_PASSWORD=$(POSTGRES_PASSWORD) \
-		-e POSTGRES_DB=$(POSTGRES_DB) \
-		-p 5432:5432 \
+	@docker run --name survey-accelerator \
+		-e POSTGRES_USER=${POSTGRES_USER} \
+		-e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \
+		-e POSTGRES_DB=${POSTGRES_DB} \
+		-p ${POSTGRES_PORT}:5432 \
 		-d pgvector/pgvector:pg16
 	@sleep 5
 	set -a && \
 	source "$(CURDIR)/deployment/docker-compose/.base.env" && \
 	source "$(CURDIR)/deployment/docker-compose/.backend.env" && \
 	set +a && \
 	cd backend && \
 	python -m alembic upgrade head
 
 teardown-db:
-	@docker stop pg-hew-ai-local
-	@docker rm pg-hew-ai-local
-
-setup-redis:
-	-@docker stop redis-hew-ai-local
-	-@docker rm redis-hew-ai-local
-	@docker system prune -f
-	@sleep 2
-	@docker run --name redis-hew-ai-local \
-		-p 6379:6379 \
-		-d redis:6.0-alpine
-
-make teardown-redis:
-	@docker stop redis-hew-ai-local
-	@docker rm redis-hew-ai-local
+	@docker stop survey-accelerator
+	@docker rm survey-accelerator
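The `set -a` / `source` / `set +a` sequence exports both env files into the environment so that `alembic upgrade head` can read them; a hypothetical sketch of how the migration env might assemble its connection URL (variable names from app/config.py, URL template assumed):

import os

# Assumed shape of the URL built inside backend/alembic's env.py; the
# variables come from the sourced .base.env / .backend.env files.
DATABASE_URL = (
    f"postgresql://{os.environ['POSTGRES_USER']}:{os.environ['POSTGRES_PASSWORD']}"
    f"@{os.environ.get('POSTGRES_HOST', 'localhost')}:{os.environ.get('POSTGRES_PORT', '5432')}"
    f"/{os.environ['POSTGRES_DB']}"
)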
21 changes: 0 additions & 21 deletions backend/app/__init__.py
@@ -1,42 +1,21 @@
-from contextlib import asynccontextmanager
-from typing import AsyncIterator
-
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from redis import asyncio as aioredis
 
 from . import ingestion, search
-from .config import REDIS_HOST
 from .utils import setup_logger
 
 logger = setup_logger()
 
 tags_metadata = [ingestion.TAG_METADATA]
 
 
-@asynccontextmanager
-async def lifespan(app: FastAPI) -> AsyncIterator[None]:
-    """
-    Lifespan events for the FastAPI application.
-    """
-
-    logger.info("Application started")
-    app.state.redis = await aioredis.from_url(REDIS_HOST)
-
-    yield
-
-    await app.state.redis.close()
-    logger.info("Application finished")
-
-
 def create_app() -> FastAPI:
     """
     Create a FastAPI application with the appropriate routers.
     """
     app = FastAPI(
         title="Survey Accelerator",
         openapi_tags=tags_metadata,
-        lifespan=lifespan,
         debug=True,
     )
 
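With the Redis lifespan gone, create_app() needs no startup or shutdown hooks; a minimal serving sketch (the uvicorn entry point is an assumption, not part of this PR):

import uvicorn

from app import create_app

app = create_app()

if __name__ == "__main__":
    # Serve the factory-built app; host/port values are illustrative.
    uvicorn.run(app, host="0.0.0.0", port=8000)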
3 changes: 0 additions & 3 deletions backend/app/auth/config.py

This file was deleted.

2 changes: 1 addition & 1 deletion backend/app/auth/dependencies.py
@@ -4,7 +4,7 @@
     HTTPBearer,
 )
 
-from .config import API_SECRET_KEY
+from ..config import API_SECRET_KEY
 
 bearer = HTTPBearer()
 
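Only the changed import is visible above; a hypothetical sketch of a bearer check built on API_SECRET_KEY (the function name and comparison logic are assumptions, not this repo's code):

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from ..config import API_SECRET_KEY

bearer = HTTPBearer()


async def authenticate_key(
    credentials: HTTPAuthorizationCredentials = Depends(bearer),
) -> None:
    # Reject any request whose bearer token does not match the shared secret.
    if credentials.credentials != API_SECRET_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")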
13 changes: 6 additions & 7 deletions backend/app/config.py
@@ -1,5 +1,9 @@
 import os
 
+# Auth
+API_SECRET_KEY = os.getenv("API_SECRET_KEY", "kk")
+
+
 # PostgreSQL Configurations
 POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres")
 POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres")
@@ -8,12 +12,9 @@
 POSTGRES_DB = os.environ.get("POSTGRES_DB", "postgres")
 DB_POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", 20))
 
-# Redis Configuration
-REDIS_HOST = os.environ.get("REDIS_HOST", "redis://localhost:6379")
-
 # Backend Configuration
-BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "")
-LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING")
+BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "./")
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
 
 # PGVector Configuration
 PGVECTOR_VECTOR_SIZE = int(os.environ.get("PGVECTOR_VECTOR_SIZE", 1024))
@@ -50,6 +51,4 @@
 )
 
 # Other Configurations
 MAX_PAGES = int(os.environ.get("MAX_PAGES", 3))
-MAIN_DOWNLOAD_DIR = "downloaded_gdrives_sa"
-XLSX_SUBDIR = os.path.join(MAIN_DOWNLOAD_DIR, "xlsx")
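Every setting above follows the `os.environ.get(name, default)` pattern, so each can be overridden per deployment; for example (values illustrative):

import os

os.environ["LOG_LEVEL"] = "DEBUG"        # e.g. injected by docker-compose or CI
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
assert LOG_LEVEL == "DEBUG"              # the "INFO" default applies only when unset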
(Airtable ingestion utilities; file name and change counts not captured)
@@ -2,13 +2,12 @@
 
 from typing import Any, Dict
 
+from pyairtable import Api
+from sqlalchemy import select
 
 from app.config import AIRTABLE_API_KEY, AIRTABLE_CONFIGS
 from app.database import get_async_session
 from app.ingestion.models import DocumentDB
 from app.utils import setup_logger
-from pyairtable import Api
-from sqlalchemy import select
-
 logger = setup_logger()
 
@@ -17,7 +16,7 @@
     raise EnvironmentError("Airtable API key not found in environment variables.")
 
 
-def get_airtable_records() -> list:
+async def get_airtable_records() -> list:
     """
     Fetch records from Airtable and return a list of records.
     Raises exceptions if there are issues fetching the records.
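Since get_airtable_records is now a coroutine, call sites must await it; a minimal sketch (the module path is assumed, as the file name was not captured above):

import asyncio

from app.ingestion.fetch_utils.airtable_utils import get_airtable_records  # path assumed


async def main() -> None:
    records = await get_airtable_records()
    print(f"Fetched {len(records)} Airtable records")


asyncio.run(main())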
149 changes: 149 additions & 0 deletions backend/app/ingestion/fetch_utils/google_drive_utils.py
@@ -0,0 +1,149 @@
import io
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

from app.config import SCOPES, SERVICE_ACCOUNT_FILE_PATH
from app.utils import setup_logger
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload

logger = setup_logger()


def extract_file_id(gdrive_url: str) -> str:
    if "id=" in gdrive_url:
        file_id = gdrive_url.split("id=")[1].split("&")[0]
        if file_id:
            return file_id
        else:
            raise ValueError(
                "Invalid Google Drive URL format: missing file ID after 'id='."
            )
    else:
        parts = gdrive_url.strip("/").split("/")
        if "d" in parts:
            d_index = parts.index("d")
            try:
                file_id = parts[d_index + 1]
                if file_id:
                    return file_id
                else:
                    raise ValueError(
                        "Invalid Google Drive URL format: missing file ID after '/d/'."
                    )
            except IndexError as e:
                raise ValueError(
                    "Invalid Google Drive URL format: incomplete URL."
                ) from e
        else:
            raise ValueError(
                """URL format not recognized. Ensure it contains 'id=' or follows the
                standard Drive URL format."""
            )
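
# A quick sanity check of the two accepted URL shapes (example IDs made up):
#   extract_file_id("https://drive.google.com/open?id=abc123")     -> "abc123"
#   extract_file_id("https://drive.google.com/file/d/abc123/view") -> "abc123"
# Any other shape raises ValueError.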


def determine_file_type(file_name: str) -> str:
    _, ext = os.path.splitext(file_name.lower())
    if ext == ".pdf":
        return "pdf"
    elif ext == ".xlsx":
        return "xlsx"
    else:
        return "other"


def download_file(file_id: str) -> Optional[io.BytesIO]:
    """
    Download a file from Google Drive using the Drive API.
    """
    creds = service_account.Credentials.from_service_account_file(
        SERVICE_ACCOUNT_FILE_PATH, scopes=SCOPES
    )
    drive_service = build("drive", "v3", credentials=creds)
    try:
        request = drive_service.files().get_media(fileId=file_id)
        file_buffer = io.BytesIO()
        downloader = MediaIoBaseDownload(file_buffer, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
        if file_buffer.getbuffer().nbytes == 0:
            raise RuntimeError("No content was downloaded from the file.")
        file_buffer.seek(0)
        return file_buffer
    except Exception as e:
        logger.error(f"Error downloading file with ID '{file_id}': {e}")
        return None
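
# Usage sketch (file ID made up): `buf = download_file("abc123")` yields a
# BytesIO rewound to position 0, or None when the Drive API call fails, so
# callers must None-check before reading.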


def download_file_wrapper(record: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """A wrapper function to process each record and download its file."""
    try:
        fields = record.get("fields", {})
        gdrive_link = fields.get("Drive link")
        file_name = fields.get("File name")
        document_id = fields.get("ID")
        survey_name = fields.get("Survey name")
        description = fields.get("Description")

        if not gdrive_link or not file_name:
            logger.error("Record is missing 'Drive link' or 'File name'")
            return None

        # Check if the file is a PDF (we only want to process PDFs)
        file_type = determine_file_type(file_name)
        if file_type != "pdf":
            logger.info(f"Skipping non-PDF file: '{file_name}'")
            return None

        file_id = extract_file_id(gdrive_link)
        if not file_id:
            logger.error(f"Could not extract file ID from link '{gdrive_link}'")
            return None

        logger.info(f"Starting download of file '{file_name}'")
        file_buffer = download_file(file_id)
        if file_buffer is None:
            logger.error(f"Failed to download file '{file_name}'")
            return None

        logger.info(f"Completed download of file '{file_name}'")
        return {
            "file_name": file_name,
            "file_buffer": file_buffer,
            "file_type": file_type,
            "document_id": document_id,
            "survey_name": survey_name,
            "summary": description,
            "fields": fields,
        }
    except Exception as e:
        logger.error(f"Error downloading file '{file_name}': {e}")
        return None


def download_all_files(
    records: List[Dict[str, Any]], n_max_workers: int
) -> List[Dict[str, Any]]:
    """Download all files concurrently using ThreadPoolExecutor."""
    downloaded_files = []

    with ThreadPoolExecutor(max_workers=n_max_workers) as executor:
        # Map each record to a future
        future_to_record = {
            executor.submit(download_file_wrapper, record): record for record in records
        }

        for future in as_completed(future_to_record):
            record = future_to_record[future]
            file_name = record.get("fields", {}).get("File name", "Unknown")
            try:
                result = future.result()
                if result is not None:
                    downloaded_files.append(result)
            except Exception as e:
                logger.error(f"Error downloading file '{file_name}': {e}")

    return downloaded_files
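Taken together, a hypothetical wiring sketch for the new module (the record literal mimics the Airtable fields the wrapper reads; the ID, names, and link are made up):

from app.ingestion.fetch_utils.google_drive_utils import download_all_files

# Shaped like the Airtable rows download_file_wrapper expects.
records = [
    {
        "fields": {
            "Drive link": "https://drive.google.com/file/d/abc123/view",
            "File name": "household_survey.pdf",
            "ID": 1,
            "Survey name": "Household Survey",
            "Description": "Example record",
        }
    }
]

files = download_all_files(records, n_max_workers=4)
for f in files:
    print(f["file_name"], f["file_buffer"].getbuffer().nbytes, "bytes")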