
Commit 0742cd6

Merge pull request #1 from IDinsight/cleanup
Cleanup
2 parents e015e49 + 5269e5f commit 0742cd6

25 files changed: +833, -755 lines

.github/workflows/python-tests.yaml (-2 lines)

@@ -11,7 +11,6 @@ env:
   POSTGRES_PASSWORD: postgres-test-pw
   POSTGRES_USER: postgres-test-user
   POSTGRES_DB: postgres-test-db
-  REDIS_HOST: redis://redis:6379
 jobs:
   container-job:
     runs-on: ubuntu-20.04
@@ -54,7 +53,6 @@ jobs:
       - name: Run Unit Tests
        env:
          PROMETHEUS_MULTIPROC_DIR: /tmp
-         REDIS_HOST: ${{ env.REDIS_HOST }}
        run: |
          cd backend
          export POSTGRES_HOST=postgres POSTGRES_USER=$POSTGRES_USER \

Makefile (+9, -39 lines)

@@ -1,5 +1,4 @@
 include ./deployment/docker-compose/.backend.env
-include ./deployment/docker-compose/.base.env
 
 PROJECT_NAME=hew-ai
 CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate
@@ -8,54 +7,25 @@ ENDPOINT_URL = localhost:8000
 guard-%:
 	@if [ -z '${${*}}' ]; then echo 'ERROR: environment variable $* not set' && exit 1; fi
 
-# Note: Run `make fresh-env psycopg2-binary=true` to manually replace psycopg with psycopg2-binary
-fresh-env :
-	conda remove --name $(PROJECT_NAME) --all -y
-	conda create --name $(PROJECT_NAME) python==3.12 -y
-
-	$(CONDA_ACTIVATE) $(PROJECT_NAME); \
-	pip install -r backend/requirements.txt --ignore-installed; \
-	pip install -r requirements-dev.txt --ignore-installed; \
-	pre-commit install
-
-	if [ "$(psycopg2-binary)" = "true" ]; then \
-		$(CONDA_ACTIVATE) $(PROJECT_NAME); \
-		pip uninstall -y psycopg2==2.9.9; \
-		pip install psycopg2-binary==2.9.9; \
-	fi
 
 setup-db: guard-POSTGRES_USER guard-POSTGRES_PASSWORD guard-POSTGRES_DB
-	-@docker stop pg-hew-ai-local
-	-@docker rm pg-hew-ai-local
+	-@docker stop survey-accelerator
+	-@docker rm survey-accelerator
 	@docker system prune -f
 	@sleep 2
-	@docker run --name pg-hew-ai-local \
-		-e POSTGRES_USER=$(POSTGRES_USER) \
-		-e POSTGRES_PASSWORD=$(POSTGRES_PASSWORD) \
-		-e POSTGRES_DB=$(POSTGRES_DB) \
-		-p 5432:5432 \
+	@docker run --name survey-accelerator \
+		-e POSTGRES_USER=${POSTGRES_USER} \
+		-e POSTGRES_PASSWORD=${POSTGRES_PASSWORD} \
+		-e POSTGRES_DB=${POSTGRES_DB} \
+		-p ${POSTGRES_PORT}:5432 \
 		-d pgvector/pgvector:pg16
 	@sleep 5
 	set -a && \
-	source "$(CURDIR)/deployment/docker-compose/.base.env" && \
 	source "$(CURDIR)/deployment/docker-compose/.backend.env" && \
 	set +a && \
 	cd backend && \
 	python -m alembic upgrade head
 
 teardown-db:
-	@docker stop pg-hew-ai-local
-	@docker rm pg-hew-ai-local
-
-setup-redis:
-	-@docker stop redis-hew-ai-local
-	-@docker rm redis-hew-ai-local
-	@docker system prune -f
-	@sleep 2
-	@docker run --name redis-hew-ai-local \
-		-p 6379:6379 \
-		-d redis:6.0-alpine
-
-make teardown-redis:
-	@docker stop redis-hew-ai-local
-	@docker rm redis-hew-ai-local
+	@docker stop survey-accelerator
+	@docker rm survey-accelerator

backend/app/__init__.py (-21 lines)

@@ -1,42 +1,21 @@
-from contextlib import asynccontextmanager
-from typing import AsyncIterator
-
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from redis import asyncio as aioredis
 
 from . import ingestion, search
-from .config import REDIS_HOST
 from .utils import setup_logger
 
 logger = setup_logger()
 
 tags_metadata = [ingestion.TAG_METADATA]
 
 
-@asynccontextmanager
-async def lifespan(app: FastAPI) -> AsyncIterator[None]:
-    """
-    Lifespan events for the FastAPI application.
-    """
-
-    logger.info("Application started")
-    app.state.redis = await aioredis.from_url(REDIS_HOST)
-
-    yield
-
-    await app.state.redis.close()
-    logger.info("Application finished")
-
-
 def create_app() -> FastAPI:
     """
     Create a FastAPI application with the appropriate routers.
     """
     app = FastAPI(
         title="Survey Accelerator",
         openapi_tags=tags_metadata,
-        lifespan=lifespan,
         debug=True,
     )
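
With the Redis lifespan removed, create_app() no longer registers startup or shutdown hooks, so serving the app no longer depends on REDIS_HOST. A minimal usage sketch, assuming a hypothetical entrypoint module and default uvicorn settings (neither is part of this commit):

    import uvicorn

    from app import create_app

    app = create_app()

    if __name__ == "__main__":
        # No Redis connection is opened at startup anymore; only the Postgres
        # settings from config.py are required. Host and port here are assumptions.
        uvicorn.run(app, host="0.0.0.0", port=8000)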

backend/app/auth/config.py (-3 lines)

This file was deleted.

backend/app/auth/dependencies.py (+1, -1 lines)

@@ -4,7 +4,7 @@
     HTTPBearer,
 )
 
-from .config import API_SECRET_KEY
+from ..config import API_SECRET_KEY
 
 bearer = HTTPBearer()
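
API_SECRET_KEY now comes from the shared app config instead of the deleted auth config module. The dependency that consumes it is not shown in this diff; below is a minimal sketch of what such a bearer-token check could look like, where the name authenticate_key and its body are assumptions:

    from fastapi import Depends, HTTPException, status
    from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

    from ..config import API_SECRET_KEY  # shared config, as in the diff above

    bearer = HTTPBearer()


    async def authenticate_key(
        credentials: HTTPAuthorizationCredentials = Depends(bearer),
    ) -> None:
        # Hypothetical check: reject requests whose bearer token does not match
        # the configured API secret key.
        if credentials.credentials != API_SECRET_KEY:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API key",
            )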

backend/app/config.py (+6, -7 lines)

@@ -1,5 +1,9 @@
 import os
 
+# Auth
+API_SECRET_KEY = os.getenv("API_SECRET_KEY", "kk")
+
+
 # PostgreSQL Configurations
 POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres")
 POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres")
@@ -8,12 +12,9 @@
 POSTGRES_DB = os.environ.get("POSTGRES_DB", "postgres")
 DB_POOL_SIZE = int(os.environ.get("DB_POOL_SIZE", 20))
 
-# Redis Configuration
-REDIS_HOST = os.environ.get("REDIS_HOST", "redis://localhost:6379")
-
 # Backend Configuration
-BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "")
-LOG_LEVEL = os.environ.get("LOG_LEVEL", "WARNING")
+BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "./")
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
 
 # PGVector Configuration
 PGVECTOR_VECTOR_SIZE = int(os.environ.get("PGVECTOR_VECTOR_SIZE", 1024))
@@ -50,6 +51,4 @@
 )
 
 # Other Configurations
-MAX_PAGES = int(os.environ.get("MAX_PAGES", 3))
 MAIN_DOWNLOAD_DIR = "downloaded_gdrives_sa"
-XLSX_SUBDIR = os.path.join(MAIN_DOWNLOAD_DIR, "xlsx")

backend/app/ingestion/utils/airtable_utils.py → backend/app/ingestion/fetch_utils/airtable_utils.py (+3, -4 lines)

@@ -2,13 +2,12 @@
 
 from typing import Any, Dict
 
-from pyairtable import Api
-from sqlalchemy import select
-
 from app.config import AIRTABLE_API_KEY, AIRTABLE_CONFIGS
 from app.database import get_async_session
 from app.ingestion.models import DocumentDB
 from app.utils import setup_logger
+from pyairtable import Api
+from sqlalchemy import select
 
 logger = setup_logger()
 
@@ -17,7 +16,7 @@
     raise EnvironmentError("Airtable API key not found in environment variables.")
 
 
-def get_airtable_records() -> list:
+async def get_airtable_records() -> list:
     """
     Fetch records from Airtable and return a list of records.
     Raises exceptions if there are issues fetching the records.
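
Because get_airtable_records() is now a coroutine, existing call sites must await it; calling it directly would return a coroutine object rather than the record list. A hypothetical call-site sketch (refresh_documents is an assumed name, not part of this diff):

    from app.ingestion.fetch_utils.airtable_utils import get_airtable_records


    async def refresh_documents() -> list:
        # Must be awaited now that the helper is declared async.
        records = await get_airtable_records()
        return records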
New file (+149 lines)

@@ -0,0 +1,149 @@
+import io
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Dict, List, Optional
+
+from app.config import SCOPES, SERVICE_ACCOUNT_FILE_PATH
+from app.utils import setup_logger
+from google.oauth2 import service_account
+from googleapiclient.discovery import build
+from googleapiclient.http import MediaIoBaseDownload
+
+logger = setup_logger()
+
+
+def extract_file_id(gdrive_url: str) -> str:
+    if "id=" in gdrive_url:
+        file_id = gdrive_url.split("id=")[1].split("&")[0]
+        if file_id:
+            return file_id
+        else:
+            raise ValueError(
+                "Invalid Google Drive URL format: missing file ID after 'id='."
+            )
+    else:
+        parts = gdrive_url.strip("/").split("/")
+        if "d" in parts:
+            d_index = parts.index("d")
+            try:
+                file_id = parts[d_index + 1]
+                if file_id:
+                    return file_id
+                else:
+                    raise ValueError(
+                        "Invalid Google Drive URL format: missing file ID after '/d/'."
+                    )
+            except IndexError as e:
+                raise ValueError(
+                    "Invalid Google Drive URL format: incomplete URL."
+                ) from e
+        else:
+            raise ValueError(
+                """URL format not recognized. Ensure it contains 'id=' or follows the
+                standard Drive URL format."""
+            )
+
+
+def determine_file_type(file_name: str) -> str:
+    _, ext = os.path.splitext(file_name.lower())
+    if ext == ".pdf":
+        return "pdf"
+    elif ext == ".xlsx":
+        return "xlsx"
+    else:
+        return "other"
+
+
+def download_file(file_id: str) -> Optional[io.BytesIO]:
+    """
+    Download a file from Google Drive using the Drive API.
+    """
+    creds = service_account.Credentials.from_service_account_file(
+        SERVICE_ACCOUNT_FILE_PATH, scopes=SCOPES
+    )
+    drive_service = build("drive", "v3", credentials=creds)
+    try:
+        request = drive_service.files().get_media(fileId=file_id)
+        file_buffer = io.BytesIO()
+        downloader = MediaIoBaseDownload(file_buffer, request)
+        done = False
+        while not done:
+            status, done = downloader.next_chunk()
+        if file_buffer.getbuffer().nbytes == 0:
+            raise RuntimeError("No content was downloaded from the file.")
+        file_buffer.seek(0)
+        return file_buffer
+    except Exception as e:
+        logger.error(f"Error downloading file with ID '{file_id}': {e}")
+        return None
+
+
+def download_file_wrapper(record: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+    """A wrapper function to process each record and download its file."""
+    try:
+        fields = record.get("fields", {})
+        gdrive_link = fields.get("Drive link")
+        file_name = fields.get("File name")
+        document_id = fields.get("ID")
+        survey_name = fields.get("Survey name")
+        description = fields.get("Description")
+
+        if not gdrive_link or not file_name:
+            logger.error("Record is missing 'Drive link' or 'File name'")
+            return None
+
+        # Check if the file is a PDF (we only want to process PDFs)
+        file_type = determine_file_type(file_name)
+        if file_type != "pdf":
+            logger.info(f"Skipping non-PDF file: '{file_name}'")
+            return None
+
+        file_id = extract_file_id(gdrive_link)
+        if not file_id:
+            logger.error(f"Could not extract file ID from link '{gdrive_link}'")
+            return None
+
+        logger.info(f"Starting download of file '{file_name}'")
+        file_buffer = download_file(file_id)
+        if file_buffer is None:
+            logger.error(f"Failed to download file '{file_name}'")
+            return None
+
+        logger.info(f"Completed download of file '{file_name}'")
+        return {
+            "file_name": file_name,
+            "file_buffer": file_buffer,
+            "file_type": file_type,
+            "document_id": document_id,
+            "survey_name": survey_name,
+            "summary": description,
+            "fields": fields,
+        }
+    except Exception as e:
+        logger.error(f"Error downloading file '{file_name}': {e}")
+        return None
+
+
+def download_all_files(
+    records: List[Dict[str, Any]], n_max_workers: int
+) -> List[Dict[str, Any]]:
+    """Download all files concurrently using ThreadPoolExecutor."""
+    downloaded_files = []
+
+    with ThreadPoolExecutor(max_workers=n_max_workers) as executor:
+        # Map each record to a future
+        future_to_record = {
+            executor.submit(download_file_wrapper, record): record for record in records
+        }
+
+        for future in as_completed(future_to_record):
+            record = future_to_record[future]
+            file_name = record.get("fields", {}).get("File name", "Unknown")
+            try:
+                result = future.result()
+                if result is not None:
+                    downloaded_files.append(result)
+            except Exception as e:
+                logger.error(f"Error downloading file '{file_name}': {e}")
+
+    return downloaded_files
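
The new module downloads Drive files with a blocking ThreadPoolExecutor, while the Airtable helper above is async. Below is a hedged sketch of how the two could be combined during ingestion; the module path gdrive_utils and the function ingest_documents are assumptions, since neither the pipeline nor the new file's path is shown in this commit view:

    import asyncio

    from app.ingestion.fetch_utils.airtable_utils import get_airtable_records
    from app.ingestion.fetch_utils.gdrive_utils import download_all_files  # path assumed


    async def ingest_documents(n_max_workers: int = 4) -> list:
        """Fetch Airtable records, then download the referenced PDFs from Drive."""
        records = await get_airtable_records()
        # download_all_files blocks on network I/O, so run it in a worker thread
        # to keep the event loop responsive.
        return await asyncio.to_thread(download_all_files, records, n_max_workers)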
