Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,9 @@
# BREVIA_APP_PORT=3000
# BREVIA_APP_VOLUME_NAME=breviaappdata

## File output options - uncomment to enable
## `FILE_OUTPUT_BASE_PATH` can be a local filesystem path or an S3 path like s3://my-bucket/path
# FILE_OUTPUT_BASE_PATH=/var/www/brevia/files
## `FILE_OUTPUT_BASE_URL` is the base URL used to access files, if omitted the files will be served via the `/download` endpoint,
## when using S3, the files can be served via the S3 URL, something like: https://my-bucket.s3.{region}.amazonaws.com/path
# FILE_OUTPUT_BASE_URL=https://example.com/download
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ brevia/extensions/*

# docs site
site/

# output folder
files/
1 change: 1 addition & 0 deletions brevia/async_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def run_job_service(
try:
lock_job_service(job_store)
service = create_service(job_store.service)
job_store.payload['job_id'] = str(job_store.uuid)
result = service.run(job_store.payload)

except Exception as exc: # pylint: disable=broad-exception-caught
Expand Down
8 changes: 4 additions & 4 deletions brevia/base_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List
from langchain_core.callbacks import CallbackManagerForRetrieverRun
"""Base retriever module"""
from langchain_core.documents import Document
from langchain_core.callbacks.manager import AsyncCallbackManagerForRetrieverRun
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever


Expand All @@ -14,8 +14,8 @@ class BreviaBaseRetriever(VectorStoreRetriever):
"""Configuration containing settings for the search from the application"""

async def _aget_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun, **kwargs
) -> list[Document]:
"""
Asynchronous implementation for retrieving relevant documents with score
Merges results from multiple custom searches using different filters.
Expand Down
1 change: 1 addition & 0 deletions brevia/routers/analyze_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def upload_analyze(

payload = json.loads(payload)
payload['file_path'] = tmp_path
payload.setdefault('file_name', file.filename)
job = async_jobs.create_job(
service=service,
payload=payload,
Expand Down
2 changes: 2 additions & 0 deletions brevia/routers/app_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
completion_router,
config_router,
providers_router,
download_router,
)


Expand All @@ -28,3 +29,4 @@ def add_routers(app: FastAPI) -> None:
app.include_router(completion_router.router)
app.include_router(config_router.router)
app.include_router(providers_router.router)
app.include_router(download_router.router)
35 changes: 35 additions & 0 deletions brevia/routers/download_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Download files endpoint."""
import os
from fastapi import APIRouter, HTTPException, status
from fastapi.responses import FileResponse
from brevia.settings import get_settings


router = APIRouter()


@router.api_route(
'/download/{file_path:path}',
methods=['GET', 'HEAD'],
tags=['Download'],
)
async def download_file(file_path: str):
"""
Endpoint to download a file from the internal file system.
"""
base_path = get_settings().file_output_base_path
# Check if the base path is an S3 file path
if base_path.startswith('s3://'):
raise HTTPException(
status.HTTP_404_NOT_FOUND,
'File download is not supported',
)

full_path = os.path.join(base_path, file_path)
if not os.path.isfile(full_path):
raise HTTPException(
status.HTTP_404_NOT_FOUND,
f'File not found: {file_path}',
)

return FileResponse(full_path, filename=os.path.basename(full_path))
23 changes: 22 additions & 1 deletion brevia/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
from brevia.tasks.text_analysis import RefineTextAnalysisTask, SummarizeTextAnalysisTask

from brevia.load_file import read
from brevia.utilities.output import LinkedFileOutput


class BaseService(ABC):
"""Base class for services"""
job_id = None

def run(self, payload: dict) -> dict:
"""Run a service using a payload"""
if not self.validate(payload):
raise ValueError(f'Invalid service payload - {payload}')
self.job_id = payload.get('job_id')
return self.execute(payload)

@abstractmethod
Expand Down Expand Up @@ -108,7 +111,7 @@ def summarize_from_file(


class RefineTextAnalysisService(BaseService):
"""Service to perform summarization from text input"""
"""Service to perform text analysis from text input"""

def execute(self, payload: dict):
"""Service logic"""
Expand Down Expand Up @@ -139,6 +142,24 @@ def validate(self, payload: dict):
return True


class RefineTextAnalysisToTxtService(RefineTextAnalysisService):
"""Service to perform text analysis from file, creating a txt file as output"""

def execute(self, payload: dict):
"""Service logic"""
result = super().execute(payload)
file_out = LinkedFileOutput(job_id=self.job_id)
file_name = 'summary.txt'
if payload.get('file_name'):
file_name = payload['file_name'].rsplit('.', 1)[0] + '.txt'
url = file_out.write(result['output'], file_name)
result['artifacts'] = [{
'name': file_name,
'url': url,
}]
return result


class FakeService(BaseService):
"""Fake class for services testing"""

Expand Down
12 changes: 11 additions & 1 deletion brevia/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from functools import lru_cache
from typing import Annotated, Any
from os import environ, path
from os import environ, path, getcwd
from urllib import parse
from sqlalchemy import NullPool, create_engine, Column, String, func, inspect
from sqlalchemy.engine import Connection
Expand Down Expand Up @@ -128,6 +128,16 @@ class Settings(BaseSettings):
# App metadata
block_openapi_urls: bool = False

# File output
file_output_base_path: str = Field(
default=f'{path.abspath(getcwd())}/files',
exclude=True
)
file_output_base_url: str = Field(
default='/download',
exclude=True
)

def update(
self,
other: dict[str, Any] | BaseSettings,
Expand Down
88 changes: 88 additions & 0 deletions brevia/utilities/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""File output utilities"""
import tempfile
import os
from brevia.settings import get_settings


class LinkedFileOutput:
"""
A class to handle file output operations in Brevia that need a download link.
"""
job_id = None

def __init__(self, job_id: str = None):
"""
Initialize the FileOutput object with a job ID.
:param job_id: The job ID to associate with this output.
"""
self.job_id = job_id

def file_path(self, filename: str):
"""
Generate the file path for the output file.
:param filename: The name of the file.
:return: The full path to the file.
"""
out_dir = get_settings().file_output_base_path
if out_dir.startswith('s3://'):
with tempfile.NamedTemporaryFile(prefix="brevia_", delete=False) as t_file:
return t_file.name
if self.job_id:
out_dir = f"{out_dir}/{self.job_id}"
os.makedirs(out_dir, exist_ok=True)

return f"{out_dir}/{filename}"

def file_url(self, filename: str):
"""
Generate the URL for the output file.
:param filename: The name of the file.
:return: The URL to access the file.
"""
# Generate the output URL
base_url = get_settings().file_output_base_url
return f'{base_url}/{filename}'

def _s3_upload(self, file_path: str, bucket_name: str, object_name: str):
"""
Upload the file to S3.
:param file_path: The path to the file to upload.
:param bucket_name: The name of the S3 bucket.
:param object_name: The S3 object name.
"""
try:
import boto3 # pylint: disable=import-outside-toplevel
s3 = boto3.client('s3')
return s3.upload_file(file_path, bucket_name, object_name)
except ModuleNotFoundError:
raise ImportError('Boto3 is not installed!')

def write(self, content: str, filename: str):
"""
Write content of the file to the specified filename.
Returns the URL of the file.
:param content: The content to write to the file.
:param filename: The name of the file to write to.
"""
output_path = self.file_path(filename)
with open(output_path, 'w', encoding='utf-8') as file:
file.write(content)

if self.job_id:
filename = f"{self.job_id}/{filename}"
base_path = get_settings().file_output_base_path
if base_path.startswith('s3://'):
# Extract bucket name and object name from S3 path
bucket_name = base_path.split('/')[2]
object_name = '/'.join(base_path.split('/')[3:]).lstrip('/')
object_name += f"/{filename}"
self._s3_upload(output_path, bucket_name, object_name.lstrip('/'))
# Remove the local temp file
os.unlink(output_path)

return self.file_url(filename)
Empty file added files/.gitkeep
Empty file.
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def update_settings():
settings.tokens_users = ''
settings.status_token = ''
settings.use_test_models = True
settings.file_output_base_path = f'{Path(__file__).parent}/files'


@pytest.fixture(autouse=True)
Expand Down
1 change: 1 addition & 0 deletions tests/files/1234/test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test
32 changes: 32 additions & 0 deletions tests/routers/test_download_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Download Router module tests."""
from unittest.mock import patch
from fastapi.testclient import TestClient
from fastapi import FastAPI, status
from brevia.routers.download_router import router

app = FastAPI()
app.include_router(router)
client = TestClient(app)


def test_download_file_success():
"""Test download file success."""
response = client.get('/download/silence.mp3')
assert response.status_code == status.HTTP_200_OK


def test_download_file_not_found():
"""Test download file not found."""
response = client.get('/download/nonexistent_file.txt')
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == {'detail': 'File not found: nonexistent_file.txt'}


@patch('brevia.routers.download_router.get_settings')
def test_download_file_s3_path(mock_get_settings):
"""Test download file from S3 path."""
mock_get_settings.return_value.file_output_base_path = 's3://mock-bucket'

response = client.get('/download/test_file.txt')
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.json() == {'detail': 'File download is not supported'}
28 changes: 28 additions & 0 deletions tests/test_services.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Services module tests"""
import os
from pathlib import Path
import pytest
from brevia.services import (
SummarizeFileService,
SummarizeTextService,
RefineTextAnalysisService,
RefineTextAnalysisToTxtService,
)
from brevia.settings import get_settings

Expand Down Expand Up @@ -62,3 +64,29 @@ def test_refine_text_analysis_fail():
'prompts': {'a': 'b'}
})
assert str(exc.value).startswith('Invalid service payload - ')


def test_summarize_file_service_txt():
"""Test SummarizeFileServiceTxt service"""
files_path = f'{Path(__file__).parent}/files'
settings = get_settings()
current_path = settings.prompts_base_path
settings.prompts_base_path = f'{files_path}/prompts'
service = RefineTextAnalysisToTxtService()
payload = {
'file_path': f'{files_path}/docs/test.txt',
'file_name': 'example.txt',
'job_id': '1234',
'prompts': {
'initial_prompt': 'initial_prompt.yml',
'refine_prompt': 'refine_prompt.yml'
}
}
result = service.run(payload)

assert 'artifacts' in result
assert result['artifacts'][0]['name'] == 'example.txt'
assert result["artifacts"][0]['url'] == '/download/1234/example.txt'

settings.prompts_base_path = current_path
os.remove(f'{files_path}/1234/example.txt')
Loading