diff --git a/app/main.py b/app/main.py index dfd28ac..35c925e 100644 --- a/app/main.py +++ b/app/main.py @@ -2,14 +2,13 @@ import sentry_sdk import uvicorn -from fastapi import Depends, FastAPI +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from pyinstrument import Profiler from starlette import status from starlette.requests import Request from app.config import LIMIT_PERIOD, Settings -from app.services.auth import validate_api_key from app.services.detect.router import detect_router from app.services.identify.router import identify_router from app.services.limiter import request_limiter diff --git a/app/services/saver/manager.py b/app/services/saver/manager.py index b18e2c5..a610a00 100644 --- a/app/services/saver/manager.py +++ b/app/services/saver/manager.py @@ -1,3 +1,4 @@ +import asyncio import uuid from typing import TYPE_CHECKING @@ -12,6 +13,33 @@ import numpy as np +async def save_image_progress( + file: UploadFile, + name: str, + user_id: str, + progress_callback: asyncio.Event, + vector: list[float] | None = None, +): + """Asynchronously saves an image and triggers a progress callback once completed. + + Args: + ---- + file (UploadFile): The image file to be uploaded. + name (str): The name for the image. + user_id (str): Unique user identifier. + progress_callback (asyncio.Event): Event to signal when the upload is complete. + vector (list[float] | None, optional): Optional vector data associated with the image. + + Returns: + ------- + str: URL of the uploaded image. + + """ + url_upload = await save_image(file=file, name=name, user_id=user_id, vector=vector) + progress_callback.set() + return url_upload + + async def save_image( file: UploadFile, name: str, diff --git a/app/services/saver/router.py b/app/services/saver/router.py index 3cbe044..1c7acbf 100644 --- a/app/services/saver/router.py +++ b/app/services/saver/router.py @@ -1,12 +1,14 @@ import asyncio +import json from fastapi import APIRouter, Depends, UploadFile from starlette.requests import Request +from starlette.responses import StreamingResponse from app.config import LIMIT_PERIOD from app.services.auth import validate_api_key from app.services.limiter import request_limiter -from app.services.saver.manager import remove_image, save_image +from app.services.saver.manager import remove_image, save_image, save_image_progress saver_router: APIRouter = APIRouter(dependencies=[Depends(validate_api_key)], tags=["Saver"]) @@ -39,29 +41,37 @@ async def post_save_image( @saver_router.post("/saver/bulk") -@request_limiter.limit(LIMIT_PERIOD) async def post_save_images( files: list[UploadFile], user_id: str, request: Request, -) -> None: - """Save multiple images. +) -> StreamingResponse: + """Save multiple images and send progress updates via SSE.""" - Args: - ---- - files (list[UploadFile]): The files you want to update. - user_id (str): The user id that is uploading the files. - request (Request): Needed for the limiter + async def event_generator(progress: list[asyncio.Event], total_images: int) -> str: + for processed, event in enumerate(progress): + await event.wait() + event.clear() + yield f"data: {json.dumps({'processed': processed, + 'total': total_images, + 'percentage': processed / total_images})}\n\n" - """ - tasks = [] + progress_events = [asyncio.Event() for _ in files] - for file in files: + tasks = [] + for idx, file in enumerate(files): name = file.filename - task = save_image(file, name, user_id) + task = save_image_progress( + file=file, name=name, user_id=user_id, progress_callback=progress_events[idx] + ) tasks.append(task) + await asyncio.gather(*tasks) + return StreamingResponse( + event_generator(progress_events, len(files)), media_type="text/event-stream" + ) + @saver_router.delete("/delete") @request_limiter.limit(LIMIT_PERIOD) diff --git a/tests/services/saver/__init__.py b/tests/services/saver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/services/saver/test_multiple_images/1-crown.jpg b/tests/services/saver/test_multiple_images/1-crown.jpg new file mode 100644 index 0000000..ec4afd7 Binary files /dev/null and b/tests/services/saver/test_multiple_images/1-crown.jpg differ diff --git a/tests/services/saver/test_multiple_images/2-galicia.jpg b/tests/services/saver/test_multiple_images/2-galicia.jpg new file mode 100644 index 0000000..d04ed67 Binary files /dev/null and b/tests/services/saver/test_multiple_images/2-galicia.jpg differ diff --git a/tests/services/saver/test_multiple_images/3-salitos.jpg b/tests/services/saver/test_multiple_images/3-salitos.jpg new file mode 100644 index 0000000..53903cb Binary files /dev/null and b/tests/services/saver/test_multiple_images/3-salitos.jpg differ diff --git a/tests/services/saver/test_multiple_images/4-crown-red.jpg b/tests/services/saver/test_multiple_images/4-crown-red.jpg new file mode 100644 index 0000000..7b4f315 Binary files /dev/null and b/tests/services/saver/test_multiple_images/4-crown-red.jpg differ diff --git a/tests/services/saver/test_multiple_images/5-star.jpg b/tests/services/saver/test_multiple_images/5-star.jpg new file mode 100644 index 0000000..cd9cbf0 Binary files /dev/null and b/tests/services/saver/test_multiple_images/5-star.jpg differ diff --git a/tests/services/saver/test_saver_bulk.py b/tests/services/saver/test_saver_bulk.py new file mode 100644 index 0000000..e15ee6b --- /dev/null +++ b/tests/services/saver/test_saver_bulk.py @@ -0,0 +1,28 @@ +from pathlib import Path +from unittest.mock import MagicMock + +import pytest +from dotenv import load_dotenv +from starlette.requests import Request + +from app.services.saver.router import post_save_images +from app.shared.utils import upload_file + +load_dotenv() + +TEST_USER: str = "test_user" + + +@pytest.mark.asyncio +async def test_saver_bulk_ok(): + """Manual test to check if the API endpoint is working. + + This test will return always true, but it's useful for me to see if + the files are being uploaded. + """ + directory: Path = Path("tests/services/saver/test_multiple_images") + paths: list = list(directory.glob("*.jpg")) + uploaded_files = [await upload_file(path) for path in paths] + + fake_request = MagicMock(spec=Request) + await post_save_images(files=uploaded_files, user_id="test_user", request=fake_request)