diff --git a/alembic/env.py b/alembic/env.py index 54f54ca3..4f8e48f2 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,11 +4,11 @@ from logging.config import fileConfig from pathlib import Path -from alembic import context from alembic.autogenerate import rewriter from alembic.operations import ops -from sqlalchemy import engine_from_config -from sqlalchemy import pool +from sqlalchemy import engine_from_config, pool + +from alembic import context from src.config.settings import settings from src.infrastructure.repositories.sql.connector import Base diff --git a/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py b/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py index 6f6f27ea..68b7e182 100644 --- a/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py +++ b/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "04e0f5f5f0af" diff --git a/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py b/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py index e670d092..9e70cd48 100644 --- a/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py +++ b/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "0ce7f69147eb" diff --git a/alembic/versions/1062951c28c7_add_ingestion_type_to_jobs.py b/alembic/versions/1062951c28c7_add_ingestion_type_to_jobs.py index 2a4fc739..1fe1cf16 100644 --- a/alembic/versions/1062951c28c7_add_ingestion_type_to_jobs.py +++ b/alembic/versions/1062951c28c7_add_ingestion_type_to_jobs.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "1062951c28c7" diff --git a/alembic/versions/25bc49e441f1_add_subject_id_to_ingestion_jobs.py b/alembic/versions/25bc49e441f1_add_subject_id_to_ingestion_jobs.py index 85ec6c3f..86b1af91 100644 --- a/alembic/versions/25bc49e441f1_add_subject_id_to_ingestion_jobs.py +++ b/alembic/versions/25bc49e441f1_add_subject_id_to_ingestion_jobs.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "25bc49e441f1" diff --git a/alembic/versions/31fc75e62cf8_add_status_message_to_ingestion_jobs.py b/alembic/versions/31fc75e62cf8_add_status_message_to_ingestion_jobs.py index ca834d9e..4e2df639 100644 --- a/alembic/versions/31fc75e62cf8_add_status_message_to_ingestion_jobs.py +++ b/alembic/versions/31fc75e62cf8_add_status_message_to_ingestion_jobs.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "31fc75e62cf8" diff --git a/alembic/versions/40f70fb7bb1c_add_index_to_chunk_index_and_source_.py b/alembic/versions/40f70fb7bb1c_add_index_to_chunk_index_and_source_.py index ba7943c9..77adad03 100644 --- a/alembic/versions/40f70fb7bb1c_add_index_to_chunk_index_and_source_.py +++ b/alembic/versions/40f70fb7bb1c_add_index_to_chunk_index_and_source_.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "40f70fb7bb1c" diff --git a/alembic/versions/4c007a7c27e8_add_status_and_error_message_to_.py b/alembic/versions/4c007a7c27e8_add_status_and_error_message_to_.py index 01a7c45b..b0e54c4f 100644 --- a/alembic/versions/4c007a7c27e8_add_status_and_error_message_to_.py +++ b/alembic/versions/4c007a7c27e8_add_status_and_error_message_to_.py @@ -9,6 +9,7 @@ from typing import Sequence, Union import sqlalchemy as sa + from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py b/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py index b34fca90..08373507 100644 --- a/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py +++ b/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "4e8d4e04a288" diff --git a/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py b/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py index d7c10bcd..53527a7e 100644 --- a/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py +++ b/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "50420d500c2e" diff --git a/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py b/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py index b82b602e..d77ae0d3 100644 --- a/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py +++ b/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "5736075a22d0" diff --git a/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py b/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py index 51389e63..3929a5f7 100644 --- a/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py +++ b/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "5ff7984a3bcc" diff --git a/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py b/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py index f248b63e..a9066038 100644 --- a/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py +++ b/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "6e53bc32edfe" diff --git a/alembic/versions/72f69987a221_rename_diarization_title_to_name.py b/alembic/versions/72f69987a221_rename_diarization_title_to_name.py index 7e6a9858..f01fa161 100644 --- a/alembic/versions/72f69987a221_rename_diarization_title_to_name.py +++ b/alembic/versions/72f69987a221_rename_diarization_title_to_name.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "72f69987a221" diff --git a/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py b/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py index 88cdee3e..f57e532a 100644 --- a/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py +++ b/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "73f13c5ff10a" diff --git a/alembic/versions/790acc587e00_add_content_column_to_chunk_index.py b/alembic/versions/790acc587e00_add_content_column_to_chunk_index.py index 987d43f7..8e76a859 100644 --- a/alembic/versions/790acc587e00_add_content_column_to_chunk_index.py +++ b/alembic/versions/790acc587e00_add_content_column_to_chunk_index.py @@ -9,6 +9,7 @@ from typing import Sequence, Union import sqlalchemy as sa + from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py b/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py index ad9b1068..d38b6bc9 100644 --- a/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py +++ b/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "946d88fe08b1" diff --git a/alembic/versions/a1b2c3d4e5f6_add_status_message_to_diarizations.py b/alembic/versions/a1b2c3d4e5f6_add_status_message_to_diarizations.py index b86fcff3..f641e1bb 100644 --- a/alembic/versions/a1b2c3d4e5f6_add_status_message_to_diarizations.py +++ b/alembic/versions/a1b2c3d4e5f6_add_status_message_to_diarizations.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "a1b2c3d4e5f6" diff --git a/alembic/versions/a4e3eb3d951c_add_tokens_count_to_chunks.py b/alembic/versions/a4e3eb3d951c_add_tokens_count_to_chunks.py index 94a0f458..1c7c2227 100644 --- a/alembic/versions/a4e3eb3d951c_add_tokens_count_to_chunks.py +++ b/alembic/versions/a4e3eb3d951c_add_tokens_count_to_chunks.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "a4e3eb3d951c" diff --git a/alembic/versions/bd01964d9b26_created_tables.py b/alembic/versions/bd01964d9b26_created_tables.py index 47b511ec..8b2190b6 100644 --- a/alembic/versions/bd01964d9b26_created_tables.py +++ b/alembic/versions/bd01964d9b26_created_tables.py @@ -9,6 +9,7 @@ from typing import Sequence, Union import sqlalchemy as sa + from alembic import op # revision identifiers, used by Alembic. diff --git a/alembic/versions/c16fab000f02_add_user_table.py b/alembic/versions/c16fab000f02_add_user_table.py index 2335eed3..46bb1530 100644 --- a/alembic/versions/c16fab000f02_add_user_table.py +++ b/alembic/versions/c16fab000f02_add_user_table.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "c16fab000f02" diff --git a/alembic/versions/c48798b08031_add_voice_samples_table.py b/alembic/versions/c48798b08031_add_voice_samples_table.py index 7acc3793..57192030 100644 --- a/alembic/versions/c48798b08031_add_voice_samples_table.py +++ b/alembic/versions/c48798b08031_add_voice_samples_table.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "c48798b08031" diff --git a/alembic/versions/d3b1a2c3d4e5_add_chunks_count_to_ingestion_jobs.py b/alembic/versions/d3b1a2c3d4e5_add_chunks_count_to_ingestion_jobs.py index a7c6c2b9..ccdee1ae 100644 --- a/alembic/versions/d3b1a2c3d4e5_add_chunks_count_to_ingestion_jobs.py +++ b/alembic/versions/d3b1a2c3d4e5_add_chunks_count_to_ingestion_jobs.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "d3b1a2c3d4e5" diff --git a/alembic/versions/d6d36d89a425_add_source_metadata_to_diarizations.py b/alembic/versions/d6d36d89a425_add_source_metadata_to_diarizations.py index 31c2815f..e38d20d4 100644 --- a/alembic/versions/d6d36d89a425_add_source_metadata_to_diarizations.py +++ b/alembic/versions/d6d36d89a425_add_source_metadata_to_diarizations.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "d6d36d89a425" diff --git a/alembic/versions/e7509b75b129_add_icon_column_to_knowledgesubject.py b/alembic/versions/e7509b75b129_add_icon_column_to_knowledgesubject.py index e7009745..a4bdf387 100644 --- a/alembic/versions/e7509b75b129_add_icon_column_to_knowledgesubject.py +++ b/alembic/versions/e7509b75b129_add_icon_column_to_knowledgesubject.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "e7509b75b129" diff --git a/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py b/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py index 7848ecf3..887bc5df 100644 --- a/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py +++ b/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py @@ -8,9 +8,9 @@ from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = "f120b614600a" diff --git a/main.py b/main.py index e0bcf4ca..6ae2b529 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,9 @@ -from contextlib import asynccontextmanager -import warnings import os +# ruff: noqa: E402 +import warnings +from contextlib import asynccontextmanager + # Suppress NNPACK warnings (Unsupported hardware) os.environ["NNPACK_CPU_FAST_8x8_CONV"] = "0" @@ -9,12 +11,14 @@ # causes DLL discovery issues on Windows that don't affect the core app. warnings.filterwarnings("ignore", category=UserWarning, module="torchcodec") -from fastapi import FastAPI, Depends # noqa: E402 +from fastapi import Depends, FastAPI # noqa: E402 from fastapi.middleware.cors import CORSMiddleware # noqa: E402 from src.config.logger import setup_logging # noqa: E402 from src.presentation.api.dependencies import get_current_user # noqa: E402 -from src.presentation.api.middleware.trace_middleware import TraceMiddleware # noqa: E402 +from src.presentation.api.middleware.trace_middleware import ( + TraceMiddleware, # noqa: E402 +) from src.presentation.api.routes import ( # noqa: E402 audio_diarization_and_recognition_router as audio_router, ) @@ -29,7 +33,9 @@ source_router, subject_router, ) -from src.presentation.api.routes import voice_profile_management_router as voice_router # noqa: E402 +from src.presentation.api.routes import ( + voice_profile_management_router as voice_router, # noqa: E402 +) logger = setup_logging() @@ -46,10 +52,10 @@ async def lifespan(app: FastAPI): from src.config.settings import Settings from src.infrastructure.services.model_loader_service import ModelLoaderService from src.infrastructure.services.re_rank_service import ReRankService + from src.infrastructure.services.redis_event_bus import RedisEventBus from src.infrastructure.services.redis_task_queue_service import ( RedisTaskQueueService, ) - from src.infrastructure.services.redis_event_bus import RedisEventBus logger.info("Initializing Settings...") _settings = Settings() @@ -81,16 +87,16 @@ async def lifespan(app: FastAPI): logger.info("Re-rank model pre-loaded successfully.") # Register worker tasks and initialize Redis Task Queue - from src.infrastructure.services.redis_task_queue_service import register_task from src.application.workers import ( - run_file_ingestion_worker, - run_youtube_ingestion_worker, - run_web_ingestion_worker, + run_audio_diarization_dispatcher_worker, run_audio_diarization_worker, run_diarization_ingestion_worker, + run_file_ingestion_worker, + run_web_ingestion_worker, run_youtube_dispatcher_worker, - run_audio_diarization_dispatcher_worker, + run_youtube_ingestion_worker, ) + from src.infrastructure.services.redis_task_queue_service import register_task register_task("run_file_ingestion_worker", run_file_ingestion_worker) register_task("run_youtube_ingestion_worker", run_youtube_ingestion_worker) @@ -216,6 +222,7 @@ def health_check(): if __name__ == "__main__": import uvicorn + from src.config.settings import settings uvicorn.run( diff --git a/mypy.ini b/mypy.ini index c0ddf542..4f6fb6bd 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,3 +5,6 @@ namespace_packages = True [mypy-tests.*] ignore_errors = True + +[mypy-scripts.*] +ignore_errors = True diff --git a/pyproject.toml b/pyproject.toml index 41fceb49..07a7551c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,3 +95,29 @@ torchaudio = { index = "pytorch-cu126" } [tool.setuptools.packages.find] where = ["."] include = ["src*"] + +[tool.ruff] +exclude = [ + ".git", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".venv", + "__pypackages__", + "dist", + "node_modules", + "venv", + "scripts", +] +line-length = 120 +indent-width = 4 + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" diff --git a/pytest.ini b/pytest.ini index fa5c97b2..c7ca6063 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,6 +3,7 @@ addopts = -v --cov=src --cov-report=term-missing --cov-report=html python_files = test_*.py *_test.py testpaths = tests pythonpath = . +norecursedirs = scripts filterwarnings = ignore::DeprecationWarning:pydub.utils diff --git a/scripts/migrate_vector_db.py b/scripts/migrate_vector_db.py new file mode 100644 index 00000000..00e75a35 --- /dev/null +++ b/scripts/migrate_vector_db.py @@ -0,0 +1,166 @@ +""" +Script to migrate/re-ingest chunks into the configured vector database. +It reads the `chunk_index` records from the SQL database and pushes +them to the Vector Store, using the current embedding model. +This is useful when changing embedding models or vector databases. + +Usage: + python scripts/migrate_vector_db.py [--clear] + +Options: + --clear Clear the target vector collection before migrating. + Required for backends that do not upsert on insert + (e.g. ChromaDB) to avoid duplicate-ID errors. + Equivalent to running `scripts/clear_vector_db.py` first. +""" + +import argparse +import os +import sys +import traceback +from datetime import datetime +from typing import cast, Dict, Any, Optional +from uuid import UUID + +# Add project root and scripts directory to sys.path +_SCRIPTS_DIR = os.path.abspath(os.path.dirname(__file__)) +_PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPTS_DIR, "..")) +sys.path.insert(0, _PROJECT_ROOT) +sys.path.insert(0, _SCRIPTS_DIR) + +from sqlalchemy.orm import Session +from sqlalchemy import or_, and_ +from clear_vector_db import clear_vector_db +from src.infrastructure.repositories.sql.connector import Session as DBSessionFactory +from src.infrastructure.repositories.sql.models.chunk_index import ChunkIndexModel +from src.infrastructure.services.model_loader_service import ModelLoaderService +from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel +from src.config.settings import Settings +from src.config.logger import Logger +from src.presentation.api.dependencies import get_vector_repository + +logger = Logger() + + +def migrate_vector_db(batch_size: int = 100, clear: bool = False) -> None: + settings = Settings() + + embedding_model_name = settings.model_embedding.name + logger.info(f"Initializing Model Loader Service with model: {embedding_model_name}") + model_loader = ModelLoaderService(model_name=embedding_model_name) + model_loader.load_model() + + logger.info("Initializing Vector Repository...") + # This automatically instantiates the vector repo for the correct type defined in .env + vector_repo = get_vector_repository(settings=settings, model_loader=model_loader) + + if not vector_repo.is_ready(): + logger.error("Vector Repository is not ready. Aborting.") + sys.exit(1) + + if clear: + logger.info("--clear flag set: clearing the vector collection before migration...") + clear_vector_db() + + db: Session = DBSessionFactory() + try: + total_migrated = 0 + last_created_at: Optional[datetime] = None + last_id: Optional[UUID] = None + + while True: + query = db.query(ChunkIndexModel).order_by( + ChunkIndexModel.created_at, ChunkIndexModel.id + ) + # Keyset pagination: skip rows already processed in previous batches + if last_created_at is not None: + query = query.filter( + or_( + ChunkIndexModel.created_at > last_created_at, + and_( + ChunkIndexModel.created_at == last_created_at, + ChunkIndexModel.id > last_id, + ), + ) + ) + + chunk_models_sql = query.limit(batch_size).all() + + if not chunk_models_sql: + break + + # Capture keyset cursors from the last row of this batch + last_created_at = cast(datetime, chunk_models_sql[-1].created_at) + last_id = cast(UUID, chunk_models_sql[-1].id) + + documents = _process_batch(chunk_models_sql, embedding_model_name) + + total_migrated += len(documents) + logger.info( + f"Uploading batch of {len(documents)} chunks to vector db... (Total migrated so far: {total_migrated})" + ) + + # create_documents will internally call the EmbeddingService for the texts and save them + vector_repo.create_documents(documents) + + # Expunge all ORM objects to prevent unbounded memory growth + db.expunge_all() + + logger.info( + f"Vector DB migration finished successfully! Total chunks migrated: {total_migrated}" + ) + + except Exception as e: + logger.error(f"Migration failed: {e}\n{traceback.format_exc()}") + sys.exit(1) + finally: + db.close() + + +def _process_batch( + chunks_sql: list[ChunkIndexModel], embedding_model_name: str +) -> list[ChunkModel]: + """Helper to convert SQL chunks to vector domain models.""" + documents = [] + for chunk_sql in chunks_sql: + extra_data: Dict[str, Any] = ( + dict(chunk_sql.extra) if isinstance(chunk_sql.extra, dict) else {} + ) + if chunk_sql.vector_store_type: + extra_data["original_vector_store_type"] = chunk_sql.vector_store_type + + doc = ChunkModel( + id=cast(UUID, chunk_sql.id), + job_id=cast(UUID, chunk_sql.job_id), + content_source_id=cast(UUID, chunk_sql.content_source_id), + source_type=str(chunk_sql.source_type or "UNKNOWN"), + external_source=cast(str, chunk_sql.external_source), + subject_id=cast(UUID, chunk_sql.subject_id), + index=cast(int, chunk_sql.index), + content=cast(str, chunk_sql.content), + tokens_count=cast(int, chunk_sql.tokens_count), + language=cast(str, chunk_sql.language), + embedding_model=embedding_model_name, + created_at=cast(datetime, chunk_sql.created_at), + version_number=cast(int, chunk_sql.version_number), + extra=extra_data, + ) + documents.append(doc) + return documents + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Migrate chunk_index records from SQL into the configured vector store." + ) + parser.add_argument( + "--clear", + action="store_true", + help=( + "Clear the target vector collection before migrating. " + "Use this for backends that do not upsert on insert (e.g. ChromaDB) " + "to avoid duplicate-ID errors." + ), + ) + args = parser.parse_args() + migrate_vector_db(clear=args.clear) diff --git a/src/application/dtos/commands/ingest_diarization_command.py b/src/application/dtos/commands/ingest_diarization_command.py index 27fd9db8..2ecef3fe 100644 --- a/src/application/dtos/commands/ingest_diarization_command.py +++ b/src/application/dtos/commands/ingest_diarization_command.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from uuid import UUID diff --git a/src/application/dtos/commands/ingest_youtube_command.py b/src/application/dtos/commands/ingest_youtube_command.py index 85f9e20a..21244bd2 100644 --- a/src/application/dtos/commands/ingest_youtube_command.py +++ b/src/application/dtos/commands/ingest_youtube_command.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional, List +from typing import List, Optional from src.application.dtos.enums.youtube_data_type import YoutubeDataType diff --git a/src/application/dtos/commands/train_voice_command.py b/src/application/dtos/commands/train_voice_command.py new file mode 100644 index 00000000..209faa6f --- /dev/null +++ b/src/application/dtos/commands/train_voice_command.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class TrainVoiceCommand: + """Command to train a voice profile from a speaker segment in a diarization.""" + + diarization_id: str + speaker_label: str + name: str + job_id: Optional[str] = None diff --git a/src/application/dtos/results/ingest_youtube_result.py b/src/application/dtos/results/ingest_youtube_result.py index 24d5e09f..1b4cbd9f 100644 --- a/src/application/dtos/results/ingest_youtube_result.py +++ b/src/application/dtos/results/ingest_youtube_result.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, List, Dict +from typing import Dict, List, Optional from uuid import UUID diff --git a/src/application/dtos/results/search_chunks_result.py b/src/application/dtos/results/search_chunks_result.py index c34b8e49..7ae4ae12 100644 --- a/src/application/dtos/results/search_chunks_result.py +++ b/src/application/dtos/results/search_chunks_result.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import List, Optional + from src.domain.entities.chunk_entity import ChunkEntity diff --git a/src/application/use_cases/auth_use_case.py b/src/application/use_cases/auth_use_case.py index 832b0864..0f391e79 100644 --- a/src/application/use_cases/auth_use_case.py +++ b/src/application/use_cases/auth_use_case.py @@ -1,14 +1,14 @@ import secrets -from typing import Optional, Dict, Any, Tuple from datetime import datetime, timezone +from typing import Any, Dict, Optional, Tuple from src.domain.entities.user import User as UserEntity -from src.domain.interfaces.repository.user_repository import IUserRepository from src.domain.exception.auth_exceptions import ( - InvalidStateError, GoogleAuthError, + InvalidStateError, UserNotCreatedError, ) +from src.domain.interfaces.repository.user_repository import IUserRepository from src.infrastructure.services.auth_service import AuthService diff --git a/src/application/use_cases/content_source_use_case.py b/src/application/use_cases/content_source_use_case.py index ed93d25b..ac43266d 100644 --- a/src/application/use_cases/content_source_use_case.py +++ b/src/application/use_cases/content_source_use_case.py @@ -1,4 +1,5 @@ from uuid import UUID + from src.config.logger import Logger from src.domain.interfaces.repository.retriver_repository import IVectorRepository from src.infrastructure.services.chunk_index_service import ChunkIndexService diff --git a/src/application/use_cases/delete_diarization_use_case.py b/src/application/use_cases/delete_diarization_use_case.py index e083b38b..eeba2592 100644 --- a/src/application/use_cases/delete_diarization_use_case.py +++ b/src/application/use_cases/delete_diarization_use_case.py @@ -1,13 +1,14 @@ import logging import os import shutil + from sqlalchemy.orm import Session from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) -from src.infrastructure.services.content_source_service import ContentSourceService from src.infrastructure.repositories.storage.storage import StorageService +from src.infrastructure.services.content_source_service import ContentSourceService logger = logging.getLogger(__name__) diff --git a/src/application/use_cases/diarization_ingestion_use_case.py b/src/application/use_cases/diarization_ingestion_use_case.py index 9c887fcc..41173854 100644 --- a/src/application/use_cases/diarization_ingestion_use_case.py +++ b/src/application/use_cases/diarization_ingestion_use_case.py @@ -9,14 +9,16 @@ from src.config.logger import Logger from src.domain.entities.chunk_entity import ChunkEntity from src.domain.entities.enums.content_source_status_enum import ContentSourceStatus +from src.domain.entities.enums.diarization_status_enum import DiarizationStatus from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.domain.interfaces.services.i_event_bus import IEventBus +from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) -from src.domain.entities.enums.diarization_status_enum import DiarizationStatus -from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor from src.infrastructure.services.chunk_index_service import ChunkIndexService +from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.services.content_source_service import ContentSourceService from src.infrastructure.services.embedding_service import EmbeddingService from src.infrastructure.services.ingestion_job_service import IngestionJobService @@ -24,9 +26,7 @@ KnowledgeSubjectService, ) from src.infrastructure.services.model_loader_service import ModelLoaderService -from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.services.text_splitter_service import TextSplitterService -from src.domain.interfaces.services.i_event_bus import IEventBus logger = Logger() @@ -368,7 +368,9 @@ def _format_transcript( curr_texts.append(text) else: if curr_speaker is not None: - ts = f"[{self._format_seconds(cast(float, curr_start))} - {self._format_seconds(cast(float, curr_end))}]" + start_str = self._format_seconds(cast(float, curr_start)) + end_str = self._format_seconds(cast(float, curr_end)) + ts = f"[{start_str} - {end_str}]" merged_lines.append(f"{ts} {curr_speaker}: {' '.join(curr_texts)}") curr_speaker, curr_start, curr_end, curr_texts = ( diff --git a/src/application/use_cases/file_ingestion_use_case.py b/src/application/use_cases/file_ingestion_use_case.py index 30c03d0a..45dfa42f 100644 --- a/src/application/use_cases/file_ingestion_use_case.py +++ b/src/application/use_cases/file_ingestion_use_case.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional from uuid import UUID - from langchain_core.documents import Document from src.application.dtos.commands.ingest_file_command import IngestFileCommand @@ -14,9 +13,11 @@ from src.domain.entities.enums.content_source_status_enum import ContentSourceStatus from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.extractors.docling_extractor import DoclingExtractor from src.infrastructure.extractors.plain_text_extractor import PlainTextExtractor from src.infrastructure.services.chunk_index_service import ChunkIndexService +from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.services.content_source_service import ContentSourceService from src.infrastructure.services.embedding_service import EmbeddingService from src.infrastructure.services.ingestion_job_service import IngestionJobService @@ -24,9 +25,7 @@ KnowledgeSubjectService, ) from src.infrastructure.services.model_loader_service import ModelLoaderService -from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.services.text_splitter_service import TextSplitterService -from src.domain.interfaces.services.i_event_bus import IEventBus logger = Logger() diff --git a/src/application/use_cases/knowledge_subject_use_case.py b/src/application/use_cases/knowledge_subject_use_case.py index 3759f979..4cd9e05d 100644 --- a/src/application/use_cases/knowledge_subject_use_case.py +++ b/src/application/use_cases/knowledge_subject_use_case.py @@ -1,10 +1,11 @@ from uuid import UUID + +from src.application.use_cases.content_source_use_case import ContentSourceUseCase from src.config.logger import Logger from src.domain.interfaces.repository.retriver_repository import IVectorRepository from src.infrastructure.services.knowledge_subject_service import ( KnowledgeSubjectService, ) -from src.application.use_cases.content_source_use_case import ContentSourceUseCase logger = Logger() diff --git a/src/application/use_cases/manage_voice_profiles.py b/src/application/use_cases/manage_voice_profiles.py index d235f559..08bde59b 100644 --- a/src/application/use_cases/manage_voice_profiles.py +++ b/src/application/use_cases/manage_voice_profiles.py @@ -1,10 +1,12 @@ -from contextlib import suppress import os +from contextlib import suppress from typing import cast from sqlalchemy.orm import Session from src.config.settings import settings +from src.domain.entities.enums.diarization_status_enum import DiarizationStatus +from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) @@ -89,10 +91,11 @@ def execute(self, name: str) -> None: class TrainVoiceProfileFromSpeakerSegmentUseCase: - def __init__(self, db: Session): + def __init__(self, db: Session, event_bus: IEventBus | None = None): self.db = db self.repo = DiarizationRepository(db) self.storage = StorageService() + self.event_bus = event_bus def execute( self, @@ -107,6 +110,23 @@ def execute( if not record.storage_path: raise ValueError("No storage path found for this diarization.") + # Ensure we update status if we have an event bus (background context) + if self.event_bus: + self.repo.update_status( + diarization_id, + DiarizationStatus.TRAINING.value, + status_message=f"Treinando voz: {name}", + ) + self.event_bus.publish( + "ingestion_status", + { + "type": "diarization", + "id": diarization_id, + "status": DiarizationStatus.TRAINING.value, + "message": f"Treinando perfil de voz '{name}'...", + }, + ) + s3_key = f"{record.storage_path}/{speaker_label}.wav" audio_cfg = settings.audio @@ -117,14 +137,52 @@ def execute( try: self.storage.download_file(s3_key, local_path) - except Exception: - raise ValueError(f"Speaker audio not found in storage: {speaker_label}") - try: hf_token = settings.auth.hf_token or "" voice_db = VoiceDB(db=self.db, hf_token=hf_token) voice_id, _ = voice_db.add(name=name, audio_path=local_path) + + if self.event_bus: + self.repo.update_status( + diarization_id, + DiarizationStatus.COMPLETED.value, + status_message=f"Voz '{name}' treinada com sucesso", + ) + self.event_bus.publish( + "ingestion_status", + { + "type": "diarization", + "id": diarization_id, + "status": DiarizationStatus.COMPLETED.value, + "message": f"Perfil de voz '{name}' registrado!", + }, + ) + # Also notify specifically about the voice + self.event_bus.publish( + "ingestion_status", + {"type": "voice", "action": "train", "name": name}, + ) + return voice_id + + except Exception as e: + if self.event_bus: + self.repo.update_status( + diarization_id, + DiarizationStatus.FAILED.value, + error_message=str(e), + status_message="Falha no treinamento de voz", + ) + self.event_bus.publish( + "ingestion_status", + { + "type": "diarization", + "id": diarization_id, + "status": DiarizationStatus.FAILED.value, + "message": f"Erro no treinamento: {str(e)}", + }, + ) + raise finally: if os.path.exists(local_path): os.remove(local_path) diff --git a/src/application/use_cases/process_audio_diarization_pipeline.py b/src/application/use_cases/process_audio_diarization_pipeline.py index a9b9aa5e..2cb00a9a 100644 --- a/src/application/use_cases/process_audio_diarization_pipeline.py +++ b/src/application/use_cases/process_audio_diarization_pipeline.py @@ -4,28 +4,25 @@ import shutil import uuid from pathlib import Path -from typing import Any, cast, Optional +from typing import Any, Optional, cast from urllib.parse import unquote from sqlalchemy.orm import Session +from src.config.settings import settings +from src.domain.entities.enums.diarization_status_enum import ( + DiarizationStatus, + DiarizationStep, +) +from src.domain.interfaces.services.i_event_bus import IEventBus +from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) -from src.config.settings import settings - from src.infrastructure.repositories.storage.storage import StorageService -from src.infrastructure.services.whisperx_audio_diarizer import AudioDiarizer from src.infrastructure.services.pyannote_voice_recognizer import VoiceRecognizer from src.infrastructure.services.voice_profile_service import VoiceDB - -from src.domain.interfaces.services.i_event_bus import IEventBus - -from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor -from src.domain.entities.enums.diarization_status_enum import ( - DiarizationStatus, - DiarizationStep, -) +from src.infrastructure.services.whisperx_audio_diarizer import AudioDiarizer logger = logging.getLogger(__name__) @@ -245,10 +242,10 @@ def _finalize_pipeline( # Create/Update ContentSource in AWAITING_VERIFICATION status if self.cs_service: try: - from src.domain.entities.enums.source_type_enum_entity import SourceType from src.domain.entities.enums.content_source_status_enum import ( ContentSourceStatus, ) + from src.domain.entities.enums.source_type_enum_entity import SourceType cs_source_type = ( SourceType.YOUTUBE if source_type == "youtube" else SourceType.AUDIO diff --git a/src/application/use_cases/search_use_case.py b/src/application/use_cases/search_use_case.py index 71cb5b0e..21f2893b 100644 --- a/src/application/use_cases/search_use_case.py +++ b/src/application/use_cases/search_use_case.py @@ -1,7 +1,6 @@ from typing import Any, List, Optional, Union from uuid import UUID - from src.application.dtos.results.search_chunks_result import SearchChunksResult from src.config.logger import Logger from src.domain.entities.enums.search_mode_enum import SearchMode diff --git a/src/application/use_cases/web_scraping_use_case.py b/src/application/use_cases/web_scraping_use_case.py index 36900da6..01a40de8 100644 --- a/src/application/use_cases/web_scraping_use_case.py +++ b/src/application/use_cases/web_scraping_use_case.py @@ -11,8 +11,10 @@ from src.domain.entities.enums.content_source_status_enum import ContentSourceStatus from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.extractors.crawl4ai_extractor import Crawl4AIExtractor from src.infrastructure.services.chunk_index_service import ChunkIndexService +from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.services.content_source_service import ContentSourceService from src.infrastructure.services.embedding_service import EmbeddingService from src.infrastructure.services.ingestion_job_service import IngestionJobService @@ -20,9 +22,7 @@ KnowledgeSubjectService, ) from src.infrastructure.services.model_loader_service import ModelLoaderService -from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.services.text_splitter_service import TextSplitterService -from src.domain.interfaces.services.i_event_bus import IEventBus logger = Logger() diff --git a/src/application/use_cases/youtube_ingestion_use_case.py b/src/application/use_cases/youtube_ingestion_use_case.py index 19c2c4be..0d3c021b 100644 --- a/src/application/use_cases/youtube_ingestion_use_case.py +++ b/src/application/use_cases/youtube_ingestion_use_case.py @@ -1,9 +1,9 @@ import concurrent.futures -from contextlib import suppress import random import threading import time import uuid +from contextlib import suppress from datetime import datetime, timezone from typing import Any, Dict, List, Optional from uuid import UUID @@ -20,12 +20,12 @@ from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus from src.domain.entities.enums.source_type_enum_entity import SourceType from src.domain.exception.youtube_exceptions import ( - YoutubeVideoPrivateException, - YoutubeVideoUnplayableException, + YoutubeIPBlockedException, + YoutubeNetworkException, YoutubeTranscriptNotFoundException, YoutubeTranscriptsDisabledException, - YoutubeNetworkException, - YoutubeIPBlockedException, + YoutubeVideoPrivateException, + YoutubeVideoUnplayableException, ) from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor diff --git a/src/application/workers.py b/src/application/workers.py index b0bceeaf..88900f3c 100644 --- a/src/application/workers.py +++ b/src/application/workers.py @@ -8,10 +8,11 @@ from src.application.dtos.commands.ingest_web_command import IngestWebCommand from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand from src.application.dtos.commands.process_audio_command import ProcessAudioCommand +from src.application.dtos.commands.train_voice_command import TrainVoiceCommand from src.application.service_registry import registry from src.infrastructure.loggers.std_logger import ( - set_global_context, clear_global_context, + set_global_context, ) logger = logging.getLogger(__name__) @@ -44,14 +45,14 @@ def run_file_ingestion_worker(cmd: IngestFileCommand): return try: + from src.application.use_cases.file_ingestion_use_case import ( + FileIngestionUseCase, + ) + from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.presentation.api.dependencies import ( resolve_ingestion_context, - resolve_vector_repository, resolve_rerank_service, - ) - from src.infrastructure.services.chunk_vector_service import ChunkVectorService - from src.application.use_cases.file_ingestion_use_case import ( - FileIngestionUseCase, + resolve_vector_repository, ) ctx = resolve_ingestion_context(app) @@ -93,15 +94,15 @@ def run_youtube_ingestion_worker(cmd: IngestYoutubeCommand): return try: - from src.presentation.api.dependencies import ( - resolve_ingestion_context, - resolve_vector_repository, + from src.application.use_cases.youtube_ingestion_use_case import ( + YoutubeIngestionUseCase, ) from src.infrastructure.services.youtube_vector_service import ( YouTubeVectorService, ) - from src.application.use_cases.youtube_ingestion_use_case import ( - YoutubeIngestionUseCase, + from src.presentation.api.dependencies import ( + resolve_ingestion_context, + resolve_vector_repository, ) ctx = resolve_ingestion_context(app) @@ -238,17 +239,17 @@ def run_diarization_ingestion_worker(cmd: IngestDiarizationCommand): return try: - from src.presentation.api.dependencies import ( - resolve_ingestion_context, - resolve_vector_repository, - resolve_rerank_service, + from src.application.use_cases.diarization_ingestion_use_case import ( + DiarizationIngestionUseCase, ) - from src.infrastructure.services.chunk_vector_service import ChunkVectorService from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) - from src.application.use_cases.diarization_ingestion_use_case import ( - DiarizationIngestionUseCase, + from src.infrastructure.services.chunk_vector_service import ChunkVectorService + from src.presentation.api.dependencies import ( + resolve_ingestion_context, + resolve_rerank_service, + resolve_vector_repository, ) ctx = resolve_ingestion_context(app) @@ -302,17 +303,17 @@ def run_web_ingestion_worker(cmd: Any): async def _run(): try: - from src.presentation.api.dependencies import ( - resolve_ingestion_context, - resolve_vector_repository, - resolve_rerank_service, - get_web_extractor, + from src.application.use_cases.web_scraping_use_case import ( + WebScrapingUseCase, ) from src.infrastructure.services.chunk_vector_service import ( ChunkVectorService, ) - from src.application.use_cases.web_scraping_use_case import ( - WebScrapingUseCase, + from src.presentation.api.dependencies import ( + get_web_extractor, + resolve_ingestion_context, + resolve_rerank_service, + resolve_vector_repository, ) ctx = resolve_ingestion_context(app) @@ -347,21 +348,20 @@ async def _run(): def _audio_diarization_subprocess(cmd_dict: dict): """Run audio diarization in a separate process to avoid torch/CUDA thread deadlocks.""" + from src.application.use_cases.process_audio_diarization_pipeline import ( + ProcessAudioDiarizationPipelineUseCase, + ) from src.infrastructure.repositories.sql.connector import ( Session as DBSessionFactory, ) - from src.application.use_cases.process_audio_diarization_pipeline import ( - ProcessAudioDiarizationPipelineUseCase, + from src.infrastructure.repositories.sql.content_source_repository import ( + ContentSourceSQLRepository, ) from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) - from src.infrastructure.services.redis_event_bus import RedisEventBus - from src.infrastructure.services.content_source_service import ContentSourceService - from src.infrastructure.repositories.sql.content_source_repository import ( - ContentSourceSQLRepository, - ) + from src.infrastructure.services.redis_event_bus import RedisEventBus db = DBSessionFactory() event_bus = RedisEventBus() @@ -597,3 +597,43 @@ def run_audio_diarization_worker(cmd: ProcessAudioCommand): ) finally: clear_global_context() + + +def run_voice_training_worker(cmd: TrainVoiceCommand): + """Background worker function for voice profile training from speaker segment.""" + set_global_context({"correlation_id": f"worker-voice-train-{cmd.name}"}) + + if isinstance(cmd, dict): + cmd = TrainVoiceCommand(**cmd) + + app = _get_app() + if not app: + clear_global_context() + return + + try: + from src.application.use_cases.manage_voice_profiles import ( + TrainVoiceProfileFromSpeakerSegmentUseCase, + ) + from src.infrastructure.repositories.sql.connector import Session as DBSession + from src.presentation.api.dependencies import resolve_ingestion_context + + ctx = resolve_ingestion_context(app) + db = DBSession() + try: + use_case = TrainVoiceProfileFromSpeakerSegmentUseCase( + db, event_bus=ctx.event_bus + ) + use_case.execute( + diarization_id=cmd.diarization_id, + speaker_label=cmd.speaker_label, + name=cmd.name, + ) + finally: + db.close() + except Exception as e: + logger.error( + f"Worker Error: Failed to execute voice training: {e}", exc_info=True + ) + finally: + clear_global_context() diff --git a/src/config/logger.py b/src/config/logger.py index 545d6a77..bd070e8e 100644 --- a/src/config/logger.py +++ b/src/config/logger.py @@ -1,5 +1,6 @@ import logging -from src.infrastructure.loggers.std_logger import StdLogger, InterceptHandler + +from src.infrastructure.loggers.std_logger import InterceptHandler, StdLogger LOG_FORMAT = "{asctime} | {levelname:<8} | {filepath}:{funcName}:{lineno} | {message} | {context}" diff --git a/src/config/settings.py b/src/config/settings.py index ab22773b..bf6c2041 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -1,8 +1,8 @@ import logging import warnings -from typing import List, Annotated, Optional +from typing import Annotated, List, Optional -from pydantic import field_validator, Field, BaseModel +from pydantic import BaseModel, Field, field_validator from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from src.config.validators import docker_host_fallback, docker_host_fallback_optional diff --git a/src/domain/entities/chunk_entity.py b/src/domain/entities/chunk_entity.py index c958bec0..91d6d5d5 100644 --- a/src/domain/entities/chunk_entity.py +++ b/src/domain/entities/chunk_entity.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from uuid import UUID, uuid4 from pydantic import BaseModel, Field diff --git a/src/domain/entities/diarization.py b/src/domain/entities/diarization.py index 5fa85dd7..cac78849 100644 --- a/src/domain/entities/diarization.py +++ b/src/domain/entities/diarization.py @@ -1,6 +1,6 @@ import os from datetime import datetime, timezone -from uuid import uuid4, UUID +from uuid import UUID, uuid4 import soundfile as sf from pydantic import BaseModel, Field diff --git a/src/domain/entities/enums/diarization_status_enum.py b/src/domain/entities/enums/diarization_status_enum.py index 68853b51..1a9565f8 100644 --- a/src/domain/entities/enums/diarization_status_enum.py +++ b/src/domain/entities/enums/diarization_status_enum.py @@ -7,6 +7,7 @@ class DiarizationStatus(Enum): COMPLETED = "completed" FAILED = "failed" AWAITING_VERIFICATION = "awaiting_verification" + TRAINING = "training" class DiarizationStep(Enum): diff --git a/src/domain/entities/ingestion_job_entity.py b/src/domain/entities/ingestion_job_entity.py index 2dddb64f..2bc98d29 100644 --- a/src/domain/entities/ingestion_job_entity.py +++ b/src/domain/entities/ingestion_job_entity.py @@ -3,6 +3,7 @@ from uuid import UUID, uuid4 from pydantic import BaseModel, Field + from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus diff --git a/src/domain/exception/youtube_exceptions.py b/src/domain/exception/youtube_exceptions.py index 8e6fd220..df257d00 100644 --- a/src/domain/exception/youtube_exceptions.py +++ b/src/domain/exception/youtube_exceptions.py @@ -57,7 +57,10 @@ class YoutubeNetworkException(YoutubeException): """Raised when there is a network-related error (DNS, connection).""" def __init__(self, video_id: str, error_msg: str): - message = f"Network error while accessing video {video_id}. Please check your connection. Details: {error_msg}" + message = ( + f"Network error while accessing video {video_id}. " + f"Please check your connection. Details: {error_msg}" + ) super().__init__(message, video_id=video_id) @@ -65,5 +68,8 @@ class YoutubeIPBlockedException(YoutubeException): """Raised when YouTube is blocking requests from the server's IP.""" def __init__(self, video_id: str, error_msg: str): - message = f"YouTube is blocking our requests for video {video_id}. Likely IP ban or temporary block. Details: {error_msg}" + message = ( + f"YouTube is blocking our requests for video {video_id}. " + f"Likely IP ban or temporary block. Details: {error_msg}" + ) super().__init__(message, video_id=video_id) diff --git a/src/domain/interfaces/logger/logger.py b/src/domain/interfaces/logger/logger.py index 99cfa535..1115c31f 100644 --- a/src/domain/interfaces/logger/logger.py +++ b/src/domain/interfaces/logger/logger.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Dict +from typing import Dict, Optional class ILogger(ABC): diff --git a/src/domain/interfaces/repository/retriver_repository.py b/src/domain/interfaces/repository/retriver_repository.py index 5d74867c..bb50fe29 100644 --- a/src/domain/interfaces/repository/retriver_repository.py +++ b/src/domain/interfaces/repository/retriver_repository.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Any +from typing import Any, List, Optional from src.domain.entities.enums.search_mode_enum import SearchMode from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel diff --git a/src/domain/interfaces/repository/user_repository.py b/src/domain/interfaces/repository/user_repository.py index 6fbcfb59..86f45ed2 100644 --- a/src/domain/interfaces/repository/user_repository.py +++ b/src/domain/interfaces/repository/user_repository.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional + from src.domain.entities.user import User diff --git a/src/domain/mappers/chunk_index_mapper.py b/src/domain/mappers/chunk_index_mapper.py index 5f3b278f..3857f78b 100644 --- a/src/domain/mappers/chunk_index_mapper.py +++ b/src/domain/mappers/chunk_index_mapper.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, List, cast +from typing import List, Optional, cast from uuid import UUID from src.domain.entities.chunk_entity import ChunkEntity diff --git a/src/domain/mappers/chunk_mapper.py b/src/domain/mappers/chunk_mapper.py index 56b63761..e07f7e94 100644 --- a/src/domain/mappers/chunk_mapper.py +++ b/src/domain/mappers/chunk_mapper.py @@ -9,6 +9,7 @@ from uuid import uuid4 from langchain_core.documents import Document + from src.domain.entities.chunk_entity import ChunkEntity from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel @@ -128,7 +129,8 @@ def _convert_to_uuid(value: Any) -> Any: def _normalize_source_type(source: str) -> str: """Normalize the source type to the canonical string value used for persistence. - Tries to map by enum member name (case-insensitive) or by enum value. If mapping fails, returns the input string. + Tries to map by enum member name (case-insensitive) or by enum value. + If mapping fails, returns the input string. """ s = source.strip() # try by name (case-insensitive) diff --git a/src/domain/mappers/content_source_mapper.py b/src/domain/mappers/content_source_mapper.py index 97769367..40d173cb 100644 --- a/src/domain/mappers/content_source_mapper.py +++ b/src/domain/mappers/content_source_mapper.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, List, Dict, Any, cast +from typing import Any, Dict, List, Optional, cast from uuid import UUID from src.domain.entities.content_source_entity import ContentSourceEntity diff --git a/src/domain/mappers/ingestion_job_mapper.py b/src/domain/mappers/ingestion_job_mapper.py index e1ec876b..3f638d73 100644 --- a/src/domain/mappers/ingestion_job_mapper.py +++ b/src/domain/mappers/ingestion_job_mapper.py @@ -1,9 +1,9 @@ -from typing import Optional, List, cast -from uuid import UUID from datetime import datetime +from typing import List, Optional, cast +from uuid import UUID -from src.domain.entities.ingestion_job_entity import IngestionJobEntity from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus +from src.domain.entities.ingestion_job_entity import IngestionJobEntity from src.infrastructure.repositories.sql.models.ingestion_job import IngestionJobModel diff --git a/src/domain/mappers/knowledge_subject_mapper.py b/src/domain/mappers/knowledge_subject_mapper.py index 086618ed..46da7e1c 100644 --- a/src/domain/mappers/knowledge_subject_mapper.py +++ b/src/domain/mappers/knowledge_subject_mapper.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional, List, cast +from typing import List, Optional, cast from uuid import UUID from src.domain.entities.knowledge_subject_entity import KnowledgeSubjectEntity diff --git a/src/infrastructure/extractors/crawl4ai_extractor.py b/src/infrastructure/extractors/crawl4ai_extractor.py index fc7a6f95..12114c4b 100644 --- a/src/infrastructure/extractors/crawl4ai_extractor.py +++ b/src/infrastructure/extractors/crawl4ai_extractor.py @@ -1,10 +1,11 @@ import re import tempfile -import httpx +from typing import Any, Dict, List + import anyio -from typing import Any, List, Dict +import httpx +from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig from langchain_core.documents import Document -from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode from src.config.logger import Logger from src.domain.interfaces.extractors.base_extractor_interface import IBaseExtractor @@ -20,7 +21,11 @@ class Crawl4AIExtractor(IBaseExtractor): def __init__(self, headless: bool = True): # Using a realistic user agent to bypass anti-bot protection - self.user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36" + self.user_agent = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/122.0.0.0 Safari/537.36" + ) self.browser_config = BrowserConfig( headless=headless, viewport_width=1280, diff --git a/src/infrastructure/extractors/docling_extractor.py b/src/infrastructure/extractors/docling_extractor.py index 6a5d37af..4132656f 100644 --- a/src/infrastructure/extractors/docling_extractor.py +++ b/src/infrastructure/extractors/docling_extractor.py @@ -1,6 +1,6 @@ import os import threading -from typing import List, Any +from typing import Any, List from docling.datamodel.base_models import InputFormat from docling.datamodel.pipeline_options import PdfPipelineOptions diff --git a/src/infrastructure/extractors/models/youtube_metadata_dto.py b/src/infrastructure/extractors/models/youtube_metadata_dto.py index c6fbc0cb..a8499d55 100644 --- a/src/infrastructure/extractors/models/youtube_metadata_dto.py +++ b/src/infrastructure/extractors/models/youtube_metadata_dto.py @@ -1,6 +1,6 @@ -from typing import Optional, List +from typing import List, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field class YoutubeMetadataDTO(BaseModel): diff --git a/src/infrastructure/extractors/plain_text_extractor.py b/src/infrastructure/extractors/plain_text_extractor.py index f071ab27..251c149f 100644 --- a/src/infrastructure/extractors/plain_text_extractor.py +++ b/src/infrastructure/extractors/plain_text_extractor.py @@ -15,7 +15,11 @@ class PlainTextExtractor: def __init__(self): self.timeout = 30.0 self.headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/91.0.4472.124 Safari/537.36" + ) } def extract(self, file_path_or_url: str) -> List[Document]: diff --git a/src/infrastructure/extractors/youtube_extractor.py b/src/infrastructure/extractors/youtube_extractor.py index 0b5518e7..529f3f60 100644 --- a/src/infrastructure/extractors/youtube_extractor.py +++ b/src/infrastructure/extractors/youtube_extractor.py @@ -6,10 +6,10 @@ import imageio_ffmpeg from youtube_transcript_api import ( - YouTubeTranscriptApi, FetchedTranscript, - TranscriptsDisabled, NoTranscriptFound, + TranscriptsDisabled, + YouTubeTranscriptApi, ) from youtube_transcript_api.proxies import GenericProxyConfig, WebshareProxyConfig from yt_dlp import YoutubeDL @@ -17,12 +17,12 @@ from src.config.logger import Logger from src.config.settings import settings from src.domain.exception.youtube_exceptions import ( + YoutubeIPBlockedException, + YoutubeNetworkException, YoutubeTranscriptNotFoundException, YoutubeTranscriptsDisabledException, YoutubeVideoPrivateException, YoutubeVideoUnplayableException, - YoutubeNetworkException, - YoutubeIPBlockedException, ) from src.domain.interfaces.extractors.youtube_extractor_interface import ( IYoutubeExtractor, @@ -54,7 +54,11 @@ def _get_common_ydl_opts(self, quiet: bool = True) -> dict: "source_address": "0.0.0.0", # nosec # Mimic a modern browser to avoid blocks "http_headers": { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/120.0.0.0 Safari/537.36" + ), "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8", "Accept-Language": "en-US,en;q=0.5", "Referer": "https://www.google.com/", diff --git a/src/infrastructure/loggers/std_logger.py b/src/infrastructure/loggers/std_logger.py index 4dd9f622..54975ca7 100644 --- a/src/infrastructure/loggers/std_logger.py +++ b/src/infrastructure/loggers/std_logger.py @@ -1,11 +1,11 @@ import inspect -from contextlib import suppress import logging import os import sys +from contextlib import suppress from contextvars import ContextVar from datetime import datetime -from typing import Optional, Any, Dict +from typing import Any, Dict, Optional from src.config.settings import settings from src.domain.interfaces.logger.logger import ILogger diff --git a/src/infrastructure/repositories/sql/chunk_index_repository.py b/src/infrastructure/repositories/sql/chunk_index_repository.py index d9dcec77..864ca304 100644 --- a/src/infrastructure/repositories/sql/chunk_index_repository.py +++ b/src/infrastructure/repositories/sql/chunk_index_repository.py @@ -1,15 +1,14 @@ -from typing import List, Optional, Any -from typing import cast +from typing import Any, List, Optional, cast from uuid import UUID import sqlalchemy as sa from sqlalchemy.orm import joinedload -from src.infrastructure.repositories.sql.utils.utils import ensure_uuid from src.config.logger import Logger from src.infrastructure.repositories.sql.connector import Connector from src.infrastructure.repositories.sql.models.chunk_index import ChunkIndexModel from src.infrastructure.repositories.sql.models.content_source import ContentSourceModel +from src.infrastructure.repositories.sql.utils.utils import ensure_uuid logger = Logger() diff --git a/src/infrastructure/repositories/sql/connector.py b/src/infrastructure/repositories/sql/connector.py index aa5e002f..439fdddf 100644 --- a/src/infrastructure/repositories/sql/connector.py +++ b/src/infrastructure/repositories/sql/connector.py @@ -1,5 +1,5 @@ from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.orm import declarative_base, sessionmaker from src.config.settings import settings diff --git a/src/infrastructure/repositories/sql/content_source_repository.py b/src/infrastructure/repositories/sql/content_source_repository.py index 4621d0fa..726e432b 100644 --- a/src/infrastructure/repositories/sql/content_source_repository.py +++ b/src/infrastructure/repositories/sql/content_source_repository.py @@ -1,6 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, List, Any -from typing import cast +from typing import Any, List, Optional, cast from uuid import UUID from src.config.logger import Logger @@ -114,7 +113,7 @@ def get_by_diarization_id( ) # Search within JSON field (syntax works for SQLite and Postgres) # Using string comparison for the diarization_id in the JSON object - from sqlalchemy import cast, String + from sqlalchemy import String, cast # Search within JSON field in a way that works for SQLite and Postgres # Using cast to String to ensure we can compare with the diarization_id diff --git a/src/infrastructure/repositories/sql/ingestion_job_repository.py b/src/infrastructure/repositories/sql/ingestion_job_repository.py index 516ce068..4e0dcc18 100644 --- a/src/infrastructure/repositories/sql/ingestion_job_repository.py +++ b/src/infrastructure/repositories/sql/ingestion_job_repository.py @@ -1,6 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, List, Any, Union -from typing import cast +from typing import Any, List, Optional, Union, cast from uuid import UUID from sqlalchemy.orm import joinedload diff --git a/src/infrastructure/repositories/sql/knowledge_subject_repository.py b/src/infrastructure/repositories/sql/knowledge_subject_repository.py index 4e494469..49579505 100644 --- a/src/infrastructure/repositories/sql/knowledge_subject_repository.py +++ b/src/infrastructure/repositories/sql/knowledge_subject_repository.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Any, cast +from typing import Any, List, Optional, cast from uuid import UUID from sqlalchemy.orm import selectinload diff --git a/src/infrastructure/repositories/sql/models/__init__.py b/src/infrastructure/repositories/sql/models/__init__.py index d1ae8fdf..392d1d5d 100644 --- a/src/infrastructure/repositories/sql/models/__init__.py +++ b/src/infrastructure/repositories/sql/models/__init__.py @@ -2,13 +2,15 @@ # Import each model module so SQLAlchemy's declarative registry sees all mappers # and string-based relationships can be resolved. -from . import chunk_index # noqa: F401 -from . import content_source # noqa: F401 -from . import diarization_record # noqa: F401 -from . import ingestion_job # noqa: F401 -from . import knowledge_subject # noqa: F401 -from . import user # noqa: F401 -from . import voice_record +from . import ( + chunk_index, # noqa: F401 + content_source, # noqa: F401 + diarization_record, # noqa: F401 + ingestion_job, # noqa: F401 + knowledge_subject, # noqa: F401 + user, # noqa: F401 + voice_record, +) __all__ = [ "knowledge_subject", diff --git a/src/infrastructure/repositories/sql/models/chunk_index.py b/src/infrastructure/repositories/sql/models/chunk_index.py index d1bed98c..734bda10 100644 --- a/src/infrastructure/repositories/sql/models/chunk_index.py +++ b/src/infrastructure/repositories/sql/models/chunk_index.py @@ -5,16 +5,16 @@ import uuid from sqlalchemy import ( + JSON, + UUID, Column, - Text, DateTime, + ForeignKey, + Index, Integer, + Text, func, - ForeignKey, text, - UUID, - Index, - JSON, ) from sqlalchemy.orm import relationship diff --git a/src/infrastructure/repositories/sql/models/content_source.py b/src/infrastructure/repositories/sql/models/content_source.py index 1a419b96..3a2cf4db 100644 --- a/src/infrastructure/repositories/sql/models/content_source.py +++ b/src/infrastructure/repositories/sql/models/content_source.py @@ -5,17 +5,17 @@ import uuid from sqlalchemy import ( + JSON, + UUID, Column, - Text, DateTime, + ForeignKey, + Index, Integer, + Text, + UniqueConstraint, func, - ForeignKey, - UUID, text, - UniqueConstraint, - Index, - JSON, ) from sqlalchemy.orm import relationship diff --git a/src/infrastructure/repositories/sql/models/diarization_record.py b/src/infrastructure/repositories/sql/models/diarization_record.py index 00c8d8f9..a3ed626d 100644 --- a/src/infrastructure/repositories/sql/models/diarization_record.py +++ b/src/infrastructure/repositories/sql/models/diarization_record.py @@ -5,7 +5,7 @@ import datetime import uuid -from sqlalchemy import Column, String, Float, DateTime, JSON, ForeignKey, UUID +from sqlalchemy import JSON, UUID, Column, DateTime, Float, ForeignKey, String from src.domain.entities.enums.diarization_status_enum import DiarizationStatus from src.infrastructure.repositories.sql.connector import Base diff --git a/src/infrastructure/repositories/sql/models/ingestion_job.py b/src/infrastructure/repositories/sql/models/ingestion_job.py index b82b67da..08c3e5eb 100644 --- a/src/infrastructure/repositories/sql/models/ingestion_job.py +++ b/src/infrastructure/repositories/sql/models/ingestion_job.py @@ -4,7 +4,7 @@ import uuid -from sqlalchemy import Column, Text, DateTime, func, ForeignKey, UUID, Integer, Index +from sqlalchemy import UUID, Column, DateTime, ForeignKey, Index, Integer, Text, func from sqlalchemy.orm import relationship, synonym from src.infrastructure.repositories.sql.connector import Base diff --git a/src/infrastructure/repositories/sql/models/knowledge_subject.py b/src/infrastructure/repositories/sql/models/knowledge_subject.py index 227f8131..80fea806 100644 --- a/src/infrastructure/repositories/sql/models/knowledge_subject.py +++ b/src/infrastructure/repositories/sql/models/knowledge_subject.py @@ -4,7 +4,7 @@ import uuid -from sqlalchemy import Column, Text, DateTime, func, UUID +from sqlalchemy import UUID, Column, DateTime, Text, func from sqlalchemy.orm import relationship from src.infrastructure.repositories.sql.connector import Base diff --git a/src/infrastructure/repositories/sql/models/user.py b/src/infrastructure/repositories/sql/models/user.py index cd3ce4d9..c5831ce4 100644 --- a/src/infrastructure/repositories/sql/models/user.py +++ b/src/infrastructure/repositories/sql/models/user.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import Optional -from sqlalchemy import String, DateTime +from sqlalchemy import DateTime, String from sqlalchemy.orm import Mapped, mapped_column from src.infrastructure.repositories.sql.connector import Base diff --git a/src/infrastructure/repositories/sql/models/voice_record.py b/src/infrastructure/repositories/sql/models/voice_record.py index 85d91234..9f83d105 100644 --- a/src/infrastructure/repositories/sql/models/voice_record.py +++ b/src/infrastructure/repositories/sql/models/voice_record.py @@ -5,7 +5,7 @@ import datetime import uuid -from sqlalchemy import Column, String, DateTime, JSON +from sqlalchemy import JSON, Column, DateTime, String from src.infrastructure.repositories.sql.connector import Base diff --git a/src/infrastructure/repositories/sql/user_repository.py b/src/infrastructure/repositories/sql/user_repository.py index 310e699f..a8faf9c2 100644 --- a/src/infrastructure/repositories/sql/user_repository.py +++ b/src/infrastructure/repositories/sql/user_repository.py @@ -1,13 +1,13 @@ -from typing import Optional, Any from datetime import datetime, timezone +from typing import Any, Optional from sqlalchemy import select -from src.infrastructure.repositories.sql.utils.utils import ensure_uuid +from src.domain.entities.user import User as UserEntity from src.domain.interfaces.repository.user_repository import IUserRepository from src.infrastructure.repositories.sql.connector import Connector from src.infrastructure.repositories.sql.models.user import User as UserModel -from src.domain.entities.user import User as UserEntity +from src.infrastructure.repositories.sql.utils.utils import ensure_uuid class UserSQLRepository(IUserRepository): diff --git a/src/infrastructure/repositories/sql/utils/utils.py b/src/infrastructure/repositories/sql/utils/utils.py index 51ff597a..800c575f 100644 --- a/src/infrastructure/repositories/sql/utils/utils.py +++ b/src/infrastructure/repositories/sql/utils/utils.py @@ -1,5 +1,6 @@ from typing import Any, Optional from uuid import UUID + from src.config.logger import Logger logger = Logger() diff --git a/src/infrastructure/repositories/vector/chroma/chunk_repository.py b/src/infrastructure/repositories/vector/chroma/chunk_repository.py index a8ba6745..03bdd94d 100644 --- a/src/infrastructure/repositories/vector/chroma/chunk_repository.py +++ b/src/infrastructure/repositories/vector/chroma/chunk_repository.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import List, Optional, Any, Dict +from typing import Any, Dict, List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/repositories/vector/faiss/chunk_repository.py b/src/infrastructure/repositories/vector/faiss/chunk_repository.py index 5f8efc58..7383f6e7 100644 --- a/src/infrastructure/repositories/vector/faiss/chunk_repository.py +++ b/src/infrastructure/repositories/vector/faiss/chunk_repository.py @@ -1,7 +1,7 @@ import json import os from datetime import datetime -from typing import List, Optional, Any +from typing import Any, List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/repositories/vector/models/chunk_model.py b/src/infrastructure/repositories/vector/models/chunk_model.py index db0eda83..25dd8d59 100644 --- a/src/infrastructure/repositories/vector/models/chunk_model.py +++ b/src/infrastructure/repositories/vector/models/chunk_model.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional from uuid import UUID, uuid4 from pydantic import BaseModel, Field diff --git a/src/infrastructure/repositories/vector/qdrant/chunk_repository.py b/src/infrastructure/repositories/vector/qdrant/chunk_repository.py index 48432e6c..5aa0b874 100644 --- a/src/infrastructure/repositories/vector/qdrant/chunk_repository.py +++ b/src/infrastructure/repositories/vector/qdrant/chunk_repository.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional, Any, Dict, cast, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID from qdrant_client.http import models as rest diff --git a/src/infrastructure/repositories/vector/qdrant/connector.py b/src/infrastructure/repositories/vector/qdrant/connector.py index 348bcadb..5cefa5a4 100644 --- a/src/infrastructure/repositories/vector/qdrant/connector.py +++ b/src/infrastructure/repositories/vector/qdrant/connector.py @@ -2,6 +2,7 @@ from typing import Optional from qdrant_client import QdrantClient + from src.config.logger import Logger warnings.filterwarnings( diff --git a/src/infrastructure/repositories/vector/weaviate/chunk_repository.py b/src/infrastructure/repositories/vector/weaviate/chunk_repository.py index e7c4d314..fee3cd75 100644 --- a/src/infrastructure/repositories/vector/weaviate/chunk_repository.py +++ b/src/infrastructure/repositories/vector/weaviate/chunk_repository.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/repositories/vector/weaviate/weaviate_client.py b/src/infrastructure/repositories/vector/weaviate/weaviate_client.py index 587aa134..7757c837 100644 --- a/src/infrastructure/repositories/vector/weaviate/weaviate_client.py +++ b/src/infrastructure/repositories/vector/weaviate/weaviate_client.py @@ -1,4 +1,5 @@ from typing import Optional + from src.config.logger import Logger from src.config.settings import VectorConfig diff --git a/src/infrastructure/repositories/vector/weaviate/weaviate_vector.py b/src/infrastructure/repositories/vector/weaviate/weaviate_vector.py index 116af407..5ec5b332 100644 --- a/src/infrastructure/repositories/vector/weaviate/weaviate_vector.py +++ b/src/infrastructure/repositories/vector/weaviate/weaviate_vector.py @@ -2,7 +2,6 @@ from src.infrastructure.repositories.vector.weaviate.weaviate_client import ( WeaviateClient, ) - from src.infrastructure.services.embedding_service import EmbeddingService logger = Logger() diff --git a/src/infrastructure/services/auth_service.py b/src/infrastructure/services/auth_service.py index 7b634bea..d022f56f 100644 --- a/src/infrastructure/services/auth_service.py +++ b/src/infrastructure/services/auth_service.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional import httpx -from jose import jwt, JWTError +from jose import JWTError, jwt from src.config.settings import settings from src.domain.entities.user import User as UserEntity diff --git a/src/infrastructure/services/chunk_vector_service.py b/src/infrastructure/services/chunk_vector_service.py index d2613f08..01f301e3 100644 --- a/src/infrastructure/services/chunk_vector_service.py +++ b/src/infrastructure/services/chunk_vector_service.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Any +from typing import Any, List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/services/content_source_service.py b/src/infrastructure/services/content_source_service.py index 527e839e..5fb47d16 100644 --- a/src/infrastructure/services/content_source_service.py +++ b/src/infrastructure/services/content_source_service.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/services/ingestion_job_service.py b/src/infrastructure/services/ingestion_job_service.py index acefd161..769cfc6f 100644 --- a/src/infrastructure/services/ingestion_job_service.py +++ b/src/infrastructure/services/ingestion_job_service.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/services/knowledge_subject_service.py b/src/infrastructure/services/knowledge_subject_service.py index c5dba81d..16c374a5 100644 --- a/src/infrastructure/services/knowledge_subject_service.py +++ b/src/infrastructure/services/knowledge_subject_service.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional from uuid import UUID from src.config.logger import Logger diff --git a/src/infrastructure/services/model_loader_service.py b/src/infrastructure/services/model_loader_service.py index 1ef75368..4b1e35c1 100644 --- a/src/infrastructure/services/model_loader_service.py +++ b/src/infrastructure/services/model_loader_service.py @@ -1,10 +1,10 @@ import logging import threading -from typing import Any, Dict, ClassVar, Optional +from typing import Any, ClassVar, Dict, Optional import torch import whisperx -from pyannote.audio import Model, Inference +from pyannote.audio import Inference, Model from sentence_transformers import SentenceTransformer from src.domain.interfaces.services.mode_loader_service import IModelLoaderService diff --git a/src/infrastructure/services/pyannote_voice_recognizer.py b/src/infrastructure/services/pyannote_voice_recognizer.py index 9086bd74..b2e8d458 100644 --- a/src/infrastructure/services/pyannote_voice_recognizer.py +++ b/src/infrastructure/services/pyannote_voice_recognizer.py @@ -1,14 +1,16 @@ import os import time from pathlib import Path + import numpy as np -from src.domain.entities.voice import MatchResult, BatchResult + +from src.domain.entities.voice import BatchResult, MatchResult +from src.infrastructure.services.model_loader_service import model_loader from src.infrastructure.utils.audio_utils import ( - load_audio_tensor, cosine_similarity, get_best_device, + load_audio_tensor, ) -from src.infrastructure.services.model_loader_service import model_loader class VoiceRecognizer: diff --git a/src/infrastructure/services/re_rank_service.py b/src/infrastructure/services/re_rank_service.py index 959ec906..429948e1 100644 --- a/src/infrastructure/services/re_rank_service.py +++ b/src/infrastructure/services/re_rank_service.py @@ -1,7 +1,9 @@ import os import tempfile from typing import List + from flashrank import Ranker, RerankRequest + from src.config.logger import Logger from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel diff --git a/src/infrastructure/services/redis_task_queue_service.py b/src/infrastructure/services/redis_task_queue_service.py index 19c215f0..d4a8a04d 100644 --- a/src/infrastructure/services/redis_task_queue_service.py +++ b/src/infrastructure/services/redis_task_queue_service.py @@ -176,17 +176,17 @@ def remove_task_by_index(self, index: int) -> Optional[dict]: def _deserialize_args(self, raw_args: list) -> list: """Reconstruct dataclass command objects from serialized dicts.""" + from src.application.dtos.commands.ingest_diarization_command import ( + IngestDiarizationCommand, + ) from src.application.dtos.commands.ingest_file_command import IngestFileCommand + from src.application.dtos.commands.ingest_web_command import IngestWebCommand from src.application.dtos.commands.ingest_youtube_command import ( IngestYoutubeCommand, ) - from src.application.dtos.commands.ingest_web_command import IngestWebCommand from src.application.dtos.commands.process_audio_command import ( ProcessAudioCommand, ) - from src.application.dtos.commands.ingest_diarization_command import ( - IngestDiarizationCommand, - ) command_classes = { "IngestFileCommand": IngestFileCommand, diff --git a/src/infrastructure/services/voice_profile_service.py b/src/infrastructure/services/voice_profile_service.py index 08d7d973..ce673b57 100644 --- a/src/infrastructure/services/voice_profile_service.py +++ b/src/infrastructure/services/voice_profile_service.py @@ -1,8 +1,8 @@ import datetime -from contextlib import suppress import logging import os import uuid +from contextlib import suppress from typing import cast from urllib.parse import unquote @@ -27,8 +27,8 @@ def __init__(self, db: Session, hf_token: str): def _get_inference(self): if self._inference is None: - from pyannote.audio import Model, Inference import torch + from pyannote.audio import Inference, Model model = Model.from_pretrained( "pyannote/wespeaker-voxceleb-resnet34-LM", use_auth_token=self.hf_token diff --git a/src/infrastructure/services/whisperx_audio_diarizer.py b/src/infrastructure/services/whisperx_audio_diarizer.py index e94becd0..247edabe 100644 --- a/src/infrastructure/services/whisperx_audio_diarizer.py +++ b/src/infrastructure/services/whisperx_audio_diarizer.py @@ -5,9 +5,9 @@ import torch import whisperx -from src.domain.entities.diarization import Segment, DiarizationResult -from src.infrastructure.utils.audio_utils import load_whisperx_audio, get_best_device +from src.domain.entities.diarization import DiarizationResult, Segment from src.infrastructure.services.model_loader_service import model_loader +from src.infrastructure.utils.audio_utils import get_best_device, load_whisperx_audio logger = logging.getLogger(__name__) diff --git a/src/infrastructure/services/youtube_data_process_service.py b/src/infrastructure/services/youtube_data_process_service.py index c22e6549..922c3bea 100644 --- a/src/infrastructure/services/youtube_data_process_service.py +++ b/src/infrastructure/services/youtube_data_process_service.py @@ -1,6 +1,6 @@ import math from contextlib import suppress -from typing import Literal, List, Tuple, Dict, Optional +from typing import Dict, List, Literal, Optional, Tuple from langchain_core.documents import Document from youtube_transcript_api import FetchedTranscript diff --git a/src/presentation/api/dependencies.py b/src/presentation/api/dependencies.py index ba693dbd..2c2b9bee 100644 --- a/src/presentation/api/dependencies.py +++ b/src/presentation/api/dependencies.py @@ -6,12 +6,12 @@ from src.application.ingestion_context import IngestionContext from src.application.use_cases.auth_use_case import AuthUseCase from src.application.use_cases.content_source_use_case import ContentSourceUseCase -from src.application.use_cases.diarization_ingestion_use_case import ( - DiarizationIngestionUseCase, -) from src.application.use_cases.delete_diarization_use_case import ( DeleteDiarizationUseCase, ) +from src.application.use_cases.diarization_ingestion_use_case import ( + DiarizationIngestionUseCase, +) from src.application.use_cases.file_ingestion_use_case import FileIngestionUseCase from src.application.use_cases.generate_speaker_audio_access_url import ( GenerateSpeakerAudioAccessUrlUseCase, @@ -22,12 +22,12 @@ from src.application.use_cases.knowledge_subject_use_case import KnowledgeSubjectUseCase from src.application.use_cases.list_s3_audio_files import ListS3AudioFilesUseCase from src.application.use_cases.manage_voice_profiles import ( - RegisterNewVoiceProfileUseCase, - TrainVoiceProfileFromSpeakerSegmentUseCase, - ListRegisteredVoiceProfilesUseCase, + DeleteVoiceAudioFileUseCase, DeleteVoiceProfileUseCase, + ListRegisteredVoiceProfilesUseCase, ListVoiceAudioFilesUseCase, - DeleteVoiceAudioFileUseCase, + RegisterNewVoiceProfileUseCase, + TrainVoiceProfileFromSpeakerSegmentUseCase, ) from src.application.use_cases.retrieve_processed_audio_history import ( RetrieveProcessedAudioHistoryUseCase, @@ -50,6 +50,9 @@ from src.infrastructure.repositories.sql.content_source_repository import ( ContentSourceSQLRepository, ) +from src.infrastructure.repositories.sql.diarization_repository import ( + DiarizationRepository, +) from src.infrastructure.repositories.sql.ingestion_job_repository import ( IngestionJobSQLRepository, ) @@ -98,6 +101,14 @@ def get_job_repo() -> IngestionJobSQLRepository: return IngestionJobSQLRepository() +def get_diarization_repo(db: Session = Depends(get_db)) -> DiarizationRepository: + from src.infrastructure.repositories.sql.diarization_repository import ( + DiarizationRepository, + ) + + return DiarizationRepository(db) + + def get_subject_repo() -> KnowledgeSubjectSQLRepository: return KnowledgeSubjectSQLRepository() @@ -229,7 +240,10 @@ def get_vector_repository( except ImportError as e: from fastapi import HTTPException - error_msg = f"Vector driver for {settings.vector.store_type} is not installed: {e}. Please run 'pip install qdrant-client' (or the appropriate driver)." + error_msg = ( + f"Vector driver for {settings.vector.store_type} is not installed: {e}. " + f"Please run 'pip install qdrant-client' (or the appropriate driver)." + ) from src.config.logger import Logger Logger().error(error_msg, context={"store_type": settings.vector.store_type}) diff --git a/src/presentation/api/middleware/trace_middleware.py b/src/presentation/api/middleware/trace_middleware.py index a5594680..87c04a9b 100644 --- a/src/presentation/api/middleware/trace_middleware.py +++ b/src/presentation/api/middleware/trace_middleware.py @@ -1,9 +1,11 @@ import uuid + from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware + from src.infrastructure.loggers.std_logger import ( - set_global_context, clear_global_context, + set_global_context, ) diff --git a/src/presentation/api/routes/audio_diarization_and_recognition_router.py b/src/presentation/api/routes/audio_diarization_and_recognition_router.py index 2e0ed8f3..ca1b7f13 100644 --- a/src/presentation/api/routes/audio_diarization_and_recognition_router.py +++ b/src/presentation/api/routes/audio_diarization_and_recognition_router.py @@ -1,21 +1,21 @@ import logging import traceback -from typing import Annotated, Any, cast, Optional +from typing import Annotated, Any, Optional, cast from uuid import UUID -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from src.application.dtos.commands.process_audio_command import ProcessAudioCommand +from src.application.use_cases.delete_diarization_use_case import ( + DeleteDiarizationUseCase, +) from src.application.use_cases.generate_speaker_audio_access_url import ( GenerateSpeakerAudioAccessUrlUseCase, ) from src.application.use_cases.identify_speakers_in_processed_audio import ( IdentifySpeakersInProcessedAudioUseCase, ) -from src.application.use_cases.delete_diarization_use_case import ( - DeleteDiarizationUseCase, -) from src.application.use_cases.list_s3_audio_files import ListS3AudioFilesUseCase from src.application.use_cases.retrieve_processed_audio_history import ( RetrieveProcessedAudioHistoryUseCase, @@ -25,12 +25,12 @@ from src.domain.interfaces.services.i_task_queue import ITaskQueue from src.presentation.api.dependencies import ( get_db, - get_task_queue_service, + get_delete_diarization_use_case, get_generate_speaker_url_use_case, get_identify_speakers_use_case, get_list_s3_files_use_case, get_retrieve_history_use_case, - get_delete_diarization_use_case, + get_task_queue_service, ) from src.presentation.api.schemas.audio_processing_requests import ( AudioProcessingRequest, @@ -53,10 +53,11 @@ async def update_diarization_segments( Updates the segments of a diarization and marks it as completed. This is used when the user confirms the final transcript. """ + from typing import cast + from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) - from typing import cast try: repo = DiarizationRepository(db) @@ -72,13 +73,13 @@ async def update_diarization_segments( # Trigger ingestion for ContentSource try: - from src.infrastructure.repositories.sql.content_source_repository import ( - ContentSourceSQLRepository, - ) from src.application.dtos.commands.ingest_diarization_command import ( IngestDiarizationCommand, ) from src.application.workers import run_diarization_ingestion_worker + from src.infrastructure.repositories.sql.content_source_repository import ( + ContentSourceSQLRepository, + ) cs_repo = ContentSourceSQLRepository() target_source = cs_repo.get_by_diarization_id(diarization_id) diff --git a/src/presentation/api/routes/auth_router.py b/src/presentation/api/routes/auth_router.py index 7f5359a6..61525290 100644 --- a/src/presentation/api/routes/auth_router.py +++ b/src/presentation/api/routes/auth_router.py @@ -1,15 +1,16 @@ -from typing import Optional, Annotated +from typing import Annotated, Optional + from fastapi import APIRouter, Depends, HTTPException, Query from src.application.use_cases.auth_use_case import AuthUseCase +from src.config.settings import Settings +from src.domain.entities.user import User +from src.domain.exception.auth_exceptions import AuthDomainError from src.presentation.api.dependencies import ( get_auth_use_case, - get_settings, get_current_user, + get_settings, ) -from src.config.settings import Settings -from src.domain.exception.auth_exceptions import AuthDomainError -from src.domain.entities.user import User router = APIRouter() diff --git a/src/presentation/api/routes/chunk_router.py b/src/presentation/api/routes/chunk_router.py index b703d35b..bf91c225 100644 --- a/src/presentation/api/routes/chunk_router.py +++ b/src/presentation/api/routes/chunk_router.py @@ -1,7 +1,7 @@ from typing import Annotated, List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Query, Body +from fastapi import APIRouter, Body, Depends, HTTPException, Query from src.config.logger import Logger from src.infrastructure.services.chunk_index_service import ChunkIndexService diff --git a/src/presentation/api/routes/ingest_router.py b/src/presentation/api/routes/ingest_router.py index 9ec2b695..c4c19f95 100644 --- a/src/presentation/api/routes/ingest_router.py +++ b/src/presentation/api/routes/ingest_router.py @@ -4,48 +4,46 @@ from uuid import UUID import anyio -from fastapi import APIRouter, Body, Depends, HTTPException -from fastapi import UploadFile, File, Form +from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, UploadFile from src.application.dtos.commands.ingest_file_command import IngestFileCommand from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand from src.application.dtos.enums.youtube_data_type import YoutubeDataType -from src.application.use_cases.file_ingestion_use_case import FileIngestionUseCase -from src.application.use_cases.web_scraping_use_case import WebScrapingUseCase -from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase -from src.infrastructure.services.content_source_service import ContentSourceService from src.application.use_cases.diarization_ingestion_use_case import ( DiarizationIngestionUseCase, ) +from src.application.use_cases.file_ingestion_use_case import FileIngestionUseCase +from src.application.use_cases.web_scraping_use_case import WebScrapingUseCase +from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase from src.application.workers import ( + run_diarization_ingestion_worker, run_file_ingestion_worker, - run_youtube_ingestion_worker, - run_youtube_dispatcher_worker, run_web_ingestion_worker, - run_diarization_ingestion_worker, + run_youtube_dispatcher_worker, + run_youtube_ingestion_worker, ) - from src.config.logger import Logger +from src.domain.entities.user import User from src.domain.interfaces.services.i_task_queue import ITaskQueue +from src.infrastructure.services.content_source_service import ContentSourceService from src.presentation.api.dependencies import ( - get_ingest_youtube_use_case, - get_file_ingestion_use_case, - get_web_scraping_use_case, - get_diarization_ingestion_use_case, - get_task_queue_service, get_cs_service, get_current_user, + get_diarization_ingestion_use_case, + get_file_ingestion_use_case, + get_ingest_youtube_use_case, + get_task_queue_service, + get_web_scraping_use_case, ) -from src.domain.entities.user import User from src.presentation.api.schemas.ingest_schemas import ( - IngestResponse, - YoutubeIngestRequest, - FileUrlIngestRequest, - WebIngestRequest, - DiarizationIngestRequest, ChannelPreviewRequest, ChannelPreviewResponse, ChannelVideoItem, + DiarizationIngestRequest, + FileUrlIngestRequest, + IngestResponse, + WebIngestRequest, + YoutubeIngestRequest, ) logger = Logger() @@ -465,8 +463,8 @@ def preview_youtube_channel( raise HTTPException(status_code=400, detail="channel_url is required") try: - from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor from src.domain.entities.enums.source_type_enum_entity import SourceType + from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor extractor = YoutubeExtractor() videos, channel_name = extractor.extract_channel_videos( diff --git a/src/presentation/api/routes/notification_router.py b/src/presentation/api/routes/notification_router.py index 90ad7a2e..799e9384 100644 --- a/src/presentation/api/routes/notification_router.py +++ b/src/presentation/api/routes/notification_router.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import AsyncGenerator, Annotated +from typing import Annotated, AsyncGenerator from fastapi import APIRouter, Depends, Request from sse_starlette.sse import EventSourceResponse diff --git a/src/presentation/api/routes/settings_router.py b/src/presentation/api/routes/settings_router.py index e254fec7..71b7c53b 100644 --- a/src/presentation/api/routes/settings_router.py +++ b/src/presentation/api/routes/settings_router.py @@ -2,21 +2,21 @@ from typing import Annotated import sqlalchemy -from src.config.settings import Settings +from fastapi import APIRouter, Depends, HTTPException, Request -from fastapi import Depends, APIRouter, HTTPException, Request -from src.presentation.api.dependencies import get_vector_repository, get_settings +from src.config.settings import Settings from src.domain.interfaces.repository.retriver_repository import IVectorRepository +from src.infrastructure.repositories.sql.connector import Connector +from src.presentation.api.dependencies import get_settings, get_vector_repository from src.presentation.api.schemas.settings_schemas import ( - SettingsResponse, AppSettingsSchema, - VectorSettingsSchema, + HealthCheckResponse, ModelSettingsSchema, - SQLSettingsSchema, RedisSettingsSchema, - HealthCheckResponse, + SettingsResponse, + SQLSettingsSchema, + VectorSettingsSchema, ) -from src.infrastructure.repositories.sql.connector import Connector router = APIRouter() diff --git a/src/presentation/api/routes/source_router.py b/src/presentation/api/routes/source_router.py index a373ca77..59f1ebcc 100644 --- a/src/presentation/api/routes/source_router.py +++ b/src/presentation/api/routes/source_router.py @@ -1,23 +1,22 @@ from typing import Annotated, List -from fastapi import Depends, HTTPException, APIRouter +from fastapi import APIRouter, Depends, HTTPException +from src.application.use_cases.content_source_use_case import ContentSourceUseCase from src.config.logger import Logger +from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.services.content_source_service import ContentSourceService from src.infrastructure.services.model_loader_service import ModelLoaderService from src.presentation.api.dependencies import ( - get_cs_service, - get_model_loader, get_content_source_use_case, + get_cs_service, get_event_bus, + get_model_loader, ) -from src.domain.interfaces.services.i_event_bus import IEventBus -from src.domain.entities.enums.source_type_enum_entity import SourceType -from src.application.use_cases.content_source_use_case import ContentSourceUseCase from src.presentation.api.schemas.model_schemas import ModelInfoResponse from src.presentation.api.schemas.source_schemas import SourceResponse, SourceUpdate - logger = Logger() router = APIRouter() diff --git a/src/presentation/api/routes/subject_router.py b/src/presentation/api/routes/subject_router.py index 2bd77507..d01d03d2 100644 --- a/src/presentation/api/routes/subject_router.py +++ b/src/presentation/api/routes/subject_router.py @@ -3,17 +3,17 @@ from fastapi import APIRouter, Body, Depends, HTTPException +from src.application.use_cases.knowledge_subject_use_case import KnowledgeSubjectUseCase from src.config.logger import Logger +from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.services.knowledge_subject_service import ( KnowledgeSubjectService, ) -from src.application.use_cases.knowledge_subject_use_case import KnowledgeSubjectUseCase from src.presentation.api.dependencies import ( + get_event_bus, get_ks_service, get_ks_use_case, - get_event_bus, ) -from src.domain.interfaces.services.i_event_bus import IEventBus from src.presentation.api.schemas.subject_schemas import ( SubjectCreate, SubjectResponse, diff --git a/src/presentation/api/routes/voice_profile_management_router.py b/src/presentation/api/routes/voice_profile_management_router.py index dc00eda5..5496a46a 100644 --- a/src/presentation/api/routes/voice_profile_management_router.py +++ b/src/presentation/api/routes/voice_profile_management_router.py @@ -1,29 +1,36 @@ -from typing import Annotated import os import tempfile -import anyio +from typing import Annotated -from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Form +import anyio +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile -from src.infrastructure.repositories.storage.storage import StorageService +from src.application.dtos.commands.train_voice_command import TrainVoiceCommand from src.application.use_cases.manage_voice_profiles import ( - RegisterNewVoiceProfileUseCase, - ListRegisteredVoiceProfilesUseCase, + DeleteVoiceAudioFileUseCase, DeleteVoiceProfileUseCase, - TrainVoiceProfileFromSpeakerSegmentUseCase, + ListRegisteredVoiceProfilesUseCase, ListVoiceAudioFilesUseCase, - DeleteVoiceAudioFileUseCase, + RegisterNewVoiceProfileUseCase, ) +from src.application.workers import run_voice_training_worker +from src.domain.entities.enums.diarization_status_enum import DiarizationStatus +from src.domain.interfaces.services.i_event_bus import IEventBus +from src.domain.interfaces.services.i_task_queue import ITaskQueue +from src.infrastructure.repositories.sql.diarization_repository import ( + DiarizationRepository, +) +from src.infrastructure.repositories.storage.storage import StorageService from src.presentation.api.dependencies import ( - get_register_voice_profile_use_case, - get_list_voice_profiles_use_case, - get_delete_voice_profile_use_case, - get_train_voice_from_speaker_use_case, - get_list_voice_audio_files_use_case, get_delete_voice_audio_file_use_case, + get_delete_voice_profile_use_case, + get_diarization_repo, get_event_bus, + get_list_voice_audio_files_use_case, + get_list_voice_profiles_use_case, + get_register_voice_profile_use_case, + get_task_queue_service, ) -from src.domain.interfaces.services.i_event_bus import IEventBus from src.presentation.api.schemas.voice_profile_requests import ( VoiceProfileRegistrationRequest, VoiceProfileTrainingFromSpeakerRequest, @@ -88,7 +95,9 @@ async def upload_and_register_new_voice_profile( @router.post( "/train-from-speaker", + status_code=202, responses={ + 202: {"description": "Accepted - Training started in background"}, 400: {"description": "Bad Request"}, 404: {"description": "Not Found"}, 500: {"description": "Internal Server Error"}, @@ -97,27 +106,55 @@ async def upload_and_register_new_voice_profile( async def train_voice_profile_from_existing_speaker_segment( request: VoiceProfileTrainingFromSpeakerRequest, event_bus: Annotated[IEventBus, Depends(get_event_bus)], - use_case: Annotated[ - TrainVoiceProfileFromSpeakerSegmentUseCase, - Depends(get_train_voice_from_speaker_use_case), - ], + task_queue: Annotated[ITaskQueue, Depends(get_task_queue_service)], + repo: Annotated[DiarizationRepository, Depends(get_diarization_repo)], ): try: - voice_id = use_case.execute( - diarization_id=request.diarization_id, - speaker_label=request.speaker_label, - name=request.name, + # 1. Verification + record = repo.get_by_id(request.diarization_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Diarization not found: {request.diarization_id}" + ) + + # 2. Update status to TRAINING immediately + repo.update_status( + request.diarization_id, + DiarizationStatus.TRAINING.value, + status_message=f"Preparando treinamento de voz: {request.name}", ) - # Notify + + # 3. Notify frontend event_bus.publish( "ingestion_status", - {"type": "voice", "action": "train", "name": request.name}, + { + "type": "diarization", + "id": request.diarization_id, + "status": DiarizationStatus.TRAINING.value, + "message": f"Iniciando treinamento de voz '{request.name}'...", + }, ) - return {"status": "success", "voice_id": voice_id, "name": request.name} - except ValueError as e: - raise HTTPException( - status_code=404 if "not found" in str(e) else 400, detail=str(e) + + # 4. Enqueue background task + cmd = TrainVoiceCommand( + diarization_id=request.diarization_id, + speaker_label=request.speaker_label, + name=request.name, + ) + + task_queue.enqueue( + run_voice_training_worker, + cmd, + task_title=f"Treino de Voz: {request.name}", ) + + return { + "status": "success", + "message": "Treinamento de voz iniciado em segundo plano", + "name": request.name, + } + except HTTPException: + raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/presentation/api/schemas/audio_processing_responses.py b/src/presentation/api/schemas/audio_processing_responses.py index 9286a614..7bd58b7b 100644 --- a/src/presentation/api/schemas/audio_processing_responses.py +++ b/src/presentation/api/schemas/audio_processing_responses.py @@ -1,5 +1,6 @@ +from typing import Dict, List, Optional + from pydantic import BaseModel -from typing import List, Optional, Dict class AudioSegmentSchema(BaseModel): diff --git a/src/presentation/api/schemas/chunk_schemas.py b/src/presentation/api/schemas/chunk_schemas.py index 3d5d9d05..44232c33 100644 --- a/src/presentation/api/schemas/chunk_schemas.py +++ b/src/presentation/api/schemas/chunk_schemas.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Optional from uuid import UUID + from pydantic import BaseModel, ConfigDict diff --git a/src/presentation/api/schemas/job_schemas.py b/src/presentation/api/schemas/job_schemas.py index 6e3ad575..40f94856 100644 --- a/src/presentation/api/schemas/job_schemas.py +++ b/src/presentation/api/schemas/job_schemas.py @@ -1,6 +1,7 @@ from datetime import datetime -from typing import Optional, List +from typing import List, Optional from uuid import UUID + from pydantic import BaseModel, ConfigDict diff --git a/src/presentation/api/schemas/search_schemas.py b/src/presentation/api/schemas/search_schemas.py index addf4d1a..3e61a1d1 100644 --- a/src/presentation/api/schemas/search_schemas.py +++ b/src/presentation/api/schemas/search_schemas.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from src.domain.entities.enums.search_mode_enum import SearchMode diff --git a/src/presentation/api/schemas/source_schemas.py b/src/presentation/api/schemas/source_schemas.py index 07a56e67..5d0421c3 100644 --- a/src/presentation/api/schemas/source_schemas.py +++ b/src/presentation/api/schemas/source_schemas.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Optional from uuid import UUID + from pydantic import BaseModel, ConfigDict diff --git a/tests/application/test_audio_diarization_workers.py b/tests/application/test_audio_diarization_workers.py index 8b52ddee..75883d8e 100644 --- a/tests/application/test_audio_diarization_workers.py +++ b/tests/application/test_audio_diarization_workers.py @@ -1,7 +1,9 @@ from unittest.mock import MagicMock, patch + import pytest -from src.application.workers import run_audio_diarization_dispatcher_worker + from src.application.dtos.commands.process_audio_command import ProcessAudioCommand +from src.application.workers import run_audio_diarization_dispatcher_worker @pytest.mark.AudioDiarizationWorker diff --git a/tests/application/test_workers.py b/tests/application/test_workers.py index ac3f2835..b85dee87 100644 --- a/tests/application/test_workers.py +++ b/tests/application/test_workers.py @@ -1,15 +1,17 @@ +import asyncio from unittest.mock import ANY, AsyncMock, MagicMock, patch + import pytest -import asyncio + +from src.application.dtos.commands.ingest_file_command import IngestFileCommand +from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand +from src.application.dtos.enums.youtube_data_type import YoutubeDataType from src.application.service_registry import registry from src.application.workers import ( run_file_ingestion_worker, - run_youtube_ingestion_worker, run_web_ingestion_worker, + run_youtube_ingestion_worker, ) -from src.application.dtos.commands.ingest_file_command import IngestFileCommand -from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand -from src.application.dtos.enums.youtube_data_type import YoutubeDataType @pytest.mark.Workers @@ -187,11 +189,12 @@ def side_effect(coro): mock_logger.error.assert_called_once() def test_run_diarization_ingestion_worker_success(self): - from src.application.workers import run_diarization_ingestion_worker + from uuid import uuid4 + from src.application.dtos.commands.ingest_diarization_command import ( IngestDiarizationCommand, ) - from uuid import uuid4 + from src.application.workers import run_diarization_ingestion_worker with ( patch( @@ -276,10 +279,10 @@ def test_audio_diarization_subprocess_success(self): mock_use_case.execute.assert_called_once() def test_run_audio_diarization_worker_success(self): - from src.application.workers import run_audio_diarization_worker from src.application.dtos.commands.process_audio_command import ( ProcessAudioCommand, ) + from src.application.workers import run_audio_diarization_worker with ( patch("multiprocessing.get_context") as mock_get_ctx, @@ -298,10 +301,10 @@ def test_run_audio_diarization_worker_success(self): mock_process.join.assert_called_once() def test_run_audio_diarization_worker_failure(self): - from src.application.workers import run_audio_diarization_worker from src.application.dtos.commands.process_audio_command import ( ProcessAudioCommand, ) + from src.application.workers import run_audio_diarization_worker with ( patch("multiprocessing.get_context") as mock_get_ctx, diff --git a/tests/application/use_cases/test_audio_recognition_use_cases.py b/tests/application/use_cases/test_audio_recognition_use_cases.py index 6b85fb3c..454a850f 100644 --- a/tests/application/use_cases/test_audio_recognition_use_cases.py +++ b/tests/application/use_cases/test_audio_recognition_use_cases.py @@ -1,6 +1,8 @@ -import pytest +# ruff: noqa: E402 from unittest.mock import MagicMock, patch +import pytest + # Module-level patch for boto3 to prevent botocore initialization during class/module loading patch("boto3.Session").start() patch("boto3.client").start() @@ -12,13 +14,15 @@ from src.application.use_cases.identify_speakers_in_processed_audio import ( # noqa: E402 IdentifySpeakersInProcessedAudioUseCase, ) -from src.application.use_cases.list_s3_audio_files import ListS3AudioFilesUseCase # noqa: E402 +from src.application.use_cases.list_s3_audio_files import ( + ListS3AudioFilesUseCase, # noqa: E402 +) from src.application.use_cases.manage_voice_profiles import ( # noqa: E402 - RegisterNewVoiceProfileUseCase, + DeleteVoiceAudioFileUseCase, DeleteVoiceProfileUseCase, ListRegisteredVoiceProfilesUseCase, ListVoiceAudioFilesUseCase, - DeleteVoiceAudioFileUseCase, + RegisterNewVoiceProfileUseCase, ) from src.infrastructure.repositories.sql.models.diarization_record import ( # noqa: E402 DiarizationRecord, diff --git a/tests/application/use_cases/test_auth_use_case.py b/tests/application/use_cases/test_auth_use_case.py index 29ed6281..5311e4a7 100644 --- a/tests/application/use_cases/test_auth_use_case.py +++ b/tests/application/use_cases/test_auth_use_case.py @@ -1,10 +1,12 @@ +from unittest.mock import AsyncMock, MagicMock + import pytest -from unittest.mock import MagicMock, AsyncMock + from src.application.use_cases.auth_use_case import AuthUseCase from src.domain.entities.user import User from src.domain.exception.auth_exceptions import ( - InvalidStateError, GoogleAuthError, + InvalidStateError, UserNotCreatedError, ) diff --git a/tests/application/use_cases/test_content_source_use_case.py b/tests/application/use_cases/test_content_source_use_case.py index 159d78e5..38ed4a39 100644 --- a/tests/application/use_cases/test_content_source_use_case.py +++ b/tests/application/use_cases/test_content_source_use_case.py @@ -1,6 +1,8 @@ -import pytest import uuid from unittest.mock import MagicMock + +import pytest + from src.application.use_cases.content_source_use_case import ContentSourceUseCase diff --git a/tests/application/use_cases/test_delete_diarization_use_case.py b/tests/application/use_cases/test_delete_diarization_use_case.py index 6e5ffbb7..125f7b37 100644 --- a/tests/application/use_cases/test_delete_diarization_use_case.py +++ b/tests/application/use_cases/test_delete_diarization_use_case.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import patch +import pytest + from src.application.use_cases.delete_diarization_use_case import ( DeleteDiarizationUseCase, ) diff --git a/tests/application/use_cases/test_diarization_ingestion_use_case.py b/tests/application/use_cases/test_diarization_ingestion_use_case.py index 10633c2d..3332ef90 100644 --- a/tests/application/use_cases/test_diarization_ingestion_use_case.py +++ b/tests/application/use_cases/test_diarization_ingestion_use_case.py @@ -1,20 +1,22 @@ -import pytest -from langchain_core.documents import Document -from src.domain.entities.enums.source_type_enum_entity import SourceType from unittest.mock import MagicMock from uuid import uuid4 -from src.application.use_cases.diarization_ingestion_use_case import ( - DiarizationIngestionUseCase, -) + +import pytest +from langchain_core.documents import Document + from src.application.dtos.commands.ingest_diarization_command import ( IngestDiarizationCommand, ) -from src.infrastructure.repositories.sql.models.diarization_record import ( - DiarizationRecord, +from src.application.use_cases.diarization_ingestion_use_case import ( + DiarizationIngestionUseCase, ) +from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) +from src.infrastructure.repositories.sql.models.diarization_record import ( + DiarizationRecord, +) @pytest.mark.DiarizationIngestion diff --git a/tests/application/use_cases/test_file_ingestion_use_case.py b/tests/application/use_cases/test_file_ingestion_use_case.py index 4fdeb053..c657de73 100644 --- a/tests/application/use_cases/test_file_ingestion_use_case.py +++ b/tests/application/use_cases/test_file_ingestion_use_case.py @@ -1,12 +1,14 @@ -import pytest -from langchain_core.documents import Document from unittest.mock import MagicMock from uuid import uuid4 -from src.application.use_cases.file_ingestion_use_case import FileIngestionUseCase + +import pytest +from langchain_core.documents import Document + from src.application.dtos.commands.ingest_file_command import IngestFileCommand -from src.domain.entities.enums.source_type_enum_entity import SourceType -from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus +from src.application.use_cases.file_ingestion_use_case import FileIngestionUseCase from src.domain.entities.enums.content_source_status_enum import ContentSourceStatus +from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus +from src.domain.entities.enums.source_type_enum_entity import SourceType @pytest.mark.FileIngestionUseCase diff --git a/tests/application/use_cases/test_knowledge_subject_use_case.py b/tests/application/use_cases/test_knowledge_subject_use_case.py index 1633edaf..731d1e49 100644 --- a/tests/application/use_cases/test_knowledge_subject_use_case.py +++ b/tests/application/use_cases/test_knowledge_subject_use_case.py @@ -1,10 +1,12 @@ -import pytest import uuid +from datetime import datetime, timezone from unittest.mock import MagicMock + +import pytest + from src.application.use_cases.knowledge_subject_use_case import KnowledgeSubjectUseCase -from src.domain.entities.knowledge_subject_entity import KnowledgeSubjectEntity from src.domain.entities.content_source_entity import ContentSourceEntity -from datetime import datetime, timezone +from src.domain.entities.knowledge_subject_entity import KnowledgeSubjectEntity @pytest.mark.KnowledgeSubjectUseCase diff --git a/tests/application/use_cases/test_process_audio_diarization_pipeline.py b/tests/application/use_cases/test_process_audio_diarization_pipeline.py index d0a27cb4..26a7a7ec 100644 --- a/tests/application/use_cases/test_process_audio_diarization_pipeline.py +++ b/tests/application/use_cases/test_process_audio_diarization_pipeline.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import MagicMock, patch + +import pytest + from src.application.use_cases.process_audio_diarization_pipeline import ( ProcessAudioDiarizationPipelineUseCase, ) diff --git a/tests/application/use_cases/test_search_use_case.py b/tests/application/use_cases/test_search_use_case.py index da48fbdc..298fe756 100644 --- a/tests/application/use_cases/test_search_use_case.py +++ b/tests/application/use_cases/test_search_use_case.py @@ -1,7 +1,8 @@ import uuid -import pytest from types import SimpleNamespace +import pytest + from src.application.use_cases.search_use_case import SearchUseCase from src.domain.entities.enums.search_mode_enum import SearchMode diff --git a/tests/application/use_cases/test_web_scraping_use_case.py b/tests/application/use_cases/test_web_scraping_use_case.py index 961e3e11..1f6849fa 100644 --- a/tests/application/use_cases/test_web_scraping_use_case.py +++ b/tests/application/use_cases/test_web_scraping_use_case.py @@ -1,10 +1,12 @@ -import pytest import uuid from unittest.mock import AsyncMock, MagicMock -from src.application.use_cases.web_scraping_use_case import WebScrapingUseCase + +import pytest +from langchain_core.documents import Document + from src.application.dtos.commands.ingest_web_command import IngestWebCommand +from src.application.use_cases.web_scraping_use_case import WebScrapingUseCase from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus -from langchain_core.documents import Document @pytest.fixture diff --git a/tests/application/use_cases/test_web_scraping_use_case_extended.py b/tests/application/use_cases/test_web_scraping_use_case_extended.py index aaa076ac..b7b1758a 100644 --- a/tests/application/use_cases/test_web_scraping_use_case_extended.py +++ b/tests/application/use_cases/test_web_scraping_use_case_extended.py @@ -1,11 +1,13 @@ -import pytest import uuid from unittest.mock import AsyncMock, MagicMock -from src.application.use_cases.web_scraping_use_case import WebScrapingUseCase + +import pytest +from langchain_core.documents import Document + from src.application.dtos.commands.ingest_web_command import IngestWebCommand -from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus +from src.application.use_cases.web_scraping_use_case import WebScrapingUseCase from src.domain.entities.enums.content_source_status_enum import ContentSourceStatus -from langchain_core.documents import Document +from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus @pytest.fixture diff --git a/tests/application/use_cases/test_youtube_ingestion_use_case.py b/tests/application/use_cases/test_youtube_ingestion_use_case.py index 9f26332d..8dbc1baf 100644 --- a/tests/application/use_cases/test_youtube_ingestion_use_case.py +++ b/tests/application/use_cases/test_youtube_ingestion_use_case.py @@ -1,7 +1,8 @@ -import pytest import uuid -from unittest.mock import MagicMock from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase diff --git a/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py b/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py index 4a0b0d01..5ce5db1d 100644 --- a/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py +++ b/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py @@ -1,9 +1,11 @@ -import pytest import uuid from unittest.mock import MagicMock, patch -from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase + +import pytest + from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand from src.application.dtos.enums.youtube_data_type import YoutubeDataType +from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase @pytest.mark.Dependencies @@ -366,10 +368,10 @@ def test_process_single_video_reprocess_cleanup(self, use_case, mock_services): assert mock_services["vector_service"].delete_by_video_id.called def test_known_exceptions_handling(self, use_case, mock_services): - from src.domain.exception.youtube_exceptions import YoutubeVideoPrivateException from src.domain.entities.enums.content_source_status_enum import ( ContentSourceStatus, ) + from src.domain.exception.youtube_exceptions import YoutubeVideoPrivateException video_id = "v1" # Ensure source exists so status update is called diff --git a/tests/application/use_cases/test_youtube_throttling.py b/tests/application/use_cases/test_youtube_throttling.py index e1a86b87..d06d351d 100644 --- a/tests/application/use_cases/test_youtube_throttling.py +++ b/tests/application/use_cases/test_youtube_throttling.py @@ -1,11 +1,11 @@ -import uuid import time -from unittest.mock import MagicMock +import uuid from types import SimpleNamespace +from unittest.mock import MagicMock from src.application.dtos.commands.ingest_youtube_command import IngestYoutubeCommand -from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase from src.application.dtos.enums.youtube_data_type import YoutubeDataType +from src.application.use_cases.youtube_ingestion_use_case import YoutubeIngestionUseCase class DummyDoc: @@ -40,10 +40,9 @@ def make_use_case_mocks(): def test_throttling_logic(monkeypatch): use_case = make_use_case_mocks() - from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor - # Mock settings from src.config.settings import settings + from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(settings.youtube, "throttle_batch_size", 2) monkeypatch.setattr(settings.youtube, "throttle_wait_seconds", 0.1) diff --git a/tests/config/test_logger.py b/tests/config/test_logger.py index 95a7cb46..51576042 100644 --- a/tests/config/test_logger.py +++ b/tests/config/test_logger.py @@ -1,4 +1,4 @@ -from src.config.logger import setup_logging, Logger +from src.config.logger import Logger, setup_logging def test_setup_logging(): diff --git a/tests/config/test_settings.py b/tests/config/test_settings.py index f92cd5f2..4e055984 100644 --- a/tests/config/test_settings.py +++ b/tests/config/test_settings.py @@ -5,7 +5,7 @@ import pytest -from src.config.settings import Settings, App, SQLConfig, VectorConfig +from src.config.settings import App, Settings, SQLConfig, VectorConfig def test_allowed_log_levels_default(): diff --git a/tests/conftest.py b/tests/conftest.py index b6db2da1..859a65a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ -from unittest.mock import MagicMock, patch import os +from unittest.mock import MagicMock, patch # Suppress NNPACK warnings (Unsupported hardware) os.environ["NNPACK_CPU_FAST_8x8_CONV"] = "0" @@ -40,8 +40,8 @@ def mock_infrastructure(): def mock_auth(): """Global mock for current user to avoid 401 Unauthorized in API tests.""" from main import app - from src.presentation.api.dependencies import get_current_user from src.domain.entities.user import User + from src.presentation.api.dependencies import get_current_user mock_user = User(id="admin", email="admin@whatyousaid.local", full_name="Admin") app.dependency_overrides[get_current_user] = lambda: mock_user diff --git a/tests/infrastructure/connectors/test_redis_connector.py b/tests/infrastructure/connectors/test_redis_connector.py index e1208aea..48d0102f 100644 --- a/tests/infrastructure/connectors/test_redis_connector.py +++ b/tests/infrastructure/connectors/test_redis_connector.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import patch + +import pytest + from src.infrastructure.connectors.redis_connector import RedisConnector diff --git a/tests/infrastructure/extractors/test_crawl4ai_extractor.py b/tests/infrastructure/extractors/test_crawl4ai_extractor.py index c9e52ebd..2ea7057d 100644 --- a/tests/infrastructure/extractors/test_crawl4ai_extractor.py +++ b/tests/infrastructure/extractors/test_crawl4ai_extractor.py @@ -1,8 +1,10 @@ +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from src.infrastructure.extractors.crawl4ai_extractor import Crawl4AIExtractor from langchain_core.documents import Document +from src.infrastructure.extractors.crawl4ai_extractor import Crawl4AIExtractor + @pytest.mark.anyio class TestCrawl4AIExtractor: diff --git a/tests/infrastructure/extractors/test_docling_extractor.py b/tests/infrastructure/extractors/test_docling_extractor.py index e62d9e0a..426c80a9 100644 --- a/tests/infrastructure/extractors/test_docling_extractor.py +++ b/tests/infrastructure/extractors/test_docling_extractor.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import MagicMock +import pytest + @pytest.mark.DoclingExtractor class TestDoclingExtractor: @@ -131,9 +132,10 @@ def test_get_file_type(self): assert extractor._get_file_type("no_ext") == "unknown" def test_extract_with_images(self, mock_converter, tmp_path): - from src.infrastructure.extractors.docling_extractor import DoclingExtractor from docling_core.types.doc import PictureItem + from src.infrastructure.extractors.docling_extractor import DoclingExtractor + test_file = tmp_path / "images.pdf" test_file.write_text("dummy") diff --git a/tests/infrastructure/extractors/test_plain_text_extractor.py b/tests/infrastructure/extractors/test_plain_text_extractor.py index 9fc21078..0dcdec99 100644 --- a/tests/infrastructure/extractors/test_plain_text_extractor.py +++ b/tests/infrastructure/extractors/test_plain_text_extractor.py @@ -1,5 +1,7 @@ +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock + from src.infrastructure.extractors.plain_text_extractor import PlainTextExtractor diff --git a/tests/infrastructure/extractors/test_youtube_extractor.py b/tests/infrastructure/extractors/test_youtube_extractor.py index 5a74bd44..cfb9b254 100644 --- a/tests/infrastructure/extractors/test_youtube_extractor.py +++ b/tests/infrastructure/extractors/test_youtube_extractor.py @@ -2,16 +2,16 @@ import pytest from youtube_transcript_api import ( - TranscriptsDisabled, NoTranscriptFound, + TranscriptsDisabled, ) from src.config.logger import Logger -from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor from src.domain.exception.youtube_exceptions import ( YoutubeTranscriptNotFoundException, YoutubeTranscriptsDisabledException, ) +from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor logger = Logger() diff --git a/tests/infrastructure/repositories/sql/test_additional_coverage.py b/tests/infrastructure/repositories/sql/test_additional_coverage.py index 774f7dce..4839a994 100644 --- a/tests/infrastructure/repositories/sql/test_additional_coverage.py +++ b/tests/infrastructure/repositories/sql/test_additional_coverage.py @@ -1,4 +1,4 @@ -from uuid import uuid4, UUID +from uuid import UUID, uuid4 import pytest diff --git a/tests/infrastructure/repositories/sql/test_chunk_index_repository.py b/tests/infrastructure/repositories/sql/test_chunk_index_repository.py index 8caa8cc0..4a3eddc0 100644 --- a/tests/infrastructure/repositories/sql/test_chunk_index_repository.py +++ b/tests/infrastructure/repositories/sql/test_chunk_index_repository.py @@ -1,6 +1,8 @@ -import pytest -from uuid import uuid4 from unittest.mock import patch +from uuid import uuid4 + +import pytest + from src.infrastructure.repositories.sql.chunk_index_repository import ( ChunkIndexSQLRepository, ) diff --git a/tests/infrastructure/repositories/sql/test_content_source_repository.py b/tests/infrastructure/repositories/sql/test_content_source_repository.py index 540c8c45..2912655c 100644 --- a/tests/infrastructure/repositories/sql/test_content_source_repository.py +++ b/tests/infrastructure/repositories/sql/test_content_source_repository.py @@ -1,6 +1,8 @@ -import pytest -from uuid import uuid4 from unittest.mock import patch +from uuid import uuid4 + +import pytest + from src.infrastructure.repositories.sql.content_source_repository import ( ContentSourceSQLRepository, ) diff --git a/tests/infrastructure/repositories/sql/test_diarization_repository.py b/tests/infrastructure/repositories/sql/test_diarization_repository.py index 51cdc05a..0ff0807c 100644 --- a/tests/infrastructure/repositories/sql/test_diarization_repository.py +++ b/tests/infrastructure/repositories/sql/test_diarization_repository.py @@ -1,8 +1,8 @@ +from src.domain.entities.diarization import DiarizationResult, Segment +from src.domain.entities.enums.diarization_status_enum import DiarizationStatus from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, ) -from src.domain.entities.diarization import DiarizationResult, Segment -from src.domain.entities.enums.diarization_status_enum import DiarizationStatus class TestDiarizationRepository: diff --git a/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py b/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py index c0f460ee..3bc9fa3c 100644 --- a/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py +++ b/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py @@ -1,12 +1,14 @@ -from uuid import uuid4, UUID +from uuid import UUID, uuid4 + import pytest -from src.infrastructure.repositories.sql.ingestion_job_repository import ( - IngestionJobSQLRepository, -) + +from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.sql.content_source_repository import ( ContentSourceSQLRepository, ) -from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.infrastructure.repositories.sql.ingestion_job_repository import ( + IngestionJobSQLRepository, +) @pytest.mark.IngestionJobRepository diff --git a/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py b/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py index e7079a5a..9523af64 100644 --- a/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py +++ b/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py @@ -1,13 +1,15 @@ -import pytest -from uuid import uuid4 from unittest.mock import patch -from src.infrastructure.repositories.sql.knowledge_subject_repository import ( - KnowledgeSubjectSQLRepository, -) +from uuid import uuid4 + +import pytest + +from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.sql.content_source_repository import ( ContentSourceSQLRepository, ) -from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.infrastructure.repositories.sql.knowledge_subject_repository import ( + KnowledgeSubjectSQLRepository, +) @pytest.mark.KnowledgeSubjectRepository diff --git a/tests/infrastructure/repositories/sql/test_repos_services.py b/tests/infrastructure/repositories/sql/test_repos_services.py index 29f986ed..465be75d 100644 --- a/tests/infrastructure/repositories/sql/test_repos_services.py +++ b/tests/infrastructure/repositories/sql/test_repos_services.py @@ -1,4 +1,4 @@ -from uuid import uuid4, UUID +from uuid import UUID, uuid4 import pytest diff --git a/tests/infrastructure/repositories/sql/test_user_repository.py b/tests/infrastructure/repositories/sql/test_user_repository.py index ae9179aa..4172d780 100644 --- a/tests/infrastructure/repositories/sql/test_user_repository.py +++ b/tests/infrastructure/repositories/sql/test_user_repository.py @@ -1,7 +1,9 @@ -import pytest from datetime import datetime, timezone -from src.infrastructure.repositories.sql.user_repository import UserSQLRepository + +import pytest + from src.domain.entities.user import User +from src.infrastructure.repositories.sql.user_repository import UserSQLRepository @pytest.mark.usefixtures("sqlite_memory") diff --git a/tests/infrastructure/repositories/test_storage.py b/tests/infrastructure/repositories/test_storage.py index 09c47edd..6f0948ef 100644 --- a/tests/infrastructure/repositories/test_storage.py +++ b/tests/infrastructure/repositories/test_storage.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import MagicMock, patch + +import pytest + from src.infrastructure.repositories.storage.storage import StorageService diff --git a/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py b/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py index 027e5074..af2d92e1 100644 --- a/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py @@ -1,3 +1,4 @@ +# ruff: noqa: E402 import sys from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -10,11 +11,13 @@ sys.modules["chromadb"] = mock_chromadb sys.modules["langchain_chroma"] = mock_langchain_chroma +from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 from src.infrastructure.repositories.vector.chroma.chunk_repository import ( # noqa: E402 ChunkChromaRepository, ) -from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel # noqa: E402 -from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 +from src.infrastructure.repositories.vector.models.chunk_model import ( + ChunkModel, # noqa: E402 +) @pytest.mark.ChunkChromaRepository diff --git a/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py b/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py index bcfc2183..8c66d678 100644 --- a/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py +++ b/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py @@ -1,8 +1,10 @@ -import pytest +# ruff: noqa: E402 import sys from unittest.mock import MagicMock, patch from uuid import uuid4 +import pytest + # Mock dependencies mock_faiss_lib = MagicMock() sys.modules["faiss"] = mock_faiss_lib @@ -14,11 +16,13 @@ mock_bm25_lib = MagicMock() sys.modules["rank_bm25"] = mock_bm25_lib +from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 from src.infrastructure.repositories.vector.faiss.chunk_repository import ( # noqa: E402 ChunkFAISSRepository, ) -from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel # noqa: E402 -from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 +from src.infrastructure.repositories.vector.models.chunk_model import ( + ChunkModel, # noqa: E402 +) @pytest.mark.ChunkFAISSRepository diff --git a/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py b/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py index ec055452..60589918 100644 --- a/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py @@ -1,8 +1,10 @@ -import pytest +# ruff: noqa: E402 import sys from unittest.mock import MagicMock, patch from uuid import uuid4 +import pytest + # Mock dependencies mock_faiss_lib = MagicMock() sys.modules["faiss"] = mock_faiss_lib @@ -10,11 +12,13 @@ sys.modules["langchain_community.vectorstores"] = MagicMock() sys.modules["langchain_community.vectorstores.faiss"] = mock_langchain_faiss +from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 from src.infrastructure.repositories.vector.faiss.chunk_repository import ( # noqa: E402 ChunkFAISSRepository, ) -from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel # noqa: E402 -from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 +from src.infrastructure.repositories.vector.models.chunk_model import ( + ChunkModel, # noqa: E402 +) @pytest.mark.ChunkFAISSRepository diff --git a/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py b/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py index 12c728f3..929cd4bb 100644 --- a/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py @@ -1,15 +1,16 @@ -import pytest -import pydantic_core +from datetime import datetime, timezone from unittest.mock import MagicMock, patch from uuid import uuid4 -from datetime import datetime, timezone + +import pydantic_core +import pytest from qdrant_client.http import models as rest +from src.domain.entities.enums.search_mode_enum import SearchMode +from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel from src.infrastructure.repositories.vector.qdrant.chunk_repository import ( ChunkQdrantRepository, ) -from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel -from src.domain.entities.enums.search_mode_enum import SearchMode @pytest.mark.ChunkQdrantRepository diff --git a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py index ea7bf4d6..1668851f 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py @@ -1,12 +1,14 @@ -import pytest +from datetime import datetime, timezone from unittest.mock import MagicMock, patch from uuid import uuid4 -from datetime import datetime, timezone + +import pytest + +from src.domain.entities.enums.search_mode_enum import SearchMode +from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel from src.infrastructure.repositories.vector.weaviate.chunk_repository import ( ChunkWeaviateRepository, ) -from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel -from src.domain.entities.enums.search_mode_enum import SearchMode @pytest.mark.ChunkWeaviateRepository diff --git a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py index 89c8d075..b3c3d8d8 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py @@ -1,8 +1,9 @@ -import pytest import sys +from types import SimpleNamespace from unittest.mock import MagicMock, patch from uuid import uuid4 -from types import SimpleNamespace + +import pytest # Mock weaviate and its complex nested structure mock_weaviate = MagicMock() @@ -15,10 +16,10 @@ sys.modules["weaviate.classes"] = MagicMock() sys.modules["weaviate.classes.query"] = MagicMock() +from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 from src.infrastructure.repositories.vector.weaviate.chunk_repository import ( # noqa: E402 ChunkWeaviateRepository, ) -from src.domain.entities.enums.search_mode_enum import SearchMode # noqa: E402 @pytest.mark.ChunkRepository diff --git a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py index f94918c9..616eac86 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py @@ -2,6 +2,7 @@ from uuid import uuid4 import pytest + from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel from src.infrastructure.repositories.vector.weaviate.chunk_repository import ( ChunkWeaviateRepository, diff --git a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client.py b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client.py index e8742340..ef707e8f 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client.py @@ -1,6 +1,8 @@ -import pytest import sys from unittest.mock import MagicMock + +import pytest + from src.infrastructure.repositories.vector.weaviate.weaviate_client import ( WeaviateClient, ) diff --git a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client_extended.py b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client_extended.py index 2c90e752..91c1d86f 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client_extended.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_client_extended.py @@ -1,6 +1,8 @@ -import pytest import sys from unittest.mock import MagicMock + +import pytest + from src.infrastructure.repositories.vector.weaviate.weaviate_client import ( WeaviateClient, ) diff --git a/tests/infrastructure/services/test_auth_service.py b/tests/infrastructure/services/test_auth_service.py index 7d42cbae..2b6c6a9f 100644 --- a/tests/infrastructure/services/test_auth_service.py +++ b/tests/infrastructure/services/test_auth_service.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import MagicMock, patch -from src.infrastructure.services.auth_service import AuthService + +import pytest + from src.domain.entities.user import User +from src.infrastructure.services.auth_service import AuthService @pytest.fixture diff --git a/tests/infrastructure/services/test_chunk_index_service.py b/tests/infrastructure/services/test_chunk_index_service.py index 6c18ed04..b6fc4f81 100644 --- a/tests/infrastructure/services/test_chunk_index_service.py +++ b/tests/infrastructure/services/test_chunk_index_service.py @@ -1,11 +1,13 @@ -import pytest +from datetime import datetime, timezone from unittest.mock import MagicMock from uuid import uuid4 -from datetime import datetime, timezone -from src.infrastructure.services.chunk_index_service import ChunkIndexService + +import pytest + from src.domain.entities.chunk_entity import ChunkEntity from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.sql.models.chunk_index import ChunkIndexModel +from src.infrastructure.services.chunk_index_service import ChunkIndexService @pytest.mark.ChunkIndexService diff --git a/tests/infrastructure/services/test_chunk_vector_service.py b/tests/infrastructure/services/test_chunk_vector_service.py index 08169f71..6fc8b012 100644 --- a/tests/infrastructure/services/test_chunk_vector_service.py +++ b/tests/infrastructure/services/test_chunk_vector_service.py @@ -1,11 +1,13 @@ -import pytest from unittest.mock import MagicMock from uuid import uuid4 -from src.infrastructure.services.chunk_vector_service import ChunkVectorService + +import pytest + from src.domain.entities.chunk_entity import ChunkEntity -from src.domain.entities.enums.source_type_enum_entity import SourceType from src.domain.entities.enums.search_mode_enum import SearchMode +from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel +from src.infrastructure.services.chunk_vector_service import ChunkVectorService @pytest.mark.ChunkVectorService diff --git a/tests/infrastructure/services/test_content_source_service.py b/tests/infrastructure/services/test_content_source_service.py index 47041c0c..f6d8c3e8 100644 --- a/tests/infrastructure/services/test_content_source_service.py +++ b/tests/infrastructure/services/test_content_source_service.py @@ -1,11 +1,13 @@ -import pytest -from unittest.mock import MagicMock -from uuid import uuid4 from datetime import datetime, timezone from types import SimpleNamespace -from src.infrastructure.services.content_source_service import ContentSourceService +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + from src.domain.entities.enums.content_source_status_enum import ContentSourceStatus from src.domain.entities.enums.source_type_enum_entity import SourceType +from src.infrastructure.services.content_source_service import ContentSourceService @pytest.mark.Dependencies diff --git a/tests/infrastructure/services/test_embedding_service.py b/tests/infrastructure/services/test_embedding_service.py index 9b2d180f..fcc9d79f 100644 --- a/tests/infrastructure/services/test_embedding_service.py +++ b/tests/infrastructure/services/test_embedding_service.py @@ -1,6 +1,7 @@ -from src.infrastructure.services.embedding_service import EmbeddingService import pytest +from src.infrastructure.services.embedding_service import EmbeddingService + class DummyModel: def encode(self, t): diff --git a/tests/infrastructure/services/test_ingestion_job_service.py b/tests/infrastructure/services/test_ingestion_job_service.py index 1b889d78..57bc88fe 100644 --- a/tests/infrastructure/services/test_ingestion_job_service.py +++ b/tests/infrastructure/services/test_ingestion_job_service.py @@ -1,10 +1,12 @@ -import pytest +from datetime import datetime, timezone from unittest.mock import MagicMock from uuid import uuid4 -from datetime import datetime, timezone -from src.infrastructure.services.ingestion_job_service import IngestionJobService + +import pytest + from src.domain.entities.enums.ingestion_job_status_enum import IngestionJobStatus from src.infrastructure.repositories.sql.models.ingestion_job import IngestionJobModel +from src.infrastructure.services.ingestion_job_service import IngestionJobService @pytest.mark.IngestionJobService diff --git a/tests/infrastructure/services/test_knowledge_subject_service.py b/tests/infrastructure/services/test_knowledge_subject_service.py index 580a5075..dba0f8c8 100644 --- a/tests/infrastructure/services/test_knowledge_subject_service.py +++ b/tests/infrastructure/services/test_knowledge_subject_service.py @@ -1,8 +1,10 @@ -import pytest +from datetime import datetime, timezone +from types import SimpleNamespace from unittest.mock import MagicMock from uuid import uuid4 -from types import SimpleNamespace -from datetime import datetime, timezone + +import pytest + from src.infrastructure.services.knowledge_subject_service import ( KnowledgeSubjectService, ) diff --git a/tests/infrastructure/services/test_model_loader_service.py b/tests/infrastructure/services/test_model_loader_service.py index 9ba16b7b..2b959aff 100644 --- a/tests/infrastructure/services/test_model_loader_service.py +++ b/tests/infrastructure/services/test_model_loader_service.py @@ -1,5 +1,8 @@ -import pytest from unittest.mock import MagicMock, patch + +# ruff: noqa: E402 +import pytest + from src.infrastructure.services.model_loader_service import ModelLoaderService diff --git a/tests/infrastructure/services/test_pyannote_voice_recognizer.py b/tests/infrastructure/services/test_pyannote_voice_recognizer.py index e8ea3c7e..d90c1f75 100644 --- a/tests/infrastructure/services/test_pyannote_voice_recognizer.py +++ b/tests/infrastructure/services/test_pyannote_voice_recognizer.py @@ -1,6 +1,8 @@ from unittest.mock import MagicMock, patch + import numpy as np import pytest + from src.infrastructure.services.pyannote_voice_recognizer import VoiceRecognizer diff --git a/tests/infrastructure/services/test_re_rank_service.py b/tests/infrastructure/services/test_re_rank_service.py index ffa2f511..358ae3dc 100644 --- a/tests/infrastructure/services/test_re_rank_service.py +++ b/tests/infrastructure/services/test_re_rank_service.py @@ -1,14 +1,18 @@ -import pytest +# ruff: noqa: E402 import sys from unittest.mock import MagicMock, patch from uuid import uuid4 +import pytest + # Mock flashrank at module level since it might not be installed in all environments mock_flashrank = MagicMock() sys.modules["flashrank"] = mock_flashrank +from src.infrastructure.repositories.vector.models.chunk_model import ( + ChunkModel, # noqa: E402 +) from src.infrastructure.services.re_rank_service import ReRankService # noqa: E402 -from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel # noqa: E402 @pytest.mark.ReRankService @@ -23,7 +27,7 @@ def test_init_success(self, monkeypatch): assert service._ranker is not None # Verify that it was called with the correct model and a cache_dir containing flashrank_cache - args, kwargs = mock_ranker_class.call_args + _, kwargs = mock_ranker_class.call_args assert kwargs["model_name"] == "test-model" assert "flashrank_cache" in kwargs["cache_dir"] diff --git a/tests/infrastructure/services/test_redis_event_bus.py b/tests/infrastructure/services/test_redis_event_bus.py index 2c43005a..6f314e33 100644 --- a/tests/infrastructure/services/test_redis_event_bus.py +++ b/tests/infrastructure/services/test_redis_event_bus.py @@ -1,6 +1,8 @@ import json -import pytest from unittest.mock import MagicMock, patch + +import pytest + from src.infrastructure.services.redis_event_bus import RedisEventBus diff --git a/tests/infrastructure/services/test_redis_task_queue_service.py b/tests/infrastructure/services/test_redis_task_queue_service.py index 3614d9a4..f5a1a6ee 100644 --- a/tests/infrastructure/services/test_redis_task_queue_service.py +++ b/tests/infrastructure/services/test_redis_task_queue_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest + from src.infrastructure.services.redis_task_queue_service import ( RedisTaskQueueService, register_task, diff --git a/tests/infrastructure/services/test_task_queue_service.py b/tests/infrastructure/services/test_task_queue_service.py index c7090d81..f976a224 100644 --- a/tests/infrastructure/services/test_task_queue_service.py +++ b/tests/infrastructure/services/test_task_queue_service.py @@ -1,5 +1,7 @@ import time + import pytest + from src.infrastructure.services.task_queue_service import TaskQueueService diff --git a/tests/infrastructure/services/test_voice_profile_service.py b/tests/infrastructure/services/test_voice_profile_service.py index 2c2a0264..b606f724 100644 --- a/tests/infrastructure/services/test_voice_profile_service.py +++ b/tests/infrastructure/services/test_voice_profile_service.py @@ -1,6 +1,8 @@ -import pytest -import numpy as np from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + from src.infrastructure.services.voice_profile_service import VoiceDB diff --git a/tests/infrastructure/services/test_whisperx_audio_diarizer.py b/tests/infrastructure/services/test_whisperx_audio_diarizer.py index 0f4bd886..26dad877 100644 --- a/tests/infrastructure/services/test_whisperx_audio_diarizer.py +++ b/tests/infrastructure/services/test_whisperx_audio_diarizer.py @@ -1,4 +1,5 @@ from unittest.mock import MagicMock, patch + import numpy as np import pytest diff --git a/tests/infrastructure/services/test_youtube_audio_downloader.py b/tests/infrastructure/services/test_youtube_audio_downloader.py index 9467075f..72d97a4a 100644 --- a/tests/infrastructure/services/test_youtube_audio_downloader.py +++ b/tests/infrastructure/services/test_youtube_audio_downloader.py @@ -1,5 +1,7 @@ -import pytest from unittest.mock import patch + +import pytest + from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor diff --git a/tests/infrastructure/services/test_youtube_vector_service.py b/tests/infrastructure/services/test_youtube_vector_service.py index 1d5a0db5..89a5ae04 100644 --- a/tests/infrastructure/services/test_youtube_vector_service.py +++ b/tests/infrastructure/services/test_youtube_vector_service.py @@ -1,6 +1,7 @@ from uuid import uuid4 import pytest + from src.domain.entities.chunk_entity import ChunkEntity from src.domain.entities.enums.source_type_enum_entity import SourceType from src.infrastructure.repositories.vector.models.chunk_model import ChunkModel diff --git a/tests/infrastructure/utils/test_audio_utils.py b/tests/infrastructure/utils/test_audio_utils.py index 63f2d036..df82454c 100644 --- a/tests/infrastructure/utils/test_audio_utils.py +++ b/tests/infrastructure/utils/test_audio_utils.py @@ -1,12 +1,14 @@ +from unittest.mock import patch + import numpy as np import pytest import torch -from unittest.mock import patch + from src.infrastructure.utils.audio_utils import ( - load_audio_tensor, - load_whisperx_audio, cosine_similarity, get_best_device, + load_audio_tensor, + load_whisperx_audio, ) diff --git a/tests/presentation/api/middleware/test_trace_middleware.py b/tests/presentation/api/middleware/test_trace_middleware.py index 100e161a..ae19c73b 100644 --- a/tests/presentation/api/middleware/test_trace_middleware.py +++ b/tests/presentation/api/middleware/test_trace_middleware.py @@ -1,6 +1,7 @@ import anyio import pytest from starlette.responses import Response + from src.presentation.api.middleware.trace_middleware import TraceMiddleware diff --git a/tests/presentation/api/routes/test_audio_diarization_router.py b/tests/presentation/api/routes/test_audio_diarization_router.py index 13881399..0b80bc40 100644 --- a/tests/presentation/api/routes/test_audio_diarization_router.py +++ b/tests/presentation/api/routes/test_audio_diarization_router.py @@ -9,11 +9,11 @@ ) from src.presentation.api.dependencies import ( get_db, - get_task_queue_service, - get_identify_speakers_use_case, - get_retrieve_history_use_case, get_generate_speaker_url_use_case, + get_identify_speakers_use_case, get_list_s3_files_use_case, + get_retrieve_history_use_case, + get_task_queue_service, ) client = TestClient(app) diff --git a/tests/presentation/api/routes/test_auth_router.py b/tests/presentation/api/routes/test_auth_router.py index 882d95f7..32bd5d9d 100644 --- a/tests/presentation/api/routes/test_auth_router.py +++ b/tests/presentation/api/routes/test_auth_router.py @@ -1,9 +1,11 @@ +from unittest.mock import AsyncMock, MagicMock + import pytest -from unittest.mock import MagicMock, AsyncMock from fastapi.testclient import TestClient + from main import app -from src.presentation.api.dependencies import get_auth_use_case, get_current_user from src.domain.entities.user import User +from src.presentation.api.dependencies import get_auth_use_case, get_current_user client = TestClient(app) diff --git a/tests/presentation/api/routes/test_chunk_router.py b/tests/presentation/api/routes/test_chunk_router.py index 7da70eb0..e33fab3c 100644 --- a/tests/presentation/api/routes/test_chunk_router.py +++ b/tests/presentation/api/routes/test_chunk_router.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import MagicMock from uuid import uuid4 + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import ( get_chunk_index_service, diff --git a/tests/presentation/api/routes/test_ingest_router.py b/tests/presentation/api/routes/test_ingest_router.py index 9cb53bcb..c9a9af72 100644 --- a/tests/presentation/api/routes/test_ingest_router.py +++ b/tests/presentation/api/routes/test_ingest_router.py @@ -1,11 +1,12 @@ -import pytest from unittest.mock import MagicMock +from uuid import UUID + +import pytest from fastapi.testclient import TestClient -from main import app -from src.presentation.api.dependencies import get_ingest_youtube_use_case +from main import app from src.application.dtos.results.ingest_youtube_result import IngestYoutubeResult -from uuid import UUID +from src.presentation.api.dependencies import get_ingest_youtube_use_case client = TestClient(app) diff --git a/tests/presentation/api/routes/test_ingest_router_extended.py b/tests/presentation/api/routes/test_ingest_router_extended.py index 3a0ea5cb..7b057a50 100644 --- a/tests/presentation/api/routes/test_ingest_router_extended.py +++ b/tests/presentation/api/routes/test_ingest_router_extended.py @@ -1,9 +1,11 @@ -import pytest from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import get_task_queue_service -from uuid import uuid4 client = TestClient(app) diff --git a/tests/presentation/api/routes/test_ingest_router_file.py b/tests/presentation/api/routes/test_ingest_router_file.py index 87d051f6..3edb031d 100644 --- a/tests/presentation/api/routes/test_ingest_router_file.py +++ b/tests/presentation/api/routes/test_ingest_router_file.py @@ -1,9 +1,11 @@ +from unittest.mock import MagicMock +from uuid import uuid4 + import pytest from fastapi.testclient import TestClient -from unittest.mock import MagicMock + from main import app from src.presentation.api.dependencies import get_file_ingestion_use_case -from uuid import uuid4 @pytest.fixture diff --git a/tests/presentation/api/routes/test_job_router.py b/tests/presentation/api/routes/test_job_router.py index ce7d08d7..61bb6bfa 100644 --- a/tests/presentation/api/routes/test_job_router.py +++ b/tests/presentation/api/routes/test_job_router.py @@ -1,11 +1,13 @@ -import pytest +from datetime import datetime, timezone +from types import SimpleNamespace from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import get_job_service -from uuid import uuid4 -from datetime import datetime, timezone -from types import SimpleNamespace @pytest.fixture diff --git a/tests/presentation/api/routes/test_notification_router.py b/tests/presentation/api/routes/test_notification_router.py index 4cbc1c9d..5eecd665 100644 --- a/tests/presentation/api/routes/test_notification_router.py +++ b/tests/presentation/api/routes/test_notification_router.py @@ -1,7 +1,9 @@ import json -import pytest from unittest.mock import MagicMock + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import get_event_bus diff --git a/tests/presentation/api/routes/test_search_router.py b/tests/presentation/api/routes/test_search_router.py index bf46c845..d6f11429 100644 --- a/tests/presentation/api/routes/test_search_router.py +++ b/tests/presentation/api/routes/test_search_router.py @@ -1,6 +1,8 @@ -import pytest from unittest.mock import MagicMock + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import get_search_chunks_use_case diff --git a/tests/presentation/api/routes/test_settings_router.py b/tests/presentation/api/routes/test_settings_router.py index 174ce4a4..ff679fed 100644 --- a/tests/presentation/api/routes/test_settings_router.py +++ b/tests/presentation/api/routes/test_settings_router.py @@ -1,6 +1,8 @@ -import pytest from unittest.mock import MagicMock, patch + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import get_settings, get_vector_repository diff --git a/tests/presentation/api/routes/test_source_router.py b/tests/presentation/api/routes/test_source_router.py index 9058aa61..9e838dae 100644 --- a/tests/presentation/api/routes/test_source_router.py +++ b/tests/presentation/api/routes/test_source_router.py @@ -1,11 +1,14 @@ -import pytest +# ruff: noqa: E402 from unittest.mock import MagicMock + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import ( + get_content_source_use_case, get_cs_service, get_model_loader, - get_content_source_use_case, ) client = TestClient(app) diff --git a/tests/presentation/api/routes/test_subject_router.py b/tests/presentation/api/routes/test_subject_router.py index c03cf7e5..737c6d5e 100644 --- a/tests/presentation/api/routes/test_subject_router.py +++ b/tests/presentation/api/routes/test_subject_router.py @@ -1,6 +1,8 @@ -import pytest from unittest.mock import MagicMock + +import pytest from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import get_ks_service, get_ks_use_case diff --git a/tests/presentation/api/routes/test_voice_profile_router.py b/tests/presentation/api/routes/test_voice_profile_router.py index aa3af782..3ad871e6 100644 --- a/tests/presentation/api/routes/test_voice_profile_router.py +++ b/tests/presentation/api/routes/test_voice_profile_router.py @@ -1,13 +1,17 @@ +from unittest.mock import MagicMock, mock_open, patch + import pytest -from unittest.mock import MagicMock, patch, mock_open from fastapi.testclient import TestClient + from main import app from src.presentation.api.dependencies import ( get_db, - get_register_voice_profile_use_case, - get_train_voice_from_speaker_use_case, - get_list_voice_profiles_use_case, get_delete_voice_profile_use_case, + get_diarization_repo, + get_event_bus, + get_list_voice_profiles_use_case, + get_register_voice_profile_use_case, + get_task_queue_service, ) client = TestClient(app) @@ -32,21 +36,35 @@ def test_register_voice_profile_success(self): app.dependency_overrides.clear() def test_train_from_speaker_success(self): + # 1. Mock dependencies app.dependency_overrides[get_db] = lambda: MagicMock() - mock_use_case = MagicMock() - app.dependency_overrides[get_train_voice_from_speaker_use_case] = lambda: ( - mock_use_case - ) + mock_task_queue = MagicMock() + app.dependency_overrides[get_task_queue_service] = lambda: mock_task_queue + + mock_repo = MagicMock() + app.dependency_overrides[get_diarization_repo] = lambda: mock_repo + mock_repo.get_by_id.return_value = MagicMock(id="d-1") - mock_use_case.execute.return_value = "v-456" + mock_event_bus = MagicMock() + app.dependency_overrides[get_event_bus] = lambda: mock_event_bus + + # 2. Execute request payload = { "diarization_id": "d-1", "speaker_label": "SPEAKER_00", "name": "Bob", } response = client.post("/rest/voices/train-from-speaker", json=payload) - assert response.status_code == 200 - assert response.json()["voice_id"] == "v-456" + + # 3. Assert status and body + assert response.status_code == 202 + assert "Treinamento de voz iniciado" in response.json()["message"] + assert response.json()["name"] == "Bob" + + # 4. Verify queue was called + assert mock_task_queue.enqueue.called + + app.dependency_overrides.clear() app.dependency_overrides.clear() diff --git a/tests/presentation/api/test_dependencies.py b/tests/presentation/api/test_dependencies.py index 4349674b..b92d9524 100644 --- a/tests/presentation/api/test_dependencies.py +++ b/tests/presentation/api/test_dependencies.py @@ -1,26 +1,28 @@ -import pytest from unittest.mock import MagicMock, patch + +import pytest + +from src.config.settings import Settings +from src.domain.entities.enums.vector_store_type_enum import VectorStoreType from src.presentation.api.dependencies import ( - get_settings, + get_chunk_index_service, get_chunk_repo, - get_source_repo, + get_chunk_vector_service, + get_cs_service, + get_embedding_service, + get_ingest_youtube_use_case, get_job_repo, - get_subject_repo, + get_job_service, + get_ks_service, get_model_loader, - get_embedding_service, - get_weaviate_client, + get_search_chunks_use_case, + get_settings, + get_source_repo, + get_subject_repo, get_vector_repository, - get_ks_service, - get_cs_service, - get_job_service, - get_chunk_vector_service, - get_chunk_index_service, + get_weaviate_client, get_youtube_vector_service, - get_search_chunks_use_case, - get_ingest_youtube_use_case, ) -from src.config.settings import Settings -from src.domain.entities.enums.vector_store_type_enum import VectorStoreType @pytest.mark.Dependencies diff --git a/tests/test_import_all_modules.py b/tests/test_import_all_modules.py index d667c7f6..b8827bf6 100644 --- a/tests/test_import_all_modules.py +++ b/tests/test_import_all_modules.py @@ -1,11 +1,12 @@ +# ruff: noqa: E402 """ Import all modules under the `src` package by package name to ensure coverage can see files that aren't otherwise executed by unit tests. This test will fail if any module fails to import so import problems are visible in CI. """ -import os import importlib +import os import traceback import src