Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ brevia/extensions/*

# docs site
site/

# output folder
./files/
1 change: 1 addition & 0 deletions brevia/async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def run_job_service(
try:
lock_job_service(job_store)
service = create_service(job_store.service)
job_store.payload['job_id'] = str(job_store.uuid)
result = service.run(job_store.payload)

except Exception as exc: # pylint: disable=broad-exception-caught
Expand Down
1 change: 1 addition & 0 deletions brevia/routers/analyze_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def upload_analyze(

payload = json.loads(payload)
payload['file_path'] = tmp_path
payload.setdefault('file_name', file.filename)
job = async_jobs.create_job(
service=service,
payload=payload,
Expand Down
2 changes: 2 additions & 0 deletions brevia/routers/app_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
completion_router,
config_router,
providers_router,
download_router,
)


Expand All @@ -28,3 +29,4 @@ def add_routers(app: FastAPI) -> None:
app.include_router(completion_router.router)
app.include_router(config_router.router)
app.include_router(providers_router.router)
app.include_router(download_router.router)
35 changes: 35 additions & 0 deletions brevia/routers/download_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Download files endpoint."""
import os
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import FileResponse
from brevia.settings import get_settings


router = APIRouter()


@router.api_route(
    '/download/{file_path:path}',
    methods=['GET', 'HEAD'],
    tags=['Download'],
)
async def download_file(file_path: str):
    """
    Endpoint to download a file from the internal file system.

    :param file_path: Path of the requested file, relative to the
        `file_output_base_path` setting.
    :return: A `FileResponse` for the requested file.
    :raises HTTPException: 404 when the output base path is an S3 path,
        when the requested path resolves outside the base folder, or when
        no regular file exists at that path.
    """
    base_path = get_settings().file_output_base_path
    # Check if the base path is an S3 file path
    if base_path.startswith('s3://'):
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            'File download is not supported',
        )

    # `file_path` comes straight from the URL: a plain `os.path.join` would
    # let an absolute path or `..` segments escape the output folder, so the
    # path is fully resolved and must still live under the base folder.
    full_path = ''
    try:
        base_dir = os.path.realpath(base_path)
        full_path = os.path.realpath(os.path.join(base_dir, file_path))
        is_inside = os.path.commonpath([base_dir, full_path]) == base_dir
    except ValueError:
        # Malformed paths (e.g. embedded NUL byte) are reported as not found
        is_inside = False

    if not is_inside or not os.path.isfile(full_path):
        raise HTTPException(
            status.HTTP_404_NOT_FOUND,
            f'File not found: {file_path}',
        )

    return FileResponse(full_path, filename=os.path.basename(full_path))
23 changes: 22 additions & 1 deletion brevia/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
from brevia.tasks.text_analysis import RefineTextAnalysisTask, SummarizeTextAnalysisTask

from brevia.load_file import read
from brevia.utilities.output import PublicFileOutput


class BaseService(ABC):
"""Base class for services"""
job_id = None

def run(self, payload: dict) -> dict:
"""Run a service using a payload"""
if not self.validate(payload):
raise ValueError(f'Invalid service payload - {payload}')
self.job_id = payload.get('job_id')
return self.execute(payload)

@abstractmethod
Expand Down Expand Up @@ -108,7 +111,7 @@ def summarize_from_file(


class RefineTextAnalysisService(BaseService):
"""Service to perform summarization from text input"""
"""Service to perform text analysis from text input"""

def execute(self, payload: dict):
"""Service logic"""
Expand Down Expand Up @@ -139,6 +142,24 @@ def validate(self, payload: dict):
return True


class RefineTextAnalysisToTxtService(RefineTextAnalysisService):
    """Service to perform text analysis from file, creating a txt file as output"""

    DEFAULT_FILE_NAME = 'summary.txt'

    def execute(self, payload: dict):
        """
        Run the parent text analysis and save its output to a txt file.

        The file is written through `PublicFileOutput` using the current
        job id; the result gains an `artifacts` list holding the name and
        URL of the written file.
        """
        result = super().execute(payload)
        file_name = self._output_file_name(payload.get('file_name'))
        file_out = PublicFileOutput(job_id=self.job_id)
        url = file_out.write(result['output'], file_name)
        result['artifacts'] = [{
            'name': file_name,
            'url': url,
        }]
        return result

    @classmethod
    def _output_file_name(cls, source_name) -> str:
        """Build a safe `.txt` file name from the (client supplied) source name."""
        if not source_name:
            return cls.DEFAULT_FILE_NAME
        # Keep only the last path component: the name is user controlled and
        # must not be able to point outside the job output folder.
        base_name = str(source_name).replace('\\', '/').rsplit('/', 1)[-1]
        stem = base_name.rsplit('.', 1)[0]
        if not stem:
            # Nothing usable left (e.g. 'dir/' or '.hidden')
            return cls.DEFAULT_FILE_NAME
        return stem + '.txt'


class FakeService(BaseService):
"""Fake class for services testing"""

Expand Down
12 changes: 11 additions & 1 deletion brevia/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from functools import lru_cache
from typing import Annotated, Any
from os import environ, path
from os import environ, path, getcwd
from urllib import parse
from sqlalchemy import NullPool, create_engine, Column, String, func, inspect
from sqlalchemy.engine import Connection
Expand Down Expand Up @@ -128,6 +128,16 @@ class Settings(BaseSettings):
# App metadata
block_openapi_urls: bool = False

# File output
file_output_base_path: str = Field(
default=f'{path.abspath(getcwd())}/files',
exclude=True
)
file_output_base_url: str = Field(
default='/download',
exclude=True
)

def update(
self,
other: dict[str, Any] | BaseSettings,
Expand Down
91 changes: 91 additions & 0 deletions brevia/utilities/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import tempfile
import os
from brevia.settings import get_settings


class PublicFileOutput:
    """
    A class to handle file output operations in Brevia that need a public link.

    Files are written below the `file_output_base_path` setting, inside a
    sub folder named after the job ID when one is set. When the base path is
    an `s3://` URI the file is staged in a local temporary directory,
    uploaded to the bucket and then removed locally.
    """
    job_id = None

    def __init__(self, job_id: str = None):
        """
        Initialize the PublicFileOutput object with a job ID.
        :param job_id: The job ID to associate with this output.
        """
        self.job_id = job_id

    @staticmethod
    def _s3_location(base_path: str, filename: str) -> tuple[str, str]:
        """
        Split an `s3://bucket/prefix` base path into bucket and object name.
        :param base_path: The S3 base path, starting with `s3://`.
        :param filename: The file name (may include a job ID folder).
        :return: A `(bucket_name, object_name)` tuple.
        """
        bucket_name, _, prefix = base_path[len('s3://'):].partition('/')
        # Drop empty parts so a missing or slash-terminated prefix
        # never produces a leading or doubled '/'
        parts = [part for part in (prefix.strip('/'), filename) if part]
        return bucket_name, '/'.join(parts)

    def file_path(self, filename: str):
        """
        Generate the local file path for the output file, creating the
        containing folder when needed.
        :param filename: The name of the file.
        :return: The full path to the file.
        """
        base_path = get_settings().file_output_base_path
        if base_path.startswith('s3://'):
            # An S3 URI is not a local folder: stage the file in a fresh
            # temporary directory, `write` uploads it and cleans up.
            return f"{tempfile.mkdtemp()}/{filename}"

        out_dir = base_path
        if self.job_id:
            out_dir = f"{base_path}/{self.job_id}"
        os.makedirs(out_dir, exist_ok=True)

        return f"{out_dir}/{filename}"

    def file_url(self, filename: str):
        """
        Generate the URL for the output file.
        :param filename: The name of the file.
        :return: The URL to access the file.
        """
        base_url = get_settings().file_output_base_url
        return f'{base_url}/{filename}'

    def _s3_upload(self, file_path: str, bucket_name: str, object_name: str):
        """
        Upload the file to S3.
        :param file_path: The path to the file to upload.
        :param bucket_name: The name of the S3 bucket.
        :param object_name: The S3 object name.
        :raises ImportError: If boto3 is not installed.
        """
        # Only the import is guarded, upload errors propagate untouched
        try:
            import boto3  # pylint: disable=import-outside-toplevel
        except ImportError as exc:
            raise ImportError('Boto3 is not installed!') from exc
        s3 = boto3.client('s3')
        return s3.upload_file(file_path, bucket_name, object_name)

    def write(self, content: str, filename: str):
        """
        Write content of the file to the specified filename.
        Returns the URL of the file.
        :param content: The content to write to the file.
        :param filename: The name of the file to write to.
        :return: The URL built from `file_output_base_url`, the job ID
            (when set) and the file name.
        """
        output_path = self.file_path(filename)
        with open(output_path, 'w', encoding='utf-8') as file:
            file.write(content)

        if self.job_id:
            filename = f"{self.job_id}/{filename}"
        base_path = get_settings().file_output_base_path
        if base_path.startswith('s3://'):
            bucket_name, object_name = self._s3_location(base_path, filename)
            try:
                self._s3_upload(output_path, bucket_name, object_name)
            finally:
                # Remove the local file and its parent tmp directory,
                # also when the upload fails
                os.remove(output_path)
                parent_dir = os.path.dirname(output_path)
                if not os.listdir(parent_dir):
                    os.rmdir(parent_dir)

        return self.file_url(filename)
Empty file added files/.gitkeep
Empty file.
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def update_settings():
settings.tokens_users = ''
settings.status_token = ''
settings.use_test_models = True
settings.file_output_base_path = f'{Path(__file__).parent}/files'


@pytest.fixture(autouse=True)
Expand Down
1 change: 1 addition & 0 deletions tests/files/1234/test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test
32 changes: 32 additions & 0 deletions tests/routers/test_download_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Download Router module tests."""
from unittest.mock import patch
from fastapi.testclient import TestClient
from fastapi import FastAPI, status
from brevia.routers.download_router import router

app = FastAPI()
app.include_router(router)
client = TestClient(app)


def test_download_file_success():
    """GET on `/download/silence.mp3` responds with 200."""
    status_code = client.get('/download/silence.mp3').status_code
    assert status_code == status.HTTP_200_OK


def test_download_file_not_found():
    """A missing file yields 404 and a detail message naming the path."""
    missing = 'nonexistent_file.txt'
    expected_body = {'detail': 'File not found: nonexistent_file.txt'}

    resp = client.get(f'/download/{missing}')

    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert resp.json() == expected_body


@patch('brevia.routers.download_router.get_settings')
def test_download_file_s3_path(mock_get_settings):
    """With an S3 output base path the endpoint answers 404 'not supported'."""
    fake_settings = mock_get_settings.return_value
    fake_settings.file_output_base_path = 's3://mock-bucket'
    expected_body = {'detail': 'File download is not supported'}

    resp = client.get('/download/test_file.txt')

    assert resp.status_code == status.HTTP_404_NOT_FOUND
    assert resp.json() == expected_body
28 changes: 28 additions & 0 deletions tests/test_services.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Services module tests"""
import os
from pathlib import Path
import pytest
from brevia.services import (
SummarizeFileService,
SummarizeTextService,
RefineTextAnalysisService,
RefineTextAnalysisToTxtService,
)
from brevia.settings import get_settings

Expand Down Expand Up @@ -62,3 +64,29 @@ def test_refine_text_analysis_fail():
'prompts': {'a': 'b'}
})
assert str(exc.value).startswith('Invalid service payload - ')


def test_summarize_file_service_txt():
    """Test RefineTextAnalysisToTxtService service"""
    files_path = f'{Path(__file__).parent}/files'
    output_file = f'{files_path}/1234/example.txt'
    settings = get_settings()
    current_path = settings.prompts_base_path
    settings.prompts_base_path = f'{files_path}/prompts'
    # Restore settings and remove the generated file even when an assertion
    # fails, so a failure here cannot leak state into other tests.
    try:
        service = RefineTextAnalysisToTxtService()
        payload = {
            'file_path': f'{files_path}/docs/test.txt',
            'file_name': 'example.txt',
            'job_id': '1234',
            'prompts': {
                'initial_prompt': 'initial_prompt.yml',
                'refine_prompt': 'refine_prompt.yml'
            }
        }
        result = service.run(payload)

        assert 'artifacts' in result
        assert result['artifacts'][0]['name'] == 'example.txt'
        assert result["artifacts"][0]['url'] == '/download/1234/example.txt'
    finally:
        settings.prompts_base_path = current_path
        if os.path.exists(output_file):
            os.remove(output_file)
Loading
Loading