diff --git a/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py b/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py index 68b7e182..b37aa2f1 100644 --- a/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py +++ b/alembic/versions/04e0f5f5f0af_add_status_message_and_error_message_to_.py @@ -23,9 +23,7 @@ def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("chunk_index", schema=None) as batch_op: - batch_op.alter_column( - "id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False) batch_op.alter_column( "content_source_id", existing_type=sa.NUMERIC(), @@ -48,9 +46,7 @@ def upgrade() -> None: with op.batch_alter_table("content_sources", schema=None) as batch_op: batch_op.add_column(sa.Column("status_message", sa.Text(), nullable=True)) batch_op.add_column(sa.Column("error_message", sa.Text(), nullable=True)) - batch_op.alter_column( - "id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False) batch_op.alter_column( "subject_id", existing_type=sa.NUMERIC(), @@ -66,14 +62,10 @@ def upgrade() -> None: existing_nullable=True, ) batch_op.drop_index(batch_op.f("ix_diarizations_title")) - batch_op.create_index( - batch_op.f("ix_diarizations_name"), ["name"], unique=False - ) + batch_op.create_index(batch_op.f("ix_diarizations_name"), ["name"], unique=False) with op.batch_alter_table("ingestion_jobs", schema=None) as batch_op: - batch_op.alter_column( - "id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False) batch_op.alter_column( "content_source_id", existing_type=sa.NUMERIC(), @@ -88,9 +80,7 @@ def upgrade() -> None: ) with op.batch_alter_table("knowledge_subjects", schema=None) as batch_op: - batch_op.alter_column( - "id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.NUMERIC(), type_=sa.UUID(), existing_nullable=False) # ### end Alembic commands ### @@ -99,9 +89,7 @@ def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("knowledge_subjects", schema=None) as batch_op: - batch_op.alter_column( - "id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False) with op.batch_alter_table("ingestion_jobs", schema=None) as batch_op: batch_op.alter_column( @@ -116,15 +104,11 @@ def downgrade() -> None: type_=sa.NUMERIC(), existing_nullable=True, ) - batch_op.alter_column( - "id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False) with op.batch_alter_table("diarizations", schema=None) as batch_op: batch_op.drop_index(batch_op.f("ix_diarizations_name")) - batch_op.create_index( - batch_op.f("ix_diarizations_title"), ["name"], unique=False - ) + batch_op.create_index(batch_op.f("ix_diarizations_title"), ["name"], unique=False) batch_op.alter_column( "subject_id", existing_type=sa.UUID(), @@ -139,9 +123,7 @@ def downgrade() -> None: type_=sa.NUMERIC(), existing_nullable=True, ) - batch_op.alter_column( - "id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False) batch_op.drop_column("error_message") batch_op.drop_column("status_message") @@ -164,8 +146,6 @@ def downgrade() -> None: type_=sa.NUMERIC(), existing_nullable=False, ) - batch_op.alter_column( - "id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False - ) + batch_op.alter_column("id", existing_type=sa.UUID(), type_=sa.NUMERIC(), existing_nullable=False) # ### end Alembic commands ### diff --git a/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py b/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py index 9e70cd48..5c264aa1 100644 --- a/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py +++ b/alembic/versions/0ce7f69147eb_update_unique_constraint_on_content_.py @@ -23,9 +23,7 @@ def upgrade() -> None: """Upgrade schema.""" # Using batch mode for SQLite compatibility with op.batch_alter_table("content_sources", schema=None) as batch_op: - batch_op.drop_constraint( - op.f("uq_content_source_external_source"), type_="unique" - ) + batch_op.drop_constraint(op.f("uq_content_source_external_source"), type_="unique") batch_op.create_unique_constraint( "uq_content_source_external_source_per_subject", ["external_source", "subject_id"], diff --git a/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py b/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py index 08373507..93448de1 100644 --- a/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py +++ b/alembic/versions/4e8d4e04a288_add_external_source_to_ingestion_jobs.py @@ -21,9 +21,7 @@ def upgrade() -> None: """Upgrade schema.""" - op.add_column( - "ingestion_jobs", sa.Column("external_source", sa.Text(), nullable=True) - ) + op.add_column("ingestion_jobs", sa.Column("external_source", sa.Text(), nullable=True)) def downgrade() -> None: diff --git a/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py b/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py index 53527a7e..fdad3527 100644 --- a/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py +++ b/alembic/versions/50420d500c2e_add_token_columns_to_content_source.py @@ -22,9 +22,7 @@ def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "content_sources", sa.Column("total_tokens", sa.Integer(), nullable=True) - ) + op.add_column("content_sources", sa.Column("total_tokens", sa.Integer(), nullable=True)) op.add_column( "content_sources", sa.Column("max_tokens_per_chunk", sa.Integer(), nullable=True), diff --git a/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py b/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py index d77ae0d3..9649a199 100644 --- a/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py +++ b/alembic/versions/5736075a22d0_add_vector_store_type_to_job_and_chunk_.py @@ -28,17 +28,13 @@ def upgrade() -> None: columns_chunk = [c["name"] for c in insp.get_columns("chunk_index")] if "vector_store_type" not in columns_chunk: with op.batch_alter_table("chunk_index", schema=None) as batch_op: - batch_op.add_column( - sa.Column("vector_store_type", sa.Text(), nullable=True) - ) + batch_op.add_column(sa.Column("vector_store_type", sa.Text(), nullable=True)) # ingestion_jobs columns_jobs = [c["name"] for c in insp.get_columns("ingestion_jobs")] if "vector_store_type" not in columns_jobs: with op.batch_alter_table("ingestion_jobs", schema=None) as batch_op: - batch_op.add_column( - sa.Column("vector_store_type", sa.Text(), nullable=True) - ) + batch_op.add_column(sa.Column("vector_store_type", sa.Text(), nullable=True)) def downgrade() -> None: diff --git a/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py b/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py index 3929a5f7..fa17cab7 100644 --- a/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py +++ b/alembic/versions/5ff7984a3bcc_optimize_sql_models_indexes_and_audit.py @@ -33,12 +33,8 @@ def upgrade() -> None: ) ) batch_op.create_index("ix_chunk_index_chunk_id", ["chunk_id"], unique=False) - batch_op.create_index( - "ix_chunk_index_content_source_id", ["content_source_id"], unique=False - ) - batch_op.create_index( - op.f("ix_chunk_index_created_at"), ["created_at"], unique=False - ) + batch_op.create_index("ix_chunk_index_content_source_id", ["content_source_id"], unique=False) + batch_op.create_index(op.f("ix_chunk_index_created_at"), ["created_at"], unique=False) batch_op.create_index("ix_chunk_index_job_id", ["job_id"], unique=False) with op.batch_alter_table("content_sources", schema=None) as batch_op: @@ -50,24 +46,16 @@ def upgrade() -> None: nullable=False, ) ) - batch_op.create_index( - op.f("ix_content_sources_created_at"), ["created_at"], unique=False - ) + batch_op.create_index(op.f("ix_content_sources_created_at"), ["created_at"], unique=False) batch_op.create_index( op.f("ix_content_sources_processing_status"), ["processing_status"], unique=False, ) - batch_op.create_index( - "ix_content_sources_source_type", ["source_type"], unique=False - ) + batch_op.create_index("ix_content_sources_source_type", ["source_type"], unique=False) batch_op.create_index("ix_content_sources_status", ["status"], unique=False) - batch_op.create_index( - "ix_content_sources_subject_id", ["subject_id"], unique=False - ) - batch_op.create_unique_constraint( - "uq_content_source_external_source", ["external_source"] - ) + batch_op.create_index("ix_content_sources_subject_id", ["subject_id"], unique=False) + batch_op.create_unique_constraint("uq_content_source_external_source", ["external_source"]) with op.batch_alter_table("ingestion_jobs", schema=None) as batch_op: batch_op.add_column( @@ -78,12 +66,8 @@ def upgrade() -> None: nullable=False, ) ) - batch_op.create_index( - "ix_ingestion_jobs_content_source_id", ["content_source_id"], unique=False - ) - batch_op.create_index( - op.f("ix_ingestion_jobs_started_at"), ["started_at"], unique=False - ) + batch_op.create_index("ix_ingestion_jobs_content_source_id", ["content_source_id"], unique=False) + batch_op.create_index(op.f("ix_ingestion_jobs_started_at"), ["started_at"], unique=False) batch_op.create_index("ix_ingestion_jobs_status", ["status"], unique=False) with op.batch_alter_table("knowledge_subjects", schema=None) as batch_op: @@ -105,9 +89,7 @@ def downgrade() -> None: with op.batch_alter_table("ingestion_jobs", schema=None) as batch_op: batch_op.drop_index("ix_ingestion_jobs_status") batch_op.drop_index(op.f("ix_ingestion_jobs_started_at")) - batch_op.create_index( - "ix_ingestion_jobs_content_source_id", ["content_source_id"], unique=False - ) + batch_op.create_index("ix_ingestion_jobs_content_source_id", ["content_source_id"], unique=False) batch_op.drop_column("updated_at") with op.batch_alter_table("content_sources", schema=None) as batch_op: diff --git a/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py b/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py index a9066038..0bad4b6e 100644 --- a/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py +++ b/alembic/versions/6e53bc32edfe_add_subject_id_to_diarization.py @@ -24,12 +24,8 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table("diarizations", schema=None) as batch_op: batch_op.add_column(sa.Column("subject_id", sa.UUID(), nullable=True)) - batch_op.create_index( - batch_op.f("ix_diarizations_subject_id"), ["subject_id"], unique=False - ) - batch_op.create_foreign_key( - "fk_diarization_subject", "knowledge_subjects", ["subject_id"], ["id"] - ) + batch_op.create_index(batch_op.f("ix_diarizations_subject_id"), ["subject_id"], unique=False) + batch_op.create_foreign_key("fk_diarization_subject", "knowledge_subjects", ["subject_id"], ["id"]) # ### end Alembic commands ### diff --git a/alembic/versions/72f69987a221_rename_diarization_title_to_name.py b/alembic/versions/72f69987a221_rename_diarization_title_to_name.py index f01fa161..2a618666 100644 --- a/alembic/versions/72f69987a221_rename_diarization_title_to_name.py +++ b/alembic/versions/72f69987a221_rename_diarization_title_to_name.py @@ -22,9 +22,7 @@ def upgrade() -> None: """Upgrade schema.""" with op.batch_alter_table("diarizations", schema=None) as batch_op: - batch_op.alter_column( - "title", new_column_name="name", existing_type=sa.String() - ) + batch_op.alter_column("title", new_column_name="name", existing_type=sa.String()) def downgrade() -> None: diff --git a/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py b/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py index f57e532a..8bac18db 100644 --- a/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py +++ b/alembic/versions/73f13c5ff10a_add_metadata_to_chunk_index.py @@ -39,9 +39,7 @@ def upgrade() -> None: # Check for FK fks = insp.get_foreign_keys("chunk_index") has_fk = any( - fk["referred_table"] == "knowledge_subjects" - and "subject_id" in fk["constrained_columns"] - for fk in fks + fk["referred_table"] == "knowledge_subjects" and "subject_id" in fk["constrained_columns"] for fk in fks ) if not has_fk: batch_op.create_foreign_key( diff --git a/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py b/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py index d38b6bc9..0de9831b 100644 --- a/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py +++ b/alembic/versions/946d88fe08b1_add_source_metadata_to_content_source.py @@ -22,9 +22,7 @@ def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.add_column( - "content_sources", sa.Column("source_metadata", sa.JSON(), nullable=True) - ) + op.add_column("content_sources", sa.Column("source_metadata", sa.JSON(), nullable=True)) # ### end Alembic commands ### diff --git a/alembic/versions/b2c3d4e5f6a7_add_status_to_voices.py b/alembic/versions/b2c3d4e5f6a7_add_status_to_voices.py new file mode 100644 index 00000000..7a7259cb --- /dev/null +++ b/alembic/versions/b2c3d4e5f6a7_add_status_to_voices.py @@ -0,0 +1,36 @@ +"""add status and status_message to voices + +Revision ID: b2c3d4e5f6a7 +Revises: 04e0f5f5f0af +Create Date: 2026-04-07 15:40:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b2c3d4e5f6a7" +down_revision: Union[str, Sequence[str], None] = "04e0f5f5f0af" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + with op.batch_alter_table("voices", schema=None) as batch_op: + batch_op.add_column(sa.Column("status", sa.String(), nullable=True)) + batch_op.add_column(sa.Column("status_message", sa.String(), nullable=True)) + + # Backfill: any existing voice row is assumed ready. + op.execute("UPDATE voices SET status = 'ready' WHERE status IS NULL") + + +def downgrade() -> None: + """Downgrade schema.""" + with op.batch_alter_table("voices", schema=None) as batch_op: + batch_op.drop_column("status_message") + batch_op.drop_column("status") diff --git a/alembic/versions/bd01964d9b26_created_tables.py b/alembic/versions/bd01964d9b26_created_tables.py index 8b2190b6..db4e6342 100644 --- a/alembic/versions/bd01964d9b26_created_tables.py +++ b/alembic/versions/bd01964d9b26_created_tables.py @@ -49,9 +49,7 @@ def upgrade() -> None: sa.Column("language", sa.Text(), nullable=True), sa.Column("embedding_model", sa.Text(), nullable=True), sa.Column("dimensions", sa.Integer(), nullable=True), - sa.Column( - "status", sa.Text(), server_default=sa.text("'active'"), nullable=False - ), + sa.Column("status", sa.Text(), server_default=sa.text("'active'"), nullable=False), sa.Column( "created_at", sa.DateTime(timezone=True), @@ -103,9 +101,7 @@ def upgrade() -> None: sa.Column("chunk_id", sa.Text(), nullable=False), sa.Column("chars", sa.Integer(), server_default=sa.text("0"), nullable=False), sa.Column("language", sa.Text(), nullable=True), - sa.Column( - "version_number", sa.Integer(), server_default=sa.text("1"), nullable=False - ), + sa.Column("version_number", sa.Integer(), server_default=sa.text("1"), nullable=False), sa.Column( "created_at", sa.DateTime(timezone=True), @@ -118,9 +114,7 @@ def upgrade() -> None: initially="IMMEDIATE", deferrable=True, ), - sa.ForeignKeyConstraint( - ["job_id"], ["ingestion_jobs.id"], initially="IMMEDIATE", deferrable=True - ), + sa.ForeignKeyConstraint(["job_id"], ["ingestion_jobs.id"], initially="IMMEDIATE", deferrable=True), sa.PrimaryKeyConstraint("id"), if_not_exists=True, ) diff --git a/alembic/versions/c16fab000f02_add_user_table.py b/alembic/versions/c16fab000f02_add_user_table.py index 46bb1530..eea5a75d 100644 --- a/alembic/versions/c16fab000f02_add_user_table.py +++ b/alembic/versions/c16fab000f02_add_user_table.py @@ -34,9 +34,7 @@ def upgrade() -> None: if_not_exists=True, ) with op.batch_alter_table("users", schema=None) as batch_op: - batch_op.create_index( - batch_op.f("ix_users_email"), ["email"], unique=True, if_not_exists=True - ) + batch_op.create_index(batch_op.f("ix_users_email"), ["email"], unique=True, if_not_exists=True) # ### end Alembic commands ### diff --git a/alembic/versions/c48798b08031_add_voice_samples_table.py b/alembic/versions/c48798b08031_add_voice_samples_table.py index 57192030..c47cab78 100644 --- a/alembic/versions/c48798b08031_add_voice_samples_table.py +++ b/alembic/versions/c48798b08031_add_voice_samples_table.py @@ -67,6 +67,4 @@ def downgrade() -> None: sa.PrimaryKeyConstraint("id"), ) with op.batch_alter_table("voice_samples", schema=None) as batch_op: - batch_op.create_index( - batch_op.f("ix_voice_samples_voice_id"), ["voice_id"], unique=False - ) + batch_op.create_index(batch_op.f("ix_voice_samples_voice_id"), ["voice_id"], unique=False) diff --git a/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py b/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py index 887bc5df..50c11189 100644 --- a/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py +++ b/alembic/versions/f120b614600a_add_diarizations_and_voices_tables.py @@ -56,9 +56,7 @@ def upgrade() -> None: if_not_exists=True, ) with op.batch_alter_table("voices", schema=None) as batch_op: - batch_op.create_index( - batch_op.f("ix_voices_name"), ["name"], unique=True, if_not_exists=True - ) + batch_op.create_index(batch_op.f("ix_voices_name"), ["name"], unique=True, if_not_exists=True) def downgrade() -> None: diff --git a/frontend/src/components/DiarizationView.tsx b/frontend/src/components/DiarizationView.tsx index 28ec2227..55a274b3 100644 --- a/frontend/src/components/DiarizationView.tsx +++ b/frontend/src/components/DiarizationView.tsx @@ -96,7 +96,13 @@ export function DiarizationView() { loadJobs(true).then(updatedJobs => { const updated = updatedJobs.find(j => j.id === activeJob.id); if (updated && updated.status !== activeJob.status) { - if (updated.status === 'awaiting_verification' || updated.status === 'completed' || updated.status === 'failed') { + // If user is already in identification/result, don't reset the + // local speakers state on transient status changes (e.g. voice + // training flipping the job to TRAINING → COMPLETED). Only + // open/reset the job when we were still waiting on the initial + // diarization run. + const inIdentification = step === 'identification' || step === 'result'; + if (!inIdentification && (updated.status === 'awaiting_verification' || updated.status === 'completed' || updated.status === 'failed')) { handleOpenJob(updated); } else { setActiveJob(updated); @@ -111,7 +117,18 @@ export function DiarizationView() { } } } - }, [lastEvent, loadJobs, activeJob, addToast, t]); + + // Voice training finished in background → clear per-speaker processing flag + if (lastEvent.type === 'voice' && lastEvent.action === 'train' && lastEvent.name) { + const finishedName = lastEvent.name as string; + setSpeakers(prev => prev.map(s => + s.trainingStatus === 'processing' && s.assigned === finishedName + ? { ...s, trainingStatus: undefined, confidence: Math.max(s.confidence, 95) } + : s + )); + refreshVoices(); + } + }, [lastEvent, loadJobs, activeJob, addToast, t, step, refreshVoices]); // -- HANDLERS -- @@ -497,9 +514,19 @@ export function DiarizationView() { speaker={trainingSpeaker} diarizationId={activeJob?.id || ''} onClose={() => setTrainingSpeaker(null)} - onTrained={() => { - refreshVoices(); - handleRecognizeSpeakers(); + onTrained={(name) => { + // Mark this speaker as "processing" and optimistically assign the + // new name so the UI reflects the in-flight training job. The + // button will switch to "Reinforce" once the backend publishes + // the voice/train completion event. + const targetId = trainingSpeaker?.id; + if (targetId) { + setSpeakers(prev => prev.map(s => + s.id === targetId + ? { ...s, assigned: name, trainingStatus: 'processing' } + : s + )); + } }} /> diff --git a/frontend/src/components/VoiceProfilesView.tsx b/frontend/src/components/VoiceProfilesView.tsx index 6d8af68d..ab734e0d 100644 --- a/frontend/src/components/VoiceProfilesView.tsx +++ b/frontend/src/components/VoiceProfilesView.tsx @@ -14,7 +14,10 @@ import { Plus, Search, Play, - Square + Square, + Loader2, + CheckCircle2, + XCircle } from 'lucide-react'; import { motion, AnimatePresence } from 'motion/react'; import { useAppContext } from '../store/AppContext'; @@ -29,14 +32,50 @@ interface AudioFile { path: string; } +type VoiceProfileStatus = 'ready' | 'processing' | 'failed'; + interface VoiceProfile { id: string; name: string; audios_path: string; samples_count: number; created_at: string; + status?: VoiceProfileStatus; + status_message?: string | null; } +const StatusBadge: React.FC<{ status?: VoiceProfileStatus; message?: string | null }> = ({ status, message }) => { + const { t } = useTranslation(); + const s = status || 'ready'; + const config: Record = { + processing: { + icon: , + cls: 'bg-amber-500/10 border-amber-500/30 text-amber-400', + label: t('voices.status.processing'), + }, + ready: { + icon: , + cls: 'bg-emerald-500/10 border-emerald-500/30 text-emerald-400', + label: t('voices.status.ready'), + }, + failed: { + icon: , + cls: 'bg-rose-500/10 border-rose-500/30 text-rose-400', + label: t('voices.status.failed'), + }, + }; + const c = config[s]; + return ( +
+ {c.icon} + {c.label} +
+ ); +}; + // --- Sub-components --- const VoiceCard = ({ @@ -72,10 +111,12 @@ const VoiceCard = ({ -

+

{profile.name}

+ +
@@ -287,7 +328,7 @@ const VoiceDetailsModal = ({ export function VoiceProfilesView() { const { t } = useTranslation(); - const { addToast, selectedSubjects } = useAppContext(); + const { addToast, selectedSubjects, lastEvent } = useAppContext(); const [voices, setVoices] = useState([]); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); @@ -313,6 +354,14 @@ export function VoiceProfilesView() { fetchVoices(); }, [fetchVoices]); + // Refetch whenever a voice-scoped SSE event arrives (train_started, + // train_progress, train, train_failed) so the status badge updates live. + useEffect(() => { + if (lastEvent && (lastEvent as any).type === 'voice') { + fetchVoices(); + } + }, [lastEvent, fetchVoices]); + const handleDeleteProfile = useCallback(async (name: string) => { try { await api.deleteVoiceProfile(name); diff --git a/frontend/src/components/diarization/SpeakerIdentificationPanel.tsx b/frontend/src/components/diarization/SpeakerIdentificationPanel.tsx index ef53cf39..2044d329 100644 --- a/frontend/src/components/diarization/SpeakerIdentificationPanel.tsx +++ b/frontend/src/components/diarization/SpeakerIdentificationPanel.tsx @@ -216,22 +216,36 @@ const SpeakerCard: React.FC<{ />
- + {(() => { + const isProcessing = speaker.trainingStatus === 'processing'; + const isUntrained = speaker.assigned.toUpperCase().startsWith('SPEAKER_') || speaker.assigned.toUpperCase() === 'UNKNOWN'; + let cls: string; + let label: string; + let icon: React.ReactNode; + if (isProcessing) { + cls = 'bg-amber-500/10 border-amber-500/20 text-amber-400 cursor-wait'; + label = t('diarization.identification.processing_voice'); + icon = ; + } else if (isUntrained) { + cls = 'bg-emerald-500/10 hover:bg-emerald-500 border-emerald-500/20 hover:border-emerald-500 text-emerald-400 hover:text-black shadow-emerald-500/10'; + label = t('diarization.identification.train_voice'); + icon = ; + } else { + cls = 'bg-blue-500/10 hover:bg-blue-500 border-blue-500/20 hover:border-blue-500 text-blue-400 hover:text-white shadow-blue-500/10'; + label = t('diarization.identification.reinforce_voice'); + icon = ; + } + return ( + + ); + })()} {speaker.confidence > 0 && (
diff --git a/frontend/src/components/diarization/VoiceTrainingModal.tsx b/frontend/src/components/diarization/VoiceTrainingModal.tsx index 6ffb0db8..b85e0775 100644 --- a/frontend/src/components/diarization/VoiceTrainingModal.tsx +++ b/frontend/src/components/diarization/VoiceTrainingModal.tsx @@ -11,7 +11,7 @@ interface VoiceTrainingModalProps { readonly speaker: Speaker | null; readonly diarizationId: string; readonly onClose: () => void; - readonly onTrained: () => void; + readonly onTrained: (name: string) => void; } export const VoiceTrainingModal: React.FC = ({ isOpen, speaker, diarizationId, onClose, onTrained }) => { @@ -75,7 +75,7 @@ export const VoiceTrainingModal: React.FC = ({ isOpen, : t('diarization.notifications.train_success', { name }); addToast(successMsg, 'success'); - onTrained(); + onTrained(name); onClose(); } catch (err: any) { console.error('Failed to train voice:', err); diff --git a/frontend/src/components/diarization/types.ts b/frontend/src/components/diarization/types.ts index 0eecc767..83f0c5cc 100644 --- a/frontend/src/components/diarization/types.ts +++ b/frontend/src/components/diarization/types.ts @@ -8,6 +8,7 @@ export interface Speaker { isPlaying: boolean; confidence: number; audioUrl?: string; + trainingStatus?: 'processing'; } export interface DiarizationJob { diff --git a/frontend/src/locales/en.json b/frontend/src/locales/en.json index d62fcdec..78787084 100644 --- a/frontend/src/locales/en.json +++ b/frontend/src/locales/en.json @@ -607,6 +607,7 @@ "new_name": "New Name...", "train_voice": "Train this voice", "reinforce_voice": "Reinforce this voice", + "processing_voice": "Training voice...", "match_pct": "{{count}}% match", "save_btn": "Save & Index" }, @@ -712,7 +713,10 @@ "delete_file_error": "Failed to delete audio sample" }, "status": { - "samples": "{{count}} audio files" + "samples": "{{count}} audio files", + "processing": "Training", + "ready": "Ready", + "failed": "Failed" }, "card": { "created_at": "Created" diff --git a/frontend/src/locales/pt-BR.json b/frontend/src/locales/pt-BR.json index e4274553..98317080 100644 --- a/frontend/src/locales/pt-BR.json +++ b/frontend/src/locales/pt-BR.json @@ -606,6 +606,7 @@ "new_name": "Novo Nome...", "train_voice": "Treinar esta voz", "reinforce_voice": "Reforçar esta voz", + "processing_voice": "Treinando voz...", "match_pct": "{{count}}% de match", "save_btn": "Salvar e Indexar" }, @@ -711,7 +712,10 @@ "delete_file_error": "Falha ao excluir amostra de áudio" }, "status": { - "samples": "{{count}} arquivos de áudio" + "samples": "{{count}} arquivos de áudio", + "processing": "Treinando", + "ready": "Pronto", + "failed": "Falhou" }, "card": { "created_at": "Criado em" diff --git a/main.py b/main.py index 6ae2b529..a212feed 100644 --- a/main.py +++ b/main.py @@ -73,9 +73,7 @@ async def lifespan(app: FastAPI): }, ) - app.state.model_loader = ModelLoaderService( - model_name=_settings.model_embedding.name - ) + app.state.model_loader = ModelLoaderService(model_name=_settings.model_embedding.name) logger.info("Embedding model pre-loaded successfully.") # Load Re-rank Model @@ -102,9 +100,7 @@ async def lifespan(app: FastAPI): register_task("run_youtube_ingestion_worker", run_youtube_ingestion_worker) register_task("run_web_ingestion_worker", run_web_ingestion_worker) register_task("run_audio_diarization_worker", run_audio_diarization_worker) - register_task( - "run_diarization_ingestion_worker", run_diarization_ingestion_worker - ) + register_task("run_diarization_ingestion_worker", run_diarization_ingestion_worker) register_task("run_youtube_dispatcher_worker", run_youtube_dispatcher_worker) register_task( "run_audio_diarization_dispatcher_worker", @@ -181,9 +177,7 @@ async def lifespan(app: FastAPI): tags=["Sources"], dependencies=secured_deps, ) -app.include_router( - job_router.router, prefix="/rest/jobs", tags=["Jobs"], dependencies=secured_deps -) +app.include_router(job_router.router, prefix="/rest/jobs", tags=["Jobs"], dependencies=secured_deps) app.include_router( settings_router.router, prefix="/rest/settings", diff --git a/src/application/use_cases/auth_use_case.py b/src/application/use_cases/auth_use_case.py index 0f391e79..c7fadee8 100644 --- a/src/application/use_cases/auth_use_case.py +++ b/src/application/use_cases/auth_use_case.py @@ -26,9 +26,7 @@ def get_login_url(self) -> Tuple[str, str]: url = self._auth_service.get_google_auth_url(state=state) return url, state - async def handle_google_callback( - self, code: str, received_state: str, expected_state: str - ) -> Dict[str, Any]: + async def handle_google_callback(self, code: str, received_state: str, expected_state: str) -> Dict[str, Any]: # 0. Validate state (only if expected_state was provided) if expected_state and received_state != expected_state: raise InvalidStateError("Invalid authentication state (CSRF Protection)") diff --git a/src/application/use_cases/delete_diarization_use_case.py b/src/application/use_cases/delete_diarization_use_case.py index eeba2592..ed0410e0 100644 --- a/src/application/use_cases/delete_diarization_use_case.py +++ b/src/application/use_cases/delete_diarization_use_case.py @@ -72,9 +72,7 @@ def execute(self, diarization_id: str) -> bool: else: os.remove(record.folder_path) except Exception as e: - logger.error( - "Failed to delete local folder %s: %s", record.folder_path, str(e) - ) + logger.error("Failed to delete local folder %s: %s", record.folder_path, str(e)) # 3. Delete from Database logger.info("Deleting database record: %s", diarization_id) diff --git a/src/application/use_cases/diarization_ingestion_use_case.py b/src/application/use_cases/diarization_ingestion_use_case.py index 41173854..7fe26127 100644 --- a/src/application/use_cases/diarization_ingestion_use_case.py +++ b/src/application/use_cases/diarization_ingestion_use_case.py @@ -93,9 +93,7 @@ def execute(self, cmd: IngestDiarizationCommand) -> Dict[str, Any]: ingestion = self.ingestion_service.get_by_id(cmd.ingestion_job_id) if ingestion is None: - ingestion = self._create_ingestion_job( - external_source, source_type, subject.id - ) + ingestion = self._create_ingestion_job(external_source, source_type, subject.id) self.ingestion_service.update_job( job_id=ingestion.id, @@ -105,16 +103,12 @@ def execute(self, cmd: IngestDiarizationCommand) -> Dict[str, Any]: total_steps=4, ) - full_text = self._format_transcript( - cast(list, record.segments), cast(dict, record.recognition_results) - ) + full_text = self._format_transcript(cast(list, record.segments), cast(dict, record.recognition_results)) if not full_text: raise ValueError("No segments found in diarization record") display_name = cmd.name or cast(str, record.name) or "Transcrição" - source = self._get_or_create_source( - source_type, external_source, subject.id, display_name, cmd, record - ) + source = self._get_or_create_source(source_type, external_source, subject.id, display_name, cmd, record) # Generate chunks and Embeddings self.ingestion_service.update_job( @@ -126,14 +120,10 @@ def execute(self, cmd: IngestDiarizationCommand) -> Dict[str, Any]: content_source_id=source.id, ) - split_docs = self._generate_split_docs( - full_text, display_name, external_source, source_type, cmd, record - ) + split_docs = self._generate_split_docs(full_text, display_name, external_source, source_type, cmd, record) # Persist Chunks - chunks = self._build_chunk_entities( - split_docs, source, subject, cmd, ingestion.id - ) + chunks = self._build_chunk_entities(split_docs, source, subject, cmd, ingestion.id) self.chunk_service.create_chunks(chunks) # Index @@ -215,9 +205,7 @@ def _resolve_source_info(self, record: Any) -> tuple[SourceType, str]: return source_type, external_source - def _create_ingestion_job( - self, external_source: str, source_type: SourceType, subject_id: UUID - ) -> Any: + def _create_ingestion_job(self, external_source: str, source_type: SourceType, subject_id: UUID) -> Any: return self.ingestion_service.create_job( content_source_id=None, status=IngestionJobStatus.STARTED, @@ -255,18 +243,14 @@ def _get_or_create_source( source_metadata=cast(dict, record.source_metadata), ) else: - self.cs_service.update_processing_status( - source.id, ContentSourceStatus.PROCESSING - ) + self.cs_service.update_processing_status(source.id, ContentSourceStatus.PROCESSING) # Update title if it has changed if cmd.name and source.title != cmd.name: self.cs_service.update_title(source.id, cmd.name) if cmd.reprocess: self.chunk_service.delete_by_content_source(source.id) - self.vector_service.delete( - filters={"content_source_id": str(source.id)} - ) + self.vector_service.delete(filters={"content_source_id": str(source.id)}) return source def _generate_split_docs( @@ -344,9 +328,7 @@ def _format_seconds(self, seconds: float) -> str: return f"{h:02d}:{m:02d}:{s:02d}" return f"{m:02d}:{s:02d}" - def _format_transcript( - self, segments: List[Dict[str, Any]], recognition: Optional[Dict[str, Any]] - ) -> str: + def _format_transcript(self, segments: List[Dict[str, Any]], recognition: Optional[Dict[str, Any]]) -> str: if not segments: return "" diff --git a/src/application/use_cases/file_ingestion_use_case.py b/src/application/use_cases/file_ingestion_use_case.py index 45dfa42f..fa49fc54 100644 --- a/src/application/use_cases/file_ingestion_use_case.py +++ b/src/application/use_cases/file_ingestion_use_case.py @@ -71,9 +71,7 @@ def execute(self, cmd: IngestFileCommand) -> Dict[str, Any]: external_source = cmd.external_source or cmd.file_name ingestion = self._get_or_create_job(cmd, source_type, external_source) - self._notify_status( - cmd, "processing", job_id=ingestion.id, step="extracting" - ) + self._notify_status(cmd, "processing", job_id=ingestion.id, step="extracting") # 1. Extraction source_path = cmd.file_url or cmd.file_path @@ -84,9 +82,7 @@ def execute(self, cmd: IngestFileCommand) -> Dict[str, Any]: source_type = self._refine_source_type(docs, source_type) # 2. Source Management - source = self._get_or_create_source( - subject, source_type, external_source, docs[0].metadata, cmd - ) + source = self._get_or_create_source(subject, source_type, external_source, docs[0].metadata, cmd) if cmd.reprocess: self._handle_reprocessing(source, ingestion) @@ -159,9 +155,7 @@ def _extract_docs(self, source_path: str, cmd: IngestFileCommand) -> List[Docume raise ValueError(f"No content extracted from {cmd.file_name}") return docs - def _refine_source_type( - self, docs: List[Document], current: SourceType - ) -> SourceType: + def _refine_source_type(self, docs: List[Document], current: SourceType) -> SourceType: docling_detected = docs[0].metadata.get("docling_source_type") extracted_ext = docs[0].metadata.get("source_type") @@ -172,19 +166,14 @@ def _refine_source_type( refined = SourceType(ext.lower()) if refined == SourceType.OTHER: continue - if ( - current in [SourceType.YOUTUBE, SourceType.WEB] - and refined == SourceType.TXT - ): + if current in [SourceType.YOUTUBE, SourceType.WEB] and refined == SourceType.TXT: continue return refined except ValueError: continue return current - def _get_or_create_job( - self, cmd: IngestFileCommand, source_type: SourceType, external_source: str - ) -> Any: + def _get_or_create_job(self, cmd: IngestFileCommand, source_type: SourceType, external_source: str) -> Any: if cmd.ingestion_job_id: job = self.ingestion_service.get_by_id(cmd.ingestion_job_id) if job: @@ -208,9 +197,7 @@ def _get_or_create_source( metadata: dict, cmd: IngestFileCommand, ) -> Any: - source = self.cs_service.get_by_source_info( - source_type, external_source, cmd.subject_id - ) + source = self.cs_service.get_by_source_info(source_type, external_source, cmd.subject_id) final_meta = {**metadata, **(cmd.source_metadata or {})} if not source: @@ -224,9 +211,7 @@ def _get_or_create_source( source_metadata=final_meta, ) - self.cs_service.update_processing_status( - source.id, ContentSourceStatus.PROCESSING - ) + self.cs_service.update_processing_status(source.id, ContentSourceStatus.PROCESSING) return source def _process_chunks( @@ -246,9 +231,7 @@ def _process_chunks( if tokenizer: splitter = TextSplitterService(tokenizer=tokenizer) - split_docs = splitter.split_text( - full_text, cmd.tokens_per_chunk, cmd.tokens_overlap, docs[0].metadata - ) + split_docs = splitter.split_text(full_text, cmd.tokens_per_chunk, cmd.tokens_overlap, docs[0].metadata) else: from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -256,9 +239,7 @@ def _process_chunks( chunk_size=cmd.tokens_per_chunk * 4, chunk_overlap=cmd.tokens_overlap * 4, ) - split_docs = ls.split_documents( - [Document(page_content=full_text, metadata=docs[0].metadata)] - ) + split_docs = ls.split_documents([Document(page_content=full_text, metadata=docs[0].metadata)]) chunks = self._build_chunk_entities(split_docs, source, subject, cmd, job_id) self.chunk_service.create_chunks(chunks) @@ -301,25 +282,13 @@ def _handle_error(self, e: Exception, ingestion: Any, source: Any): logger.error(e, context={"action": "file_ingestion_execute"}) if ingestion: msg = str(e).lower() - status = ( - IngestionJobStatus.CANCELLED - if ("404" in msg or "not found" in msg) - else IngestionJobStatus.FAILED - ) - self.ingestion_service.update_job( - ingestion.id, status, error_message=str(e) - ) + status = IngestionJobStatus.CANCELLED if ("404" in msg or "not found" in msg) else IngestionJobStatus.FAILED + self.ingestion_service.update_job(ingestion.id, status, error_message=str(e)) if source: - self.cs_service.update_processing_status( - source.id, ContentSourceStatus.FAILED - ) + self.cs_service.update_processing_status(source.id, ContentSourceStatus.FAILED) def _cleanup(self, cmd: IngestFileCommand): - if ( - cmd.delete_after_ingestion - and cmd.file_path - and os.path.exists(cmd.file_path) - ): + if cmd.delete_after_ingestion and cmd.file_path and os.path.exists(cmd.file_path): parent = os.path.dirname(cmd.file_path) if any(t in parent.lower() for t in ["tmp", "temp"]): shutil.rmtree(parent, ignore_errors=True) @@ -345,9 +314,7 @@ def _determine_source_type_refined(self, cmd: IngestFileCommand) -> SourceType: pass ext = cmd.file_name.split(".")[-1].lower() if "." in cmd.file_name else "" - if cmd.external_source and any( - d in cmd.external_source for d in ["youtube.com", "youtu.be"] - ): + if cmd.external_source and any(d in cmd.external_source for d in ["youtube.com", "youtu.be"]): return SourceType.YOUTUBE mapping = { @@ -383,9 +350,7 @@ def _build_chunk_entities( t_count = 0 if tokenizer: try: - t_count = len( - tokenizer.encode(doc.page_content, add_special_tokens=False) - ) + t_count = len(tokenizer.encode(doc.page_content, add_special_tokens=False)) except Exception: t_count = len(doc.page_content) // 4 else: diff --git a/src/application/use_cases/identify_speakers_in_processed_audio.py b/src/application/use_cases/identify_speakers_in_processed_audio.py index 251f8dfc..b5b9fef2 100644 --- a/src/application/use_cases/identify_speakers_in_processed_audio.py +++ b/src/application/use_cases/identify_speakers_in_processed_audio.py @@ -39,9 +39,7 @@ def execute(self, diarization_id: str) -> dict: if not s3_prefix: raise ValueError("No storage path found for this diarization.") - local_dir = os.path.join( - audio_cfg.temp_download_dir, f"recognize_{diarization_id}" - ) + local_dir = os.path.join(audio_cfg.temp_download_dir, f"recognize_{diarization_id}") os.makedirs(local_dir, exist_ok=True) try: @@ -64,15 +62,11 @@ def execute(self, diarization_id: str) -> dict: best_match, spk, ) - _, s3_path = voice_db.add( - name=best_match, audio_path=match.audio_path - ) + _, s3_path = voice_db.add(name=best_match, audio_path=match.audio_path) if s3_path: reinforced_paths[spk] = s3_path except Exception as e: - logger.error( - "Failed to reinforce voice profile '%s': %s", best_match, e - ) + logger.error("Failed to reinforce voice profile '%s': %s", best_match, e) recognition_data: dict[str, object] = { "mapping": mapping, diff --git a/src/application/use_cases/manage_voice_profiles.py b/src/application/use_cases/manage_voice_profiles.py index 08bde59b..1d65df68 100644 --- a/src/application/use_cases/manage_voice_profiles.py +++ b/src/application/use_cases/manage_voice_profiles.py @@ -5,7 +5,6 @@ from sqlalchemy.orm import Session from src.config.settings import settings -from src.domain.entities.enums.diarization_status_enum import DiarizationStatus from src.domain.interfaces.services.i_event_bus import IEventBus from src.infrastructure.repositories.sql.diarization_repository import ( DiarizationRepository, @@ -43,9 +42,7 @@ def execute(self) -> list[dict]: samples_count = 0 if r.audios_path: with suppress(Exception): - files = storage.list_files( - prefix=cast(str, r.audios_path), extension=".wav" - ) + files = storage.list_files(prefix=cast(str, r.audios_path), extension=".wav") samples_count = len(files) result.append( @@ -55,6 +52,8 @@ def execute(self) -> list[dict]: "audios_path": r.audios_path, "created_at": r.created_at.isoformat() if r.created_at else None, "samples_count": samples_count, + "status": r.status or "ready", + "status_message": r.status_message, } ) return result @@ -110,29 +109,25 @@ def execute( if not record.storage_path: raise ValueError("No storage path found for this diarization.") - # Ensure we update status if we have an event bus (background context) + # Voice training is orthogonal to the diarization lifecycle — do NOT + # mutate the diarization record's status here. Only emit voice-scoped + # events so the UI can track progress per-speaker. if self.event_bus: - self.repo.update_status( - diarization_id, - DiarizationStatus.TRAINING.value, - status_message=f"Treinando voz: {name}", - ) self.event_bus.publish( "ingestion_status", { - "type": "diarization", - "id": diarization_id, - "status": DiarizationStatus.TRAINING.value, - "message": f"Treinando perfil de voz '{name}'...", + "type": "voice", + "action": "train_progress", + "name": name, + "diarization_id": diarization_id, + "speaker_label": speaker_label, }, ) s3_key = f"{record.storage_path}/{speaker_label}.wav" audio_cfg = settings.audio - local_path = os.path.join( - audio_cfg.temp_download_dir, f"train_{diarization_id}_{speaker_label}.wav" - ) + local_path = os.path.join(audio_cfg.temp_download_dir, f"train_{diarization_id}_{speaker_label}.wav") os.makedirs(audio_cfg.temp_download_dir, exist_ok=True) try: @@ -143,43 +138,30 @@ def execute( voice_id, _ = voice_db.add(name=name, audio_path=local_path) if self.event_bus: - self.repo.update_status( - diarization_id, - DiarizationStatus.COMPLETED.value, - status_message=f"Voz '{name}' treinada com sucesso", - ) self.event_bus.publish( "ingestion_status", { - "type": "diarization", - "id": diarization_id, - "status": DiarizationStatus.COMPLETED.value, - "message": f"Perfil de voz '{name}' registrado!", + "type": "voice", + "action": "train", + "name": name, + "diarization_id": diarization_id, + "speaker_label": speaker_label, }, ) - # Also notify specifically about the voice - self.event_bus.publish( - "ingestion_status", - {"type": "voice", "action": "train", "name": name}, - ) return voice_id except Exception as e: if self.event_bus: - self.repo.update_status( - diarization_id, - DiarizationStatus.FAILED.value, - error_message=str(e), - status_message="Falha no treinamento de voz", - ) self.event_bus.publish( "ingestion_status", { - "type": "diarization", - "id": diarization_id, - "status": DiarizationStatus.FAILED.value, - "message": f"Erro no treinamento: {str(e)}", + "type": "voice", + "action": "train_failed", + "name": name, + "diarization_id": diarization_id, + "speaker_label": speaker_label, + "error": str(e), }, ) raise diff --git a/src/application/use_cases/process_audio_diarization_pipeline.py b/src/application/use_cases/process_audio_diarization_pipeline.py index 2cb00a9a..7c498797 100644 --- a/src/application/use_cases/process_audio_diarization_pipeline.py +++ b/src/application/use_cases/process_audio_diarization_pipeline.py @@ -44,13 +44,9 @@ def __init__( self.storage = StorageService() logger.info("Storage connection established, bucket=%s", self.storage.bucket) - def _notify( - self, diarization_id: str | None, status: str, message: str | None = None - ): + def _notify(self, diarization_id: str | None, status: str, message: str | None = None): if diarization_id: - self.repo.update_status( - diarization_id, status, status_message=message or "" - ) + self.repo.update_status(diarization_id, status, status_message=message or "") if self.event_bus and diarization_id: self.event_bus.publish( "ingestion_status", @@ -117,14 +113,10 @@ def execute( # 2. Prepare folders and workspace display_name = self._resolve_display_name(audio_raw_path, yt_metadata) recognition_folder = os.path.join(video_folder, "recognition") - audio_path = self._prepare_local_workspace( - audio_raw_path, video_folder, process_id - ) + audio_path = self._prepare_local_workspace(audio_raw_path, video_folder, process_id) # 3. Diarization - diarizer = AudioDiarizer( - hf_token=hf_token, model_size=model_size or "large-v2" - ) + diarizer = AudioDiarizer(hf_token=hf_token, model_size=model_size or "large-v2") diarization_result = self._run_diarization( diarizer, audio_path, @@ -154,9 +146,7 @@ def execute( recognition_data = {} if recognize_voices: - recognition_data = self._identify_voices( - recognition_folder, hf_token, diarization_id - ) + recognition_data = self._identify_voices(recognition_folder, hf_token, diarization_id) if recognition_data: db_record.recognition_results = cast(Any, recognition_data) self.db.commit() @@ -181,9 +171,7 @@ def execute( finally: self._cleanup_local_files(video_folder, audio_path) - def _resolve_display_name( - self, audio_raw_path: str, yt_metadata: Optional[Any] - ) -> str: + def _resolve_display_name(self, audio_raw_path: str, yt_metadata: Optional[Any]) -> str: original_title = Path(audio_raw_path).stem if yt_metadata and hasattr(yt_metadata, "title") and yt_metadata.title: return yt_metadata.title @@ -247,9 +235,7 @@ def _finalize_pipeline( ) from src.domain.entities.enums.source_type_enum_entity import SourceType - cs_source_type = ( - SourceType.YOUTUBE if source_type == "youtube" else SourceType.AUDIO - ) + cs_source_type = SourceType.YOUTUBE if source_type == "youtube" else SourceType.AUDIO subject_id = getattr(db_record, "subject_id", None) # Check if source already exists to avoid duplication @@ -276,9 +262,7 @@ def _finalize_pipeline( status_message="Aguardando revisão dos falantes", ) # Merge metadata - self.cs_service.update_metadata( - content_source_id=existing_source.id, metadata=source_metadata - ) + self.cs_service.update_metadata(content_source_id=existing_source.id, metadata=source_metadata) else: self.cs_service.create_source( subject_id=subject_id, @@ -318,14 +302,10 @@ def _resolve_audio_source( if not video_id: raise ValueError(f"Invalid YouTube source: {source}") - yt_extractor = YoutubeExtractor( - video_id=video_id, language=language or "pt" - ) + yt_extractor = YoutubeExtractor(video_id=video_id, language=language or "pt") # Use the full URL for downloading to be safe download_url = f"https://www.youtube.com/watch?v={video_id}" - audio_path = yt_extractor.download_audio( - download_url, output_dir=settings.audio.temp_download_dir - ) + audio_path = yt_extractor.download_audio(download_url, output_dir=settings.audio.temp_download_dir) if not audio_path: raise RuntimeError("YouTube download failed") @@ -339,23 +319,17 @@ def _resolve_audio_source( if source_type == "upload": s3_key = unquote(source.replace(f"s3://{self.storage.bucket}/", "")) - local_path = os.path.join( - settings.audio.temp_download_dir, f"{process_id}_{Path(s3_key).name}" - ) + local_path = os.path.join(settings.audio.temp_download_dir, f"{process_id}_{Path(s3_key).name}") os.makedirs(settings.audio.temp_download_dir, exist_ok=True) self.storage.download_file(s3_key, local_path) return local_path, source, None raise ValueError(f"Unsupported source type: {source_type}") - def _prepare_local_workspace( - self, audio_path: str, video_folder: str, process_id: str - ) -> str: + def _prepare_local_workspace(self, audio_path: str, video_folder: str, process_id: str) -> str: download_folder = os.path.join(video_folder, "download") os.makedirs(download_folder, exist_ok=True) - audio_dest = os.path.join( - download_folder, f"input_{process_id}{Path(audio_path).suffix}" - ) + audio_dest = os.path.join(download_folder, f"input_{process_id}{Path(audio_path).suffix}") os.replace(audio_path, audio_dest) return audio_dest @@ -383,12 +357,8 @@ def _run_diarization( max_speakers=max_s, ) - def _identify_voices( - self, recognition_folder: str, hf_token: str, d_id: str | None - ) -> dict: - self._notify( - d_id, DiarizationStatus.PROCESSING.value, DiarizationStep.RECOGNIZING.value - ) + def _identify_voices(self, recognition_folder: str, hf_token: str, d_id: str | None) -> dict: + self._notify(d_id, DiarizationStatus.PROCESSING.value, DiarizationStep.RECOGNIZING.value) voice_db = VoiceDB(db=self.db, hf_token=hf_token) if len(voice_db) == 0: return {} @@ -410,9 +380,7 @@ def _identify_voices( spk, best_score, ) - _, s3_path = voice_db.add( - name=best_match, audio_path=match.audio_path - ) + _, s3_path = voice_db.add(name=best_match, audio_path=match.audio_path) if s3_path: reinforced_paths[spk] = s3_path except Exception as e: @@ -444,16 +412,10 @@ def _identify_voices( }, } - def _update_record_metadata( - self, record: Any, storage_prefix: str, yt_metadata: Optional[Any] - ): + def _update_record_metadata(self, record: Any, storage_prefix: str, yt_metadata: Optional[Any]): record.storage_path = cast(Any, storage_prefix) if yt_metadata: - metadata_dict = ( - yt_metadata.model_dump() - if hasattr(yt_metadata, "model_dump") - else vars(yt_metadata) - ) + metadata_dict = yt_metadata.model_dump() if hasattr(yt_metadata, "model_dump") else vars(yt_metadata) record.source_metadata = cast(Any, metadata_dict) self.db.commit() diff --git a/src/application/use_cases/retrieve_processed_audio_history.py b/src/application/use_cases/retrieve_processed_audio_history.py index 7b2749bd..494d8f4a 100644 --- a/src/application/use_cases/retrieve_processed_audio_history.py +++ b/src/application/use_cases/retrieve_processed_audio_history.py @@ -10,9 +10,7 @@ class RetrieveProcessedAudioHistoryUseCase: def __init__(self, db: Session): self.repo = DiarizationRepository(db) - def execute( - self, limit: int = 10, offset: int = 0, subject_id: str | None = None - ) -> list[dict]: + def execute(self, limit: int = 10, offset: int = 0, subject_id: str | None = None) -> list[dict]: records = self.repo.get_all(limit=limit, offset=offset, subject_id=subject_id) return [ diff --git a/src/application/use_cases/search_use_case.py b/src/application/use_cases/search_use_case.py index 21f2893b..f15e002f 100644 --- a/src/application/use_cases/search_use_case.py +++ b/src/application/use_cases/search_use_case.py @@ -33,9 +33,7 @@ def execute( context={ "query": query, "top_k": top_k, - "subject_ids": [str(sid) for sid in subject_ids] - if subject_ids - else None, + "subject_ids": [str(sid) for sid in subject_ids] if subject_ids else None, "subject_name": subject_name, "search_mode": str(search_mode), "re_rank": re_rank, @@ -49,9 +47,7 @@ def execute( filters: Optional[Any] = None # Resolve subject_name to ID if provided if subject_name: - logger.debug( - "Resolving subject name", context={"subject_name": subject_name} - ) + logger.debug("Resolving subject name", context={"subject_name": subject_name}) if not self.ks_service: raise ValueError("ks_service is required to filter by subject_name") subject = self.ks_service.get_by_name(subject_name) @@ -100,9 +96,7 @@ def execute( if subject_cache[sid]: res.extra["subject_name"] = subject_cache[sid] - logger.info( - "Search completed", context={"query": query, "results_count": len(results)} - ) + logger.info("Search completed", context={"query": query, "results_count": len(results)}) return SearchChunksResult( query=query, diff --git a/src/application/use_cases/web_scraping_use_case.py b/src/application/use_cases/web_scraping_use_case.py index 01a40de8..9dd28a0e 100644 --- a/src/application/use_cases/web_scraping_use_case.py +++ b/src/application/use_cases/web_scraping_use_case.py @@ -84,11 +84,7 @@ async def execute(self, cmd: IngestWebCommand) -> Dict[str, Any]: # 1. Create or retrieve Ingestion Job if cmd.ingestion_job_id: try: - jid = ( - UUID(cmd.ingestion_job_id) - if isinstance(cmd.ingestion_job_id, str) - else cmd.ingestion_job_id - ) + jid = UUID(cmd.ingestion_job_id) if isinstance(cmd.ingestion_job_id, str) else cmd.ingestion_job_id ingestion = self.ingestion_service.get_by_id(jid) except Exception as e: logger.warning( @@ -161,9 +157,7 @@ async def execute(self, cmd: IngestWebCommand) -> Dict[str, Any]: ) else: # Update title and metadata if it exists - self.cs_service.update_processing_status( - source.id, ContentSourceStatus.PROCESSING - ) + self.cs_service.update_processing_status(source.id, ContentSourceStatus.PROCESSING) # --- REPROCESSING CLEANUP --- if source and source.id and getattr(cmd, "reprocess", False): @@ -206,8 +200,7 @@ async def execute(self, cmd: IngestWebCommand) -> Dict[str, Any]: tokenizer = ( self.model_loader_service.model.tokenizer - if hasattr(self.model_loader_service, "model") - and hasattr(self.model_loader_service.model, "tokenizer") + if hasattr(self.model_loader_service, "model") and hasattr(self.model_loader_service.model, "tokenizer") else None ) @@ -233,9 +226,7 @@ async def execute(self, cmd: IngestWebCommand) -> Dict[str, Any]: split_docs = langchain_splitter.split_documents([full_doc]) # 5. Build and Persist Chunks - chunks = self._build_chunk_entities( - split_docs, source, subject, cmd, ingestion.id - ) + chunks = self._build_chunk_entities(split_docs, source, subject, cmd, ingestion.id) self.chunk_service.create_chunks(chunks) # 6. Index in Vector Store @@ -259,9 +250,7 @@ async def execute(self, cmd: IngestWebCommand) -> Dict[str, Any]: chunks_count=len(chunks), ) - total_tokens = sum( - c.tokens_count for c in chunks if c.tokens_count is not None - ) + total_tokens = sum(c.tokens_count for c in chunks if c.tokens_count is not None) dims = getattr(self.model_loader_service, "dimensions", 0) self.cs_service.finish_ingestion( @@ -320,9 +309,7 @@ async def execute(self, cmd: IngestWebCommand) -> Dict[str, Any]: def _resolve_subject(self, cmd: IngestWebCommand): if cmd.subject_id: subject = self.ks_service.get_subject_by_id( - UUID(cmd.subject_id) - if isinstance(cmd.subject_id, str) - else cmd.subject_id + UUID(cmd.subject_id) if isinstance(cmd.subject_id, str) else cmd.subject_id ) if not subject: raise ValueError(f"Subject not found: {cmd.subject_id}") @@ -345,18 +332,14 @@ def _build_chunk_entities( list_chunks: List[ChunkEntity] = [] tokenizer = None - if hasattr(self.model_loader_service, "model") and hasattr( - self.model_loader_service.model, "tokenizer" - ): + if hasattr(self.model_loader_service, "model") and hasattr(self.model_loader_service.model, "tokenizer"): tokenizer = self.model_loader_service.model.tokenizer for i, doc in enumerate(docs): tokens_count = None if tokenizer: try: - tokens = tokenizer.encode( - doc.page_content, add_special_tokens=False - ) + tokens = tokenizer.encode(doc.page_content, add_special_tokens=False) tokens_count = len(tokens) except Exception: tokens_count = len(doc.page_content) // 4 diff --git a/src/application/use_cases/youtube_ingestion_use_case.py b/src/application/use_cases/youtube_ingestion_use_case.py index 0d3c021b..2246a853 100644 --- a/src/application/use_cases/youtube_ingestion_use_case.py +++ b/src/application/use_cases/youtube_ingestion_use_case.py @@ -129,9 +129,7 @@ def _report_status( def _resolve_video_list(self, cmd: IngestYoutubeCommand) -> List[str]: """Resolves whether we are dealing with a playlist or a list of videos.""" if cmd.data_type == YoutubeDataType.PLAYLIST: - playlist_url = cmd.video_url or ( - cmd.video_urls[0] if cmd.video_urls else None - ) + playlist_url = cmd.video_url or (cmd.video_urls[0] if cmd.video_urls else None) if not playlist_url: raise ValueError("No video_url provided for playlist") @@ -143,9 +141,7 @@ def _resolve_video_list(self, cmd: IngestYoutubeCommand) -> List[str]: "No videos found in playlist", context={"playlist_url": playlist_url}, ) - raise ValueError( - f"No videos found in playlist: {playlist_url}. Verify if the URL is valid and public." - ) + raise ValueError(f"No videos found in playlist: {playlist_url}. Verify if the URL is valid and public.") return video_list video_list = [] @@ -211,10 +207,7 @@ def _execute_batch( batch_has_network_error = False with concurrent.futures.ThreadPoolExecutor(max_workers=len(batch)) as executor: - futures = [ - executor.submit(self._process_video_task_wrapper, url, subject, cmd) - for url in batch - ] + futures = [executor.submit(self._process_video_task_wrapper, url, subject, cmd) for url in batch] for future in concurrent.futures.as_completed(futures): single_result = future.result() result.video_results.append(single_result) @@ -224,20 +217,13 @@ def _execute_batch( if single_result.get("is_network_error"): batch_has_network_error = True - if ( - not single_result.get("skipped", False) - and "error" not in single_result - ): - result.created_chunks = ( - result.created_chunks or 0 - ) + single_result.get("created_chunks", 0) + if not single_result.get("skipped", False) and "error" not in single_result: + result.created_chunks = (result.created_chunks or 0) + single_result.get("created_chunks", 0) result.vector_ids.extend(single_result.get("vector_ids", [])) return {"ip_blocked": ip_blocked, "network_error": batch_has_network_error} - def _process_video_task_wrapper( - self, url: str, subject: Any, cmd: IngestYoutubeCommand - ) -> Dict[str, Any]: + def _process_video_task_wrapper(self, url: str, subject: Any, cmd: IngestYoutubeCommand) -> Dict[str, Any]: """Wrapper for processing a single video with error classification for batch results.""" vid_id = "unknown" try: @@ -282,9 +268,7 @@ def _apply_throttling(self, wait_time: float) -> None: ) time.sleep(total_wait) - def _finalize_parent_job( - self, ingestion: Any, result: IngestYoutubeResult, any_failed: bool - ) -> None: + def _finalize_parent_job(self, ingestion: Any, result: IngestYoutubeResult, any_failed: bool) -> None: """Updates the status of the main tracking job.""" if not any_failed: self._finish_job(ingestion, chunks_count=result.created_chunks) @@ -296,17 +280,10 @@ def _finalize_parent_job( return # 2. Real Errors vs Cancellations - real_errors = [ - r["error"] - for r in result.video_results - if "error" in r and not r.get("cancelled", False) - ] + real_errors = [r["error"] for r in result.video_results if "error" in r and not r.get("cancelled", False)] if real_errors: - error_summary = ( - f"Ingestion failed for {len(real_errors)} items: " - + "; ".join(real_errors)[:200] - ) + error_summary = f"Ingestion failed for {len(real_errors)} items: " + "; ".join(real_errors)[:200] self._fail_job(ingestion, error_summary) else: # Only known limitations/cancellations @@ -316,21 +293,15 @@ def _is_ip_blocked_in_results(self, result: IngestYoutubeResult) -> bool: """Checks if any video result failed due to an IP block.""" return any(r.get("is_ip_blocked") for r in result.video_results) - def _handle_ip_block_failure( - self, ingestion: Any, result: IngestYoutubeResult - ) -> None: + def _handle_ip_block_failure(self, ingestion: Any, result: IngestYoutubeResult) -> None: """Handles the termination of a job due to an IP block.""" block_item = next(r for r in result.video_results if r.get("is_ip_blocked")) error_summary = f"ABORTED: YouTube is blocking our requests (IP Ban/Block). {block_item.get('error', '')[:150]}" self._fail_job(ingestion, error_summary) - def _handle_partial_ingestion_status( - self, ingestion: Any, result: IngestYoutubeResult - ) -> None: + def _handle_partial_ingestion_status(self, ingestion: Any, result: IngestYoutubeResult) -> None: """Reports status for jobs that were partially successful (due to private videos, etc).""" - cancelled_msgs = [ - r["error"] for r in result.video_results if r.get("cancelled", False) - ] + cancelled_msgs = [r["error"] for r in result.video_results if r.get("cancelled", False)] summary = f"Partial ingestion: {len(cancelled_msgs)} items skipped (private/unplayable)." self._report_status( job_id=ingestion.id, @@ -339,20 +310,12 @@ def _handle_partial_ingestion_status( chunks_count=result.created_chunks, ) - def _finalize_parent_source( - self, source: Any, result: IngestYoutubeResult, cmd: IngestYoutubeCommand - ) -> None: + def _finalize_parent_source(self, source: Any, result: IngestYoutubeResult, cmd: IngestYoutubeCommand) -> None: """Updates the status of the parent source.""" - real_errors = [ - r["error"] - for r in result.video_results - if "error" in r and not r.get("cancelled", False) - ] + real_errors = [r["error"] for r in result.video_results if "error" in r and not r.get("cancelled", False)] if real_errors: - self._fail_ingestion( - source, error_message=real_errors[0] if real_errors else None - ) + self._fail_ingestion(source, error_message=real_errors[0] if real_errors else None) elif cmd.data_type != YoutubeDataType.PLAYLIST: # For single videos, if not already DONE/FAILED current_source = self.cs_service.get_by_id(source.id) @@ -383,11 +346,7 @@ def execute(self, cmd: IngestYoutubeCommand) -> IngestYoutubeResult: source = None if cmd.ingestion_job_id: with suppress(Exception): - jid = ( - UUID(cmd.ingestion_job_id) - if isinstance(cmd.ingestion_job_id, str) - else cmd.ingestion_job_id - ) + jid = UUID(cmd.ingestion_job_id) if isinstance(cmd.ingestion_job_id, str) else cmd.ingestion_job_id ingestion = self.ingestion_service.get_by_id(jid) if ingestion and ingestion.content_source_id: source = self.cs_service.get_by_id(ingestion.content_source_id) @@ -413,11 +372,7 @@ def execute(self, cmd: IngestYoutubeCommand) -> IngestYoutubeResult: # 3. For single video ingestion, if it fails, raise error for API is_batch = len(video_list) > 1 - if ( - any_failed - and cmd.data_type != YoutubeDataType.PLAYLIST - and not is_batch - ): + if any_failed and cmd.data_type != YoutubeDataType.PLAYLIST and not is_batch: failed_item = next(r for r in result.video_results if "error" in r) raise ValueError(failed_item["error"]) @@ -428,15 +383,11 @@ def execute(self, cmd: IngestYoutubeCommand) -> IngestYoutubeResult: logger.info( "YouTube ingestion completed", context={ - "job_ids": [ - r.get("job_id") for r in result.video_results if r.get("job_id") - ] + "job_ids": [r.get("job_id") for r in result.video_results if r.get("job_id")] or cmd.ingestion_job_id, "chunks": result.created_chunks, "total_videos": len(result.video_results), - "skipped": sum( - 1 for r in result.video_results if r.get("skipped", False) - ), + "skipped": sum(1 for r in result.video_results if r.get("skipped", False)), }, ) return result @@ -467,9 +418,7 @@ def execute(self, cmd: IngestYoutubeCommand) -> IngestYoutubeResult: try: self._fail_job(ingestion, error_msg) except Exception as ej: - logger.error( - ej, context={"action": "fail_job", "job_id": str(ingestion.id)} - ) + logger.error(ej, context={"action": "fail_job", "job_id": str(ingestion.id)}) raise @@ -480,20 +429,14 @@ def _ensure_ingestion_context( existing = self._check_existing_source(video_id, subject_id) # 1. Skip if already exists and done (not reprocessing) - if ( - existing - and existing.processing_status == "done" - and not getattr(cmd, "reprocess", False) - ): + if existing and existing.processing_status == "done" and not getattr(cmd, "reprocess", False): return self._handle_duplicate_ingestion(video_id, existing) # 2. Get or create Job job = self._resolve_or_create_job(cmd, existing, video_id, subject_id) return existing, job, False - def _handle_duplicate_ingestion( - self, video_id: str, existing: Any - ) -> tuple[Any, Any, bool]: + def _handle_duplicate_ingestion(self, video_id: str, existing: Any) -> tuple[Any, Any, bool]: """Handles skipping ingestion when a duplicate source is found and completed.""" logger.info( "Source already exists and is DONE, skipping ingestion (reprocess=False)", @@ -534,29 +477,21 @@ def _resolve_or_create_job( job = self.ingestion_service.get_by_id(job_uuid) if not job: - job = self._create_ingestion_job( - source=existing, external_source=video_id, subject_id=subject_id - ) + job = self._create_ingestion_job(source=existing, external_source=video_id, subject_id=subject_id) return job - def _handle_reprocessing_cleanup( - self, source: Any, job_id: UUID, video_id: str - ) -> None: + def _handle_reprocessing_cleanup(self, source: Any, job_id: UUID, video_id: str) -> None: """Cleans up previous ingestion data if reprocessing.""" logger.info( "REPROCESSING: Performing pre-ingestion cleanup", context={"source_id": str(source.id), "video_id": video_id}, ) try: - sql_del = self.chunk_service.delete_by_content_source( - content_source_id=source.id - ) + sql_del = self.chunk_service.delete_by_content_source(content_source_id=source.id) vec_del = self.vector_service.delete_by_video_id(video_id=video_id) # Mark previous jobs as REPROCESSED - self.ingestion_service.mark_previous_jobs_as_reprocessed( - content_source_id=source.id, current_job_id=job_id - ) + self.ingestion_service.mark_previous_jobs_as_reprocessed(content_source_id=source.id, current_job_id=job_id) logger.info( "Reprocessing cleanup finished", @@ -568,9 +503,7 @@ def _handle_reprocessing_cleanup( context={"source_id": str(source.id), "error": str(ce)}, ) - self.cs_service.update_processing_status( - content_source_id=source.id, status=ContentSourceStatus.PROCESSING - ) + self.cs_service.update_processing_status(content_source_id=source.id, status=ContentSourceStatus.PROCESSING) def _rollback_failed_ingestion( self, job_id: UUID, video_id: str, error_msg: str, source: Optional[Any] = None @@ -592,9 +525,7 @@ def _rollback_failed_ingestion( }, ) except Exception as er: - logger.error( - er, context={"action": "rollback_ingestion", "job_id": str(job_id)} - ) + logger.error(er, context={"action": "rollback_ingestion", "job_id": str(job_id)}) if source: try: @@ -615,14 +546,10 @@ def _process_single_video( self, video_url: str, video_id: str, subject: Any, cmd: IngestYoutubeCommand ) -> Dict[str, Any]: """Orchestrates the ingestion of a single YouTube video.""" - logger.info( - "Processing video", context={"video_id": video_id, "video_url": video_url} - ) + logger.info("Processing video", context={"video_id": video_id, "video_url": video_url}) # 1. Resolve source and job early - source, ingestion, skipped = self._ensure_ingestion_context( - video_id, subject.id, cmd - ) + source, ingestion, skipped = self._ensure_ingestion_context(video_id, subject.id, cmd) if skipped: return self._format_skipped_result(video_url, video_id, source, ingestion) @@ -638,12 +565,7 @@ def _process_single_video( # 3. Extract core metadata yt_extractor = YoutubeExtractor(video_id=video_id, language=cmd.language) metadata = yt_extractor.extract_metadata() - extracted_title = ( - metadata.full_title - or metadata.title - or cmd.title - or f"Video {video_id}" - ) + extracted_title = metadata.full_title or metadata.title or cmd.title or f"Video {video_id}" # 4. Ensure source exists and is linked if source is None: @@ -671,13 +593,9 @@ def _process_single_video( docs = self._extract_and_split(cmd, video_id, yt_extractor=yt_extractor) if not docs: - raise ValueError( - f"No transcript chunks generated for video {video_id}." - ) + raise ValueError(f"No transcript chunks generated for video {video_id}.") - self.cs_service.update_title( - content_source_id=source.id, title=extracted_title - ) + self.cs_service.update_title(content_source_id=source.id, title=extracted_title) self._mark_source_processing(source) # 6. Embed and Persist @@ -689,9 +607,7 @@ def _process_single_video( current_step=2, chunks_count=len(docs), ) - chunks = self._build_chunk_entities( - docs, source, subject, cmd, job_id=ingestion.id - ) + chunks = self._build_chunk_entities(docs, source, subject, cmd, job_id=ingestion.id) self._persist_chunks(chunks) # 7. Vector Indexing @@ -706,9 +622,7 @@ def _process_single_video( created_ids = self._index_chunks(chunks) # 8. Success Finalization - total_tokens = sum( - c.tokens_count for c in chunks if c.tokens_count is not None - ) + total_tokens = sum(c.tokens_count for c in chunks if c.tokens_count is not None) self._report_status( job_id=ingestion.id, status="completed", @@ -749,9 +663,7 @@ def _process_single_video( ) if ingestion: - self._report_status( - job_id=ingestion.id, status="failed", error=error_msg - ) + self._report_status(job_id=ingestion.id, status="failed", error=error_msg) if source: self.cs_service.update_processing_status( content_source_id=source.id, status=ContentSourceStatus.CANCELLED @@ -768,14 +680,10 @@ def _process_single_video( except Exception as e: error_msg = str(e) - logger.error( - e, context={"video_id": video_id, "action": "process_single_video"} - ) + logger.error(e, context={"video_id": video_id, "action": "process_single_video"}) if ingestion: - self._rollback_failed_ingestion( - ingestion.id, video_id, error_msg, source=source - ) + self._rollback_failed_ingestion(ingestion.id, video_id, error_msg, source=source) raise @@ -796,22 +704,16 @@ def _fail_job(self, ingestion, error_message: str) -> None: status=IngestionJobStatus.FAILED, error_message=error_message, ) - logger.info( - "Ingestion job updated to FAILED", context={"job_id": str(ingestion.id)} - ) + logger.info("Ingestion job updated to FAILED", context={"job_id": str(ingestion.id)}) def _resolve_subject(self, cmd: IngestYoutubeCommand): if getattr(cmd, "subject_id", None): try: subject_id_val = ( - cmd.subject_id - if isinstance(cmd.subject_id, uuid.UUID) - else uuid.UUID(str(cmd.subject_id)) + cmd.subject_id if isinstance(cmd.subject_id, uuid.UUID) else uuid.UUID(str(cmd.subject_id)) ) except Exception as e: - logger.error( - e, context={"subject_id": getattr(cmd, "subject_id", None)} - ) + logger.error(e, context={"subject_id": getattr(cmd, "subject_id", None)}) raise ValueError(f"Invalid subject_id provided: {e}") subject = self.ks_service.get_subject_by_id(subject_id_val) if subject is None: @@ -829,9 +731,7 @@ def _resolve_subject(self, cmd: IngestYoutubeCommand): "KnowledgeSubject not found by name", context={"subject_name": cmd.subject_name}, ) - raise ValueError( - f"KnowledgeSubject with name '{cmd.subject_name}' not found" - ) + raise ValueError(f"KnowledgeSubject with name '{cmd.subject_name}' not found") return subject logger.error( @@ -903,9 +803,7 @@ def _create_ingestion_job( return ingestion def _mark_source_processing(self, source) -> None: - self.cs_service.update_processing_status( - content_source_id=source.id, status=ContentSourceStatus.PROCESSING - ) + self.cs_service.update_processing_status(content_source_id=source.id, status=ContentSourceStatus.PROCESSING) logger.debug( "Content source marked as PROCESSING", context={"content_source_id": str(source.id)}, @@ -917,15 +815,11 @@ def _extract_and_split( video_id: str, yt_extractor: Optional[YoutubeExtractor] = None, ) -> List[Document]: - logger.debug( - "Starting extraction and transcript split", context={"video_id": video_id} - ) + logger.debug("Starting extraction and transcript split", context={"video_id": video_id}) if yt_extractor is None: yt_extractor = YoutubeExtractor(video_id=video_id, language=cmd.language) - ytts = YoutubeDataProcessService( - model_loader_service=self.model_loader_service, yt_extractor=yt_extractor - ) + ytts = YoutubeDataProcessService(model_loader_service=self.model_loader_service, yt_extractor=yt_extractor) effective_tokens = cmd.tokens_per_chunk docs: List[Document] = ytts.split_transcript( mode="tokens", @@ -940,12 +834,8 @@ def _extract_and_split( return docs def _update_ingestion_processing(self, ingestion) -> None: - self.ingestion_service.update_job( - job_id=ingestion.id, status=IngestionJobStatus.PROCESSING - ) - logger.debug( - "Ingestion job updated to PROCESSING", context={"job_id": str(ingestion.id)} - ) + self.ingestion_service.update_job(job_id=ingestion.id, status=IngestionJobStatus.PROCESSING) + logger.debug("Ingestion job updated to PROCESSING", context={"job_id": str(ingestion.id)}) self.event_bus.publish( "ingestion_status", { @@ -986,9 +876,7 @@ def _build_chunk_entities( def _persist_chunks(self, chunks: List[ChunkEntity]) -> None: self.chunk_service.create_chunks(chunks) - logger.debug( - "Persisted chunks to SQL repository", context={"num_chunks": len(chunks)} - ) + logger.debug("Persisted chunks to SQL repository", context={"num_chunks": len(chunks)}) def _index_chunks(self, chunks: List[ChunkEntity]) -> List[str]: created_ids = self.vector_service.index_documents(chunks) @@ -1034,9 +922,7 @@ def _finish_job(self, ingestion, chunks_count: Optional[int] = None) -> None: context={"job_id": str(ingestion.id), "chunks": chunks_count}, ) - def _format_skipped_result( - self, video_url: str, video_id: str, source: Any, ingestion: Any - ) -> Dict[str, Any]: + def _format_skipped_result(self, video_url: str, video_id: str, source: Any, ingestion: Any) -> Dict[str, Any]: """Formats the result for a skipped ingestion.""" return { "video_url": video_url, diff --git a/src/application/workers.py b/src/application/workers.py index 88900f3c..a9a7be28 100644 --- a/src/application/workers.py +++ b/src/application/workers.py @@ -74,9 +74,7 @@ def run_file_ingestion_worker(cmd: IngestFileCommand): use_case.execute(cmd) except Exception as e: - logger.error( - f"Worker Error: Failed to execute file ingestion: {e}", exc_info=True - ) + logger.error(f"Worker Error: Failed to execute file ingestion: {e}", exc_info=True) finally: clear_global_context() @@ -123,9 +121,7 @@ def run_youtube_ingestion_worker(cmd: IngestYoutubeCommand): use_case.execute(cmd) except Exception as e: - logger.error( - f"Worker Error: Failed to execute YouTube ingestion: {e}", exc_info=True - ) + logger.error(f"Worker Error: Failed to execute YouTube ingestion: {e}", exc_info=True) finally: clear_global_context() @@ -135,9 +131,7 @@ def run_youtube_dispatcher_worker(cmd: IngestYoutubeCommand): Resolves the list of URLs and enqueues individual workers for each video. """ - set_global_context( - {"correlation_id": _get_correlation_id(cmd, "worker-youtube-dispatcher")} - ) + set_global_context({"correlation_id": _get_correlation_id(cmd, "worker-youtube-dispatcher")}) if isinstance(cmd, dict): cmd = IngestYoutubeCommand(**cmd) @@ -157,9 +151,7 @@ def run_youtube_dispatcher_worker(cmd: IngestYoutubeCommand): video_list = [] if cmd.data_type == YoutubeDataType.PLAYLIST: # We need to resolve the playlist entries - playlist_url = cmd.video_url or ( - cmd.video_urls[0] if cmd.video_urls else None - ) + playlist_url = cmd.video_url or (cmd.video_urls[0] if cmd.video_urls else None) if not playlist_url: logger.warning("No URL provided for playlist dispatcher") return @@ -168,9 +160,7 @@ def run_youtube_dispatcher_worker(cmd: IngestYoutubeCommand): video_list = extractor.extract_playlist_videos(playlist_url) elif cmd.data_type == YoutubeDataType.CHANNEL: # Resolve the channel entries - channel_url = cmd.video_url or ( - cmd.video_urls[0] if cmd.video_urls else None - ) + channel_url = cmd.video_url or (cmd.video_urls[0] if cmd.video_urls else None) if not channel_url: logger.warning("No URL provided for channel dispatcher") return @@ -182,14 +172,10 @@ def run_youtube_dispatcher_worker(cmd: IngestYoutubeCommand): video_list = [v for v in cmd.video_urls if v] if not video_list: - logger.warning( - f"YouTube Dispatcher resolved 0 videos for type {cmd.data_type}." - ) + logger.warning(f"YouTube Dispatcher resolved 0 videos for type {cmd.data_type}.") return - logger.info( - f"YouTube Dispatcher resolved {len(video_list)} videos. Enqueueing individual tasks..." - ) + logger.info(f"YouTube Dispatcher resolved {len(video_list)} videos. Enqueueing individual tasks...") # 2. Enqueue each video as a separate task for url in video_list: @@ -212,14 +198,10 @@ def run_youtube_dispatcher_worker(cmd: IngestYoutubeCommand): run_youtube_ingestion_worker, single_cmd, task_title=f"YouTube: {url}", - metadata={"parent_dispatcher_job": str(cmd.ingestion_job_id)} - if cmd.ingestion_job_id - else {}, + metadata={"parent_dispatcher_job": str(cmd.ingestion_job_id)} if cmd.ingestion_job_id else {}, ) - logger.info( - f"Successfully dispatched {len(video_list)} YouTube ingestion tasks." - ) + logger.info(f"Successfully dispatched {len(video_list)} YouTube ingestion tasks.") except Exception as e: logger.error(f"YouTube Dispatcher Worker Error: {e}", exc_info=True) @@ -229,9 +211,7 @@ def run_youtube_dispatcher_worker(cmd: IngestYoutubeCommand): def run_diarization_ingestion_worker(cmd: IngestDiarizationCommand): """Background worker function for direct diarization ingestion.""" - set_global_context( - {"correlation_id": _get_correlation_id(cmd, "worker-diarization")} - ) + set_global_context({"correlation_id": _get_correlation_id(cmd, "worker-diarization")}) app = _get_app() if not app: @@ -280,9 +260,7 @@ def run_diarization_ingestion_worker(cmd: IngestDiarizationCommand): finally: db.close() except Exception as e: - logger.error( - f"Worker Error: Failed to execute diarization ingestion: {e}", exc_info=True - ) + logger.error(f"Worker Error: Failed to execute diarization ingestion: {e}", exc_info=True) finally: clear_global_context() @@ -337,9 +315,7 @@ async def _run(): await use_case.execute(cmd) except Exception as e: - logging.getLogger(__name__).error( - f"Worker Error: Failed to execute Web Scraping: {e}", exc_info=True - ) + logging.getLogger(__name__).error(f"Worker Error: Failed to execute Web Scraping: {e}", exc_info=True) finally: clear_global_context() @@ -370,9 +346,7 @@ def _audio_diarization_subprocess(cmd_dict: dict): diarization_id = cmd_dict.get("diarization_id") try: - use_case = ProcessAudioDiarizationPipelineUseCase( - db, event_bus=event_bus, cs_service=cs_service - ) + use_case = ProcessAudioDiarizationPipelineUseCase(db, event_bus=event_bus, cs_service=cs_service) use_case.execute( source_type=cmd_dict["source_type"], source=cmd_dict["source"], @@ -439,12 +413,7 @@ def run_audio_diarization_dispatcher_worker(cmd: ProcessAudioCommand): repo = DiarizationRepository(db) # Detect YouTube Channel vs Playlist - is_channel = ( - "/channel/" in cmd.source - or "/c/" in cmd.source - or "/user/" in cmd.source - or "@" in cmd.source - ) + is_channel = "/channel/" in cmd.source or "/c/" in cmd.source or "/user/" in cmd.source or "@" in cmd.source logger.info( "Resolving %s URLs for diarization: %s", @@ -478,13 +447,9 @@ def run_audio_diarization_dispatcher_worker(cmd: ProcessAudioCommand): ) if ( - existing - and existing.status - != "failed" # DiarizationStatus is not imported here as Enum yet + existing and existing.status != "failed" # DiarizationStatus is not imported here as Enum yet ): - logger.info( - "Skipping duplicate video for diarization dispatcher: %s", url - ) + logger.info("Skipping duplicate video for diarization dispatcher: %s", url) continue # 1. Create a pending record @@ -549,9 +514,7 @@ def run_audio_diarization_worker(cmd: ProcessAudioCommand): process.join() if process.exitcode != 0: - logger.error( - "Audio diarization subprocess exited with code %d", process.exitcode - ) + logger.error("Audio diarization subprocess exited with code %d", process.exitcode) if cmd.diarization_id: from src.infrastructure.repositories.sql.connector import ( Session as DBSessionFactory, @@ -592,20 +555,19 @@ def run_audio_diarization_worker(cmd: ProcessAudioCommand): else: logger.info("Audio diarization subprocess completed successfully") except Exception as e: - logger.error( - f"Worker Error: Failed to execute audio diarization: {e}", exc_info=True - ) + logger.error(f"Worker Error: Failed to execute audio diarization: {e}", exc_info=True) finally: clear_global_context() def run_voice_training_worker(cmd: TrainVoiceCommand): """Background worker function for voice profile training from speaker segment.""" - set_global_context({"correlation_id": f"worker-voice-train-{cmd.name}"}) - + # Redis-serialized payloads arrive as dicts — convert BEFORE touching fields. if isinstance(cmd, dict): cmd = TrainVoiceCommand(**cmd) + set_global_context({"correlation_id": f"worker-voice-train-{cmd.name}"}) + app = _get_app() if not app: clear_global_context() @@ -621,9 +583,7 @@ def run_voice_training_worker(cmd: TrainVoiceCommand): ctx = resolve_ingestion_context(app) db = DBSession() try: - use_case = TrainVoiceProfileFromSpeakerSegmentUseCase( - db, event_bus=ctx.event_bus - ) + use_case = TrainVoiceProfileFromSpeakerSegmentUseCase(db, event_bus=ctx.event_bus) use_case.execute( diarization_id=cmd.diarization_id, speaker_label=cmd.speaker_label, @@ -632,8 +592,6 @@ def run_voice_training_worker(cmd: TrainVoiceCommand): finally: db.close() except Exception as e: - logger.error( - f"Worker Error: Failed to execute voice training: {e}", exc_info=True - ) + logger.error(f"Worker Error: Failed to execute voice training: {e}", exc_info=True) finally: clear_global_context() diff --git a/src/config/settings.py b/src/config/settings.py index bf6c2041..485e2ba1 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -8,15 +8,11 @@ from src.config.validators import docker_host_fallback, docker_host_fallback_optional from src.domain.entities.enums.vector_store_type_enum import VectorStoreType -warnings.filterwarnings( - "ignore", message='.*has conflict with protected namespace "model_".*' -) +warnings.filterwarnings("ignore", message='.*has conflict with protected namespace "model_".*') class SQLConfig(BaseModel): - type: Optional[str] = Field( - default=None, description="SQL database connection type" - ) + type: Optional[str] = Field(default=None, description="SQL database connection type") host: Optional[str] = Field(default=None, description="SQL database host") port: Optional[str] = Field(default=None, description="SQL database port") user: Optional[str] = Field(default=None, description="SQL database username") @@ -31,9 +27,7 @@ class SQLConfig(BaseModel): @field_validator("host", mode="after") @classmethod def _fallback_host(cls, v: Optional[str]) -> Optional[str]: - return docker_host_fallback_optional( - v, {"postgres", "mysql", "mariadb", "mssql", "db"} - ) + return docker_host_fallback_optional(v, {"postgres", "mysql", "mariadb", "mssql", "db"}) @property def url(self) -> str: @@ -68,16 +62,10 @@ class VectorConfig(BaseSettings): description="Path to store vector index files (for Chroma and FAISS)", ) - weaviate_host: str = Field( - default="localhost", description="WeaviateConfig host URL" - ) + weaviate_host: str = Field(default="localhost", description="WeaviateConfig host URL") weaviate_port: int = Field(default=8081, description="WeaviateConfig port") - weaviate_api_key: Optional[str] = Field( - default=None, description="WeaviateConfig API key for authentication" - ) - weaviate_grpc_port: int = Field( - default=50051, description="WeaviateConfig gRPC port for local connections" - ) + weaviate_api_key: Optional[str] = Field(default=None, description="WeaviateConfig API key for authentication") + weaviate_grpc_port: int = Field(default=50051, description="WeaviateConfig gRPC port for local connections") chroma_host: str = Field(default="localhost", description="ChromaDB host URL") chroma_port: int = Field(default=8000, description="ChromaDB port") @@ -106,8 +94,7 @@ def weaviate_url(self) -> str: class App(BaseSettings): env: str = Field( default="development", - description="Application environment (e.g., 'development', 'production', " - "'testing')", + description="Application environment (e.g., 'development', 'production', 'testing')", ) port: int = Field(default=5000, description="Application port") @@ -153,9 +140,7 @@ def allowed_log_levels(self) -> set: "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - return { - level_map[level] for level in self.list_log_levels if level in level_map - } + return {level_map[level] for level in self.list_log_levels if level in level_map} class ModelRerank(BaseSettings): @@ -193,33 +178,21 @@ def _fallback_host(cls, v: str) -> str: class YoutubeConfig(BaseSettings): # Throttling configurations - throttle_batch_size: int = Field( - default=2, description="Number of videos to process before waiting" - ) - throttle_wait_seconds: int = Field( - default=65, description="Seconds to wait between batches" - ) + throttle_batch_size: int = Field(default=2, description="Number of videos to process before waiting") + throttle_wait_seconds: int = Field(default=65, description="Seconds to wait between batches") # Proxy configurations - proxy_enabled: bool = Field( - default=False, description="Enable or disable proxy usage for YouTube" - ) + proxy_enabled: bool = Field(default=False, description="Enable or disable proxy usage for YouTube") proxy_url: Optional[str] = Field( default=None, description="Generic proxy URL (e.g. http://user:pass@host:port)", ) - webshare_username: Optional[str] = Field( - default=None, description="Webshare.io username for optimized proxy" - ) - webshare_password: Optional[str] = Field( - default=None, description="Webshare.io password for optimized proxy" - ) + webshare_username: Optional[str] = Field(default=None, description="Webshare.io username for optimized proxy") + webshare_password: Optional[str] = Field(default=None, description="Webshare.io password for optimized proxy") class StorageConfig(BaseSettings): - minio_url: str = Field( - default="http://localhost:9000", description="MinIO/S3 endpoint URL" - ) + minio_url: str = Field(default="http://localhost:9000", description="MinIO/S3 endpoint URL") minio_root_user: str = Field(default="root", description="MinIO access key") minio_root_password: str = Field(default="password", description="MinIO secret key") minio_bucket: str = Field(default="whatyousaid", description="MinIO bucket name") @@ -251,18 +224,10 @@ class AudioConfig(BaseSettings): class AuthConfig(BaseSettings): - hf_token: Optional[str] = Field( - default=None, description="HuggingFace token for pyannote/whisper models" - ) - enable_google: bool = Field( - default=False, description="Enable Google SSO authentication" - ) - google_client_id: Optional[str] = Field( - default=None, description="Google OAuth2 Client ID" - ) - google_client_secret: Optional[str] = Field( - default=None, description="Google OAuth2 Client Secret" - ) + hf_token: Optional[str] = Field(default=None, description="HuggingFace token for pyannote/whisper models") + enable_google: bool = Field(default=False, description="Enable Google SSO authentication") + google_client_id: Optional[str] = Field(default=None, description="Google OAuth2 Client ID") + google_client_secret: Optional[str] = Field(default=None, description="Google OAuth2 Client Secret") redirect_uri: str = Field( default="http://localhost:5000/rest/auth/google/callback", description="Google OAuth2 Redirect URI", @@ -286,12 +251,9 @@ def _warn_default_secret(cls, v: str) -> str: env = os.getenv("APP__ENV", "development") if env == "production": raise ValueError( - "JWT secret must be changed from default in production. " - "Set AUTH__JWT_SECRET environment variable." + "JWT secret must be changed from default in production. Set AUTH__JWT_SECRET environment variable." ) - logging.getLogger(__name__).warning( - "Using default JWT secret — change AUTH__JWT_SECRET for production." - ) + logging.getLogger(__name__).warning("Using default JWT secret — change AUTH__JWT_SECRET for production.") return v @@ -304,33 +266,15 @@ class Settings(BaseSettings): protected_namespaces=(), ) app: App = Field(default_factory=App, description="Application settings") - sql: SQLConfig = Field( - default_factory=SQLConfig, description="SQL database settings" - ) - vector: VectorConfig = Field( - default_factory=VectorConfig, description="Vector store settings" - ) - model_embedding: ModelEmbedding = Field( - default_factory=ModelEmbedding, description="Model embedding settings" - ) - model_rerank: ModelRerank = Field( - default_factory=ModelRerank, description="Model rerank settings" - ) - docling: DoclingConfig = Field( - default_factory=DoclingConfig, description="Docling settings" - ) - redis: RedisConfig = Field( - default_factory=RedisConfig, description="Redis settings" - ) - youtube: YoutubeConfig = Field( - default_factory=YoutubeConfig, description="YouTube ingestion settings" - ) - auth: AuthConfig = Field( - default_factory=AuthConfig, description="Authentication settings" - ) - storage: StorageConfig = Field( - default_factory=StorageConfig, description="MinIO/S3 storage settings" - ) + sql: SQLConfig = Field(default_factory=SQLConfig, description="SQL database settings") + vector: VectorConfig = Field(default_factory=VectorConfig, description="Vector store settings") + model_embedding: ModelEmbedding = Field(default_factory=ModelEmbedding, description="Model embedding settings") + model_rerank: ModelRerank = Field(default_factory=ModelRerank, description="Model rerank settings") + docling: DoclingConfig = Field(default_factory=DoclingConfig, description="Docling settings") + redis: RedisConfig = Field(default_factory=RedisConfig, description="Redis settings") + youtube: YoutubeConfig = Field(default_factory=YoutubeConfig, description="YouTube ingestion settings") + auth: AuthConfig = Field(default_factory=AuthConfig, description="Authentication settings") + storage: StorageConfig = Field(default_factory=StorageConfig, description="MinIO/S3 storage settings") audio: AudioConfig = Field( default_factory=AudioConfig, description="Audio diarization/recognition settings", diff --git a/src/config/validators.py b/src/config/validators.py index cdfecbe9..03e2ab00 100644 --- a/src/config/validators.py +++ b/src/config/validators.py @@ -5,18 +5,12 @@ def docker_host_fallback(host: str, docker_names: set[str]) -> str: """Fallback to localhost if docker service names are used on Windows/non-docker.""" - if ( - host in docker_names - and sys.platform == "win32" - and not os.path.exists("/.dockerenv") - ): + if host in docker_names and sys.platform == "win32" and not os.path.exists("/.dockerenv"): return "localhost" return host -def docker_host_fallback_optional( - host: Optional[str], docker_names: set[str] -) -> Optional[str]: +def docker_host_fallback_optional(host: Optional[str], docker_names: set[str]) -> Optional[str]: """Same as docker_host_fallback but accepts Optional[str].""" if host is None: return None diff --git a/src/domain/entities/chunk_entity.py b/src/domain/entities/chunk_entity.py index 91d6d5d5..f9e7b7f0 100644 --- a/src/domain/entities/chunk_entity.py +++ b/src/domain/entities/chunk_entity.py @@ -14,12 +14,8 @@ class ChunkEntity(BaseModel): domain models; persistence will ensure required identifiers when needed. """ - id: UUID = Field( - default_factory=lambda: uuid4(), description="Logical ID of the chunk" - ) - job_id: Optional[UUID] = Field( - default=None, description="ID of the processing job that created this chunk" - ) + id: UUID = Field(default_factory=lambda: uuid4(), description="Logical ID of the chunk") + job_id: Optional[UUID] = Field(default=None, description="ID of the processing job that created this chunk") content_source_id: Optional[UUID] = Field( default=None, description="ID of the original content source, e.g., video id or document id", @@ -28,9 +24,7 @@ class ChunkEntity(BaseModel): external_source: Optional[str] = Field(default=None) subject_id: Optional[UUID] = Field(default=None) - chunk_id: Optional[str] = Field( - default=None, description="External reference ID for the chunk" - ) + chunk_id: Optional[str] = Field(default=None, description="External reference ID for the chunk") index: Optional[int] = Field( default=None, description="Original sequence number of the chunk within the source", diff --git a/src/domain/entities/content_source_entity.py b/src/domain/entities/content_source_entity.py index 0f508c59..3296c182 100644 --- a/src/domain/entities/content_source_entity.py +++ b/src/domain/entities/content_source_entity.py @@ -8,12 +8,8 @@ class ContentSourceEntity(BaseModel): """Domain entity representing a content source (e.g., a YouTube video or document).""" - id: UUID = Field( - default_factory=lambda: uuid4(), description="Logical ID of the content source" - ) - subject_id: Optional[UUID] = Field( - default=None, description="Associated knowledge subject id" - ) + id: UUID = Field(default_factory=lambda: uuid4(), description="Logical ID of the content source") + subject_id: Optional[UUID] = Field(default=None, description="Associated knowledge subject id") source_type: str = Field(..., description="Type of source, e.g., 'youtube', 'pdf'") external_source: str = Field(..., description="External id or URL of the source") title: Optional[str] = Field(default=None) @@ -29,6 +25,4 @@ class ContentSourceEntity(BaseModel): chunks: int = Field(default=0) status_message: Optional[str] = Field(default=None) error_message: Optional[str] = Field(default=None) - source_metadata: Optional[dict] = Field( - default=None, description="Source-specific metadata in JSON format" - ) + source_metadata: Optional[dict] = Field(default=None, description="Source-specific metadata in JSON format") diff --git a/src/domain/entities/diarization.py b/src/domain/entities/diarization.py index cac78849..5890a47d 100644 --- a/src/domain/entities/diarization.py +++ b/src/domain/entities/diarization.py @@ -49,9 +49,7 @@ def duration(self) -> float: def speakers(self) -> list[str]: return sorted({s.speaker for s in self.segments}) - def export_speaker_audio( - self, output_dir: str, min_segment_duration: float = 0.5 - ) -> dict[str, str]: + def export_speaker_audio(self, output_dir: str, min_segment_duration: float = 0.5) -> dict[str, str]: """Export individual audio files per speaker from the diarized audio.""" os.makedirs(output_dir, exist_ok=True) exported: dict[str, str] = {} @@ -59,11 +57,7 @@ def export_speaker_audio( data, sr = sf.read(self.audio_path, dtype="float32") for speaker in self.speakers: # Filter segments by duration to avoid noise/stutters in the sample - speaker_segments = [ - s - for s in self.segments - if s.speaker == speaker and s.duration >= min_segment_duration - ] + speaker_segments = [s for s in self.segments if s.speaker == speaker and s.duration >= min_segment_duration] # Fallback: if all segments are short, take the longest ones anyway if not speaker_segments: diff --git a/src/domain/entities/knowledge_subject_entity.py b/src/domain/entities/knowledge_subject_entity.py index d6ea1cc0..4b77d55e 100644 --- a/src/domain/entities/knowledge_subject_entity.py +++ b/src/domain/entities/knowledge_subject_entity.py @@ -15,13 +15,9 @@ class KnowledgeSubjectEntity(BaseModel): default_factory=lambda: uuid4(), description="Logical ID of the knowledge subject", ) - external_ref: Optional[str] = Field( - default=None, description="Optional external reference ID" - ) + external_ref: Optional[str] = Field(default=None, description="Optional external reference ID") name: str = Field(..., description="Human-readable name of the subject") - description: Optional[str] = Field( - default=None, description="Optional longer description" - ) + description: Optional[str] = Field(default=None, description="Optional longer description") icon: Optional[str] = Field(default=None, description="Icon name for the frontend") source_count: int = Field(default=0, description="Number of content sources") created_at: datetime = Field( diff --git a/src/domain/exception/youtube_exceptions.py b/src/domain/exception/youtube_exceptions.py index df257d00..bef4077b 100644 --- a/src/domain/exception/youtube_exceptions.py +++ b/src/domain/exception/youtube_exceptions.py @@ -35,11 +35,7 @@ def __init__( available_languages: Optional[list[str]] = None, ): lang_str = f" in language '{language}'" if language else "" - avail_str = ( - f" Available languages: {', '.join(available_languages)}." - if available_languages - else "" - ) + avail_str = f" Available languages: {', '.join(available_languages)}." if available_languages else "" message = f"Transcript not found for video {video_id}{lang_str}.{avail_str}" super().__init__(message, video_id=video_id) self.available_languages = available_languages @@ -57,10 +53,7 @@ class YoutubeNetworkException(YoutubeException): """Raised when there is a network-related error (DNS, connection).""" def __init__(self, video_id: str, error_msg: str): - message = ( - f"Network error while accessing video {video_id}. " - f"Please check your connection. Details: {error_msg}" - ) + message = f"Network error while accessing video {video_id}. Please check your connection. Details: {error_msg}" super().__init__(message, video_id=video_id) diff --git a/src/domain/interfaces/repository/retriver_repository.py b/src/domain/interfaces/repository/retriver_repository.py index bb50fe29..518d0dfd 100644 --- a/src/domain/interfaces/repository/retriver_repository.py +++ b/src/domain/interfaces/repository/retriver_repository.py @@ -31,9 +31,7 @@ def delete(self, filters: Optional[Any]) -> int: raise NotImplementedError @abstractmethod - def list_chunks( - self, filters: Optional[Any], limit: int = 1000 - ) -> List[ChunkModel]: + def list_chunks(self, filters: Optional[Any], limit: int = 1000) -> List[ChunkModel]: """List chunks matching query and filters without vector search.""" raise NotImplementedError diff --git a/src/domain/mappers/chunk_index_mapper.py b/src/domain/mappers/chunk_index_mapper.py index 3857f78b..b3d013e6 100644 --- a/src/domain/mappers/chunk_index_mapper.py +++ b/src/domain/mappers/chunk_index_mapper.py @@ -52,9 +52,7 @@ def _extract_cs_metadata(model: ChunkIndexModel) -> dict: } -def _build_entity_kwargs( - model: ChunkIndexModel, cs_meta: dict, source_type: SourceType -) -> dict: +def _build_entity_kwargs(model: ChunkIndexModel, cs_meta: dict, source_type: SourceType) -> dict: """Construct keyword args for ChunkEntity from model and extracted metadata. Having this in a helper reduces the number of expressions inside the main @@ -63,9 +61,7 @@ def _build_entity_kwargs( return { "id": cast(UUID, getattr(model, "id")), "job_id": cast(Optional[UUID], getattr(model, "job_id", None)), - "content_source_id": cast( - Optional[UUID], getattr(model, "content_source_id", None) - ), + "content_source_id": cast(Optional[UUID], getattr(model, "content_source_id", None)), "source_type": source_type, "external_source": cast( Optional[str], @@ -106,9 +102,7 @@ def model_to_entity(model: Optional[ChunkIndexModel]) -> Optional[ChunkEntity]: if model is None: return None cs_meta = _extract_cs_metadata(model) - source_type_str = getattr(model, "source_type", None) or cs_meta.get( - "source_type_str" - ) + source_type_str = getattr(model, "source_type", None) or cs_meta.get("source_type_str") source_type = _resolve_source_type(source_type_str) kwargs = _build_entity_kwargs(model, cs_meta, source_type) return ChunkEntity(**kwargs) diff --git a/src/domain/mappers/chunk_mapper.py b/src/domain/mappers/chunk_mapper.py index e07f7e94..9b8fa7d6 100644 --- a/src/domain/mappers/chunk_mapper.py +++ b/src/domain/mappers/chunk_mapper.py @@ -38,10 +38,7 @@ def entity_to_model(entity: ChunkEntity) -> ChunkModel: # try to normalize enum names like 'YOUTUBE' to their values with suppress(Exception): data["source_type"] = SourceType[source].value - if ( - not isinstance(data.get("source_type"), str) - or data["source_type"] == source - ): + if not isinstance(data.get("source_type"), str) or data["source_type"] == source: with suppress(Exception): data["source_type"] = SourceType(source).value return ChunkModel(**data) @@ -69,9 +66,7 @@ def document_to_model(document: Document) -> ChunkModel: """ metadata = dict(getattr(document, "metadata", {}) or {}) # prefer page_content, fall back to content - content = getattr(document, "page_content", None) or getattr( - document, "content", None - ) + content = getattr(document, "page_content", None) or getattr(document, "content", None) data = metadata.copy() if content is not None: data["content"] = content diff --git a/src/domain/mappers/content_source_mapper.py b/src/domain/mappers/content_source_mapper.py index 40d173cb..3a67695b 100644 --- a/src/domain/mappers/content_source_mapper.py +++ b/src/domain/mappers/content_source_mapper.py @@ -25,21 +25,15 @@ def model_to_entity( created_at=cast(datetime, getattr(model, "created_at")), ingested_at=cast(Optional[datetime], getattr(model, "ingested_at", None)), processing_status=cast(str, getattr(model, "processing_status", "pending")), - embedding_model=cast( - Optional[str], getattr(model, "embedding_model", None) - ), + embedding_model=cast(Optional[str], getattr(model, "embedding_model", None)), dimensions=cast(Optional[int], getattr(model, "dimensions", None)), total_tokens=cast(Optional[int], getattr(model, "total_tokens", None)), - max_tokens_per_chunk=cast( - Optional[int], getattr(model, "max_tokens_per_chunk", None) - ), + max_tokens_per_chunk=cast(Optional[int], getattr(model, "max_tokens_per_chunk", None)), status=cast(str, getattr(model, "status", "active")), chunks=cast(int, getattr(model, "chunks", 0)), status_message=cast(Optional[str], getattr(model, "status_message", None)), error_message=cast(Optional[str], getattr(model, "error_message", None)), - source_metadata=cast( - Optional[dict], getattr(model, "source_metadata", None) - ), + source_metadata=cast(Optional[dict], getattr(model, "source_metadata", None)), ) @staticmethod diff --git a/src/domain/mappers/ingestion_job_mapper.py b/src/domain/mappers/ingestion_job_mapper.py index 3f638d73..ea602ff5 100644 --- a/src/domain/mappers/ingestion_job_mapper.py +++ b/src/domain/mappers/ingestion_job_mapper.py @@ -35,9 +35,7 @@ def model_to_entity( return IngestionJobEntity( id=cast(UUID, getattr(model, "id")), - content_source_id=cast( - Optional[UUID], getattr(model, "content_source_id", None) - ), + content_source_id=cast(Optional[UUID], getattr(model, "content_source_id", None)), started_at=cast(datetime, getattr(model, "started_at")), created_at=cast(datetime, getattr(model, "created_at")), finished_at=cast(Optional[datetime], getattr(model, "finished_at", None)), @@ -49,15 +47,9 @@ def model_to_entity( ingestion_type=cast(Optional[str], getattr(model, "ingestion_type", None)), source_title=source_title, chunks_count=cast(Optional[int], getattr(model, "chunks_count", None)), - embedding_model=cast( - Optional[str], getattr(model, "embedding_model", None) - ), - pipeline_version=cast( - Optional[str], getattr(model, "pipeline_version", None) - ), - external_source=cast( - Optional[str], getattr(model, "external_source", None) - ), + embedding_model=cast(Optional[str], getattr(model, "embedding_model", None)), + pipeline_version=cast(Optional[str], getattr(model, "pipeline_version", None)), + external_source=cast(Optional[str], getattr(model, "external_source", None)), subject_id=cast(Optional[UUID], subject_id), ) @@ -66,8 +58,6 @@ def model_list_to_entities( models: List[IngestionJobModel], ) -> List[IngestionJobEntity]: temp = [ - IngestionJobMapper.model_to_entity(o) - for o in models - if o is not None and isinstance(o, IngestionJobModel) + IngestionJobMapper.model_to_entity(o) for o in models if o is not None and isinstance(o, IngestionJobModel) ] return [r for r in temp if r is not None] diff --git a/src/domain/mappers/knowledge_subject_mapper.py b/src/domain/mappers/knowledge_subject_mapper.py index 46da7e1c..c853cdd4 100644 --- a/src/domain/mappers/knowledge_subject_mapper.py +++ b/src/domain/mappers/knowledge_subject_mapper.py @@ -31,7 +31,5 @@ def model_to_entity( def model_list_to_entities( models: List[KnowledgeSubjectModel], ) -> List[KnowledgeSubjectEntity]: - temp = [ - KnowledgeSubjectMapper.model_to_entity(m) for m in models if m is not None - ] + temp = [KnowledgeSubjectMapper.model_to_entity(m) for m in models if m is not None] return [r for r in temp if r is not None] diff --git a/src/infrastructure/extractors/crawl4ai_extractor.py b/src/infrastructure/extractors/crawl4ai_extractor.py index 12114c4b..eac872ca 100644 --- a/src/infrastructure/extractors/crawl4ai_extractor.py +++ b/src/infrastructure/extractors/crawl4ai_extractor.py @@ -68,9 +68,7 @@ async def extract(self, source: str, **kwargs: Any) -> List[Document]: # 4. Multi-depth crawl if requested depth = kwargs.get("depth", 1) if depth > 1: - sub_docs = await self._crawl_sub_pages( - crawler, result, source, run_config, kwargs - ) + sub_docs = await self._crawl_sub_pages(crawler, result, source, run_config, kwargs) documents.extend(sub_docs) return documents @@ -82,19 +80,12 @@ def _build_run_config(self, kwargs: Dict[str, Any]) -> CrawlerRunConfig: return CrawlerRunConfig( css_selector=kwargs.get("css_selector"), word_count_threshold=kwargs.get("word_count_threshold", 200), - cache_mode=CacheMode.BYPASS - if kwargs.get("bypass_cache", True) - else CacheMode.ENABLED, + cache_mode=CacheMode.BYPASS if kwargs.get("bypass_cache", True) else CacheMode.ENABLED, ) - async def _handle_crawl_failure( - self, result: Any, source: str, kwargs: Dict[str, Any] - ) -> List[Document]: + async def _handle_crawl_failure(self, result: Any, source: str, kwargs: Dict[str, Any]) -> List[Document]: # FALLBACK: If blocked by anti-bot or structural error (common for PDFs), try direct download - if ( - "Blocked by anti-bot" in result.error_message - or "Structural" in result.error_message - ): + if "Blocked by anti-bot" in result.error_message or "Structural" in result.error_message: logger.warning( "Crawl4AI blocked or failed structurally, trying direct download fallback", context={"url": source, "error": result.error_message}, @@ -115,9 +106,7 @@ def _clean_markdown(self, text: str, exclude_links: bool) -> str: return text return re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) - def _build_main_documents( - self, result: Any, source: str, exclude_links: bool - ) -> List[Document]: + def _build_main_documents(self, result: Any, source: str, exclude_links: bool) -> List[Document]: main_markdown = self._clean_markdown(result.markdown, exclude_links) metadata = { "source": source, @@ -148,11 +137,7 @@ async def _crawl_sub_pages( ) internal_links = result.links.get("internal", []) sub_urls = list( - { - link["href"] - for link in internal_links - if link.get("href") and link["href"].startswith("http") - } + {link["href"] for link in internal_links if link.get("href") and link["href"].startswith("http")} ) if not sub_urls: @@ -164,18 +149,14 @@ async def _crawl_sub_pages( if not sub_urls: return [] - sub_results = await crawler.arun_many( - urls=sub_urls, config=run_config, concurrency_count=concurrency_count - ) + sub_results = await crawler.arun_many(urls=sub_urls, config=run_config, concurrency_count=concurrency_count) documents = [] for sub_res in sub_results: if sub_res.success: sub_metadata = { "source": sub_res.url, - "title": sub_res.metadata.get("title", "") - if sub_res.metadata - else "", + "title": sub_res.metadata.get("title", "") if sub_res.metadata else "", "source_type": "web", "scraper": "crawl4ai", "status_code": sub_res.status_code, @@ -184,9 +165,7 @@ async def _crawl_sub_pages( } documents.append( Document( - page_content=self._clean_markdown( - sub_res.markdown, exclude_links - ), + page_content=self._clean_markdown(sub_res.markdown, exclude_links), metadata=sub_metadata, ) ) @@ -214,9 +193,7 @@ async def _extract_pdf_directly(self, url: str, **kwargs: Any) -> List[Document] # Basic Content-Type check to confirm it's actually a PDF content_type = response.headers.get("Content-Type", "").lower() - if "application/pdf" not in content_type and not url.lower().endswith( - ".pdf" - ): + if "application/pdf" not in content_type and not url.lower().endswith(".pdf"): logger.warning( "URL might not be a PDF", context={"url": url, "content_type": content_type}, @@ -225,9 +202,7 @@ async def _extract_pdf_directly(self, url: str, **kwargs: Any) -> List[Document] # Fix: Use anyio.to_thread.run_sync for synchronous tempfile operations if needed, # or just create a path and use anyio for I/O. def create_temp_file(): - with tempfile.NamedTemporaryFile( - delete=False, suffix=".pdf" - ) as tmp: + with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp: return tmp.name tmp_path = await anyio.to_thread.run_sync(create_temp_file) diff --git a/src/infrastructure/extractors/docling_extractor.py b/src/infrastructure/extractors/docling_extractor.py index 4132656f..5a53d552 100644 --- a/src/infrastructure/extractors/docling_extractor.py +++ b/src/infrastructure/extractors/docling_extractor.py @@ -25,9 +25,7 @@ def _get_pipeline_options(self, do_ocr: bool = False) -> PdfPipelineOptions: pipeline_options.do_ocr = do_ocr if settings.app.device == "cpu": - pipeline_options.accelerator_options.num_threads = ( - settings.docling.cpu_num_threads - ) + pipeline_options.accelerator_options.num_threads = settings.docling.cpu_num_threads return pipeline_options def __init__(self): @@ -86,9 +84,7 @@ def extract(self, file_path: str, do_ocr: bool = False) -> List[Document]: return [doc] except Exception as e: - logger.error( - f"Error extracting with Docling: {e}", context={"file_path": file_path} - ) + logger.error(f"Error extracting with Docling: {e}", context={"file_path": file_path}) raise ValueError(f"Failed to extract content from {file_path}: {str(e)}") def _convert_document(self, file_path: str, do_ocr: bool) -> Any: @@ -142,32 +138,20 @@ def _get_extension(self, origin_filename: str, file_path: str) -> str: def _get_document_stats(self, document: Any) -> dict: return { "num_pages": len(document.pages) if hasattr(document, "pages") else 0, - "num_pictures": len(document.pictures) - if hasattr(document, "pictures") - else 0, + "num_pictures": len(document.pictures) if hasattr(document, "pictures") else 0, "num_tables": len(document.tables) if hasattr(document, "tables") else 0, "num_groups": len(document.groups) if hasattr(document, "groups") else 0, "texts_count": len(document.texts) if hasattr(document, "texts") else 0, - "key_value_items_count": len(document.key_value_items) - if hasattr(document, "key_value_items") - else 0, - "form_items_count": len(document.form_items) - if hasattr(document, "form_items") - else 0, - "field_items_count": len(document.field_items) - if hasattr(document, "field_items") - else 0, + "key_value_items_count": len(document.key_value_items) if hasattr(document, "key_value_items") else 0, + "form_items_count": len(document.form_items) if hasattr(document, "form_items") else 0, + "field_items_count": len(document.field_items) if hasattr(document, "field_items") else 0, } def _enrich_with_doc_meta(self, result: Any, metadata: dict) -> None: doc_meta = None if hasattr(result.document, "meta"): doc_meta = result.document.meta - elif ( - hasattr(result, "input") - and hasattr(result.input, "document") - and hasattr(result.input.document, "meta") - ): + elif hasattr(result, "input") and hasattr(result.input, "document") and hasattr(result.input.document, "meta"): doc_meta = result.input.document.meta if doc_meta: diff --git a/src/infrastructure/extractors/models/youtube_metadata_dto.py b/src/infrastructure/extractors/models/youtube_metadata_dto.py index a8499d55..9741eb5b 100644 --- a/src/infrastructure/extractors/models/youtube_metadata_dto.py +++ b/src/infrastructure/extractors/models/youtube_metadata_dto.py @@ -7,21 +7,13 @@ class YoutubeMetadataDTO(BaseModel): model_config = ConfigDict(populate_by_name=True, extra="ignore") video_id: Optional[str] = Field(default=None, description="ID do vídeo do YouTube") - original_url: Optional[str] = Field( - default=None, description="URL original do vídeo" - ) + original_url: Optional[str] = Field(default=None, description="URL original do vídeo") title: Optional[str] = Field(default=None, description="Título do vídeo") - full_title: Optional[str] = Field( - default=None, description="Título completo do vídeo", alias="fulltitle" - ) + full_title: Optional[str] = Field(default=None, description="Título completo do vídeo", alias="fulltitle") description: Optional[str] = Field(default=None, description="Descrição do vídeo") duration: Optional[int] = Field(default=None, description="Duração em segundos") - duration_string: Optional[str] = Field( - default=None, description="Duração formatada" - ) - categories: Optional[List[str]] = Field( - default=None, description="Categorias do vídeo" - ) + duration_string: Optional[str] = Field(default=None, description="Duração formatada") + categories: Optional[List[str]] = Field(default=None, description="Categorias do vídeo") tags: Optional[List[str]] = Field(default=None, description="Tags do vídeo") channel: Optional[str] = Field(default=None, description="Nome do canal") channel_id: Optional[str] = Field(default=None, description="ID do canal") diff --git a/src/infrastructure/extractors/plain_text_extractor.py b/src/infrastructure/extractors/plain_text_extractor.py index 251c149f..b210b710 100644 --- a/src/infrastructure/extractors/plain_text_extractor.py +++ b/src/infrastructure/extractors/plain_text_extractor.py @@ -42,9 +42,7 @@ def extract(self, file_path_or_url: str) -> List[Document]: def _extract_from_url(self, url: str) -> List[Document]: logger.info("Extracting plain text from URL", context={"url": url}) try: - with httpx.Client( - follow_redirects=True, headers=self.headers, timeout=self.timeout - ) as client: + with httpx.Client(follow_redirects=True, headers=self.headers, timeout=self.timeout) as client: response = client.get(url) response.raise_for_status() content = response.text @@ -63,9 +61,7 @@ def _extract_from_url(self, url: str) -> List[Document]: raise ValueError(f"Failed to download content from {url}: {str(e)}") def _extract_from_local(self, file_path: str) -> List[Document]: - logger.info( - "Extracting plain text from local file", context={"file_path": file_path} - ) + logger.info("Extracting plain text from local file", context={"file_path": file_path}) if not os.path.exists(file_path): raise FileNotFoundError(f"File not found: {file_path}") diff --git a/src/infrastructure/extractors/youtube_extractor.py b/src/infrastructure/extractors/youtube_extractor.py index 529f3f60..87e68e86 100644 --- a/src/infrastructure/extractors/youtube_extractor.py +++ b/src/infrastructure/extractors/youtube_extractor.py @@ -37,9 +37,7 @@ class YoutubeExtractor(IYoutubeExtractor): def __init__(self, video_id: str | None = None, language: str = "pt"): self.video_id = video_id - self.video_url = ( - f"https://www.youtube.com/watch?v={video_id}" if video_id else None - ) + self.video_url = f"https://www.youtube.com/watch?v={video_id}" if video_id else None self.language = language self._ffmpeg_path = imageio_ffmpeg.get_ffmpeg_exe() @@ -69,9 +67,7 @@ def _get_common_ydl_opts(self, quiet: bool = True) -> dict: "retry_sleep_functions": {"http": lambda n: 5 * 2**n}, # Use multiple clients to avoid "Sign in to confirm you're not a bot" # mediaconnect is often more resilient for servers - "extractor_args": { - "youtube": {"player_client": ["mediaconnect", "web", "mweb", "android"]} - }, + "extractor_args": {"youtube": {"player_client": ["mediaconnect", "web", "mweb", "android"]}}, } # Use cookies if provided by user in data/cookies.txt @@ -89,7 +85,9 @@ def _get_common_ydl_opts(self, quiet: bool = True) -> dict: and settings.youtube.webshare_password and settings.youtube.webshare_password.strip() ): - proxy = f"http://{settings.youtube.webshare_username}:{settings.youtube.webshare_password}@p.webshare.io:80" + proxy = ( + f"http://{settings.youtube.webshare_username}:{settings.youtube.webshare_password}@p.webshare.io:80" + ) if proxy: opts["proxy"] = proxy @@ -106,17 +104,12 @@ def _run_with_retry(self, action, max_retries: int = 3, initial_delay: int = 5): last_exception = e # Check for specific fatal errors that shouldn't be retried err_msg = str(e) - if any( - x in err_msg for x in ["Private video", "not available", "Sign in"] - ): + if any(x in err_msg for x in ["Private video", "not available", "Sign in"]): logger.error(f"Fatal YouTube error: {e}") raise # Check for IP blockings - if ( - "blocking requests from your IP" in err_msg - or "IP is blocked" in err_msg - ): + if "blocking requests from your IP" in err_msg or "IP is blocked" in err_msg: logger.error("YouTube IP Block detected in extractor retry loop") raise YoutubeIPBlockedException(self.video_id or "unknown", err_msg) @@ -186,9 +179,7 @@ def _extract(): ) return YoutubeMetadataDTO(video_id=self.video_id or "unknown") - def download_audio( - self, url: str, output_dir: str = "./temp_audio", quality: str = "192" - ) -> str | None: + def download_audio(self, url: str, output_dir: str = "./temp_audio", quality: str = "192") -> str | None: """Downloads and extracts audio from a YouTube video with resilience.""" os.makedirs(output_dir, exist_ok=True) @@ -216,9 +207,7 @@ def _download(): try: return self._run_with_retry(_download) except Exception as e: - logger.error( - f"Download failed after ALL retries: {e}", context={"url": url} - ) + logger.error(f"Download failed after ALL retries: {e}", context={"url": url}) return None def extract_playlist_videos(self, playlist_url: str) -> list[str]: @@ -241,9 +230,7 @@ def extract_playlist_videos(self, playlist_url: str) -> list[str]: context={"playlist_url": playlist_url, "error": str(e)}, ) - logger.info( - "Starting playlist extraction", context={"playlist_url": playlist_url} - ) + logger.info("Starting playlist extraction", context={"playlist_url": playlist_url}) def _extract(): ydl_opts = self._get_common_ydl_opts() @@ -291,9 +278,7 @@ def _extract(): return [], chan videos = self._parse_channel_entries(channel_info["entries"]) - channel_name = ( - channel_info.get("channel") or channel_info.get("uploader") or "" - ) + channel_name = channel_info.get("channel") or channel_info.get("uploader") or "" return videos, channel_name try: @@ -324,9 +309,7 @@ def _fetch_transcript_with_fallback( try: # First attempt: Try preferred languages in order - transcript = api.fetch( - video_id=self.video_id, languages=preferred_languages - ) + transcript = api.fetch(video_id=self.video_id, languages=preferred_languages) logger.debug( "Transcript fetched successfully (preferred).", context={ @@ -358,15 +341,10 @@ def _handle_transcript_error(self, error: Exception): if "This video is private" in error_msg: raise YoutubeVideoPrivateException(self.video_id or "unknown") if "unplayable" in error_msg.lower(): - raise YoutubeVideoUnplayableException( - self.video_id or "unknown", reason=error_msg - ) + raise YoutubeVideoUnplayableException(self.video_id or "unknown", reason=error_msg) # Hard Stop on IP Block - if ( - "blocking requests from your IP" in error_msg - or "IP is blocked" in error_msg - ): + if "blocking requests from your IP" in error_msg or "IP is blocked" in error_msg: logger.error("YouTube IP Block detected during transcript fetch") raise YoutubeIPBlockedException(self.video_id or "unknown", error_msg) @@ -377,17 +355,12 @@ def _handle_transcript_error(self, error: Exception): ) raise YoutubeTranscriptsDisabledException(self.video_id or "unknown") - if ( - isinstance(error, NoTranscriptFound) - or "No transcript available" in error_msg - ): + if isinstance(error, NoTranscriptFound) or "No transcript available" in error_msg: logger.error( "No transcript available for video in ANY language", context={"video_id": self.video_id, "error": error_msg}, ) - raise YoutubeTranscriptNotFoundException( - self.video_id or "unknown", self.language - ) + raise YoutubeTranscriptNotFoundException(self.video_id or "unknown", self.language) msg = f"Unexpected error while fetching transcript for video {self.video_id}: {error_msg}" logger.error( diff --git a/src/infrastructure/loggers/std_logger.py b/src/infrastructure/loggers/std_logger.py index 54975ca7..8e208444 100644 --- a/src/infrastructure/loggers/std_logger.py +++ b/src/infrastructure/loggers/std_logger.py @@ -11,9 +11,7 @@ from src.domain.interfaces.logger.logger import ILogger # Global context for logging via contextvars -_global_log_context: ContextVar[Dict[str, Any]] = ContextVar( - "global_log_context", default={} -) +_global_log_context: ContextVar[Dict[str, Any]] = ContextVar("global_log_context", default={}) def set_global_context(context: Dict[str, Any]) -> None: @@ -130,9 +128,7 @@ def get_log_record(level: str, message: str): asctime = datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] filename = os.path.basename(frame_best.filename) # Caminho relativo ao diretório do projeto - project_root = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../../..") - ) + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) filepath_abs = os.path.abspath(frame_best.filename) filepath_rel = os.path.relpath(filepath_abs, project_root).replace("\\", "/") lineno = frame_best.lineno @@ -232,33 +228,23 @@ def _log( if log_method: log_method(formatted_message) - def info( - self, message: str, context: dict[str, Any] | None = None, *args, **kwargs - ) -> None: + def info(self, message: str, context: dict[str, Any] | None = None, *args, **kwargs) -> None: """Log at INFO level.""" self._log("INFO", message, context, *args, **kwargs) - def debug( - self, message: str, context: dict[str, Any] | None = None, *args, **kwargs - ) -> None: + def debug(self, message: str, context: dict[str, Any] | None = None, *args, **kwargs) -> None: """Log at DEBUG level.""" self._log("DEBUG", message, context, *args, **kwargs) - def warning( - self, message: str, context: dict[str, Any] | None = None, *args, **kwargs - ) -> None: + def warning(self, message: str, context: dict[str, Any] | None = None, *args, **kwargs) -> None: """Log at WARNING level.""" self._log("WARNING", message, context, *args, **kwargs) - def error( - self, error: Any, context: dict[str, Any] | None = None, *args, **kwargs - ) -> None: + def error(self, error: Any, context: dict[str, Any] | None = None, *args, **kwargs) -> None: """Log at ERROR level with optional exception/context support.""" self._log("ERROR", str(error), context, *args, **kwargs) - def critical( - self, message: str, context: dict[str, Any] | None = None, *args, **kwargs - ) -> None: + def critical(self, message: str, context: dict[str, Any] | None = None, *args, **kwargs) -> None: """Log at CRITICAL level.""" self._log("CRITICAL", message, context, *args, **kwargs) diff --git a/src/infrastructure/repositories/sql/chunk_index_repository.py b/src/infrastructure/repositories/sql/chunk_index_repository.py index 864ca304..ea4f4e2e 100644 --- a/src/infrastructure/repositories/sql/chunk_index_repository.py +++ b/src/infrastructure/repositories/sql/chunk_index_repository.py @@ -50,9 +50,7 @@ def create_chunks(self, chunks: List[dict]) -> List[UUID]: # Update ContentSource count (increment by the number of chunks added) if orm_objs: # Collect content_source_ids to update - source_ids = { - o.content_source_id for o in orm_objs if o.content_source_id - } + source_ids = {o.content_source_id for o in orm_objs if o.content_source_id} for sid in source_ids: count = sum(1 for o in orm_objs if o.content_source_id == sid) session.query(ContentSourceModel).filter_by(id=sid).update( @@ -60,16 +58,12 @@ def create_chunks(self, chunks: List[dict]) -> List[UUID]: ) session.commit() - logger.debug( - "Created chunk index rows", context={"count": len(orm_objs)} - ) + logger.debug("Created chunk index rows", context={"count": len(orm_objs)}) return [cast(UUID, o.id) for o in orm_objs] except Exception as e: session.rollback() - logger.error( - "Error creating chunk index rows", context={"error": str(e)} - ) + logger.error("Error creating chunk index rows", context={"error": str(e)}) raise def list_by_content_source( @@ -101,9 +95,7 @@ def list_chunks( ) -> List[ChunkIndexModel]: source_id = ensure_uuid(source_id) with Connector() as session: - query = session.query(ChunkIndexModel).options( - joinedload(ChunkIndexModel.content_source) - ) + query = session.query(ChunkIndexModel).options(joinedload(ChunkIndexModel.content_source)) if source_id: query = query.filter_by(content_source_id=source_id).order_by( @@ -120,11 +112,7 @@ def list_chunks( def count_by_content_source(self, content_source_id: Any) -> int: content_source_id = ensure_uuid(content_source_id) with Connector() as session: - return ( - session.query(ChunkIndexModel) - .filter_by(content_source_id=content_source_id) - .count() - ) + return session.query(ChunkIndexModel).filter_by(content_source_id=content_source_id).count() def delete_by_content_source(self, content_source_id: Any) -> int: content_source_id = ensure_uuid(content_source_id) @@ -137,17 +125,13 @@ def delete_by_content_source(self, content_source_id: Any) -> int: ) # Update ContentSource count to 0 - session.query(ContentSourceModel).filter_by( - id=content_source_id - ).update({"chunks": 0}) + session.query(ContentSourceModel).filter_by(id=content_source_id).update({"chunks": 0}) session.commit() return int(deleted) except Exception as e: session.rollback() - logger.error( - "Error deleting chunk index rows", context={"error": str(e)} - ) + logger.error("Error deleting chunk index rows", context={"error": str(e)}) raise def delete_by_job_id(self, job_id: Any) -> int: @@ -156,24 +140,14 @@ def delete_by_job_id(self, job_id: Any) -> int: with Connector() as session: try: # 1. Find the content_source_id before deleting - chunks = ( - session.query(ChunkIndexModel.content_source_id) - .filter_by(job_id=job_id) - .all() - ) + chunks = session.query(ChunkIndexModel.content_source_id).filter_by(job_id=job_id).all() if not chunks: return 0 - source_ids = { - c.content_source_id for c in chunks if c.content_source_id - } + source_ids = {c.content_source_id for c in chunks if c.content_source_id} # 2. Delete the chunks - deleted = ( - session.query(ChunkIndexModel) - .filter_by(job_id=job_id) - .delete(synchronize_session=False) - ) + deleted = session.query(ChunkIndexModel).filter_by(job_id=job_id).delete(synchronize_session=False) # 3. Update ContentSource counts (decrement by the number of chunks deleted) if deleted > 0: @@ -197,9 +171,7 @@ def delete_by_job_id(self, job_id: Any) -> int: ) raise - def search( - self, query: Optional[str], top_k: int = 10, filters: Optional[Any] = None - ) -> List[ChunkIndexModel]: + def search(self, query: Optional[str], top_k: int = 10, filters: Optional[Any] = None) -> List[ChunkIndexModel]: with Connector() as session: q = ( session.query(ChunkIndexModel) @@ -239,9 +211,9 @@ def delete_chunk(self, chunk_id: Any) -> bool: session.delete(chunk) # 3. Decrement count in ContentSource - session.query(ContentSourceModel).filter_by( - id=content_source_id - ).update({"chunks": ContentSourceModel.chunks - 1}) + session.query(ContentSourceModel).filter_by(id=content_source_id).update( + {"chunks": ContentSourceModel.chunks - 1} + ) session.commit() return True diff --git a/src/infrastructure/repositories/sql/content_source_repository.py b/src/infrastructure/repositories/sql/content_source_repository.py index 726e432b..6db68951 100644 --- a/src/infrastructure/repositories/sql/content_source_repository.py +++ b/src/infrastructure/repositories/sql/content_source_repository.py @@ -77,9 +77,7 @@ def create( return cast(UUID, cs.id) except Exception as e: - logger.error( - "Error creating ContentSource", context={**extra, "error": str(e)} - ) + logger.error("Error creating ContentSource", context={**extra, "error": str(e)}) session.rollback() raise @@ -101,9 +99,7 @@ def get_by_id(self, cs_id: Any) -> Optional[ContentSourceModel]: ) raise - def get_by_diarization_id( - self, diarization_id: str - ) -> Optional[ContentSourceModel]: + def get_by_diarization_id(self, diarization_id: str) -> Optional[ContentSourceModel]: """Find a ContentSource that has this diarization_id in its source_metadata JSON.""" with Connector() as session: try: @@ -121,13 +117,9 @@ def get_by_diarization_id( result = ( session.query(ContentSourceModel) .filter( - cast( - ContentSourceModel.source_metadata["diarization_id"], String - ) - == f'"{diarization_id}"' + cast(ContentSourceModel.source_metadata["diarization_id"], String) == f'"{diarization_id}"' if session.bind.dialect.name == "sqlite" - else ContentSourceModel.source_metadata["diarization_id"].astext - == diarization_id + else ContentSourceModel.source_metadata["diarization_id"].astext == diarization_id ) .first() ) @@ -161,9 +153,7 @@ def get_by_source_info( result = query.order_by(ContentSourceModel.created_at.desc()).all() - logger.debug( - "Fetch successful", context={**extra, "count": len(result)} - ) + logger.debug("Fetch successful", context={**extra, "count": len(result)}) return result except Exception as e: logger.error( @@ -204,16 +194,12 @@ def list_by_subject( ) raise - def list( - self, limit: Optional[int] = None, offset: Optional[int] = None - ) -> List[ContentSourceModel]: + def list(self, limit: Optional[int] = None, offset: Optional[int] = None) -> List[ContentSourceModel]: with Connector() as session: try: extra = {"limit": limit, "offset": offset} logger.debug("Listing all ContentSources", context=extra) - query = session.query(ContentSourceModel).order_by( - ContentSourceModel.created_at.desc() - ) + query = session.query(ContentSourceModel).order_by(ContentSourceModel.created_at.desc()) if offset is not None: query = query.offset(offset) @@ -236,11 +222,7 @@ def count_by_subject(self, subject_id: Any) -> int: try: extra = {"subject_id": subject_id} logger.debug("Counting ContentSources by subject ID", context=extra) - result = ( - session.query(ContentSourceModel) - .filter_by(subject_id=subject_id) - .count() - ) + result = session.query(ContentSourceModel).filter_by(subject_id=subject_id).count() logger.debug("Count successful", context={**extra, "count": result}) return result except Exception as e: @@ -260,9 +242,7 @@ def update_status( with Connector() as session: try: extra = {"content_source_id": content_source_id, "status": status} - logger.debug( - "Updating processing status for ContentSource", context=extra - ) + logger.debug("Updating processing status for ContentSource", context=extra) cs = session.get(ContentSourceModel, content_source_id) if cs is None: logger.warning("ContentSource not found for update", context=extra) @@ -289,9 +269,7 @@ def update_title(self, content_source_id: UUID, title: str) -> None: logger.debug("Updating title for ContentSource", context=extra) cs = session.get(ContentSourceModel, content_source_id) if cs is None: - logger.warning( - "ContentSource not found for title update", context=extra - ) + logger.warning("ContentSource not found for title update", context=extra) return cs.title = title session.commit() @@ -329,9 +307,7 @@ def finish_ingestion( logger.debug("Finishing ingestion for ContentSource", context=extra) cs = session.get(ContentSourceModel, content_source_id) if cs is None: - logger.warning( - "ContentSource not found for finishing ingestion", context=extra - ) + logger.warning("ContentSource not found for finishing ingestion", context=extra) return # Explicitly update processing_status to 'done' @@ -360,9 +336,7 @@ def finish_ingestion( session.rollback() raise - def list_external_sources_by_subject( - self, subject_id: Any, source_type: str - ) -> List[str]: + def list_external_sources_by_subject(self, subject_id: Any, source_type: str) -> List[str]: """Return all external_source values for a given subject and source_type. Optimized query that only fetches the external_source column. @@ -394,9 +368,7 @@ def update_metadata(self, content_source_id: UUID, metadata: dict) -> None: logger.debug("Updating metadata for ContentSource", context=extra) cs = session.get(ContentSourceModel, content_source_id) if cs is None: - logger.warning( - "ContentSource not found for metadata update", context=extra - ) + logger.warning("ContentSource not found for metadata update", context=extra) return # Merge existing metadata with new metadata current = dict(cs.source_metadata or {}) @@ -420,9 +392,7 @@ def delete(self, content_source_id: UUID) -> bool: logger.debug("Deleting ContentSource", context=extra) cs = session.get(ContentSourceModel, content_source_id) if cs is None: - logger.warning( - "ContentSource not found for deletion", context=extra - ) + logger.warning("ContentSource not found for deletion", context=extra) return False session.delete(cs) session.commit() diff --git a/src/infrastructure/repositories/sql/diarization_repository.py b/src/infrastructure/repositories/sql/diarization_repository.py index fc4b2837..26baf739 100644 --- a/src/infrastructure/repositories/sql/diarization_repository.py +++ b/src/infrastructure/repositories/sql/diarization_repository.py @@ -30,9 +30,7 @@ def create_pending( language=language, status=DiarizationStatus.PENDING.value, model_size=model_size, - subject_id=UUID(subject_id) - if subject_id and isinstance(subject_id, str) - else subject_id, + subject_id=UUID(subject_id) if subject_id and isinstance(subject_id, str) else subject_id, ) self.db.add(record) self.db.commit() @@ -129,37 +127,22 @@ def get_all( parsed_id = UUID(subject_id) if isinstance(subject_id, str) else subject_id query = query.filter(DiarizationRecord.subject_id == parsed_id) - result = ( - query.order_by(DiarizationRecord.created_at.desc()) - .offset(offset) - .limit(limit) - .all() - ) + result = query.order_by(DiarizationRecord.created_at.desc()).offset(offset).limit(limit).all() return cast(List[DiarizationRecord], cast(object, result)) def get_by_id(self, diarization_id: str) -> Optional[DiarizationRecord]: - result = ( - self.db.query(DiarizationRecord) - .filter(DiarizationRecord.id == diarization_id) - .first() - ) + result = self.db.query(DiarizationRecord).filter(DiarizationRecord.id == diarization_id).first() return cast(Optional[DiarizationRecord], result) def delete(self, diarization_id: str) -> bool: - record = ( - self.db.query(DiarizationRecord) - .filter(DiarizationRecord.id == diarization_id) - .first() - ) + record = self.db.query(DiarizationRecord).filter(DiarizationRecord.id == diarization_id).first() if not record: return False self.db.delete(record) self.db.commit() return True - def update_recognition_results( - self, diarization_id: str, recognition_results: dict - ) -> Optional[DiarizationRecord]: + def update_recognition_results(self, diarization_id: str, recognition_results: dict) -> Optional[DiarizationRecord]: record = self.get_by_id(diarization_id) if not record: return None @@ -168,9 +151,7 @@ def update_recognition_results( self.db.refresh(record) return record - def reset_for_reprocessing( - self, diarization_id: str - ) -> Optional[DiarizationRecord]: + def reset_for_reprocessing(self, diarization_id: str) -> Optional[DiarizationRecord]: """Resets the record for a new diarization run.""" record = self.get_by_id(diarization_id) if not record: diff --git a/src/infrastructure/repositories/sql/ingestion_job_repository.py b/src/infrastructure/repositories/sql/ingestion_job_repository.py index 4e0dcc18..4c216103 100644 --- a/src/infrastructure/repositories/sql/ingestion_job_repository.py +++ b/src/infrastructure/repositories/sql/ingestion_job_repository.py @@ -59,15 +59,11 @@ def create_job( session.add(job) session.commit() session.refresh(job) - logger.debug( - "Ingestion job created successfully", context={"job_id": job.id} - ) + logger.debug("Ingestion job created successfully", context={"job_id": job.id}) return cast(UUID, job.id) except Exception as e: - logger.error( - "Error creating ingestion job", context={**extra, "error": str(e)} - ) + logger.error("Error creating ingestion job", context={**extra, "error": str(e)}) session.rollback() raise @@ -136,9 +132,7 @@ def update_job( session.commit() logger.debug("Ingestion job updated successfully", context=extra) except Exception as e: - logger.error( - "Error updating ingestion job", context={**extra, "error": str(e)} - ) + logger.error("Error updating ingestion job", context={**extra, "error": str(e)}) session.rollback() raise @@ -175,9 +169,7 @@ def link_job_to_source( job.ingestion_type = ingestion_type session.commit() else: - logger.warning( - "Job not found for linking", context={"job_id": job_id} - ) + logger.warning("Job not found for linking", context={"job_id": job_id}) except Exception as e: logger.error( "Error linking job to source", @@ -228,9 +220,7 @@ def get_by_id(self, job_id: Any) -> Optional[IngestionJobModel]: ) raise - def list_recent_jobs( - self, limit: int = 50, offset: int = 0 - ) -> List[IngestionJobModel]: + def list_recent_jobs(self, limit: int = 50, offset: int = 0) -> List[IngestionJobModel]: """Backward compatible list_recent_jobs with offset support.""" return self.list_jobs(limit=limit, offset=offset) @@ -252,27 +242,19 @@ def list_jobs( "search": search, }, ) - query = session.query(IngestionJobModel).options( - joinedload(IngestionJobModel.content_source) - ) + query = session.query(IngestionJobModel).options(joinedload(IngestionJobModel.content_source)) if status: if status == "processing": - query = query.filter( - IngestionJobModel.status.in_(["processing", "started"]) - ) + query = query.filter(IngestionJobModel.status.in_(["processing", "started"])) elif status == "completed": - query = query.filter( - IngestionJobModel.status.in_(["done", "finished"]) - ) + query = query.filter(IngestionJobModel.status.in_(["done", "finished"])) elif status == "failed": # Exclude Duplicates from Failed query = query.filter( IngestionJobModel.status.in_(["failed", "error"]), (IngestionJobModel.error_message.is_(None)) - | ( - ~IngestionJobModel.error_message.ilike(DUPLICATE_FILTER) - ), + | (~IngestionJobModel.error_message.ilike(DUPLICATE_FILTER)), ) elif status == "cancelled": # Include Duplicates in Cancelled @@ -291,33 +273,22 @@ def list_jobs( | (IngestionJobModel.external_source.ilike(search_term)) ) - result = ( - query.order_by(IngestionJobModel.created_at.desc()) - .limit(limit) - .offset(offset) - .all() - ) + result = query.order_by(IngestionJobModel.created_at.desc()).limit(limit).offset(offset).all() return result except Exception as e: logger.error("Error listing ingestion jobs", context={"error": str(e)}) raise - def count_jobs( - self, status: Optional[str] = None, search: Optional[str] = None - ) -> int: + def count_jobs(self, status: Optional[str] = None, search: Optional[str] = None) -> int: with Connector() as session: try: query = session.query(IngestionJobModel) if status: if status == "processing": - query = query.filter( - IngestionJobModel.status.in_(["processing", "started"]) - ) + query = query.filter(IngestionJobModel.status.in_(["processing", "started"])) elif status == "completed": - query = query.filter( - IngestionJobModel.status.in_(["done", "finished"]) - ) + query = query.filter(IngestionJobModel.status.in_(["done", "finished"])) elif status == "failed": # Exclude Duplicates from Failed query = query.filter( @@ -361,29 +332,19 @@ def get_status_counts(self, search: Optional[str] = None) -> dict: ) total = base_query.count() - processing = base_query.filter( - IngestionJobModel.status.in_(["processing", "started"]) - ).count() - completed = base_query.filter( - IngestionJobModel.status.in_(["done", "finished"]) - ).count() + processing = base_query.filter(IngestionJobModel.status.in_(["processing", "started"])).count() + completed = base_query.filter(IngestionJobModel.status.in_(["done", "finished"])).count() # Treat "Duplicate" errors as CANCELLED - duplicate_filter = IngestionJobModel.error_message.ilike( - DUPLICATE_FILTER - ) - not_duplicate_filter = (IngestionJobModel.error_message.is_(None)) | ( - ~duplicate_filter - ) + duplicate_filter = IngestionJobModel.error_message.ilike(DUPLICATE_FILTER) + not_duplicate_filter = (IngestionJobModel.error_message.is_(None)) | (~duplicate_filter) failed = base_query.filter( IngestionJobModel.status.in_(["failed", "error"]), not_duplicate_filter, ).count() - cancelled = base_query.filter( - (IngestionJobModel.status == "cancelled") | duplicate_filter - ).count() + cancelled = base_query.filter((IngestionJobModel.status == "cancelled") | duplicate_filter).count() return { "total": total, @@ -396,9 +357,7 @@ def get_status_counts(self, search: Optional[str] = None) -> dict: logger.error("Error getting status counts", context={"error": str(e)}) raise - def list_recent_jobs_by_subject( - self, subject_id: Any, limit: int = 50, offset: int = 0 - ) -> List[IngestionJobModel]: + def list_recent_jobs_by_subject(self, subject_id: Any, limit: int = 50, offset: int = 0) -> List[IngestionJobModel]: subject_id = ensure_uuid(subject_id) from src.infrastructure.repositories.sql.models.content_source import ( ContentSourceModel, @@ -436,14 +395,8 @@ def list_by_content_source(self, content_source_id: Any) -> List[IngestionJobMod with Connector() as session: try: extra = {"content_source_id": content_source_id} - logger.debug( - "Listing ingestion jobs by content source ID", context=extra - ) - result = ( - session.query(IngestionJobModel) - .filter_by(content_source_id=content_source_id) - .all() - ) + logger.debug("Listing ingestion jobs by content source ID", context=extra) + result = session.query(IngestionJobModel).filter_by(content_source_id=content_source_id).all() logger.debug("List successful", context={**extra, "count": len(result)}) return result except Exception as e: @@ -453,9 +406,7 @@ def list_by_content_source(self, content_source_id: Any) -> List[IngestionJobMod ) raise - def mark_previous_jobs_as_reprocessed( - self, content_source_id: Any, current_job_id: Any - ) -> int: + def mark_previous_jobs_as_reprocessed(self, content_source_id: Any, current_job_id: Any) -> int: content_source_id = ensure_uuid(content_source_id) current_job_id = ensure_uuid(current_job_id) """Mark all previous jobs for a content source as REPROCESSED.""" diff --git a/src/infrastructure/repositories/sql/knowledge_subject_repository.py b/src/infrastructure/repositories/sql/knowledge_subject_repository.py index 49579505..5d38a616 100644 --- a/src/infrastructure/repositories/sql/knowledge_subject_repository.py +++ b/src/infrastructure/repositories/sql/knowledge_subject_repository.py @@ -43,9 +43,7 @@ def create_subject( session.add(ks) session.commit() session.refresh(ks) - logger.debug( - "KnowledgeSubject created successfully", context={"id": ks.id} - ) + logger.debug("KnowledgeSubject created successfully", context={"id": ks.id}) return cast(UUID, ks.id) except Exception as e: @@ -122,9 +120,7 @@ def list(self, limit: int = 100) -> List[KnowledgeSubjectModel]: .limit(limit) .all() ) - logger.debug( - "List successful", context={"limit": limit, "count": len(result)} - ) + logger.debug("List successful", context={"limit": limit, "count": len(result)}) return result except Exception as e: logger.error( @@ -155,9 +151,7 @@ def update( ) ks = session.get(KnowledgeSubjectModel, id) if ks is None: - logger.warning( - "KnowledgeSubject not found for update", context={"id": id} - ) + logger.warning("KnowledgeSubject not found for update", context={"id": id}) return if name is not None: ks.name = name @@ -168,9 +162,7 @@ def update( if icon is not None: ks.icon = icon session.commit() - logger.debug( - "KnowledgeSubject updated successfully", context={"id": id} - ) + logger.debug("KnowledgeSubject updated successfully", context={"id": id}) except Exception as e: logger.error( "Error updating KnowledgeSubject", @@ -210,18 +202,14 @@ def delete(self, id: UUID) -> int: def get_by_name(self, name) -> Optional[KnowledgeSubjectModel]: with Connector() as session: try: - logger.debug( - "Fetching KnowledgeSubject by name", context={"name": name} - ) + logger.debug("Fetching KnowledgeSubject by name", context={"name": name}) result = ( session.query(KnowledgeSubjectModel) .options(selectinload(KnowledgeSubjectModel.content_sources)) .filter_by(name=name) .first() ) - logger.debug( - "Fetch successful", context={"name": name, "result": result} - ) + logger.debug("Fetch successful", context={"name": name, "result": result}) return result except Exception as e: logger.error( diff --git a/src/infrastructure/repositories/sql/models/chunk_index.py b/src/infrastructure/repositories/sql/models/chunk_index.py index 734bda10..6ad883f1 100644 --- a/src/infrastructure/repositories/sql/models/chunk_index.py +++ b/src/infrastructure/repositories/sql/models/chunk_index.py @@ -51,9 +51,7 @@ class ChunkIndexModel(Base): extra = Column(JSON, nullable=True) version_number = Column(Integer, nullable=False, server_default=text("1")) vector_store_type = Column(Text, nullable=True) - created_at = Column( - DateTime(timezone=True), server_default=func.now(), nullable=False, index=True - ) + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True) updated_at = Column( DateTime(timezone=True), server_default=func.now(), diff --git a/src/infrastructure/repositories/sql/models/content_source.py b/src/infrastructure/repositories/sql/models/content_source.py index 3a2cf4db..7686f055 100644 --- a/src/infrastructure/repositories/sql/models/content_source.py +++ b/src/infrastructure/repositories/sql/models/content_source.py @@ -37,9 +37,7 @@ class ContentSourceModel(Base): language = Column(Text, nullable=True) embedding_model = Column(Text, nullable=True) status = Column(Text, nullable=False, server_default=text("'active'")) - created_at = Column( - DateTime(timezone=True), server_default=func.now(), nullable=False, index=True - ) + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True) updated_at = Column( DateTime(timezone=True), server_default=func.now(), diff --git a/src/infrastructure/repositories/sql/models/diarization_record.py b/src/infrastructure/repositories/sql/models/diarization_record.py index a3ed626d..1b637fae 100644 --- a/src/infrastructure/repositories/sql/models/diarization_record.py +++ b/src/infrastructure/repositories/sql/models/diarization_record.py @@ -20,9 +20,7 @@ class DiarizationRecord(Base): id = Column(String, primary_key=True, default=_generate_uuid) name = Column(String, index=True) - subject_id = Column( - UUID, ForeignKey("knowledge_subjects.id"), nullable=True, index=True - ) + subject_id = Column(UUID, ForeignKey("knowledge_subjects.id"), nullable=True, index=True) source_type = Column(String) external_source = Column(String) language = Column(String) diff --git a/src/infrastructure/repositories/sql/models/ingestion_job.py b/src/infrastructure/repositories/sql/models/ingestion_job.py index 08c3e5eb..ae7dc01e 100644 --- a/src/infrastructure/repositories/sql/models/ingestion_job.py +++ b/src/infrastructure/repositories/sql/models/ingestion_job.py @@ -19,9 +19,7 @@ class IngestionJobModel(Base): ForeignKey("content_sources.id", deferrable=True, initially="IMMEDIATE"), nullable=True, ) - started_at = Column( - DateTime(timezone=True), server_default=func.now(), nullable=False, index=True - ) + started_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False, index=True) created_at = synonym("started_at") updated_at = Column( DateTime(timezone=True), @@ -50,9 +48,7 @@ class IngestionJobModel(Base): ) content_source = relationship("ContentSourceModel", back_populates="ingestion_jobs") - chunks = relationship( - "ChunkIndexModel", back_populates="job", cascade="all, delete-orphan" - ) + chunks = relationship("ChunkIndexModel", back_populates="job", cascade="all, delete-orphan") __table_args__ = ( Index("ix_ingestion_jobs_content_source_id", "content_source_id"), diff --git a/src/infrastructure/repositories/sql/models/knowledge_subject.py b/src/infrastructure/repositories/sql/models/knowledge_subject.py index 80fea806..595ddcb9 100644 --- a/src/infrastructure/repositories/sql/models/knowledge_subject.py +++ b/src/infrastructure/repositories/sql/models/knowledge_subject.py @@ -18,9 +18,7 @@ class KnowledgeSubjectModel(Base): name = Column(Text, nullable=False) description = Column(Text, nullable=True) icon = Column(Text, nullable=True) - created_at = Column( - DateTime(timezone=True), server_default=func.now(), nullable=False - ) + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column( DateTime(timezone=True), server_default=func.now(), @@ -29,6 +27,4 @@ class KnowledgeSubjectModel(Base): ) # Relationships (string-based names avoid import-order issues) - content_sources = relationship( - "ContentSourceModel", back_populates="subject", cascade="all, delete-orphan" - ) + content_sources = relationship("ContentSourceModel", back_populates="subject", cascade="all, delete-orphan") diff --git a/src/infrastructure/repositories/sql/models/user.py b/src/infrastructure/repositories/sql/models/user.py index c5831ce4..fa6c9a0e 100644 --- a/src/infrastructure/repositories/sql/models/user.py +++ b/src/infrastructure/repositories/sql/models/user.py @@ -11,17 +11,11 @@ class User(Base): __tablename__ = "users" id: Mapped[str] = mapped_column(String(36), primary_key=True) - email: Mapped[str] = mapped_column( - String(255), unique=True, index=True, nullable=False - ) + email: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False) full_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) picture_url: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) - ) - last_login: Mapped[Optional[datetime]] = mapped_column( - DateTime(timezone=True), nullable=True - ) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) + last_login: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True) def __repr__(self) -> str: return f"" diff --git a/src/infrastructure/repositories/sql/models/voice_record.py b/src/infrastructure/repositories/sql/models/voice_record.py index 9f83d105..8b0d476b 100644 --- a/src/infrastructure/repositories/sql/models/voice_record.py +++ b/src/infrastructure/repositories/sql/models/voice_record.py @@ -20,7 +20,9 @@ class VoiceRecord(Base): id = Column(String, primary_key=True, default=_generate_uuid) name = Column(String, unique=True, index=True) embedding = Column(JSON) - audios_path = Column( - String - ) # S3 directory prefix where audio samples are stored (e.g. "voices/{id}/") + audios_path = Column(String) # S3 directory prefix where audio samples are stored (e.g. "voices/{id}/") + # Lifecycle status: "processing" (training/reinforcing in background), + # "ready" (usable), "failed" (training error). Nullable for legacy rows. + status = Column(String, nullable=True, default="ready") + status_message = Column(String, nullable=True) created_at = Column(DateTime, default=lambda: datetime.datetime.now(datetime.UTC)) diff --git a/src/infrastructure/repositories/sql/utils/utils.py b/src/infrastructure/repositories/sql/utils/utils.py index 800c575f..470f773a 100644 --- a/src/infrastructure/repositories/sql/utils/utils.py +++ b/src/infrastructure/repositories/sql/utils/utils.py @@ -6,9 +6,7 @@ logger = Logger() -def ensure_uuid( - val: Any, error_msg: str = "Invalid UUID string provided" -) -> Optional[UUID]: +def ensure_uuid(val: Any, error_msg: str = "Invalid UUID string provided") -> Optional[UUID]: """Ensures that the provided value is a UUID object or attempt to convert it if it's a string. Args: diff --git a/src/infrastructure/repositories/storage/storage.py b/src/infrastructure/repositories/storage/storage.py index c4d598da..eeefc980 100644 --- a/src/infrastructure/repositories/storage/storage.py +++ b/src/infrastructure/repositories/storage/storage.py @@ -16,9 +16,7 @@ def __init__(self): if not endpoint.startswith("http"): endpoint = f"http://{endpoint}" - logger.info( - "Connecting to MinIO at %s (bucket=%s)", endpoint, storage_cfg.minio_bucket - ) + logger.info("Connecting to MinIO at %s (bucket=%s)", endpoint, storage_cfg.minio_bucket) self.s3 = boto3.client( "s3", endpoint_url=endpoint, @@ -86,12 +84,8 @@ def delete_directory(self, s3_prefix: str): for page in paginator.paginate(Bucket=self.bucket, Prefix=s3_prefix): if "Contents" in page: delete_list = [{"Key": obj["Key"]} for obj in page["Contents"]] - self.s3.delete_objects( - Bucket=self.bucket, Delete={"Objects": delete_list} - ) - logger.info( - "Deleted %d objects with prefix %s", len(delete_list), s3_prefix - ) + self.s3.delete_objects(Bucket=self.bucket, Delete={"Objects": delete_list}) + logger.info("Deleted %d objects with prefix %s", len(delete_list), s3_prefix) def list_files(self, prefix: str = "", extension: str | None = None) -> list[dict]: paginator = self.s3.get_paginator("list_objects_v2") diff --git a/src/infrastructure/repositories/vector/chroma/chunk_repository.py b/src/infrastructure/repositories/vector/chroma/chunk_repository.py index 03bdd94d..a676454e 100644 --- a/src/infrastructure/repositories/vector/chroma/chunk_repository.py +++ b/src/infrastructure/repositories/vector/chroma/chunk_repository.py @@ -51,9 +51,7 @@ def __init__( self._chroma_client = None def create_documents(self, documents: List[ChunkModel]) -> List[str]: - logger.debug( - "Creating documents in ChromaDB", context={"num_documents": len(documents)} - ) + logger.debug("Creating documents in ChromaDB", context={"num_documents": len(documents)}) try: if not self._vector_store: @@ -172,16 +170,12 @@ def retriever( ) return [] - def _semantic_search( - self, query: str, top_kn: int, chroma_filter: Optional[Dict] - ) -> List[ChunkModel]: + def _semantic_search(self, query: str, top_kn: int, chroma_filter: Optional[Dict]) -> List[ChunkModel]: """Standard Chroma vector similarity search.""" if not self._vector_store: return [] - docs_with_scores = self._vector_store.similarity_search_with_score( - query, k=top_kn, filter=chroma_filter - ) + docs_with_scores = self._vector_store.similarity_search_with_score(query, k=top_kn, filter=chroma_filter) mapper = ChunkMapper() models: List[ChunkModel] = [] @@ -201,9 +195,7 @@ def _get_all_docs(self, chroma_filter: Optional[Dict]): collection = self._chroma_client.get_collection(self._collection_name) # Get all documents matching the filter (or all if no filter) - results = collection.get( - where=chroma_filter, include=["documents", "metadatas"] - ) + results = collection.get(where=chroma_filter, include=["documents", "metadatas"]) from langchain_core.documents import Document @@ -221,9 +213,7 @@ def _get_all_docs(self, chroma_filter: Optional[Dict]): ) return docs - def _bm25_search( - self, query: str, top_kn: int, chroma_filter: Optional[Dict] - ) -> List[ChunkModel]: + def _bm25_search(self, query: str, top_kn: int, chroma_filter: Optional[Dict]) -> List[ChunkModel]: """BM25 keyword search over Chroma docs using rank_bm25.""" try: from rank_bm25 import BM25Okapi @@ -264,9 +254,7 @@ def _bm25_search( return models - def _hybrid_search( - self, query: str, top_kn: int, chroma_filter: Optional[Dict] - ) -> List[ChunkModel]: + def _hybrid_search(self, query: str, top_kn: int, chroma_filter: Optional[Dict]) -> List[ChunkModel]: """Custom Hybrid search using Reciprocal Rank Fusion (RRF).""" fetch_k = max(top_kn * 3, 20) @@ -295,9 +283,7 @@ def _hybrid_search( rrf_scores[doc_id] = score # Sort by RRF score - sorted_ids = sorted( - rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True - ) + sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True) final_results = [] for doc_id in sorted_ids[:top_kn]: @@ -323,9 +309,7 @@ def delete(self, filters: Optional[Any]) -> int: try: chroma_filter = self._build_chroma_filter(filters) if not chroma_filter: - logger.warning( - "Delete called without filters in Chroma, skipping for safety." - ) + logger.warning("Delete called without filters in Chroma, skipping for safety.") return 0 collection = self._chroma_client.get_collection(self._collection_name) @@ -346,9 +330,7 @@ def delete(self, filters: Optional[Any]) -> int: ) return 0 - def list_chunks( - self, filters: Optional[Any], limit: int = 1000 - ) -> List[ChunkModel]: + def list_chunks(self, filters: Optional[Any], limit: int = 1000) -> List[ChunkModel]: if not self._chroma_client: return [] @@ -356,9 +338,7 @@ def list_chunks( chroma_filter = self._build_chroma_filter(filters) collection = self._chroma_client.get_collection(self._collection_name) - results = collection.get( - where=chroma_filter, limit=limit, include=["documents", "metadatas"] - ) + results = collection.get(where=chroma_filter, limit=limit, include=["documents", "metadatas"]) mapper = ChunkMapper() models = [] diff --git a/src/infrastructure/repositories/vector/faiss/chunk_repository.py b/src/infrastructure/repositories/vector/faiss/chunk_repository.py index 7383f6e7..e2979a8b 100644 --- a/src/infrastructure/repositories/vector/faiss/chunk_repository.py +++ b/src/infrastructure/repositories/vector/faiss/chunk_repository.py @@ -32,9 +32,7 @@ def _load_or_create(self): from langchain_community.vectorstores import FAISS if os.path.exists(os.path.join(self._index_path, f"{self._index_name}.faiss")): - logger.debug( - "Loading existing FAISS index", context={"path": self._index_path} - ) + logger.debug("Loading existing FAISS index", context={"path": self._index_path}) try: self._vector_store = FAISS.load_local( folder_path=self._index_path, @@ -43,32 +41,22 @@ def _load_or_create(self): allow_dangerous_deserialization=True, ) except Exception as e: - logger.error( - e, context={"action": "load_faiss_index", "path": self._index_path} - ) + logger.error(e, context={"action": "load_faiss_index", "path": self._index_path}) self._vector_store = None else: - logger.debug( - "FAISS index not found, it will be created upon first document addition" - ) + logger.debug("FAISS index not found, it will be created upon first document addition") def _save(self): """Save the FAISS index to disk.""" if self._vector_store: os.makedirs(self._index_path, exist_ok=True) - self._vector_store.save_local( - folder_path=self._index_path, index_name=self._index_name - ) + self._vector_store.save_local(folder_path=self._index_path, index_name=self._index_name) def create_documents(self, documents: List[ChunkModel]) -> List[str]: - logger.debug( - "Creating documents in FAISS", context={"num_documents": len(documents)} - ) + logger.debug("Creating documents in FAISS", context={"num_documents": len(documents)}) try: - texts: List[str] = [ - doc.content for doc in documents if doc.content is not None - ] + texts: List[str] = [doc.content for doc in documents if doc.content is not None] if not texts: return [] @@ -78,9 +66,7 @@ def create_documents(self, documents: List[ChunkModel]) -> List[str]: metadatas = [] for doc in valid_docs: - meta = doc.model_dump( - exclude={"content", "score"} - ) # No longer exclude ID + meta = doc.model_dump(exclude={"content", "score"}) # No longer exclude ID # Convert UUIDs and datetimes to string for better compatibility for key, value in meta.items(): @@ -168,21 +154,15 @@ def filter_func(metadata: dict) -> bool: else: return self._semantic_search(query, top_kn, filter_callable) except Exception as e: - logger.error( - "Error retrieving from FAISS", context={"query": query, "error": str(e)} - ) + logger.error("Error retrieving from FAISS", context={"query": query, "error": str(e)}) raise - def _semantic_search( - self, query: str, top_kn: int, filter_callable: Optional[Any] - ) -> List[ChunkModel]: + def _semantic_search(self, query: str, top_kn: int, filter_callable: Optional[Any]) -> List[ChunkModel]: """Standard FAISS vector similarity search.""" if not self._vector_store: return [] - docs_with_scores = self._vector_store.similarity_search_with_score( - query, k=top_kn, filter=filter_callable - ) + docs_with_scores = self._vector_store.similarity_search_with_score(query, k=top_kn, filter=filter_callable) mapper = ChunkMapper() models: List[ChunkModel] = [] for doc, score in docs_with_scores: @@ -202,9 +182,7 @@ def _get_all_docs(self, filter_callable: Optional[Any]): all_docs = [d for d in all_docs if filter_callable(d.metadata)] return all_docs - def _bm25_search( - self, query: str, top_kn: int, filter_callable: Optional[Any] - ) -> List[ChunkModel]: + def _bm25_search(self, query: str, top_kn: int, filter_callable: Optional[Any]) -> List[ChunkModel]: """BM25 keyword search over the in-memory FAISS docstore using rank_bm25.""" try: from rank_bm25 import BM25Okapi @@ -213,24 +191,18 @@ def _bm25_search( "rank_bm25 is not installed", context={"hint": "Run: pip install rank-bm25"}, ) - raise ImportError( - "rank-bm25 package required for BM25 search. Install with: pip install rank-bm25" - ) + raise ImportError("rank-bm25 package required for BM25 search. Install with: pip install rank-bm25") all_docs = self._get_all_docs(filter_callable) if not all_docs: return [] - tokenized_corpus = [ - (doc.page_content or "").lower().split() for doc in all_docs - ] + tokenized_corpus = [(doc.page_content or "").lower().split() for doc in all_docs] bm25 = BM25Okapi(tokenized_corpus) scores = bm25.get_scores(query.lower().split()) # Get top_kn indices sorted by descending score - ranked_indices = sorted( - range(len(scores)), key=lambda i: scores[i], reverse=True - )[:top_kn] + ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_kn] mapper = ChunkMapper() models: List[ChunkModel] = [] @@ -242,9 +214,7 @@ def _bm25_search( models.append(model) return models - def _hybrid_search( - self, query: str, top_kn: int, filter_callable: Optional[Any] - ) -> List[ChunkModel]: + def _hybrid_search(self, query: str, top_kn: int, filter_callable: Optional[Any]) -> List[ChunkModel]: """Hybrid search: merge BM25 + semantic results using Reciprocal Rank Fusion (RRF).""" # Fetch more candidates per method to improve fusion quality fetch_k = max(top_kn * 3, 20) @@ -280,9 +250,7 @@ def _doc_key(model: ChunkModel) -> str: try: content_str = (model.content or "").strip() - content_hash = hashlib.md5( - content_str.encode("utf-8"), usedforsecurity=False - ).hexdigest() + content_hash = hashlib.md5(content_str.encode("utf-8"), usedforsecurity=False).hexdigest() return content_hash except TypeError: # Fallback for older python versions if needed, though 3.12+ supports it @@ -301,9 +269,7 @@ def _doc_key(model: ChunkModel) -> str: # Sort by RRF score descending and take top_kn # Use a stable sort key to avoid TypeError when comparing different key types on tie-breaks - ranked_keys = sorted( - scores.keys(), key=lambda k: (scores[k], str(k)), reverse=True - )[:top_kn] + ranked_keys = sorted(scores.keys(), key=lambda k: (scores[k], str(k)), reverse=True)[:top_kn] results = [] for key in ranked_keys: @@ -311,9 +277,7 @@ def _doc_key(model: ChunkModel) -> str: model.score = scores[key] results.append(model) - logger.info( - "Hybrid search fusion completed", context={"total_results": len(results)} - ) + logger.info("Hybrid search fusion completed", context={"total_results": len(results)}) return results def delete(self, filters: Optional[Any]) -> int: @@ -323,9 +287,7 @@ def delete(self, filters: Optional[Any]) -> int: try: if not filters: - logger.warning( - "Delete called without filters in FAISS, skipping for safety." - ) + logger.warning("Delete called without filters in FAISS, skipping for safety.") return 0 # If it's a simple ID filter @@ -366,12 +328,8 @@ def delete(self, filters: Optional[Any]) -> int: ) raise - def list_chunks( - self, filters: Optional[Any], limit: int = 1000 - ) -> List[ChunkModel]: - logger.debug( - "Listing chunks from FAISS", context={"filters": filters, "limit": limit} - ) + def list_chunks(self, filters: Optional[Any], limit: int = 1000) -> List[ChunkModel]: + logger.debug("Listing chunks from FAISS", context={"filters": filters, "limit": limit}) if not self._vector_store: return [] diff --git a/src/infrastructure/repositories/vector/models/chunk_model.py b/src/infrastructure/repositories/vector/models/chunk_model.py index 25dd8d59..9ac75ee2 100644 --- a/src/infrastructure/repositories/vector/models/chunk_model.py +++ b/src/infrastructure/repositories/vector/models/chunk_model.py @@ -12,32 +12,20 @@ class ChunkModel(BaseModel): ) job_id: UUID = Field(description="ID of the processing job that created this chunk") - content_source_id: UUID = Field( - description="ID of the original content source, e.g., video ID, document ID, etc." - ) + content_source_id: UUID = Field(description="ID of the original content source, e.g., video ID, document ID, etc.") source_type: str = Field(description="e.g., YOUTUBE, PDF, WEB_PAGE, etc.") - external_source: Optional[str] = Field( - default=None, description="URL, file path, id, etc." - ) - subject_id: Optional[UUID] = Field( - default=None, description="Optional subject or category for the chunk" - ) + external_source: Optional[str] = Field(default=None, description="URL, file path, id, etc.") + subject_id: Optional[UUID] = Field(default=None, description="Optional subject or category for the chunk") index: Optional[int] = Field( default=None, description="Original sequence number of the chunk within the source", ) - content: Optional[str] = Field( - default=None, description="Text content of the chunk" - ) - tokens_count: Optional[int] = Field( - default=None, description="Number of tokens in the content" - ) + content: Optional[str] = Field(default=None, description="Text content of the chunk") + tokens_count: Optional[int] = Field(default=None, description="Number of tokens in the content") extra: Dict[str, Any] = Field(default_factory=dict) - language: Optional[str] = Field( - default=None, description="Language of the content, e.g., 'en', 'pt', etc." - ) + language: Optional[str] = Field(default=None, description="Language of the content, e.g., 'en', 'pt', etc.") embedding_model: Optional[str] = Field( default=None, description="Name of the embedding models used to generate the vector", diff --git a/src/infrastructure/repositories/vector/qdrant/chunk_repository.py b/src/infrastructure/repositories/vector/qdrant/chunk_repository.py index 5aa0b874..c22af663 100644 --- a/src/infrastructure/repositories/vector/qdrant/chunk_repository.py +++ b/src/infrastructure/repositories/vector/qdrant/chunk_repository.py @@ -39,9 +39,7 @@ def _ensure_collection_exists(self): context={"collection": self._collection_name}, ) # Get vector size from embedding service - vector_size = ( - self._embedding_service.model_loader_service.dimensions - ) + vector_size = self._embedding_service.model_loader_service.dimensions client.create_collection( collection_name=self._collection_name, @@ -70,9 +68,7 @@ def _ensure_collection_exists(self): ) def create_documents(self, documents: List[ChunkModel]) -> List[str]: - logger.debug( - "Creating documents in Qdrant", context={"num_documents": len(documents)} - ) + logger.debug("Creating documents in Qdrant", context={"num_documents": len(documents)}) try: self._ensure_collection_exists() @@ -158,9 +154,7 @@ def retriever( ) raise - def _semantic_search( - self, query: str, top_kn: int, filters: Optional[rest.Filter] - ) -> List[ChunkModel]: + def _semantic_search(self, query: str, top_kn: int, filters: Optional[rest.Filter]) -> List[ChunkModel]: query_vector = self._embedding_service.embed_query(query) with self._connector as client: @@ -174,9 +168,7 @@ def _semantic_search( return self._transform_hits(search_result) - def _bm25_search( - self, query: str, top_kn: int, filters: Optional[rest.Filter] - ) -> List[ChunkModel]: + def _bm25_search(self, query: str, top_kn: int, filters: Optional[rest.Filter]) -> List[ChunkModel]: """Qdrant full-text search as a proxy for BM25.""" # Create a text match filter for the content text_filter = rest.Filter( @@ -233,9 +225,7 @@ def _bm25_search( return self._transform_hits(search_result) - def _hybrid_search( - self, query: str, top_kn: int, filters: Optional[rest.Filter] - ) -> List[ChunkModel]: + def _hybrid_search(self, query: str, top_kn: int, filters: Optional[rest.Filter]) -> List[ChunkModel]: """Hybrid search combining semantic and text match.""" # Simple implementation: get top results from both and merge semantic_results = self._semantic_search(query, top_kn * 2, filters) @@ -262,9 +252,7 @@ def _reciprocal_rank_fusion( fused_scores[chunk_id] = fused_scores.get(chunk_id, 0.0) + score # Sort by fused score - sorted_ids = sorted( - fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True - ) + sorted_ids = sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True) final_results = [] for chunk_id in sorted_ids[:top_n]: @@ -296,9 +284,7 @@ def _convert_filters(self, filters: Optional[Any]) -> Optional[rest.Filter]: for k, v in filters.items(): if k == "id": if isinstance(v, list): - must_conditions.append( - rest.HasIdCondition(has_id=[str(id_val) for id_val in v]) - ) + must_conditions.append(rest.HasIdCondition(has_id=[str(id_val) for id_val in v])) else: must_conditions.append(rest.HasIdCondition(has_id=[str(v)])) else: @@ -311,9 +297,7 @@ def _convert_filters(self, filters: Optional[Any]) -> Optional[rest.Filter]: ) ) else: - must_conditions.append( - rest.FieldCondition(key=k, match=rest.MatchValue(value=v)) - ) + must_conditions.append(rest.FieldCondition(key=k, match=rest.MatchValue(value=v))) if must_conditions: return rest.Filter(must=must_conditions) @@ -331,9 +315,7 @@ def _transform_hits(self, hits: List[Any]) -> List[ChunkModel]: # Convert string dates back to datetime objects if "created_at" in payload and isinstance(payload["created_at"], str): try: - payload["created_at"] = datetime.fromisoformat( - payload["created_at"] - ) + payload["created_at"] = datetime.fromisoformat(payload["created_at"]) except ValueError: pass @@ -353,9 +335,7 @@ def _transform_hits(self, hits: List[Any]) -> List[ChunkModel]: def delete(self, filters: Optional[Any]) -> int: qdrant_filters = self._convert_filters(filters) if not qdrant_filters: - logger.warning( - "Delete called without filters in Qdrant, skipping for safety." - ) + logger.warning("Delete called without filters in Qdrant, skipping for safety.") return 0 self._ensure_collection_exists() @@ -374,9 +354,7 @@ def delete(self, filters: Optional[Any]) -> int: logger.error("Error deleting from Qdrant", context={"error": str(e)}) raise - def list_chunks( - self, filters: Optional[Any], limit: int = 1000 - ) -> List[ChunkModel]: + def list_chunks(self, filters: Optional[Any], limit: int = 1000) -> List[ChunkModel]: qdrant_filters = self._convert_filters(filters) try: diff --git a/src/infrastructure/repositories/vector/qdrant/connector.py b/src/infrastructure/repositories/vector/qdrant/connector.py index 5cefa5a4..863286d9 100644 --- a/src/infrastructure/repositories/vector/qdrant/connector.py +++ b/src/infrastructure/repositories/vector/qdrant/connector.py @@ -5,9 +5,7 @@ from src.config.logger import Logger -warnings.filterwarnings( - "ignore", message="Api key is used with an insecure connection." -) +warnings.filterwarnings("ignore", message="Api key is used with an insecure connection.") logger = Logger() diff --git a/src/infrastructure/repositories/vector/weaviate/chunk_repository.py b/src/infrastructure/repositories/vector/weaviate/chunk_repository.py index fee3cd75..c1ae941f 100644 --- a/src/infrastructure/repositories/vector/weaviate/chunk_repository.py +++ b/src/infrastructure/repositories/vector/weaviate/chunk_repository.py @@ -43,9 +43,7 @@ def __init__( ) def create_documents(self, documents: List[ChunkModel]) -> List[str]: - logger.debug( - "Creating documents in Weaviate", context={"num_documents": len(documents)} - ) + logger.debug("Creating documents in Weaviate", context={"num_documents": len(documents)}) try: texts = [doc.content for doc in documents] @@ -87,17 +85,13 @@ def create_documents(self, documents: List[ChunkModel]) -> List[str]: raise ValueError("All 'ids' must be strings.") with self.vector_store as vector_store: - created_ids = vector_store.add_texts( - texts=texts, metadatas=meta_datas, ids=ids - ) + created_ids = vector_store.add_texts(texts=texts, metadatas=meta_datas, ids=ids) logger.debug( "Created documents in Weaviate", context={ "num_documents": len(documents), - "created_ids_count": len(created_ids) - if created_ids is not None - else 0, + "created_ids_count": len(created_ids) if created_ids is not None else 0, }, ) @@ -136,9 +130,7 @@ def retriever( weaviate_filters_list.append(Filter.by_id().equal(v)) else: if isinstance(v, list): - weaviate_filters_list.append( - Filter.by_property(k).contains_any(v) - ) + weaviate_filters_list.append(Filter.by_property(k).contains_any(v)) else: weaviate_filters_list.append(Filter.by_property(k).equal(v)) @@ -167,19 +159,13 @@ def retriever( else: return self._semantic_search(query, top_kn, weaviate_filters) except Exception as e: - logger.error( - "Error retrieving documents", context={"query": query, "error": str(e)} - ) + logger.error("Error retrieving documents", context={"query": query, "error": str(e)}) raise - def _semantic_search( - self, query: str, top_kn: int, weaviate_filters: Optional[Any] - ) -> List[ChunkModel]: + def _semantic_search(self, query: str, top_kn: int, weaviate_filters: Optional[Any]) -> List[ChunkModel]: """Standard semantic (vector) search via LangChain WeaviateVectorStore.""" with self.vector_store as vector_store: - docs_with_scores = vector_store.similarity_search_with_score( - query, k=top_kn, filters=weaviate_filters - ) + docs_with_scores = vector_store.similarity_search_with_score(query, k=top_kn, filters=weaviate_filters) mapper = ChunkMapper() all_models: List[ChunkModel] = [] @@ -231,9 +217,7 @@ def _weaviate_objects_to_models(self, response_objects: list) -> List[ChunkModel chunks.append(chunk_model) return chunks - def _bm25_search( - self, query: str, top_kn: int, weaviate_filters: Optional[Any] - ) -> List[ChunkModel]: + def _bm25_search(self, query: str, top_kn: int, weaviate_filters: Optional[Any]) -> List[ChunkModel]: """Native Weaviate BM25 keyword search.""" with self._weaviate_client as client: collection = client.collections.get(self._collection_name) @@ -253,9 +237,7 @@ def _bm25_search( ) return models - def _hybrid_search( - self, query: str, top_kn: int, weaviate_filters: Optional[Any] - ) -> List[ChunkModel]: + def _hybrid_search(self, query: str, top_kn: int, weaviate_filters: Optional[Any]) -> List[ChunkModel]: """Native Weaviate Hybrid search (vector + BM25, alpha=0.5).""" # Generate query vector since Weaviate collection has no automatic vectorizer query_vector = self._embedding_service.embed_query(query) @@ -295,9 +277,7 @@ def delete(self, filters: Optional[Any]) -> int: weaviate_filters_list.append(Filter.by_id().equal(v)) else: if isinstance(v, list): - weaviate_filters_list.append( - Filter.by_property(k).contains_any(v) - ) + weaviate_filters_list.append(Filter.by_property(k).contains_any(v)) else: weaviate_filters_list.append(Filter.by_property(k).equal(v)) @@ -311,9 +291,7 @@ def delete(self, filters: Optional[Any]) -> int: # Weaviate v4 requires a valid Filter object for delete_many # To delete all, we can use a filter that matches everything (not recommended for production without care) # For now, let's just log and return 0 if no filter is provided to avoid accidental mass deletion - logger.warning( - "Delete called without filters in Weaviate, skipping for safety." - ) + logger.warning("Delete called without filters in Weaviate, skipping for safety.") return 0 logger.debug("Deleting documents", context={"filters": weaviate_filters}) @@ -336,9 +314,7 @@ def delete(self, filters: Optional[Any]) -> int: ) raise - def list_chunks( - self, filters: Optional[Any], limit: int = 1000 - ) -> List[ChunkModel]: + def list_chunks(self, filters: Optional[Any], limit: int = 1000) -> List[ChunkModel]: logger.debug("Listing chunks", context={"filters": filters, "limit": limit}) try: @@ -365,13 +341,9 @@ def list_chunks( weaviate_filters_list.append(Filter.by_id().equal(v)) else: if isinstance(v, list): - weaviate_filters_list.append( - Filter.by_property(k).contains_any(v) - ) + weaviate_filters_list.append(Filter.by_property(k).contains_any(v)) else: - weaviate_filters_list.append( - Filter.by_property(k).equal(v) - ) + weaviate_filters_list.append(Filter.by_property(k).equal(v)) if weaviate_filters_list: if len(weaviate_filters_list) == 1: @@ -379,16 +351,12 @@ def list_chunks( else: weaviate_filters = Filter.all_of(weaviate_filters_list) - response = collection.query.fetch_objects( - filters=weaviate_filters, limit=limit, include_vector=True - ) + response = collection.query.fetch_objects(filters=weaviate_filters, limit=limit, include_vector=True) chunks = [] for obj in response.objects: if not hasattr(obj, "uuid"): - logger.warning( - "Object missing 'uuid' attribute", context={"object": obj} - ) + logger.warning("Object missing 'uuid' attribute", context={"object": obj}) continue properties = obj.properties @@ -401,9 +369,7 @@ def list_chunks( chunks.append(chunk_model) # Sort by index if present - chunks.sort( - key=lambda x: x.index if x.index is not None else float("inf") - ) + chunks.sort(key=lambda x: x.index if x.index is not None else float("inf")) logger.debug( "Listed chunks", @@ -412,9 +378,7 @@ def list_chunks( return chunks except Exception as e: - logger.error( - "Error listing chunks", context={"filters": filters, "error": str(e)} - ) + logger.error("Error listing chunks", context={"filters": filters, "error": str(e)}) raise def is_ready(self) -> bool: diff --git a/src/infrastructure/repositories/vector/weaviate/weaviate_client.py b/src/infrastructure/repositories/vector/weaviate/weaviate_client.py index 7757c837..0471ac99 100644 --- a/src/infrastructure/repositories/vector/weaviate/weaviate_client.py +++ b/src/infrastructure/repositories/vector/weaviate/weaviate_client.py @@ -20,9 +20,7 @@ def _create_client(self): if self._weaviate_config.weaviate_api_key: client = weaviate.connect_to_weaviate_cloud( cluster_url=self._weaviate_config.weaviate_url, - auth_credentials=Auth.api_key( - self._weaviate_config.weaviate_api_key - ), + auth_credentials=Auth.api_key(self._weaviate_config.weaviate_api_key), ) else: client = weaviate.connect_to_local( @@ -45,9 +43,7 @@ def _create_client(self): return client except Exception as e: - logger.error( - "Error creating WeaviateConfig connection", context={"error": str(e)} - ) + logger.error("Error creating WeaviateConfig connection", context={"error": str(e)}) raise def __enter__(self): @@ -61,15 +57,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): try: self._client.close() except Exception as e: - logger.error( - "Error closing WeaviateConfig connection", context={"error": str(e)} - ) + logger.error("Error closing WeaviateConfig connection", context={"error": str(e)}) finally: self._client = None - def create_collection_if_not_exists( - self, collection_name: str, dimensions: Optional[int] = None - ): + def create_collection_if_not_exists(self, collection_name: str, dimensions: Optional[int] = None): """Creates the collection with explicit property types if it doesn't exist. This prevents Weaviate auto-schema from misidentifying types (e.g. tokens_count as text). @@ -104,19 +96,13 @@ def create_collection_if_not_exists( wvc.Property(name="version_number", data_type=wvc.DataType.INT), # Text fields wvc.Property(name="source_type", data_type=wvc.DataType.TEXT), - wvc.Property( - name="external_source", data_type=wvc.DataType.TEXT - ), + wvc.Property(name="external_source", data_type=wvc.DataType.TEXT), wvc.Property(name="language", data_type=wvc.DataType.TEXT), wvc.Property(name="content", data_type=wvc.DataType.TEXT), - wvc.Property( - name="embedding_model", data_type=wvc.DataType.TEXT - ), + wvc.Property(name="embedding_model", data_type=wvc.DataType.TEXT), # ID fields (stored as TEXT for simplicity or UUID if supported by the client version) wvc.Property(name="job_id", data_type=wvc.DataType.TEXT), - wvc.Property( - name="content_source_id", data_type=wvc.DataType.TEXT - ), + wvc.Property(name="content_source_id", data_type=wvc.DataType.TEXT), wvc.Property(name="subject_id", data_type=wvc.DataType.TEXT), # Extra metadata as text (JSON string) wvc.Property(name="extra_json", data_type=wvc.DataType.TEXT), diff --git a/src/infrastructure/services/auth_service.py b/src/infrastructure/services/auth_service.py index d022f56f..6548ddb9 100644 --- a/src/infrastructure/services/auth_service.py +++ b/src/infrastructure/services/auth_service.py @@ -74,9 +74,7 @@ def create_access_token(self, user: UserEntity) -> str: def verify_token(self, token: str) -> Optional[Dict[str, Any]]: """Verifies a local JWT and returns the payload.""" try: - payload = jwt.decode( - token, self.jwt_secret, algorithms=[self.jwt_algorithm] - ) + payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) return payload except JWTError: return None diff --git a/src/infrastructure/services/chunk_index_service.py b/src/infrastructure/services/chunk_index_service.py index d1d24bc3..b2cd1117 100644 --- a/src/infrastructure/services/chunk_index_service.py +++ b/src/infrastructure/services/chunk_index_service.py @@ -13,9 +13,7 @@ class ChunkIndexService: """Service that works with the chunk_index SQL table and returns domain ChunkEntity items where appropriate.""" - def __init__( - self, repository: ChunkIndexSQLRepository, logger: Optional[Logger] = None - ) -> None: + def __init__(self, repository: ChunkIndexSQLRepository, logger: Optional[Logger] = None) -> None: self._repo = repository self._logger = logger or Logger() self._mapper = ChunkIndexMapper() @@ -34,16 +32,12 @@ def create_chunks(self, entities: List[ChunkEntity]) -> List[UUID]: "chars": len(e.content) if e.content is not None else 0, "tokens_count": e.tokens_count, "language": e.language, - "source_type": e.source_type.value - if isinstance(e.source_type, SourceType) - else e.source_type, + "source_type": e.source_type.value if isinstance(e.source_type, SourceType) else e.source_type, "subject_id": e.subject_id, "external_source": e.external_source, "extra": e.extra, "version_number": e.version_number, - "vector_store_type": e.extra.get("vector_store_type") - if hasattr(e, "extra") - else None, + "vector_store_type": e.extra.get("vector_store_type") if hasattr(e, "extra") else None, } ) return self._repo.create_chunks(rows) @@ -54,9 +48,7 @@ def list_by_content_source( limit: Optional[int] = None, offset: Optional[int] = None, ) -> List[ChunkEntity]: - models = self._repo.list_by_content_source( - content_source_id=content_source_id, limit=limit, offset=offset - ) + models = self._repo.list_by_content_source(content_source_id=content_source_id, limit=limit, offset=offset) temp = [self._mapper.model_to_entity(m) for m in models] return [e for e in temp if e is not None] @@ -71,9 +63,7 @@ def delete_by_job_id(self, job_id: UUID) -> int: """Delete from SQL by job_id. Vector store sync should happen at Use Case level.""" return self._repo.delete_by_job_id(job_id=job_id) - def search( - self, query: Optional[str], top_k: int = 10, filters: Optional[Any] = None - ) -> List[ChunkEntity]: + def search(self, query: Optional[str], top_k: int = 10, filters: Optional[Any] = None) -> List[ChunkEntity]: models = self._repo.search(query=query, top_k=top_k, filters=filters) temp = [self._mapper.model_to_entity(m) for m in models] return [e for e in temp if e is not None] @@ -92,9 +82,7 @@ def list_chunks( source_id: Optional[UUID] = None, search_query: Optional[str] = None, ) -> List[ChunkEntity]: - models = self._repo.list_chunks( - limit=limit, offset=offset, source_id=source_id, search_query=search_query - ) + models = self._repo.list_chunks(limit=limit, offset=offset, source_id=source_id, search_query=search_query) temp = [self._mapper.model_to_entity(m) for m in models] return [e for e in temp if e is not None] diff --git a/src/infrastructure/services/chunk_vector_service.py b/src/infrastructure/services/chunk_vector_service.py index 01f301e3..1335f63b 100644 --- a/src/infrastructure/services/chunk_vector_service.py +++ b/src/infrastructure/services/chunk_vector_service.py @@ -84,9 +84,7 @@ def list_by_source(self, filters: Optional[Any] = None) -> List[ChunkEntity]: def delete(self, filters: Optional[Any]) -> int: """Delete documents from the vector store based on provided filters.""" - logger.debug( - "Deleting documents from vector store", context={"filters": str(filters)} - ) + logger.debug("Deleting documents from vector store", context={"filters": str(filters)}) return self._repository.delete(filters=filters) def delete_by_id(self, chunk_id: UUID) -> int: diff --git a/src/infrastructure/services/content_source_service.py b/src/infrastructure/services/content_source_service.py index 5fb47d16..fb6d0a1c 100644 --- a/src/infrastructure/services/content_source_service.py +++ b/src/infrastructure/services/content_source_service.py @@ -17,9 +17,7 @@ class ContentSourceService: Receives a ContentSourceSQLRepository and returns domain entities as outputs. """ - def __init__( - self, repository: ContentSourceSQLRepository, logger: Optional[Logger] = None - ) -> None: + def __init__(self, repository: ContentSourceSQLRepository, logger: Optional[Logger] = None) -> None: self._repo = repository self._logger = logger or Logger() @@ -39,13 +37,9 @@ def create_source( source_metadata: Optional[dict] = None, ) -> ContentSourceEntity: """Create a content source and return a domain entity.""" - self._logger.debug( - "Creating content source", context={"external_source": external_source} - ) + self._logger.debug("Creating content source", context={"external_source": external_source}) - effective_processing_status = processing_status or ( - status.value if status is not None else "pending" - ) + effective_processing_status = processing_status or (status.value if status is not None else "pending") created_id = self._repo.create( subject_id=subject_id, @@ -88,17 +82,13 @@ def get_by_source_info( subject_id=subject_id, ) - return ( - ContentSourceMapper.model_to_entity(list_models[0]) if list_models else None - ) + return ContentSourceMapper.model_to_entity(list_models[0]) if list_models else None def get_by_id(self, id: UUID) -> Optional[ContentSourceEntity]: model = self._repo.get_by_id(id) return ContentSourceMapper.model_to_entity(model) - def get_by_diarization_id( - self, diarization_id: str - ) -> Optional[ContentSourceEntity]: + def get_by_diarization_id(self, diarization_id: str) -> Optional[ContentSourceEntity]: """Get a content source by its diarization_id in metadata.""" model = self._repo.get_by_diarization_id(diarization_id) return ContentSourceMapper.model_to_entity(model) @@ -112,9 +102,7 @@ def list_by_subject( models = self._repo.list_by_subject(subject_id, limit=limit, offset=offset) return ContentSourceMapper.model_list_to_entities(models) - def list_all( - self, limit: Optional[int] = None, offset: Optional[int] = None - ) -> List[ContentSourceEntity]: + def list_all(self, limit: Optional[int] = None, offset: Optional[int] = None) -> List[ContentSourceEntity]: models = self._repo.list(limit=limit, offset=offset) return ContentSourceMapper.model_list_to_entities(models) @@ -168,20 +156,14 @@ def update_title(self, content_source_id: UUID, title: str) -> None: def update_metadata(self, content_source_id: UUID, metadata: dict) -> None: """Update the metadata of a content source.""" - self._repo.update_metadata( - content_source_id=content_source_id, metadata=metadata - ) + self._repo.update_metadata(content_source_id=content_source_id, metadata=metadata) - def get_existing_external_sources( - self, subject_id: UUID, source_type: SourceType - ) -> set[str]: + def get_existing_external_sources(self, subject_id: UUID, source_type: SourceType) -> set[str]: """Return a set of all external_source values for a subject and source_type. Used for bulk deduplication when ingesting channels or large batches. """ - raw = self._repo.list_external_sources_by_subject( - subject_id=subject_id, source_type=source_type.value - ) + raw = self._repo.list_external_sources_by_subject(subject_id=subject_id, source_type=source_type.value) return set(raw) def delete_source(self, content_source_id: UUID) -> bool: diff --git a/src/infrastructure/services/ingestion_job_service.py b/src/infrastructure/services/ingestion_job_service.py index 769cfc6f..50682493 100644 --- a/src/infrastructure/services/ingestion_job_service.py +++ b/src/infrastructure/services/ingestion_job_service.py @@ -13,9 +13,7 @@ class IngestionJobService: """Service layer for ingestion jobs.""" - def __init__( - self, repository: IngestionJobSQLRepository, logger: Optional[Logger] = None - ) -> None: + def __init__(self, repository: IngestionJobSQLRepository, logger: Optional[Logger] = None) -> None: self._repo = repository self._logger = logger or Logger() @@ -97,15 +95,11 @@ def get_by_id(self, job_id: UUID) -> Optional[IngestionJobEntity]: return None return IngestionJobMapper.model_to_entity(model) - def list_by_content_source( - self, content_source_id: UUID - ) -> List[IngestionJobEntity]: + def list_by_content_source(self, content_source_id: UUID) -> List[IngestionJobEntity]: models = self._repo.list_by_content_source(content_source_id) return IngestionJobMapper.model_list_to_entities(models) - def list_recent_jobs( - self, limit: int = 50, offset: int = 0 - ) -> List[IngestionJobEntity]: + def list_recent_jobs(self, limit: int = 50, offset: int = 0) -> List[IngestionJobEntity]: """List recent ingestion jobs, ordered by creation date.""" models = self._repo.list_recent_jobs(limit=limit, offset=offset) return IngestionJobMapper.model_list_to_entities(models) @@ -118,9 +112,7 @@ def list_jobs( search: Optional[str] = None, ) -> dict: """List jobs with pagination and filters. Returns {'jobs': [...], 'total': int, 'stats': {...}}""" - models = self._repo.list_jobs( - limit=limit, offset=offset, status=status, search=search - ) + models = self._repo.list_jobs(limit=limit, offset=offset, status=status, search=search) total = self._repo.count_jobs(status=status, search=search) stats = self._repo.get_status_counts(search=search) @@ -134,15 +126,9 @@ def list_recent_jobs_by_subject( self, subject_id: UUID, limit: int = 50, offset: int = 0 ) -> List[IngestionJobEntity]: """List recent ingestion jobs for a specific subject.""" - models = self._repo.list_recent_jobs_by_subject( - subject_id, limit=limit, offset=offset - ) + models = self._repo.list_recent_jobs_by_subject(subject_id, limit=limit, offset=offset) return IngestionJobMapper.model_list_to_entities(models) - def mark_previous_jobs_as_reprocessed( - self, content_source_id: UUID, current_job_id: UUID - ) -> int: + def mark_previous_jobs_as_reprocessed(self, content_source_id: UUID, current_job_id: UUID) -> int: """Mark previous jobs as reprocessed.""" - return self._repo.mark_previous_jobs_as_reprocessed( - content_source_id, current_job_id - ) + return self._repo.mark_previous_jobs_as_reprocessed(content_source_id, current_job_id) diff --git a/src/infrastructure/services/knowledge_subject_service.py b/src/infrastructure/services/knowledge_subject_service.py index 16c374a5..be568c42 100644 --- a/src/infrastructure/services/knowledge_subject_service.py +++ b/src/infrastructure/services/knowledge_subject_service.py @@ -19,9 +19,7 @@ class KnowledgeSubjectService: All outputs that represent subject data are returned as KnowledgeSubjectEntity instances. """ - def __init__( - self, repository: KnowledgeSubjectSQLRepository, logger: Optional[Logger] = None - ) -> None: + def __init__(self, repository: KnowledgeSubjectSQLRepository, logger: Optional[Logger] = None) -> None: self._repo = repository self._logger = logger or Logger() @@ -37,9 +35,7 @@ def create_subject( "Creating knowledge subject", context={"name": name, "external_ref": external_ref, "icon": icon}, ) - created_id = self._repo.create_subject( - name=name, external_ref=external_ref, description=description, icon=icon - ) + created_id = self._repo.create_subject(name=name, external_ref=external_ref, description=description, icon=icon) model = self._repo.get_by_id(created_id) entity = KnowledgeSubjectMapper.model_to_entity(model) if entity is None: @@ -57,9 +53,7 @@ def get_subject_by_id(self, id: UUID) -> Optional[KnowledgeSubjectEntity]: model = self._repo.get_by_id(id) return KnowledgeSubjectMapper.model_to_entity(model) - def get_subject_by_external_ref( - self, external_ref: str - ) -> Optional[KnowledgeSubjectEntity]: + def get_subject_by_external_ref(self, external_ref: str) -> Optional[KnowledgeSubjectEntity]: """Fetch a subject by an external reference string and return as an Entity.""" model = self._repo.get_by_external_ref(external_ref) return KnowledgeSubjectMapper.model_to_entity(model) @@ -75,9 +69,7 @@ def get_or_create_by_external_ref( If name is not provided when creating, external_ref is used as the name. Returns a Domain Entity representing the subject. """ - self._logger.debug( - "get_or_create_by_external_ref", context={"external_ref": external_ref} - ) + self._logger.debug("get_or_create_by_external_ref", context={"external_ref": external_ref}) existing = self._repo.get_by_external_ref(external_ref) if existing is not None: entity = KnowledgeSubjectMapper.model_to_entity(existing) diff --git a/src/infrastructure/services/model_loader_service.py b/src/infrastructure/services/model_loader_service.py index 4b1e35c1..f870721a 100644 --- a/src/infrastructure/services/model_loader_service.py +++ b/src/infrastructure/services/model_loader_service.py @@ -82,9 +82,7 @@ def get_align_model(self, language_code: str, device: str): with self._lock: if key not in ModelLoaderService._models: logger.info("Loading Alignment model: %s", language_code) - ModelLoaderService._models[key] = whisperx.load_align_model( - language_code=language_code, device=device - ) + ModelLoaderService._models[key] = whisperx.load_align_model(language_code=language_code, device=device) return ModelLoaderService._models[key] def get_diarization_pipeline(self, hf_token: str, device: str): @@ -106,14 +104,10 @@ def get_voice_inference(self, hf_token: str, device: str): with self._lock: if key not in ModelLoaderService._models: logger.info("Loading Pyannote Voice Identification model") - model = Model.from_pretrained( - "pyannote/wespeaker-voxceleb-resnet34-LM", use_auth_token=hf_token - ) + model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", use_auth_token=hf_token) if model is None: raise RuntimeError("Failed to load Pyannote Model") - ModelLoaderService._models[key] = Inference( - model, window="whole", device=torch.device(device) - ) + ModelLoaderService._models[key] = Inference(model, window="whole", device=torch.device(device)) return ModelLoaderService._models[key] def clear_cache(self): diff --git a/src/infrastructure/services/pyannote_voice_recognizer.py b/src/infrastructure/services/pyannote_voice_recognizer.py index b2e8d458..59de8222 100644 --- a/src/infrastructure/services/pyannote_voice_recognizer.py +++ b/src/infrastructure/services/pyannote_voice_recognizer.py @@ -21,9 +21,7 @@ def __init__(self, voice_db, hf_token: str, threshold: float = 0.8): self._device = get_best_device() def _get_inference(self): - return model_loader.get_voice_inference( - hf_token=self.hf_token, device=self._device - ) + return model_loader.get_voice_inference(hf_token=self.hf_token, device=self._device) def _compare(self, embedding: np.ndarray) -> list[tuple[str, float, str]]: scores = [] diff --git a/src/infrastructure/services/re_rank_service.py b/src/infrastructure/services/re_rank_service.py index 429948e1..e6aef1aa 100644 --- a/src/infrastructure/services/re_rank_service.py +++ b/src/infrastructure/services/re_rank_service.py @@ -34,10 +34,7 @@ def rerank(self, query: str, documents: List[ChunkModel]) -> List[ChunkModel]: ) # FlashRank expects a list of dicts with 'id' and 'text' - passages = [ - {"id": str(doc.id), "text": doc.content, "meta": {"model": doc}} - for doc in documents - ] + passages = [{"id": str(doc.id), "text": doc.content, "meta": {"model": doc}} for doc in documents] rerank_request = RerankRequest(query=query, passages=passages) results = self._ranker.rerank(rerank_request) diff --git a/src/infrastructure/services/redis_task_queue_service.py b/src/infrastructure/services/redis_task_queue_service.py index d4a8a04d..662f47a2 100644 --- a/src/infrastructure/services/redis_task_queue_service.py +++ b/src/infrastructure/services/redis_task_queue_service.py @@ -52,16 +52,12 @@ def __init__(self, queue_name: str = "wys_task_queue", num_workers: int = 1): def start(self): if self._workers: - logger.warning( - "RedisTaskQueueService already started.", context={"where": "start"} - ) + logger.warning("RedisTaskQueueService already started.", context={"where": "start"}) return self._should_stop = False for i in range(self._num_workers): - t = threading.Thread( - target=self._worker_loop, name=f"RedisTaskWorker-{i}", daemon=True - ) + t = threading.Thread(target=self._worker_loop, name=f"RedisTaskWorker-{i}", daemon=True) t.start() self._workers.append(t) logger.info( @@ -123,9 +119,7 @@ def peek_queue(self, limit: int = 50) -> list[dict]: try: # Fetch raw JSON payloads from the list # Redis list is LPUSH (front is index 0) - raw_tasks = cast( - list[bytes], self._redis.lrange(self._queue_name, 0, limit - 1) - ) + raw_tasks = cast(list[bytes], self._redis.lrange(self._queue_name, 0, limit - 1)) tasks = [] for payload in raw_tasks: try: diff --git a/src/infrastructure/services/task_queue_service.py b/src/infrastructure/services/task_queue_service.py index 4468407b..4296550e 100644 --- a/src/infrastructure/services/task_queue_service.py +++ b/src/infrastructure/services/task_queue_service.py @@ -24,16 +24,12 @@ def __init__(self, num_workers: int = 1): def start(self): """Starts the background worker threads.""" if self._workers: - logger.warning( - "TaskQueueService already started.", context={"where": "start"} - ) + logger.warning("TaskQueueService already started.", context={"where": "start"}) return self._should_stop = False for i in range(self._num_workers): - t = threading.Thread( - target=self._worker_loop, name="TaskQueueWorker-" + str(i), daemon=True - ) + t = threading.Thread(target=self._worker_loop, name="TaskQueueWorker-" + str(i), daemon=True) t.start() self._workers.append(t) logger.info( diff --git a/src/infrastructure/services/text_splitter_service.py b/src/infrastructure/services/text_splitter_service.py index d2d9f9d4..1f493dd7 100644 --- a/src/infrastructure/services/text_splitter_service.py +++ b/src/infrastructure/services/text_splitter_service.py @@ -58,20 +58,14 @@ def split_text( # 2. Decode back to text try: - chunk_text = self.tokenizer.decode( - chunk_ids, skip_special_tokens=True - ) + chunk_text = self.tokenizer.decode(chunk_ids, skip_special_tokens=True) except Exception: chunk_text = self.tokenizer.decode(chunk_ids) chunk_metadata = (metadata or {}).copy() - chunk_metadata.update( - {"tokens_count": len(chunk_ids), "chunk_index": chunk_index} - ) + chunk_metadata.update({"tokens_count": len(chunk_ids), "chunk_index": chunk_index}) - documents.append( - Document(page_content=chunk_text, metadata=chunk_metadata) - ) + documents.append(Document(page_content=chunk_text, metadata=chunk_metadata)) i += step chunk_index += 1 @@ -102,8 +96,6 @@ def split_text( "is_fallback": True, } ) - documents.append( - Document(page_content=content, metadata=chunk_metadata) - ) + documents.append(Document(page_content=content, metadata=chunk_metadata)) return documents diff --git a/src/infrastructure/services/voice_profile_service.py b/src/infrastructure/services/voice_profile_service.py index ce673b57..ab192383 100644 --- a/src/infrastructure/services/voice_profile_service.py +++ b/src/infrastructure/services/voice_profile_service.py @@ -3,7 +3,7 @@ import os import uuid from contextlib import suppress -from typing import cast +from typing import Any, cast from urllib.parse import unquote import numpy as np @@ -30,9 +30,7 @@ def _get_inference(self): import torch from pyannote.audio import Inference, Model - model = Model.from_pretrained( - "pyannote/wespeaker-voxceleb-resnet34-LM", use_auth_token=self.hf_token - ) + model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", use_auth_token=self.hf_token) device = torch.device(self._device) self._inference = Inference(model, window="whole", device=device) return self._inference @@ -53,16 +51,12 @@ def add(self, name: str, audio_path: str) -> tuple[str, str]: local_temp_file = None is_s3 = ( - audio_path.startswith("s3://") - or audio_path.startswith("processed/") - or audio_path.startswith("voices/") + audio_path.startswith("s3://") or audio_path.startswith("processed/") or audio_path.startswith("voices/") ) if is_s3 or not os.path.exists(audio_path): s3_key = unquote(audio_path.replace(f"s3://{self.storage.bucket}/", "")) - local_temp_file = os.path.join( - temp_download_dir, f"tmp_voice_{uuid.uuid4()}.wav" - ) + local_temp_file = os.path.join(temp_download_dir, f"tmp_voice_{uuid.uuid4()}.wav") os.makedirs(temp_download_dir, exist_ok=True) try: @@ -71,27 +65,47 @@ def add(self, name: str, audio_path: str) -> tuple[str, str]: audio_path = local_temp_file except Exception as e: if is_s3: - raise ValueError( - f"Failed to download from S3 (key: {s3_key}): {str(e)}" - ) + raise ValueError(f"Failed to download from S3 (key: {s3_key}): {str(e)}") else: - raise FileNotFoundError( - f"Local file not found and S3 download skipped: {audio_path}" - ) + raise FileNotFoundError(f"Local file not found and S3 download skipped: {audio_path}") if not os.path.exists(audio_path): raise FileNotFoundError(f"Final audio path does not exist: {audio_path}") + record: VoiceRecord | None = None try: - existing = ( - self.db.query(VoiceRecord).filter(VoiceRecord.name == name).first() - ) + existing = self.db.query(VoiceRecord).filter(VoiceRecord.name == name).first() + is_reinforcement = existing is not None + + # Create or mark the target row as "processing" and commit BEFORE the + # slow embedding extraction, so the API/UI can observe the in-flight + # state while training/reinforcement runs. + if is_reinforcement and existing: + record = existing + record.status = cast(Any, "processing") + record.status_message = cast(Any, f"Reforçando voz '{name}'...") + else: + voice_id = str(uuid.uuid4()) + record = VoiceRecord( + id=voice_id, + name=name, + embedding=[], + audios_path=f"voices/{voice_id}/", + status="processing", + status_message=f"Treinando voz '{name}'...", + created_at=datetime.datetime.now(datetime.UTC), + ) + self.db.add(record) + self.db.commit() + + if not record: + raise ValueError("Failed to create or retrieve voice record.") logger.info("Extracting embedding for voice: %s", name) new_embedding = np.array(self._extract_embedding(audio_path)) - if existing: - old_emb = np.array(existing.embedding) + if is_reinforcement: + old_emb = np.array(record.embedding) similarity = cosine_similarity(old_emb, new_embedding) if similarity >= 0.95: @@ -101,7 +115,10 @@ def add(self, name: str, audio_path: str) -> tuple[str, str]: name, similarity, ) - return str(existing.id), "" + record.status = cast(Any, "ready") + record.status_message = cast(Any, None) + self.db.commit() + return str(record.id), "" logger.info( "Reinforcing existing voice profile: %s (similarity=%.4f)", @@ -109,34 +126,33 @@ def add(self, name: str, audio_path: str) -> tuple[str, str]: similarity, ) combined_emb = (old_emb + new_embedding) / 2.0 - existing.embedding = combined_emb.tolist() + record.embedding = combined_emb.tolist() sample_id = str(uuid.uuid4()) - target_s3_key = f"{existing.audios_path}sample_{sample_id}.wav" + target_s3_key = f"{record.audios_path}sample_{sample_id}.wav" + self.storage.upload_file(audio_path, target_s3_key) + else: + record.embedding = new_embedding.tolist() + target_s3_key = f"{record.audios_path}reference_{record.id}.wav" + logger.info("Uploading reference audio to: %s", target_s3_key) self.storage.upload_file(audio_path, target_s3_key) - self.db.commit() - return str(existing.id), target_s3_key - - # Create new voice - voice_id = str(uuid.uuid4()) - audios_path = f"voices/{voice_id}/" - - target_s3_key = f"{audios_path}reference_{voice_id}.wav" - logger.info("Uploading reference audio to: %s", target_s3_key) - self.storage.upload_file(audio_path, target_s3_key) - - new_voice = VoiceRecord( - id=voice_id, - name=name, - embedding=new_embedding.tolist(), - audios_path=audios_path, - created_at=datetime.datetime.now(datetime.UTC), - ) - self.db.add(new_voice) + record.status = cast(Any, "ready") + record.status_message = cast(Any, None) self.db.commit() - return voice_id, target_s3_key + return str(record.id), target_s3_key + except Exception as e: + if record is not None: + with suppress(Exception): + self.db.rollback() + # Re-fetch to work on an attached instance after rollback + fresh = self.db.query(VoiceRecord).filter(VoiceRecord.id == record.id).first() + if fresh is not None: + fresh.status = cast(Any, "failed") + fresh.status_message = cast(Any, str(e)[:500]) + self.db.commit() + raise finally: if local_temp_file and os.path.exists(local_temp_file): os.remove(local_temp_file) @@ -161,25 +177,23 @@ def list_audio_files(self, voice_id: str) -> list[dict]: voice = self.db.query(VoiceRecord).filter(VoiceRecord.id == voice_id).first() if not voice or not voice.audios_path: return [] - return self.storage.list_files( - prefix=cast(str, voice.audios_path), extension=".wav" - ) + return self.storage.list_files(prefix=cast(str, voice.audios_path), extension=".wav") def delete_audio_file(self, s3_key: str) -> None: """Delete a specific audio file from S3.""" self.storage.delete_file(s3_key) def list_voices(self) -> dict[str, str]: - voices = self.db.query(VoiceRecord).all() + voices = self.db.query(VoiceRecord).filter(VoiceRecord.status == "ready").all() return {cast(str, v.name): cast(str, v.audios_path) for v in voices} @property def voices(self) -> dict: - records = self.db.query(VoiceRecord).all() - return { - cast(str, r.name): {"embedding": r.embedding, "id": cast(str, r.id)} - for r in records - } + # Only expose voices that finished training — placeholder rows in + # "processing" state have empty embeddings and would break similarity + # comparisons. + records = self.db.query(VoiceRecord).filter(VoiceRecord.status == "ready").all() + return {cast(str, r.name): {"embedding": r.embedding, "id": cast(str, r.id)} for r in records} def __len__(self) -> int: - return self.db.query(VoiceRecord).count() + return self.db.query(VoiceRecord).filter(VoiceRecord.status == "ready").count() diff --git a/src/infrastructure/services/whisperx_audio_diarizer.py b/src/infrastructure/services/whisperx_audio_diarizer.py index 247edabe..a013df6b 100644 --- a/src/infrastructure/services/whisperx_audio_diarizer.py +++ b/src/infrastructure/services/whisperx_audio_diarizer.py @@ -76,9 +76,7 @@ def _align(self, result: dict, audio: np.ndarray, language: str | None) -> dict: logger.info("[2/3] Word alignment starting") lang = language or result.get("language", "en") try: - model_a, metadata = model_loader.get_align_model( - language_code=lang, device=self._device - ) + model_a, metadata = model_loader.get_align_model(language_code=lang, device=self._device) result = whisperx.align( result["segments"], model_a, @@ -152,9 +150,7 @@ def run( result_trans = self._transcribe(audio, language) result_aligned = self._align(result_trans, audio, language) - segments, lang = self._diarize( - audio, result_aligned, num_speakers, min_speakers, max_speakers - ) + segments, lang = self._diarize(audio, result_aligned, num_speakers, min_speakers, max_speakers) return DiarizationResult( segments=segments, diff --git a/src/infrastructure/services/youtube_data_process_service.py b/src/infrastructure/services/youtube_data_process_service.py index 922c3bea..3f74b38f 100644 --- a/src/infrastructure/services/youtube_data_process_service.py +++ b/src/infrastructure/services/youtube_data_process_service.py @@ -23,9 +23,7 @@ class YoutubeDataProcessService: Se não houver tokenizer, faz fallback para tiktoken. """ - def __init__( - self, model_loader_service: IModelLoaderService, yt_extractor: YoutubeExtractor - ): + def __init__(self, model_loader_service: IModelLoaderService, yt_extractor: YoutubeExtractor): self.model_loader_service: IModelLoaderService = model_loader_service self.yt_extractor = yt_extractor @@ -58,14 +56,10 @@ def split_transcript( return documents if mode == "time" or not tokens_per_chunk: - return self._split_by_time( - transcript, time_window_size, time_overlap, context - ) + return self._split_by_time(transcript, time_window_size, time_overlap, context) if mode == "tokens": - return self._split_by_tokens( - transcript, tokens_per_chunk, tokens_overlap, context - ) + return self._split_by_tokens(transcript, tokens_per_chunk, tokens_overlap, context) logger.error("Unknown splitting mode.", context={**context, "mode": mode}) raise ValueError(f"Unknown splitting mode: {mode}") @@ -93,11 +87,7 @@ def _split_by_time( for i in range(windows): start = i * step end = start + window_size - window_text = [ - self._get_text(snippet) - for snippet in transcript - if start <= self._get_start(snippet) < end - ] + window_text = [self._get_text(snippet) for snippet in transcript if start <= self._get_start(snippet) < end] if window_text: doc_context = { @@ -108,11 +98,7 @@ def _split_by_time( "window_text_length": len(window_text), } logger.debug("Creating document for time window", context=doc_context) - documents.append( - self._create_document( - window_text, start, end, self._get_video_id(transcript) - ) - ) + documents.append(self._create_document(window_text, start, end, self._get_video_id(transcript))) logger.debug( "Transcript split into windows", @@ -129,9 +115,7 @@ def _split_by_tokens( ) -> List[Document]: step = tokens_per_chunk - token_overlap if step <= 0: - logger.error( - "token_overlap must be smaller than tokens_per_chunk", context=context - ) + logger.error("token_overlap must be smaller than tokens_per_chunk", context=context) raise ValueError("token_overlap must be smaller than tokens_per_chunk") tokenizer = getattr(self.model_loader_service.model, "tokenizer", None) @@ -141,12 +125,8 @@ def _split_by_tokens( "No tokenizer available in the models. Please configure a tokenizer in model_loader_service.models." ) - token_ids, token_meta = self._tokenize_transcript( - transcript, tokenizer, context - ) - documents = self._create_token_chunks( - token_ids, token_meta, tokens_per_chunk, step, transcript, context - ) + token_ids, token_meta = self._tokenize_transcript(transcript, tokenizer, context) + documents = self._create_token_chunks(token_ids, token_meta, tokens_per_chunk, step, transcript, context) logger.debug( "Transcript split into token windows", context={**context, "token_windows_created": len(documents)}, @@ -176,9 +156,7 @@ def _encode(txt: str): end_time = start_time + (duration or 0.0) if not text: - logger.debug( - "Skipping empty snippet", context={**context, "snippet_index": idx} - ) + logger.debug("Skipping empty snippet", context={**context, "snippet_index": idx}) continue ids = _encode(text) @@ -188,12 +166,8 @@ def _encode(txt: str): ) for t_id in ids: token_ids.append(t_id) - token_meta.append( - {"start": start_time, "end": end_time, "snippet_index": idx} - ) - logger.debug( - "Tokenization complete", context={**context, "total_tokens": len(token_ids)} - ) + token_meta.append({"start": start_time, "end": end_time, "snippet_index": idx}) + logger.debug("Tokenization complete", context={**context, "total_tokens": len(token_ids)}) return token_ids, token_meta def _create_token_chunks( @@ -211,9 +185,7 @@ def _create_token_chunks( def _decode(_ids: list): try: - return self.model_loader_service.model.tokenizer.decode( - _ids, skip_special_tokens=True - ) + return self.model_loader_service.model.tokenizer.decode(_ids, skip_special_tokens=True) except TypeError: return self.model_loader_service.model.tokenizer.decode(_ids) except AttributeError: diff --git a/src/infrastructure/services/youtube_vector_service.py b/src/infrastructure/services/youtube_vector_service.py index ebb5919a..9331b35c 100644 --- a/src/infrastructure/services/youtube_vector_service.py +++ b/src/infrastructure/services/youtube_vector_service.py @@ -25,24 +25,18 @@ def index_documents(self, documents: List[ChunkEntity]) -> List[str]: return result - def search( - self, query: str, top_k: int = 5, filters: Optional[Any] = None - ) -> List[ChunkEntity]: + def search(self, query: str, top_k: int = 5, filters: Optional[Any] = None) -> List[ChunkEntity]: if not query: raise ValueError("Query must be provided for search") - models: List[ChunkModel] = self._repository.retriever( - query=query, top_kn=top_k, filters=filters - ) + models: List[ChunkModel] = self._repository.retriever(query=query, top_kn=top_k, filters=filters) mapper = ChunkMapper() entities: List[ChunkEntity] = [mapper.model_to_entity(doc) for doc in models] return entities - def search_by_video_id( - self, video_id: str, filters: Optional[Any] = None - ) -> List[ChunkEntity]: + def search_by_video_id(self, video_id: str, filters: Optional[Any] = None) -> List[ChunkEntity]: if not video_id: raise ValueError("video_id must be provided") @@ -55,9 +49,7 @@ def search_by_video_id( # we just pass it through, but we prefer dicts now. combined_filters = filters - models: List[ChunkModel] = self._repository.list_chunks( - filters=combined_filters - ) + models: List[ChunkModel] = self._repository.list_chunks(filters=combined_filters) mapper = ChunkMapper() entities: List[ChunkEntity] = [mapper.model_to_entity(doc) for doc in models] diff --git a/src/infrastructure/utils/audio_utils.py b/src/infrastructure/utils/audio_utils.py index 10f202bb..1070ae1d 100644 --- a/src/infrastructure/utils/audio_utils.py +++ b/src/infrastructure/utils/audio_utils.py @@ -21,9 +21,7 @@ def load_audio_tensor(audio_path: str) -> dict: if path.suffix.lower() not in {".wav", ".flac", ".ogg", ".opus"}: if not PYDUB_AVAILABLE: - raise ImportError( - f"To load '{path.suffix}' install pydub: pip install pydub" - ) + raise ImportError(f"To load '{path.suffix}' install pydub: pip install pydub") tmp_wav = str(path.with_suffix(".tmp_utils.wav")) AudioSegment.from_file(audio_path).export(tmp_wav, format="wav") read_path = tmp_wav diff --git a/src/presentation/api/routes/audio_diarization_and_recognition_router.py b/src/presentation/api/routes/audio_diarization_and_recognition_router.py index ca1b7f13..52e5eaf5 100644 --- a/src/presentation/api/routes/audio_diarization_and_recognition_router.py +++ b/src/presentation/api/routes/audio_diarization_and_recognition_router.py @@ -98,9 +98,7 @@ async def update_diarization_segments( language=cast(str, target_source.language or "pt"), source_type=cast(Optional[str], target_source.source_type), external_source=cast(Optional[str], target_source.external_source), - source_metadata=cast( - Optional[dict[str, Any]], target_source.source_metadata - ), + source_metadata=cast(Optional[dict[str, Any]], target_source.source_metadata), ) task_queue.enqueue( @@ -166,9 +164,7 @@ async def start_audio_processing_pipeline( min_speakers=request.min_speakers, max_speakers=request.max_speakers, model_size=request.model_size or "large-v2", - recognize_voices=request.recognize_voices - if request.recognize_voices is not None - else True, + recognize_voices=request.recognize_voices if request.recognize_voices is not None else True, subject_id=request.subject_id, ) task_queue.enqueue( @@ -232,9 +228,7 @@ async def start_audio_processing_pipeline( min_speakers=request.min_speakers, max_speakers=request.max_speakers, model_size=request.model_size or "large-v2", - recognize_voices=request.recognize_voices - if request.recognize_voices is not None - else True, + recognize_voices=request.recognize_voices if request.recognize_voices is not None else True, subject_id=request.subject_id, ) @@ -268,18 +262,14 @@ async def start_audio_processing_pipeline( ) async def identify_speakers_in_existing_diarization( diarization_id: str, - use_case: Annotated[ - IdentifySpeakersInProcessedAudioUseCase, Depends(get_identify_speakers_use_case) - ], + use_case: Annotated[IdentifySpeakersInProcessedAudioUseCase, Depends(get_identify_speakers_use_case)], ): logger.info("Speaker recognition request for diarization_id=%s", diarization_id) try: return use_case.execute(diarization_id) except ValueError as e: logger.warning("Recognition failed (ValueError): %s", str(e)) - raise HTTPException( - status_code=404 if "not found" in str(e) else 400, detail=str(e) - ) + raise HTTPException(status_code=404 if "not found" in str(e) else 400, detail=str(e)) except Exception as e: logger.error("Recognition failed: %s\n%s", str(e), traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @@ -319,26 +309,20 @@ async def list_available_s3_files_for_recognition( async def generate_signed_url_for_speaker_audio( diarization_id: str, speaker_label: str, - use_case: Annotated[ - GenerateSpeakerAudioAccessUrlUseCase, Depends(get_generate_speaker_url_use_case) - ], + use_case: Annotated[GenerateSpeakerAudioAccessUrlUseCase, Depends(get_generate_speaker_url_use_case)], ): try: return use_case.execute(diarization_id, speaker_label) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - logger.error( - "Presigned URL generation failed: %s\n%s", str(e), traceback.format_exc() - ) + logger.error("Presigned URL generation failed: %s\n%s", str(e), traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @router.get("") async def retrieve_all_processed_audio_history( - use_case: Annotated[ - RetrieveProcessedAudioHistoryUseCase, Depends(get_retrieve_history_use_case) - ], + use_case: Annotated[RetrieveProcessedAudioHistoryUseCase, Depends(get_retrieve_history_use_case)], limit: int = 10, offset: int = 0, subject_id: str | None = None, @@ -354,9 +338,7 @@ async def retrieve_all_processed_audio_history( logger.info("Audio history returned %d records", len(result)) return result except Exception as e: - logger.error( - "Failed to retrieve audio history: %s\n%s", str(e), traceback.format_exc() - ) + logger.error("Failed to retrieve audio history: %s\n%s", str(e), traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @@ -368,9 +350,7 @@ async def retrieve_all_processed_audio_history( ) async def delete_diarization_record( diarization_id: str, - use_case: Annotated[ - DeleteDiarizationUseCase, Depends(get_delete_diarization_use_case) - ], + use_case: Annotated[DeleteDiarizationUseCase, Depends(get_delete_diarization_use_case)], ): deleted = use_case.execute(diarization_id) if not deleted: diff --git a/src/presentation/api/routes/chunk_router.py b/src/presentation/api/routes/chunk_router.py index bf91c225..17297a56 100644 --- a/src/presentation/api/routes/chunk_router.py +++ b/src/presentation/api/routes/chunk_router.py @@ -30,9 +30,7 @@ def get_chunks( ): """Retrieve text chunks with optional filtering by source or search query""" try: - chunks = chunk_service.list_chunks( - limit=limit, offset=offset, source_id=source_id, search_query=q - ) + chunks = chunk_service.list_chunks(limit=limit, offset=offset, source_id=source_id, search_query=q) return chunks except Exception as e: logger.error(e, context={"action": "list_chunks"}) diff --git a/src/presentation/api/routes/ingest_router.py b/src/presentation/api/routes/ingest_router.py index c4c19f95..fbda4c14 100644 --- a/src/presentation/api/routes/ingest_router.py +++ b/src/presentation/api/routes/ingest_router.py @@ -90,8 +90,7 @@ def ingest_youtube( is_bulk = request.video_urls and len(request.video_urls) > 0 is_playlist = request.data_type == "playlist" is_channel = request.data_type == "channel" or ( - request.video_url - and any(x in request.video_url for x in ["/channel/", "/c/", "/user/", "@"]) + request.video_url and any(x in request.video_url for x in ["/channel/", "/c/", "/user/", "@"]) ) if is_channel: @@ -101,19 +100,13 @@ def ingest_youtube( logger.info( "Running ingestion in background via queue", context={ - "reason": "reprocess" - if request.reprocess - else ("bulk/playlist" if not is_channel else "channel") + "reason": "reprocess" if request.reprocess else ("bulk/playlist" if not is_channel else "channel") }, ) # Determine which worker to use: dispatcher for playlists/channels/bulk, ingestion for single is_bulk_processing = is_bulk or is_playlist or is_channel - worker = ( - run_youtube_dispatcher_worker - if is_bulk_processing - else run_youtube_ingestion_worker - ) + worker = run_youtube_dispatcher_worker if is_bulk_processing else run_youtube_ingestion_worker if request.reprocess or is_bulk_processing: # Determine the reason for background processing @@ -124,17 +117,13 @@ def ingest_youtube( else: reason = "bulk/playlist" - logger.info( - "Running ingestion in background via queue", context={"reason": reason} - ) + logger.info("Running ingestion in background via queue", context={"reason": reason}) task_queue.enqueue( worker, cmd, task_title=request.title or request.video_url or "YouTube Ingestion", - metadata={"job_id": str(request.ingestion_job_id)} - if request.ingestion_job_id - else {}, + metadata={"job_id": str(request.ingestion_job_id)} if request.ingestion_job_id else {}, ) return IngestResponse( skipped=False, @@ -155,9 +144,7 @@ def ingest_youtube( except HTTPException: raise except ValueError as ve: - logger.warning( - "Validation error in youtube ingestion", context={"error": str(ve)} - ) + logger.warning("Validation error in youtube ingestion", context={"error": str(ve)}) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, context={"action": "youtube_ingestion"}) @@ -361,13 +348,9 @@ async def ingest_web( task_title=request.title or request.url, metadata={"url": request.url}, ) - logger.info( - "Web ingestion task enqueued successfully", context={"url": request.url} - ) + logger.info("Web ingestion task enqueued successfully", context={"url": request.url}) except Exception as e: - logger.error( - f"Failed to enqueue web ingestion task: {e}", context={"url": request.url} - ) + logger.error(f"Failed to enqueue web ingestion task: {e}", context={"url": request.url}) raise HTTPException(status_code=500, detail=f"Failed to enqueue task: {str(e)}") return { @@ -387,9 +370,7 @@ async def ingest_web( ) async def ingest_diarization( request: Annotated[DiarizationIngestRequest, Body()], - use_case: Annotated[ - DiarizationIngestionUseCase, Depends(get_diarization_ingestion_use_case) - ], + use_case: Annotated[DiarizationIngestionUseCase, Depends(get_diarization_ingestion_use_case)], task_queue: Annotated[ITaskQueue, Depends(get_task_queue_service)], ): """ @@ -467,9 +448,7 @@ def preview_youtube_channel( from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor extractor = YoutubeExtractor() - videos, channel_name = extractor.extract_channel_videos( - request.channel_url.strip() - ) + videos, channel_name = extractor.extract_channel_videos(request.channel_url.strip()) if not videos: return ChannelPreviewResponse( diff --git a/src/presentation/api/routes/job_router.py b/src/presentation/api/routes/job_router.py index a215db8b..4996511f 100644 --- a/src/presentation/api/routes/job_router.py +++ b/src/presentation/api/routes/job_router.py @@ -45,9 +45,7 @@ def get_jobs( """Retrieve ingestion jobs with pagination and filtering""" try: offset = (page - 1) * page_size - result = job_service.list_jobs( - limit=page_size, offset=offset, status=status, search=search - ) + result = job_service.list_jobs(limit=page_size, offset=offset, status=status, search=search) return PaginatedJobsResponse( items=result["jobs"], total=result["total"], @@ -56,9 +54,7 @@ def get_jobs( stats=result["stats"], ) except Exception as e: - logger.error( - e, context={"action": "list_jobs", "page": page, "page_size": page_size} - ) + logger.error(e, context={"action": "list_jobs", "page": page, "page_size": page_size}) raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR) diff --git a/src/presentation/api/routes/notification_router.py b/src/presentation/api/routes/notification_router.py index 799e9384..110f8b85 100644 --- a/src/presentation/api/routes/notification_router.py +++ b/src/presentation/api/routes/notification_router.py @@ -33,9 +33,7 @@ async def event_generator() -> AsyncGenerator[dict, None]: try: while not await request.is_disconnected(): - message = await loop.run_in_executor( - None, lambda: pubsub.get_message(timeout=1.0) - ) + message = await loop.run_in_executor(None, lambda: pubsub.get_message(timeout=1.0)) if message and message["type"] == "message": data = json.loads(message["data"]) yield {"event": "message", "data": json.dumps(data)} diff --git a/src/presentation/api/routes/settings_router.py b/src/presentation/api/routes/settings_router.py index 71b7c53b..79bcebd6 100644 --- a/src/presentation/api/routes/settings_router.py +++ b/src/presentation/api/routes/settings_router.py @@ -25,9 +25,7 @@ def get_current_settings(settings: Annotated[Settings, Depends(get_settings)]): """Return current application settings (sanitized)""" return SettingsResponse( - app=AppSettingsSchema( - env=settings.app.env, log_levels=", ".join(settings.app.list_log_levels) - ), + app=AppSettingsSchema(env=settings.app.env, log_levels=", ".join(settings.app.list_log_levels)), vector=VectorSettingsSchema( store_type=settings.vector.store_type.value, weaviate_host=settings.vector.weaviate_host, @@ -69,9 +67,7 @@ def check_component_health( try: if component == "api": latency = int((time.time() - start_time) * 1000) - return HealthCheckResponse( - status="success", latency_ms=latency, message="API is responding" - ) + return HealthCheckResponse(status="success", latency_ms=latency, message="API is responding") elif component == "sql": with Connector() as session: @@ -110,9 +106,7 @@ def check_component_health( message="Vector store is ready", ) else: - return HealthCheckResponse( - status="error", message="Vector store is not ready" - ) + return HealthCheckResponse(status="error", message="Vector store is not ready") elif component == "model": # For now, we just assume it's okay if it loaded correctly at startup @@ -125,9 +119,7 @@ def check_component_health( ) else: - raise HTTPException( - status_code=400, detail=f"Unknown component: {component}" - ) + raise HTTPException(status_code=400, detail=f"Unknown component: {component}") except Exception as e: return HealthCheckResponse(status="error", message=str(e)) diff --git a/src/presentation/api/routes/subject_router.py b/src/presentation/api/routes/subject_router.py index d01d03d2..07a05193 100644 --- a/src/presentation/api/routes/subject_router.py +++ b/src/presentation/api/routes/subject_router.py @@ -112,9 +112,7 @@ def update_subject( }, ) except Exception as e: - logger.error( - e, context={"action": "update_subject", "subject_id": str(subject_id)} - ) + logger.error(e, context={"action": "update_subject", "subject_id": str(subject_id)}) raise HTTPException(status_code=500, detail=INTERNAL_ERROR) @@ -145,7 +143,5 @@ def delete_subject( except HTTPException: raise except Exception as e: - logger.error( - e, context={"action": "delete_subject", "subject_id": str(subject_id)} - ) + logger.error(e, context={"action": "delete_subject", "subject_id": str(subject_id)}) raise HTTPException(status_code=500, detail=INTERNAL_ERROR) diff --git a/src/presentation/api/routes/voice_profile_management_router.py b/src/presentation/api/routes/voice_profile_management_router.py index 5496a46a..d2b845bc 100644 --- a/src/presentation/api/routes/voice_profile_management_router.py +++ b/src/presentation/api/routes/voice_profile_management_router.py @@ -14,7 +14,6 @@ RegisterNewVoiceProfileUseCase, ) from src.application.workers import run_voice_training_worker -from src.domain.entities.enums.diarization_status_enum import DiarizationStatus from src.domain.interfaces.services.i_event_bus import IEventBus from src.domain.interfaces.services.i_task_queue import ITaskQueue from src.infrastructure.repositories.sql.diarization_repository import ( @@ -43,9 +42,7 @@ async def register_new_voice_profile( request: VoiceProfileRegistrationRequest, event_bus: Annotated[IEventBus, Depends(get_event_bus)], - use_case: Annotated[ - RegisterNewVoiceProfileUseCase, Depends(get_register_voice_profile_use_case) - ], + use_case: Annotated[RegisterNewVoiceProfileUseCase, Depends(get_register_voice_profile_use_case)], ): try: voice_id = use_case.execute(request.name, request.audio_path) @@ -62,17 +59,17 @@ async def register_new_voice_profile( @router.post("/upload", responses={400: {"description": "Bad Request"}}) async def upload_and_register_new_voice_profile( event_bus: Annotated[IEventBus, Depends(get_event_bus)], - use_case: Annotated[ - RegisterNewVoiceProfileUseCase, Depends(get_register_voice_profile_use_case) - ], + use_case: Annotated[RegisterNewVoiceProfileUseCase, Depends(get_register_voice_profile_use_case)], name: str = Form(...), - file: UploadFile = File(...), + file: UploadFile | None = File(None), ): - if not file.filename: + if not file or not file.filename or not file.filename.strip(): raise HTTPException(status_code=400, detail="No filename provided") temp_dir = tempfile.mkdtemp() - temp_path = os.path.join(temp_dir, file.filename) + # Sanitize the filename to avoid path traversal or restricted character issues + filename = os.path.basename(file.filename) + temp_path = os.path.join(temp_dir, filename) try: async with await anyio.open_file(temp_path, "wb") as buffer: while content := await file.read(1024 * 1024): # 1MB chunks @@ -80,9 +77,7 @@ async def upload_and_register_new_voice_profile( voice_id = use_case.execute(name, temp_path) # Notify - event_bus.publish( - "ingestion_status", {"type": "voice", "action": "register", "name": name} - ) + event_bus.publish("ingestion_status", {"type": "voice", "action": "register", "name": name}) return {"status": "success", "voice_id": voice_id, "name": name} except Exception as e: raise HTTPException(status_code=400, detail=str(e)) @@ -113,29 +108,23 @@ async def train_voice_profile_from_existing_speaker_segment( # 1. Verification record = repo.get_by_id(request.diarization_id) if not record: - raise HTTPException( - status_code=404, detail=f"Diarization not found: {request.diarization_id}" - ) - - # 2. Update status to TRAINING immediately - repo.update_status( - request.diarization_id, - DiarizationStatus.TRAINING.value, - status_message=f"Preparando treinamento de voz: {request.name}", - ) + raise HTTPException(status_code=404, detail=f"Diarization not found: {request.diarization_id}") - # 3. Notify frontend + # 2. Notify frontend (voice-scoped — do NOT touch the diarization record's + # status, since voice training is orthogonal to the diarization lifecycle + # and mutating the record's status corrupts its real state on reload) event_bus.publish( "ingestion_status", { - "type": "diarization", - "id": request.diarization_id, - "status": DiarizationStatus.TRAINING.value, - "message": f"Iniciando treinamento de voz '{request.name}'...", + "type": "voice", + "action": "train_started", + "name": request.name, + "diarization_id": request.diarization_id, + "speaker_label": request.speaker_label, }, ) - # 4. Enqueue background task + # 3. Enqueue background task cmd = TrainVoiceCommand( diarization_id=request.diarization_id, speaker_label=request.speaker_label, @@ -161,9 +150,7 @@ async def train_voice_profile_from_existing_speaker_segment( @router.get("") async def list_all_registered_voice_profiles( - use_case: Annotated[ - ListRegisteredVoiceProfilesUseCase, Depends(get_list_voice_profiles_use_case) - ], + use_case: Annotated[ListRegisteredVoiceProfilesUseCase, Depends(get_list_voice_profiles_use_case)], ): return use_case.execute() @@ -172,16 +159,12 @@ async def list_all_registered_voice_profiles( async def delete_existing_voice_profile( name: str, event_bus: Annotated[IEventBus, Depends(get_event_bus)], - use_case: Annotated[ - DeleteVoiceProfileUseCase, Depends(get_delete_voice_profile_use_case) - ], + use_case: Annotated[DeleteVoiceProfileUseCase, Depends(get_delete_voice_profile_use_case)], ): try: use_case.execute(name) # Notify - event_bus.publish( - "ingestion_status", {"type": "voice", "action": "delete", "name": name} - ) + event_bus.publish("ingestion_status", {"type": "voice", "action": "delete", "name": name}) return { "status": "success", "message": f"Voice profile '{name}' successfully removed", @@ -193,9 +176,7 @@ async def delete_existing_voice_profile( @router.get("/{voice_id}/audios") async def list_voice_audio_files( voice_id: str, - use_case: Annotated[ - ListVoiceAudioFilesUseCase, Depends(get_list_voice_audio_files_use_case) - ], + use_case: Annotated[ListVoiceAudioFilesUseCase, Depends(get_list_voice_audio_files_use_case)], ): return use_case.execute(voice_id) @@ -204,16 +185,12 @@ async def list_voice_audio_files( async def delete_voice_audio_file( s3_key: str, event_bus: Annotated[IEventBus, Depends(get_event_bus)], - use_case: Annotated[ - DeleteVoiceAudioFileUseCase, Depends(get_delete_voice_audio_file_use_case) - ], + use_case: Annotated[DeleteVoiceAudioFileUseCase, Depends(get_delete_voice_audio_file_use_case)], ): try: use_case.execute(s3_key) # Notify (voice updated) - event_bus.publish( - "ingestion_status", {"type": "voice", "action": "audio_deleted"} - ) + event_bus.publish("ingestion_status", {"type": "voice", "action": "audio_deleted"}) return {"status": "success", "message": "Audio file deleted"} except Exception as e: raise HTTPException(status_code=404, detail=str(e)) diff --git a/tests/application/test_audio_diarization_workers.py b/tests/application/test_audio_diarization_workers.py index 75883d8e..1509237c 100644 --- a/tests/application/test_audio_diarization_workers.py +++ b/tests/application/test_audio_diarization_workers.py @@ -18,9 +18,7 @@ def mock_app(self): def mock_db_session(self): return MagicMock() - def test_run_audio_diarization_dispatcher_worker_deduplication( - self, mock_app, mock_db_session - ): + def test_run_audio_diarization_dispatcher_worker_deduplication(self, mock_app, mock_db_session): # 1. Setup command cmd = ProcessAudioCommand( source_type="youtube", @@ -40,16 +38,12 @@ def test_run_audio_diarization_dispatcher_worker_deduplication( with ( patch("src.application.workers.registry.get", return_value=mock_app), - patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeExtractor" - ) as mock_extractor_cls, + patch("src.infrastructure.extractors.youtube_extractor.YoutubeExtractor") as mock_extractor_cls, patch( "src.infrastructure.repositories.sql.connector.Session", return_value=mock_db_session, ), - patch( - "src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository" - ) as mock_repo_cls, + patch("src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository") as mock_repo_cls, ): mock_extractor = mock_extractor_cls.return_value mock_extractor.extract_playlist_videos.return_value = video_urls @@ -83,9 +77,7 @@ def test_run_audio_diarization_dispatcher_worker_deduplication( single_cmd = args[1] assert single_cmd.source == "https://youtube.com/watch?v=v2" - def test_run_audio_diarization_dispatcher_worker_retry_failed( - self, mock_app, mock_db_session - ): + def test_run_audio_diarization_dispatcher_worker_retry_failed(self, mock_app, mock_db_session): # 1. Setup command cmd = ProcessAudioCommand( source_type="youtube", @@ -98,22 +90,16 @@ def test_run_audio_diarization_dispatcher_worker_retry_failed( with ( patch("src.application.workers.registry.get", return_value=mock_app), - patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeExtractor" - ) as mock_extractor_cls, + patch("src.infrastructure.extractors.youtube_extractor.YoutubeExtractor") as mock_extractor_cls, patch( "src.infrastructure.repositories.sql.connector.Session", return_value=mock_db_session, ), - patch( - "src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository" - ) as mock_repo_cls, + patch("src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository") as mock_repo_cls, ): mock_extractor = mock_extractor_cls.return_value # Mock playlist extraction to return one video - mock_extractor.extract_playlist_videos.return_value = [ - "https://youtube.com/watch?v=v1" - ] + mock_extractor.extract_playlist_videos.return_value = ["https://youtube.com/watch?v=v1"] mock_repo = mock_repo_cls.return_value mock_repo.get_by_external_source.return_value = failed_record diff --git a/tests/application/test_workers.py b/tests/application/test_workers.py index b85dee87..25b4ce5d 100644 --- a/tests/application/test_workers.py +++ b/tests/application/test_workers.py @@ -27,61 +27,41 @@ def setup_registry(self): def test_run_file_ingestion_worker_success(self): with ( - patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx, + patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx, patch("src.presentation.api.dependencies.resolve_vector_repository"), patch("src.presentation.api.dependencies.resolve_rerank_service"), - patch( - "src.infrastructure.services.chunk_vector_service.ChunkVectorService" - ), - patch( - "src.application.use_cases.file_ingestion_use_case.FileIngestionUseCase" - ) as mock_use_case_cls, + patch("src.infrastructure.services.chunk_vector_service.ChunkVectorService"), + patch("src.application.use_cases.file_ingestion_use_case.FileIngestionUseCase") as mock_use_case_cls, ): mock_use_case = MagicMock() mock_use_case_cls.return_value = mock_use_case mock_ctx.return_value = MagicMock() - cmd = IngestFileCommand( - file_path="test.pdf", file_name="test.pdf", subject_name="test" - ) + cmd = IngestFileCommand(file_path="test.pdf", file_name="test.pdf", subject_name="test") run_file_ingestion_worker(cmd) mock_use_case.execute.assert_called_once_with(cmd) def test_run_file_ingestion_worker_no_app(self): registry._services = {} # Remove app - cmd = IngestFileCommand( - file_path="test.pdf", file_name="test.pdf", subject_name="test" - ) + cmd = IngestFileCommand(file_path="test.pdf", file_name="test.pdf", subject_name="test") with patch("src.application.workers.logger") as mock_logger: run_file_ingestion_worker(cmd) mock_logger.error.assert_called_once() def test_run_file_ingestion_worker_exception(self): - with patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx: + with patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx: mock_ctx.side_effect = Exception("Test error") - cmd = IngestFileCommand( - file_path="test.pdf", file_name="test.pdf", subject_name="test" - ) + cmd = IngestFileCommand(file_path="test.pdf", file_name="test.pdf", subject_name="test") with patch("src.application.workers.logger") as mock_logger: run_file_ingestion_worker(cmd) mock_logger.error.assert_called_once() def test_run_youtube_ingestion_worker_success(self): with ( - patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx, + patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx, patch("src.presentation.api.dependencies.resolve_vector_repository"), - patch( - "src.infrastructure.services.youtube_vector_service.YouTubeVectorService" - ), - patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeIngestionUseCase" - ) as mock_use_case_cls, + patch("src.infrastructure.services.youtube_vector_service.YouTubeVectorService"), + patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeIngestionUseCase") as mock_use_case_cls, ): mock_use_case = MagicMock() mock_use_case_cls.return_value = mock_use_case @@ -106,9 +86,7 @@ def test_run_youtube_ingestion_worker_no_app(self): # Should return silently def test_run_youtube_ingestion_worker_exception(self): - with patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx: + with patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx: mock_ctx.side_effect = Exception("Test error") cmd = IngestYoutubeCommand( video_url="https://youtube.com/watch?v=123", @@ -121,18 +99,12 @@ def test_run_youtube_ingestion_worker_exception(self): def test_run_web_ingestion_worker_success(self): with ( - patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx, + patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx, patch("src.presentation.api.dependencies.resolve_vector_repository"), patch("src.presentation.api.dependencies.resolve_rerank_service"), - patch( - "src.infrastructure.services.chunk_vector_service.ChunkVectorService" - ), + patch("src.infrastructure.services.chunk_vector_service.ChunkVectorService"), patch("src.presentation.api.dependencies.get_web_extractor"), - patch( - "src.application.use_cases.web_scraping_use_case.WebScrapingUseCase" - ) as mock_use_case_cls, + patch("src.application.use_cases.web_scraping_use_case.WebScrapingUseCase") as mock_use_case_cls, patch("asyncio.run") as mock_asyncio_run, ): mock_use_case = MagicMock() @@ -164,9 +136,7 @@ def test_run_web_ingestion_worker_no_app(self): def test_run_web_ingestion_worker_exception(self): with ( - patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx, + patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx, patch("asyncio.run") as mock_asyncio_run, ): mock_ctx.side_effect = Exception("Test error") @@ -197,20 +167,12 @@ def test_run_diarization_ingestion_worker_success(self): from src.application.workers import run_diarization_ingestion_worker with ( - patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx, + patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx, patch("src.presentation.api.dependencies.resolve_vector_repository"), patch("src.presentation.api.dependencies.resolve_rerank_service"), - patch( - "src.infrastructure.services.chunk_vector_service.ChunkVectorService" - ), - patch( - "src.infrastructure.repositories.sql.connector.Session" - ) as mock_session_cls, - patch( - "src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository" - ), + patch("src.infrastructure.services.chunk_vector_service.ChunkVectorService"), + patch("src.infrastructure.repositories.sql.connector.Session") as mock_session_cls, + patch("src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository"), patch( "src.application.use_cases.diarization_ingestion_use_case.DiarizationIngestionUseCase" ) as mock_use_case_cls, @@ -238,9 +200,7 @@ def test_run_diarization_ingestion_worker_no_app(self): def test_run_diarization_ingestion_worker_exception(self): from src.application.workers import run_diarization_ingestion_worker - with patch( - "src.presentation.api.dependencies.resolve_ingestion_context" - ) as mock_ctx: + with patch("src.presentation.api.dependencies.resolve_ingestion_context") as mock_ctx: mock_ctx.side_effect = Exception("Test error") cmd = MagicMock() with patch("src.application.workers.logger") as mock_logger: @@ -251,9 +211,7 @@ def test_audio_diarization_subprocess_success(self): from src.application.workers import _audio_diarization_subprocess with ( - patch( - "src.infrastructure.repositories.sql.connector.Session" - ) as mock_session_cls, + patch("src.infrastructure.repositories.sql.connector.Session") as mock_session_cls, patch("src.infrastructure.services.redis_event_bus.RedisEventBus"), patch( "src.application.use_cases.process_audio_diarization_pipeline.ProcessAudioDiarizationPipelineUseCase" @@ -308,12 +266,8 @@ def test_run_audio_diarization_worker_failure(self): with ( patch("multiprocessing.get_context") as mock_get_ctx, - patch( - "src.infrastructure.repositories.sql.connector.Session" - ) as mock_session_factory, - patch( - "src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository" - ) as mock_repo_cls, + patch("src.infrastructure.repositories.sql.connector.Session") as mock_session_factory, + patch("src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository") as mock_repo_cls, patch("src.infrastructure.services.redis_event_bus.RedisEventBus"), ): mock_ctx = MagicMock() @@ -327,11 +281,7 @@ def test_run_audio_diarization_worker_failure(self): mock_repo = MagicMock() mock_repo_cls.return_value = mock_repo - cmd = ProcessAudioCommand( - source_type="youtube", source="url", diarization_id="test-id" - ) + cmd = ProcessAudioCommand(source_type="youtube", source="url", diarization_id="test-id") run_audio_diarization_worker(cmd) - mock_repo.update_status.assert_called_with( - "test-id", "failed", error_message=ANY, status_message=ANY - ) + mock_repo.update_status.assert_called_with("test-id", "failed", error_message=ANY, status_message=ANY) diff --git a/tests/application/use_cases/test_audio_recognition_use_cases.py b/tests/application/use_cases/test_audio_recognition_use_cases.py index 454a850f..764352ae 100644 --- a/tests/application/use_cases/test_audio_recognition_use_cases.py +++ b/tests/application/use_cases/test_audio_recognition_use_cases.py @@ -35,12 +35,8 @@ class TestAudioRecognitionUseCases: def mock_infra_and_fs(self): # Stub StorageService and FS logic globally for this class with ( - patch( - "src.application.use_cases.identify_speakers_in_processed_audio.StorageService" - ), - patch( - "src.application.use_cases.generate_speaker_audio_access_url.StorageService" - ), + patch("src.application.use_cases.identify_speakers_in_processed_audio.StorageService"), + patch("src.application.use_cases.generate_speaker_audio_access_url.StorageService"), patch("src.application.use_cases.list_s3_audio_files.StorageService"), patch("src.application.use_cases.manage_voice_profiles.StorageService"), patch("src.infrastructure.services.voice_profile_service.StorageService"), @@ -70,9 +66,7 @@ def test_generate_speaker_url(self, sqlite_memory): sqlite_memory.commit() use_case = GenerateSpeakerAudioAccessUrlUseCase(sqlite_memory) - with patch.object( - use_case.storage, "get_presigned_url", return_value="http://p" - ): + with patch.object(use_case.storage, "get_presigned_url", return_value="http://p"): result = use_case.execute("1", "S0") assert result["url"] == "http://p" @@ -96,16 +90,12 @@ def test_list_s3_files(self, sqlite_memory): sqlite_memory.commit() use_case = ListS3AudioFilesUseCase(sqlite_memory) - with patch.object( - use_case.storage, "list_files", return_value=[{"key": "f1.wav"}] - ): + with patch.object(use_case.storage, "list_files", return_value=[{"key": "f1.wav"}]): files = use_case.execute("1") assert len(files) == 1 def test_register_voice_profile(self, sqlite_memory): - with patch( - "src.application.use_cases.manage_voice_profiles.VoiceDB" - ) as mock_vdb_cls: + with patch("src.application.use_cases.manage_voice_profiles.VoiceDB") as mock_vdb_cls: mock_vdb = mock_vdb_cls.return_value mock_vdb.add.return_value = ("v-123", "voices/v-123/sample.wav") @@ -120,19 +110,13 @@ def test_identify_speakers(self, mock_rm, sqlite_memory): sqlite_memory.commit() with ( - patch( - "src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB" - ) as mock_db_cls, - patch( - "src.application.use_cases.identify_speakers_in_processed_audio.VoiceRecognizer" - ) as mock_rec_cls, + patch("src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB") as mock_db_cls, + patch("src.application.use_cases.identify_speakers_in_processed_audio.VoiceRecognizer") as mock_rec_cls, ): mock_db = mock_db_cls.return_value mock_db.__len__.return_value = 1 mock_rec = mock_rec_cls.return_value - mock_rec.identify_dir.return_value = MagicMock( - mapping={"S0": "N"}, id_mapping={"S0": "id"}, results={} - ) + mock_rec.identify_dir.return_value = MagicMock(mapping={"S0": "N"}, id_mapping={"S0": "id"}, results={}) use_case = IdentifySpeakersInProcessedAudioUseCase(sqlite_memory) res = use_case.execute("1") @@ -148,9 +132,7 @@ def test_identify_speakers_empty_db(self, sqlite_memory): sqlite_memory.add(record) sqlite_memory.commit() - with patch( - "src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB" - ) as mock_db_cls: + with patch("src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB") as mock_db_cls: mock_db = mock_db_cls.return_value mock_db.__len__.return_value = 0 use_case = IdentifySpeakersInProcessedAudioUseCase(sqlite_memory) @@ -162,9 +144,7 @@ def test_identify_speakers_no_storage_path(self, sqlite_memory): sqlite_memory.add(record) sqlite_memory.commit() - with patch( - "src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB" - ) as mock_db_cls: + with patch("src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB") as mock_db_cls: mock_db = mock_db_cls.return_value mock_db.__len__.return_value = 1 use_case = IdentifySpeakersInProcessedAudioUseCase(sqlite_memory) @@ -179,31 +159,21 @@ def test_identify_speakers_cleanup_error(self, mock_rm, sqlite_memory): mock_rm.side_effect = Exception("Cleanup failed") with ( - patch( - "src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB" - ) as mock_db_cls, - patch( - "src.application.use_cases.identify_speakers_in_processed_audio.VoiceRecognizer" - ) as mock_rec_cls, - patch( - "src.application.use_cases.identify_speakers_in_processed_audio.logger" - ) as mock_logger, + patch("src.application.use_cases.identify_speakers_in_processed_audio.VoiceDB") as mock_db_cls, + patch("src.application.use_cases.identify_speakers_in_processed_audio.VoiceRecognizer") as mock_rec_cls, + patch("src.application.use_cases.identify_speakers_in_processed_audio.logger") as mock_logger, ): mock_db = mock_db_cls.return_value mock_db.__len__.return_value = 1 mock_rec = mock_rec_cls.return_value - mock_rec.identify_dir.return_value = MagicMock( - mapping={"S0": "N"}, id_mapping={"S0": "id"}, results={} - ) + mock_rec.identify_dir.return_value = MagicMock(mapping={"S0": "N"}, id_mapping={"S0": "id"}, results={}) use_case = IdentifySpeakersInProcessedAudioUseCase(sqlite_memory) use_case.execute("5") mock_logger.warning.assert_called() def test_delete_voice_profile(self, sqlite_memory): - with patch( - "src.application.use_cases.manage_voice_profiles.VoiceDB" - ) as mock_vdb_cls: + with patch("src.application.use_cases.manage_voice_profiles.VoiceDB") as mock_vdb_cls: mock_vdb = mock_vdb_cls.return_value use_case = DeleteVoiceProfileUseCase(sqlite_memory) use_case.execute(name="N") @@ -223,9 +193,7 @@ def test_register_voice_profile_no_name(self, sqlite_memory): use_case.execute(name="", audio_path="a") def test_list_voice_audio_files(self, sqlite_memory): - with patch( - "src.application.use_cases.manage_voice_profiles.VoiceDB" - ) as mock_vdb_cls: + with patch("src.application.use_cases.manage_voice_profiles.VoiceDB") as mock_vdb_cls: mock_vdb = mock_vdb_cls.return_value mock_vdb.list_audio_files.return_value = [{"key": "test.wav"}] @@ -235,9 +203,7 @@ def test_list_voice_audio_files(self, sqlite_memory): assert res[0]["key"] == "test.wav" def test_delete_voice_audio_file(self, sqlite_memory): - with patch( - "src.application.use_cases.manage_voice_profiles.VoiceDB" - ) as mock_vdb_cls: + with patch("src.application.use_cases.manage_voice_profiles.VoiceDB") as mock_vdb_cls: mock_vdb = mock_vdb_cls.return_value use_case = DeleteVoiceAudioFileUseCase(sqlite_memory) use_case.execute(s3_key="path/test.wav") diff --git a/tests/application/use_cases/test_auth_use_case.py b/tests/application/use_cases/test_auth_use_case.py index 5311e4a7..ec873422 100644 --- a/tests/application/use_cases/test_auth_use_case.py +++ b/tests/application/use_cases/test_auth_use_case.py @@ -36,13 +36,9 @@ def test_get_login_url(self, use_case, mock_service): assert len(state) > 0 @pytest.mark.asyncio - async def test_handle_google_callback_new_user( - self, use_case, mock_repo, mock_service - ): + async def test_handle_google_callback_new_user(self, use_case, mock_repo, mock_service): # 1. Mock token exchange - mock_service.exchange_code_for_token = AsyncMock( - return_value={"access_token": "abc"} - ) + mock_service.exchange_code_for_token = AsyncMock(return_value={"access_token": "abc"}) # 2. Mock user info mock_service.get_google_user_info = AsyncMock( @@ -68,20 +64,14 @@ async def test_handle_google_callback_new_user( mock_repo.create.assert_called_once() @pytest.mark.asyncio - async def test_handle_google_callback_invalid_state( - self, use_case, mock_repo, mock_service - ): + async def test_handle_google_callback_invalid_state(self, use_case, mock_repo, mock_service): with pytest.raises(InvalidStateError, match="Invalid authentication state"): await use_case.handle_google_callback("test_code", "received", "expected") @pytest.mark.asyncio - async def test_handle_google_callback_existing_user( - self, use_case, mock_repo, mock_service - ): + async def test_handle_google_callback_existing_user(self, use_case, mock_repo, mock_service): # 1. Mock token exchange - mock_service.exchange_code_for_token = AsyncMock( - return_value={"access_token": "abc"} - ) + mock_service.exchange_code_for_token = AsyncMock(return_value={"access_token": "abc"}) # 2. Mock user info mock_service.get_google_user_info = AsyncMock( @@ -89,9 +79,7 @@ async def test_handle_google_callback_existing_user( ) # 3. Mock repository (found -> update login) - existing_user = User( - id="u2", email="existing@example.com", full_name="Existing User" - ) + existing_user = User(id="u2", email="existing@example.com", full_name="Existing User") mock_repo.get_by_email.return_value = existing_user mock_repo.update_last_login.return_value = existing_user @@ -121,23 +109,15 @@ async def test_handle_google_callback_no_access_token(self, use_case, mock_servi @pytest.mark.asyncio async def test_handle_google_callback_missing_info(self, use_case, mock_service): - mock_service.exchange_code_for_token = AsyncMock( - return_value={"access_token": "abc"} - ) + mock_service.exchange_code_for_token = AsyncMock(return_value={"access_token": "abc"}) mock_service.get_google_user_info = AsyncMock(return_value={"email": "e@e.c"}) with pytest.raises(GoogleAuthError, match="Google user info missing"): await use_case.handle_google_callback("code", "s", "s") @pytest.mark.asyncio - async def test_handle_google_callback_user_not_created( - self, use_case, mock_repo, mock_service - ): - mock_service.exchange_code_for_token = AsyncMock( - return_value={"access_token": "abc"} - ) - mock_service.get_google_user_info = AsyncMock( - return_value={"email": "e@e.c", "name": "N"} - ) + async def test_handle_google_callback_user_not_created(self, use_case, mock_repo, mock_service): + mock_service.exchange_code_for_token = AsyncMock(return_value={"access_token": "abc"}) + mock_service.get_google_user_info = AsyncMock(return_value={"email": "e@e.c", "name": "N"}) mock_repo.get_by_email.return_value = None mock_repo.create.return_value = None with pytest.raises(UserNotCreatedError, match="Failed to create or retrieve"): diff --git a/tests/application/use_cases/test_content_source_use_case.py b/tests/application/use_cases/test_content_source_use_case.py index 38ed4a39..8529e08e 100644 --- a/tests/application/use_cases/test_content_source_use_case.py +++ b/tests/application/use_cases/test_content_source_use_case.py @@ -31,9 +31,7 @@ def test_delete_source_success(use_case, mock_services): assert success is True mock_services["cs_service"].get_by_id.assert_called_once_with(source_id) - mock_services["chunk_service"].delete_by_content_source.assert_called_once_with( - source_id - ) + mock_services["chunk_service"].delete_by_content_source.assert_called_once_with(source_id) mock_services["vector_repo"].delete.assert_called_once() mock_services["cs_service"].delete_source.assert_called_once_with(source_id) @@ -52,9 +50,7 @@ def test_delete_source_not_found(use_case, mock_services): def test_delete_source_exception(use_case, mock_services): source_id = uuid.uuid4() mock_services["cs_service"].get_by_id.return_value = MagicMock(id=source_id) - mock_services["chunk_service"].delete_by_content_source.side_effect = Exception( - "DB error" - ) + mock_services["chunk_service"].delete_by_content_source.side_effect = Exception("DB error") with pytest.raises(Exception, match="DB error"): use_case.delete(source_id) diff --git a/tests/application/use_cases/test_delete_diarization_use_case.py b/tests/application/use_cases/test_delete_diarization_use_case.py index 125f7b37..486ec0bd 100644 --- a/tests/application/use_cases/test_delete_diarization_use_case.py +++ b/tests/application/use_cases/test_delete_diarization_use_case.py @@ -14,18 +14,14 @@ class TestDeleteDiarizationUseCase: @pytest.fixture def mock_storage(self): - with patch( - "src.application.use_cases.delete_diarization_use_case.StorageService" - ) as mock: + with patch("src.application.use_cases.delete_diarization_use_case.StorageService") as mock: instance = mock.return_value instance.bucket = "test-bucket" yield instance @pytest.fixture def mock_cs_service(self): - return patch( - "src.infrastructure.services.content_source_service.ContentSourceService" - ).start() + return patch("src.infrastructure.services.content_source_service.ContentSourceService").start() def test_execute_success(self, sqlite_memory, mock_storage, mock_cs_service): # Setup record @@ -39,9 +35,7 @@ def test_execute_success(self, sqlite_memory, mock_storage, mock_cs_service): sqlite_memory.add(record) sqlite_memory.commit() - use_case = DeleteDiarizationUseCase( - sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage - ) + use_case = DeleteDiarizationUseCase(sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage) with ( patch("os.path.exists", return_value=True), @@ -55,15 +49,11 @@ def test_execute_success(self, sqlite_memory, mock_storage, mock_cs_service): mock_rmtree.assert_called_once_with("/tmp/local/folder") # Verify DB deletion - deleted_record = ( - sqlite_memory.query(DiarizationRecord).filter_by(id="test-id").first() - ) + deleted_record = sqlite_memory.query(DiarizationRecord).filter_by(id="test-id").first() assert deleted_record is None def test_execute_not_found(self, sqlite_memory, mock_storage, mock_cs_service): - use_case = DeleteDiarizationUseCase( - sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage - ) + use_case = DeleteDiarizationUseCase(sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage) result = use_case.execute("non-existent") assert result is False @@ -78,24 +68,16 @@ def test_execute_no_paths(self, sqlite_memory, mock_storage, mock_cs_service): sqlite_memory.add(record) sqlite_memory.commit() - use_case = DeleteDiarizationUseCase( - sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage - ) + use_case = DeleteDiarizationUseCase(sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage) result = use_case.execute("test-id-no-paths") assert result is True mock_storage.delete_directory.assert_not_called() - deleted_record = ( - sqlite_memory.query(DiarizationRecord) - .filter_by(id="test-id-no-paths") - .first() - ) + deleted_record = sqlite_memory.query(DiarizationRecord).filter_by(id="test-id-no-paths").first() assert deleted_record is None - def test_execute_s3_error_continues( - self, sqlite_memory, mock_storage, mock_cs_service - ): + def test_execute_s3_error_continues(self, sqlite_memory, mock_storage, mock_cs_service): record = DiarizationRecord( id="test-id-s3-error", name="Test", @@ -107,25 +89,17 @@ def test_execute_s3_error_continues( mock_storage.delete_directory.side_effect = Exception("S3 Delete Failed") - use_case = DeleteDiarizationUseCase( - sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage - ) + use_case = DeleteDiarizationUseCase(sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage) result = use_case.execute("test-id-s3-error") # Should still return true because DB deletion succeeds assert result is True mock_storage.delete_directory.assert_called_once() - deleted_record = ( - sqlite_memory.query(DiarizationRecord) - .filter_by(id="test-id-s3-error") - .first() - ) + deleted_record = sqlite_memory.query(DiarizationRecord).filter_by(id="test-id-s3-error").first() assert deleted_record is None - def test_execute_local_file_instead_of_dir( - self, sqlite_memory, mock_storage, mock_cs_service - ): + def test_execute_local_file_instead_of_dir(self, sqlite_memory, mock_storage, mock_cs_service): record = DiarizationRecord( id="test-id-file", name="Test", @@ -135,9 +109,7 @@ def test_execute_local_file_instead_of_dir( sqlite_memory.add(record) sqlite_memory.commit() - use_case = DeleteDiarizationUseCase( - sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage - ) + use_case = DeleteDiarizationUseCase(sqlite_memory, cs_service=mock_cs_service, storage_service=mock_storage) with ( patch("os.path.exists", return_value=True), @@ -149,9 +121,5 @@ def test_execute_local_file_instead_of_dir( assert result is True mock_remove.assert_called_once_with("/tmp/local/file.txt") - deleted_record = ( - sqlite_memory.query(DiarizationRecord) - .filter_by(id="test-id-file") - .first() - ) + deleted_record = sqlite_memory.query(DiarizationRecord).filter_by(id="test-id-file").first() assert deleted_record is None diff --git a/tests/application/use_cases/test_diarization_ingestion_use_case.py b/tests/application/use_cases/test_diarization_ingestion_use_case.py index 3332ef90..9fa2dba7 100644 --- a/tests/application/use_cases/test_diarization_ingestion_use_case.py +++ b/tests/application/use_cases/test_diarization_ingestion_use_case.py @@ -63,9 +63,7 @@ def test_execute_success(self, use_case_deps, sqlite_memory): db.commit() subject_id = uuid4() - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=subject_id - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=subject_id) job_mock = MagicMock(id=uuid4()) use_case_deps["ingestion_service"].create_job.return_value = job_mock @@ -75,9 +73,7 @@ def test_execute_success(self, use_case_deps, sqlite_memory): use_case_deps["cs_service"].create_source.return_value = MagicMock(id=source_id) use_case_deps["vector_service"].index_documents.return_value = ["vec1"] - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=subject_id - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=subject_id) result = use_case.execute(cmd) @@ -108,9 +104,7 @@ def test_execute_subject_not_found(self, use_case_deps, sqlite_memory): db.commit() use_case_deps["ks_service"].get_subject_by_id.return_value = None - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=uuid4() - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=uuid4()) with pytest.raises(ValueError, match="Subject not found"): use_case.execute(cmd) @@ -132,16 +126,10 @@ def test_execute_reprocess(self, use_case_deps, sqlite_memory): db.commit() subject_id = uuid4() - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=subject_id - ) - use_case_deps["cs_service"].get_by_source_info.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=subject_id) + use_case_deps["cs_service"].get_by_source_info.return_value = MagicMock(id=uuid4()) - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=subject_id, reprocess=True - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=subject_id, reprocess=True) use_case.execute(cmd) assert use_case_deps["chunk_service"].delete_by_content_source.called @@ -163,13 +151,9 @@ def test_execute_error_handling(self, use_case_deps, sqlite_memory): db.add(record) db.commit() - use_case_deps["ks_service"].get_subject_by_id.side_effect = Exception( - "Service error" - ) + use_case_deps["ks_service"].get_subject_by_id.side_effect = Exception("Service error") - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=uuid4() - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=uuid4()) with pytest.raises(Exception, match="Service error"): use_case.execute(cmd) @@ -190,28 +174,20 @@ def test_execute_failure_after_job_creation(self, use_case_deps, sqlite_memory): db.add(record) db.commit() - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) job_mock = MagicMock(id=uuid4()) use_case_deps["ingestion_service"].create_job.return_value = job_mock # Fail at _get_or_create_source - use_case_deps["cs_service"].get_by_source_info.side_effect = Exception( - "Late error" - ) + use_case_deps["cs_service"].get_by_source_info.side_effect = Exception("Late error") - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=uuid4() - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=uuid4()) with pytest.raises(Exception, match="Late error"): use_case.execute(cmd) use_case_deps["ingestion_service"].update_job.assert_any_call( job_id=job_mock.id, - status=pytest.importorskip( - "src.domain.entities.enums.ingestion_job_status_enum" - ).IngestionJobStatus.FAILED, + status=pytest.importorskip("src.domain.entities.enums.ingestion_job_status_enum").IngestionJobStatus.FAILED, error_message="Late error", ) @@ -223,12 +199,7 @@ def test_resolve_source_info_upload(self, use_case_deps): record.source_metadata = None st, es = use_case._resolve_source_info(record) - assert ( - st - == pytest.importorskip( - "src.domain.entities.enums.source_type_enum_entity" - ).SourceType.AUDIO - ) + assert st == pytest.importorskip("src.domain.entities.enums.source_type_enum_entity").SourceType.AUDIO assert es == "s3://path" def test_format_transcript_long_audio(self, use_case_deps): @@ -250,13 +221,9 @@ def test_execute_empty_transcript_error(self, use_case_deps, sqlite_memory): ) db.add(record) db.commit() - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=uuid4() - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=uuid4()) with pytest.raises(ValueError, match="No segments found"): use_case.execute(cmd) @@ -264,9 +231,7 @@ def test_resolve_source_info_branches(self, use_case_deps): use_case = DiarizationIngestionUseCase(**use_case_deps) # Test 1: Invalid source_type -> OTHER - record = MagicMock( - source_type="garbage", external_source="url", source_metadata=None - ) + record = MagicMock(source_type="garbage", external_source="url", source_metadata=None) st, es = use_case._resolve_source_info(record) assert st == SourceType.OTHER @@ -285,9 +250,7 @@ def test_generate_split_docs_no_tokenizer(self, use_case_deps, monkeypatch): record = MagicMock(source_metadata={"meta": "data"}) mock_splitter = MagicMock() - mock_splitter.split_documents.return_value = [ - Document(page_content="c", metadata={}) - ] + mock_splitter.split_documents.return_value = [Document(page_content="c", metadata={})] monkeypatch.setattr( "langchain_text_splitters.RecursiveCharacterTextSplitter", lambda **kwargs: mock_splitter, @@ -321,14 +284,10 @@ def test_execute_with_existing_job_id(self, use_case_deps, sqlite_memory): job_id = uuid4() mock_job = MagicMock(id=job_id) use_case_deps["ingestion_service"].get_by_id.return_value = mock_job - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) use_case_deps["cs_service"].create_source.return_value = MagicMock(id=uuid4()) - cmd = IngestDiarizationCommand( - diarization_id=diarization_id, subject_id=uuid4(), ingestion_job_id=job_id - ) + cmd = IngestDiarizationCommand(diarization_id=diarization_id, subject_id=uuid4(), ingestion_job_id=job_id) result = use_case.execute(cmd) assert result["job_id"] == job_id use_case_deps["ingestion_service"].get_by_id.assert_called_with(job_id) diff --git a/tests/application/use_cases/test_file_ingestion_use_case.py b/tests/application/use_cases/test_file_ingestion_use_case.py index c657de73..7acefa61 100644 --- a/tests/application/use_cases/test_file_ingestion_use_case.py +++ b/tests/application/use_cases/test_file_ingestion_use_case.py @@ -50,25 +50,17 @@ def test_execute_success(self, use_case_deps, mock_extractor): job_id = uuid4() source_id = uuid4() - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=subject_id - ) - use_case_deps["ingestion_service"].create_job.return_value = MagicMock( - id=job_id - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=subject_id) + use_case_deps["ingestion_service"].create_job.return_value = MagicMock(id=job_id) use_case_deps["cs_service"].get_by_source_info.return_value = None use_case_deps["cs_service"].create_source.return_value = MagicMock( id=source_id, source_type=SourceType.DOCX, external_source="test.docx" ) - mock_extractor.extract.return_value = [ - MagicMock(page_content="content", metadata={"source": "test.docx"}) - ] + mock_extractor.extract.return_value = [MagicMock(page_content="content", metadata={"source": "test.docx"})] use_case_deps["vector_service"].index_documents.return_value = ["vec1"] - cmd = IngestFileCommand( - file_path="/tmp/test.docx", file_name="test.docx", subject_id=subject_id - ) + cmd = IngestFileCommand(file_path="/tmp/test.docx", file_name="test.docx", subject_id=subject_id) result = use_case.execute(cmd) @@ -83,27 +75,21 @@ def test_execute_subject_not_found(self, use_case_deps): use_case = FileIngestionUseCase(**use_case_deps) use_case_deps["ks_service"].get_subject_by_id.return_value = None - cmd = IngestFileCommand( - file_path="/tmp/test.docx", file_name="test.docx", subject_id=uuid4() - ) + cmd = IngestFileCommand(file_path="/tmp/test.docx", file_name="test.docx", subject_id=uuid4()) with pytest.raises(ValueError, match="Subject not found"): use_case.execute(cmd) def test_execute_extraction_failure(self, use_case_deps, mock_extractor): use_case = FileIngestionUseCase(**use_case_deps) - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) job_mock = MagicMock(id=uuid4()) use_case_deps["ingestion_service"].create_job.return_value = job_mock use_case_deps["ingestion_service"].get_by_id.return_value = job_mock mock_extractor.extract.side_effect = Exception("Extraction failed") - cmd = IngestFileCommand( - file_path="/tmp/test.docx", file_name="test.docx", subject_id=uuid4() - ) + cmd = IngestFileCommand(file_path="/tmp/test.docx", file_name="test.docx", subject_id=uuid4()) with pytest.raises(Exception, match="Extraction failed"): use_case.execute(cmd) @@ -117,9 +103,7 @@ def test_determine_source_type_fallbacks(self, use_case_deps): use_case = FileIngestionUseCase(**use_case_deps) def check(fname): - return use_case._determine_source_type_refined( - IngestFileCommand(file_name=fname) - ) + return use_case._determine_source_type_refined(IngestFileCommand(file_name=fname)) assert check("test.docx") == SourceType.DOCX assert check("test.doc") == SourceType.DOCX @@ -136,22 +120,16 @@ def check(fname): def test_execute_rollback_on_failure(self, use_case_deps, mock_extractor): use_case = FileIngestionUseCase(**use_case_deps) - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) job_mock = MagicMock(id=uuid4()) - source_mock = MagicMock( - id=uuid4(), source_type=SourceType.DOCX, external_source="f" - ) + source_mock = MagicMock(id=uuid4(), source_type=SourceType.DOCX, external_source="f") use_case_deps["ingestion_service"].create_job.return_value = job_mock use_case_deps["cs_service"].get_by_source_info.return_value = None use_case_deps["cs_service"].create_source.return_value = source_mock # Fail at vector indexing mock_extractor.extract.return_value = [MagicMock(page_content="c", metadata={})] - use_case_deps["vector_service"].index_documents.side_effect = Exception( - "Vector fail" - ) + use_case_deps["vector_service"].index_documents.side_effect = Exception("Vector fail") cmd = IngestFileCommand(file_path="f", file_name="f", subject_id=uuid4()) with pytest.raises(Exception, match="Vector fail"): @@ -161,39 +139,27 @@ def test_execute_rollback_on_failure(self, use_case_deps, mock_extractor): source_mock.id, ContentSourceStatus.FAILED ) - def test_execute_source_type_refinement( - self, use_case_deps, mock_extractor, monkeypatch - ): + def test_execute_source_type_refinement(self, use_case_deps, mock_extractor, monkeypatch): use_case = FileIngestionUseCase(**use_case_deps) # Force it to be OTHER first so refinement triggers - monkeypatch.setattr( - use_case, "_determine_source_type_refined", lambda cmd: SourceType.OTHER - ) + monkeypatch.setattr(use_case, "_determine_source_type_refined", lambda cmd: SourceType.OTHER) subject_id = uuid4() job_id = uuid4() source_id = uuid4() - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=subject_id - ) - use_case_deps["ingestion_service"].create_job.return_value = MagicMock( - id=job_id - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=subject_id) + use_case_deps["ingestion_service"].create_job.return_value = MagicMock(id=job_id) use_case_deps["cs_service"].get_by_source_info.return_value = None use_case_deps["cs_service"].create_source.return_value = MagicMock( id=source_id, source_type=SourceType.PDF, external_source="test.docx" ) # Mock Docling to detect PDF despite .docx extension in filename - mock_extractor.extract.return_value = [ - MagicMock(page_content="content", metadata={"source_type": "pdf"}) - ] + mock_extractor.extract.return_value = [MagicMock(page_content="content", metadata={"source_type": "pdf"})] use_case_deps["vector_service"].index_documents.return_value = ["vec1"] - cmd = IngestFileCommand( - file_path="/tmp/test.docx", file_name="test.docx", subject_id=subject_id - ) + cmd = IngestFileCommand(file_path="/tmp/test.docx", file_name="test.docx", subject_id=subject_id) use_case.execute(cmd) @@ -201,27 +167,19 @@ def test_execute_source_type_refinement( _, kwargs = use_case_deps["cs_service"].create_source.call_args assert kwargs["source_type"] == SourceType.PDF - def test_execute_fallback_splitter( - self, use_case_deps, mock_extractor, monkeypatch - ): + def test_execute_fallback_splitter(self, use_case_deps, mock_extractor, monkeypatch): # Remove model from deps to trigger fallback splitter del use_case_deps["model_loader_service"].model use_case = FileIngestionUseCase(**use_case_deps) - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) - use_case_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) + use_case_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid4()) use_case_deps["cs_service"].get_by_source_info.return_value = None use_case_deps["cs_service"].create_source.return_value = MagicMock( id=uuid4(), source_type=SourceType.DOCX, external_source="test.docx" ) - mock_extractor.extract.return_value = [ - MagicMock(page_content="content line 1\ncontent line 2", metadata={}) - ] + mock_extractor.extract.return_value = [MagicMock(page_content="content line 1\ncontent line 2", metadata={})] # Mock RecursiveCharacterTextSplitter from its original package mock_splitter = MagicMock() @@ -234,9 +192,7 @@ def test_execute_fallback_splitter( lambda **kwargs: mock_splitter, ) - cmd = IngestFileCommand( - file_path="/tmp/test.docx", file_name="test.docx", subject_id=uuid4() - ) + cmd = IngestFileCommand(file_path="/tmp/test.docx", file_name="test.docx", subject_id=uuid4()) result = use_case.execute(cmd) assert result["created_chunks"] == 2 @@ -247,9 +203,7 @@ def test_build_chunk_entities_tokenizer_exceptions(self, use_case_deps): # First exception: General failure triggers fallback len//4 tokenizer.encode.side_effect = Exception("Fatal") docs = [MagicMock(page_content="test", metadata={})] - source = MagicMock( - id=uuid4(), source_type=SourceType.DOCX, external_source="test.docx" - ) + source = MagicMock(id=uuid4(), source_type=SourceType.DOCX, external_source="test.docx") subject = MagicMock(id=uuid4()) chunks = use_case._build_chunk_entities( @@ -263,27 +217,19 @@ def test_build_chunk_entities_tokenizer_exceptions(self, use_case_deps): def test_execute_no_source_path(self, use_case_deps): use_case = FileIngestionUseCase(**use_case_deps) - use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid4() - ) - cmd = IngestFileCommand( - file_path=None, file_url=None, file_name="f", subject_id=uuid4() - ) + use_case_deps["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid4()) + cmd = IngestFileCommand(file_path=None, file_url=None, file_name="f", subject_id=uuid4()) with pytest.raises(ValueError, match="Neither file_path nor file_url provided"): use_case.execute(cmd) - def test_extract_docs_plain_text_fallback( - self, use_case_deps, mock_extractor, monkeypatch - ): + def test_extract_docs_plain_text_fallback(self, use_case_deps, mock_extractor, monkeypatch): use_case = FileIngestionUseCase(**use_case_deps) mock_extractor.extract.side_effect = Exception("format not allowed") mock_plain = MagicMock() mock_plain.extract.return_value = [Document(page_content="plain", metadata={})] monkeypatch.setattr(use_case, "plain_text_extractor", mock_plain) - docs = use_case._extract_docs( - "test.xyz", IngestFileCommand(file_name="test.xyz") - ) + docs = use_case._extract_docs("test.xyz", IngestFileCommand(file_name="test.xyz")) assert docs[0].page_content == "plain" def test_extract_docs_no_content_error(self, use_case_deps, mock_extractor): @@ -301,14 +247,10 @@ def test_refine_source_type_branches(self, use_case_deps): # Test Case 2: Current is YOUTUBE and refined is TXT (skips) docs = [Document(page_content="c", metadata={"docling_source_type": "txt"})] - assert ( - use_case._refine_source_type(docs, SourceType.YOUTUBE) == SourceType.YOUTUBE - ) + assert use_case._refine_source_type(docs, SourceType.YOUTUBE) == SourceType.YOUTUBE # Test Case 3: ValueError in SourceType - docs = [ - Document(page_content="c", metadata={"docling_source_type": "invalid_type"}) - ] + docs = [Document(page_content="c", metadata={"docling_source_type": "invalid_type"})] assert use_case._refine_source_type(docs, SourceType.DOCX) == SourceType.DOCX def test_get_or_create_job_reuse(self, use_case_deps): @@ -323,9 +265,7 @@ def test_get_or_create_job_reuse(self, use_case_deps): def test_resolve_subject_missing_id_and_name(self, use_case_deps): use_case = FileIngestionUseCase(**use_case_deps) - cmd = IngestFileCommand( - file_path="f", file_name="f", subject_id=None, subject_name=None - ) + cmd = IngestFileCommand(file_path="f", file_name="f", subject_id=None, subject_name=None) with pytest.raises(ValueError, match="Subject missing"): use_case._resolve_subject(cmd) @@ -341,9 +281,7 @@ def test_determine_source_type_refined_v_youtube(self, use_case_deps): assert use_case._determine_source_type_refined(cmd) == SourceType.TXT # YouTube from external_source - cmd = IngestFileCommand( - file_name="f", external_source="https://youtube.com/watch?v=123" - ) + cmd = IngestFileCommand(file_name="f", external_source="https://youtube.com/watch?v=123") assert use_case._determine_source_type_refined(cmd) == SourceType.YOUTUBE def test_cleanup_temp_dir(self, use_case_deps, monkeypatch): @@ -355,16 +293,12 @@ def test_cleanup_temp_dir(self, use_case_deps, monkeypatch): monkeypatch.setattr("os.remove", mock_remove) # Temp dir - cmd = IngestFileCommand( - file_path="/tmp/dir/file.txt", file_name="f", delete_after_ingestion=True - ) + cmd = IngestFileCommand(file_path="/tmp/dir/file.txt", file_name="f", delete_after_ingestion=True) use_case._cleanup(cmd) mock_rmtree.assert_called_once() # Single file mock_rmtree.reset_mock() - cmd = IngestFileCommand( - file_path="/data/file.txt", file_name="f", delete_after_ingestion=True - ) + cmd = IngestFileCommand(file_path="/data/file.txt", file_name="f", delete_after_ingestion=True) use_case._cleanup(cmd) mock_remove.assert_called_once() diff --git a/tests/application/use_cases/test_knowledge_subject_use_case.py b/tests/application/use_cases/test_knowledge_subject_use_case.py index 731d1e49..b90170aa 100644 --- a/tests/application/use_cases/test_knowledge_subject_use_case.py +++ b/tests/application/use_cases/test_knowledge_subject_use_case.py @@ -31,9 +31,7 @@ def use_case(self, mock_ks_service, mock_cs_use_case, mock_vector_repo): vector_repo=mock_vector_repo, ) - def test_delete_knowledge_success( - self, use_case, mock_ks_service, mock_cs_use_case, mock_vector_repo - ): + def test_delete_knowledge_success(self, use_case, mock_ks_service, mock_cs_use_case, mock_vector_repo): subject_id = uuid.uuid4() # 1. Subject exists @@ -65,16 +63,12 @@ def test_delete_knowledge_success( # Assertions assert result is True mock_ks_service.get_subject_by_id.assert_called_once_with(subject_id) - mock_vector_repo.delete.assert_called_once_with( - filters={"subject_id": str(subject_id)} - ) + mock_vector_repo.delete.assert_called_once_with(filters={"subject_id": str(subject_id)}) mock_cs_use_case.cs_service.list_by_subject.assert_called_once_with(subject_id) assert mock_cs_use_case.delete.call_count == 2 mock_ks_service.delete_subject.assert_called_once_with(subject_id) - def test_delete_knowledge_not_found( - self, use_case, mock_ks_service, mock_vector_repo - ): + def test_delete_knowledge_not_found(self, use_case, mock_ks_service, mock_vector_repo): subject_id = uuid.uuid4() mock_ks_service.get_subject_by_id.return_value = None @@ -84,9 +78,7 @@ def test_delete_knowledge_not_found( mock_ks_service.get_subject_by_id.assert_called_once_with(subject_id) mock_vector_repo.delete.assert_not_called() - def test_delete_knowledge_exception( - self, use_case, mock_ks_service, mock_vector_repo - ): + def test_delete_knowledge_exception(self, use_case, mock_ks_service, mock_vector_repo): subject_id = uuid.uuid4() mock_ks_service.get_subject_by_id.return_value = MagicMock() mock_vector_repo.delete.side_effect = Exception("Vector error") diff --git a/tests/application/use_cases/test_process_audio_diarization_pipeline.py b/tests/application/use_cases/test_process_audio_diarization_pipeline.py index 26a7a7ec..4680c501 100644 --- a/tests/application/use_cases/test_process_audio_diarization_pipeline.py +++ b/tests/application/use_cases/test_process_audio_diarization_pipeline.py @@ -10,12 +10,8 @@ @pytest.mark.ProcessAudioDiarizationPipeline class TestProcessAudioDiarizationPipeline: - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.StorageService" - ) - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor" - ) + @patch("src.application.use_cases.process_audio_diarization_pipeline.StorageService") + @patch("src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor") @patch("src.application.use_cases.process_audio_diarization_pipeline.AudioDiarizer") @patch("src.application.use_cases.process_audio_diarization_pipeline.VoiceDB") @patch("os.makedirs") @@ -36,9 +32,7 @@ def test_execute_youtube_success( # Setup mocks mock_extractor = mock_extractor_cls.return_value mock_extractor.download_audio.return_value = "/tmp/audio.mp3" - mock_extractor.extract_metadata.return_value = MagicMock( - title=None, full_title=None - ) + mock_extractor.extract_metadata.return_value = MagicMock(title=None, full_title=None) mock_diarizer = mock_diarizer_cls.return_value mock_diarization_result = MagicMock(spec=DiarizationResult) @@ -68,12 +62,8 @@ def test_execute_youtube_success( assert mock_extractor.download_audio.called assert mock_diarizer.run.called - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.StorageService" - ) - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor" - ) + @patch("src.application.use_cases.process_audio_diarization_pipeline.StorageService") + @patch("src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor") @patch("src.application.use_cases.process_audio_diarization_pipeline.AudioDiarizer") @patch("src.application.use_cases.process_audio_diarization_pipeline.VoiceDB") @patch("os.makedirs") @@ -116,12 +106,8 @@ def test_execute_upload_success( assert result["storage_path"] == "processed/uuid-456/recognition" assert mock_storage.download_file.called - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.StorageService" - ) - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor" - ) + @patch("src.application.use_cases.process_audio_diarization_pipeline.StorageService") + @patch("src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor") @patch("src.application.use_cases.process_audio_diarization_pipeline.AudioDiarizer") @patch("src.application.use_cases.process_audio_diarization_pipeline.VoiceDB") @patch("os.makedirs") @@ -160,9 +146,7 @@ def test_execute_with_diarization_id( "src.infrastructure.repositories.sql.diarization_repository.DiarizationRepository.update_status" ) as mock_update_status, ): - use_case = ProcessAudioDiarizationPipelineUseCase( - sqlite_memory, event_bus=mock_event_bus - ) + use_case = ProcessAudioDiarizationPipelineUseCase(sqlite_memory, event_bus=mock_event_bus) use_case.execute( source_type="youtube", source="https://youtube.com/watch?v=test", @@ -173,12 +157,8 @@ def test_execute_with_diarization_id( assert mock_event_bus.publish.called assert mock_update_status.called - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.StorageService" - ) - @patch( - "src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor" - ) + @patch("src.application.use_cases.process_audio_diarization_pipeline.StorageService") + @patch("src.application.use_cases.process_audio_diarization_pipeline.YoutubeExtractor") @patch("src.application.use_cases.process_audio_diarization_pipeline.AudioDiarizer") @patch("src.application.use_cases.process_audio_diarization_pipeline.VoiceDB") @patch("os.makedirs") diff --git a/tests/application/use_cases/test_search_use_case.py b/tests/application/use_cases/test_search_use_case.py index 298fe756..2f0d3ca4 100644 --- a/tests/application/use_cases/test_search_use_case.py +++ b/tests/application/use_cases/test_search_use_case.py @@ -29,11 +29,7 @@ def retrieve( self.last_search_mode = search_mode self.last_re_rank = re_rank # return dummy chunks as simple objects (avoid ChunkEntity validation in tests) - return [ - SimpleNamespace( - id=uuid.uuid4(), content="a", subject_id=uuid.uuid4(), extra={} - ) - ] + return [SimpleNamespace(id=uuid.uuid4(), content="a", subject_id=uuid.uuid4(), extra={})] class DummyKS: @@ -112,18 +108,14 @@ def test_search_passes_re_rank_to_service(): def test_search_both_id_and_name_raises(): vec = DummyVectorService() uc = SearchUseCase(vector_service=vec, ks_service=None) - with pytest.raises( - ValueError, match="Provide only one of subject_ids or subject_name" - ): + with pytest.raises(ValueError, match="Provide only one of subject_ids or subject_name"): uc.execute(query="q", subject_ids=[uuid.uuid4()], subject_name="Alice") def test_search_name_without_ks_service_raises(): vec = DummyVectorService() uc = SearchUseCase(vector_service=vec, ks_service=None) - with pytest.raises( - ValueError, match="ks_service is required to filter by subject_name" - ): + with pytest.raises(ValueError, match="ks_service is required to filter by subject_name"): uc.execute(query="q", subject_name="Alice") diff --git a/tests/application/use_cases/test_web_scraping_use_case.py b/tests/application/use_cases/test_web_scraping_use_case.py index 1f6849fa..bf34f769 100644 --- a/tests/application/use_cases/test_web_scraping_use_case.py +++ b/tests/application/use_cases/test_web_scraping_use_case.py @@ -29,14 +29,10 @@ async def test_web_scraping_use_case_execute_success(mock_dependencies): # Setup use_case = WebScrapingUseCase(**mock_dependencies, vector_store_type="weaviate") - cmd = IngestWebCommand( - url="https://example.com", subject_id=str(uuid.uuid4()), language="en" - ) + cmd = IngestWebCommand(url="https://example.com", subject_id=str(uuid.uuid4()), language="en") # Mock behavior - mock_dependencies["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_dependencies["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid.uuid4()) mock_dependencies["extractor"].extract.return_value = [ Document(page_content="Scraped content", metadata={"title": "Test"}) ] @@ -72,9 +68,7 @@ async def test_web_scraping_use_case_extraction_failure(mock_dependencies): cmd = IngestWebCommand(url="https://fail.com", subject_id=str(uuid.uuid4())) - mock_dependencies["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_dependencies["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid.uuid4()) mock_dependencies["extractor"].extract.side_effect = Exception("Scraping error") job_mock = MagicMock(id=uuid.uuid4()) diff --git a/tests/application/use_cases/test_web_scraping_use_case_extended.py b/tests/application/use_cases/test_web_scraping_use_case_extended.py index b7b1758a..2776ad36 100644 --- a/tests/application/use_cases/test_web_scraping_use_case_extended.py +++ b/tests/application/use_cases/test_web_scraping_use_case_extended.py @@ -36,12 +36,8 @@ async def test_resolve_subject_by_id_success(self, mock_deps): mock_subject = MagicMock(id=subject_id) mock_deps["ks_service"].get_subject_by_id.return_value = mock_subject - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) mock_deps["cs_service"].get_by_source_info.return_value = None mock_deps["cs_service"].create_source.return_value = MagicMock( id=uuid.uuid4(), external_source="http://test.com" @@ -58,9 +54,7 @@ async def test_execute_extractor_exception_logged_and_raised(self, mock_deps): mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) mock_deps["extractor"].extract.side_effect = Exception("Scraping error") - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) with pytest.raises(Exception, match="Scraping error"): await use_case.execute(cmd) @@ -73,12 +67,8 @@ async def test_resolve_subject_by_name_success(self, mock_deps): mock_subject = MagicMock(id=uuid.uuid4()) mock_deps["ks_service"].get_by_name.return_value = mock_subject - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) mock_deps["cs_service"].get_by_source_info.return_value = None mock_deps["cs_service"].create_source.return_value = MagicMock( id=uuid.uuid4(), external_source="http://test.com" @@ -111,9 +101,7 @@ async def test_resolve_subject_none_provided_raises(self, mock_deps): use_case = WebScrapingUseCase(**mock_deps, vector_store_type="qdrant") cmd = IngestWebCommand(url="http://test.com") - with pytest.raises( - ValueError, match="Either subject_id or subject_name must be provided" - ): + with pytest.raises(ValueError, match="Either subject_id or subject_name must be provided"): await use_case.execute(cmd) @pytest.mark.asyncio @@ -122,9 +110,7 @@ async def test_execute_empty_docs_raises(self, mock_deps): cmd = IngestWebCommand(url="http://test.com", subject_name="S") mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) mock_deps["extractor"].extract.return_value = [] - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) with pytest.raises(ValueError, match="No content extracted"): await use_case.execute(cmd) @@ -134,16 +120,12 @@ async def test_execute_with_existing_job_id(self, mock_deps): mock_deps["model_loader_service"].model_name = "test-model" job_id = uuid.uuid4() use_case = WebScrapingUseCase(**mock_deps, vector_store_type="qdrant") - cmd = IngestWebCommand( - url="http://test.com", subject_name="S", ingestion_job_id=str(job_id) - ) + cmd = IngestWebCommand(url="http://test.com", subject_name="S", ingestion_job_id=str(job_id)) mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) mock_job = MagicMock(id=job_id) mock_deps["ingestion_service"].get_by_id.return_value = mock_job - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] mock_deps["cs_service"].get_by_source_info.return_value = None mock_deps["cs_service"].create_source.return_value = MagicMock( id=uuid.uuid4(), external_source="http://test.com" @@ -155,26 +137,18 @@ async def test_execute_with_existing_job_id(self, mock_deps): mock_deps["ingestion_service"].create_job.assert_not_called() @pytest.mark.asyncio - async def test_execute_with_existing_job_id_invalid_uuid_falls_back( - self, mock_deps - ): + async def test_execute_with_existing_job_id_invalid_uuid_falls_back(self, mock_deps): mock_deps["model_loader_service"].model_name = "test-model" use_case = WebScrapingUseCase(**mock_deps, vector_store_type="qdrant") - cmd = IngestWebCommand( - url="http://test.com", subject_name="S", ingestion_job_id="invalid-uuid" - ) + cmd = IngestWebCommand(url="http://test.com", subject_name="S", ingestion_job_id="invalid-uuid") mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] mock_deps["cs_service"].get_by_source_info.return_value = None mock_deps["cs_service"].create_source.return_value = MagicMock( id=uuid.uuid4(), external_source="http://test.com" ) - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) await use_case.execute(cmd) @@ -187,12 +161,8 @@ async def test_execute_reprocess_cleanup(self, mock_deps): cmd = IngestWebCommand(url="http://test.com", subject_name="S", reprocess=True) mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) source_id = uuid.uuid4() mock_source = MagicMock(id=source_id, external_source="http://test.com") @@ -200,12 +170,8 @@ async def test_execute_reprocess_cleanup(self, mock_deps): await use_case.execute(cmd) - mock_deps["chunk_service"].delete_by_content_source.assert_called_once_with( - source_id - ) - mock_deps["vector_service"].delete.assert_called_once_with( - filters={"content_source_id": str(source_id)} - ) + mock_deps["chunk_service"].delete_by_content_source.assert_called_once_with(source_id) + mock_deps["vector_service"].delete.assert_called_once_with(filters={"content_source_id": str(source_id)}) @pytest.mark.asyncio async def test_execute_reprocess_cleanup_error_ignored(self, mock_deps): @@ -214,19 +180,13 @@ async def test_execute_reprocess_cleanup_error_ignored(self, mock_deps): cmd = IngestWebCommand(url="http://test.com", subject_name="S", reprocess=True) mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) source_id = uuid.uuid4() mock_source = MagicMock(id=source_id, external_source="http://test.com") mock_deps["cs_service"].get_by_source_info.return_value = mock_source - mock_deps["chunk_service"].delete_by_content_source.side_effect = Exception( - "Cleanup error" - ) + mock_deps["chunk_service"].delete_by_content_source.side_effect = Exception("Cleanup error") # Should not raise exception await use_case.execute(cmd) @@ -247,20 +207,14 @@ async def test_build_chunk_entities_with_tokenizer_success(self, mock_deps): subject = MagicMock(id=uuid.uuid4()) job_id = uuid.uuid4() - chunks = use_case._build_chunk_entities( - docs, source, subject, IngestWebCommand(url="url"), job_id - ) + chunks = use_case._build_chunk_entities(docs, source, subject, IngestWebCommand(url="url"), job_id) assert len(chunks) == 1 assert chunks[0].tokens_count == 3 - mock_model.tokenizer.encode.assert_called_once_with( - "Hello world", add_special_tokens=False - ) + mock_model.tokenizer.encode.assert_called_once_with("Hello world", add_special_tokens=False) @pytest.mark.asyncio - async def test_build_chunk_entities_with_tokenizer_exception_fallback( - self, mock_deps - ): + async def test_build_chunk_entities_with_tokenizer_exception_fallback(self, mock_deps): mock_model = MagicMock() mock_model.tokenizer.encode.side_effect = Exception("Tokenizer error") mock_deps["model_loader_service"].model = mock_model @@ -273,9 +227,7 @@ async def test_build_chunk_entities_with_tokenizer_exception_fallback( subject = MagicMock(id=uuid.uuid4()) job_id = uuid.uuid4() - chunks = use_case._build_chunk_entities( - docs, source, subject, IngestWebCommand(url="url"), job_id - ) + chunks = use_case._build_chunk_entities(docs, source, subject, IngestWebCommand(url="url"), job_id) assert len(chunks) == 1 # Fallback is len(page_content) // 4 @@ -294,9 +246,7 @@ async def test_execute_langchain_splitter_fallback(self, mock_deps): mock_deps["extractor"].extract.return_value = [ Document(page_content="Very long content " * 100, metadata={"title": "T"}) ] - mock_deps["ingestion_service"].create_job.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=uuid.uuid4()) mock_deps["cs_service"].get_by_source_info.return_value = None mock_deps["cs_service"].create_source.return_value = MagicMock( id=uuid.uuid4(), external_source="http://test.com" @@ -314,9 +264,7 @@ async def test_execute_failure_marks_source_failed(self, mock_deps): cmd = IngestWebCommand(url="http://test.com", subject_name="S") mock_deps["ks_service"].get_by_name.return_value = MagicMock(id=uuid.uuid4()) - mock_deps["extractor"].extract.return_value = [ - Document(page_content="content", metadata={"title": "T"}) - ] + mock_deps["extractor"].extract.return_value = [Document(page_content="content", metadata={"title": "T"})] job_id = uuid.uuid4() mock_deps["ingestion_service"].create_job.return_value = MagicMock(id=job_id) @@ -326,9 +274,7 @@ async def test_execute_failure_marks_source_failed(self, mock_deps): mock_deps["cs_service"].get_by_source_info.return_value = mock_source # Fail during chunk creation - mock_deps["chunk_service"].create_chunks.side_effect = Exception( - "Execute error" - ) + mock_deps["chunk_service"].create_chunks.side_effect = Exception("Execute error") with pytest.raises(Exception, match="Execute error"): await use_case.execute(cmd) diff --git a/tests/application/use_cases/test_youtube_ingestion_use_case.py b/tests/application/use_cases/test_youtube_ingestion_use_case.py index 8dbc1baf..5dd55545 100644 --- a/tests/application/use_cases/test_youtube_ingestion_use_case.py +++ b/tests/application/use_cases/test_youtube_ingestion_use_case.py @@ -66,9 +66,7 @@ def create_source( **kwargs, ): # ensure source_type is a value the use case can handle - val = ( - source_type.value if hasattr(source_type, "value") else str(source_type) - ) + val = source_type.value if hasattr(source_type, "value") else str(source_type) src = SimpleNamespace( id=uuid.uuid4(), source_type=val, @@ -85,9 +83,7 @@ def update_processing_status(self, content_source_id, status): # noop for tests return None - def finish_ingestion( - self, content_source_id, embedding_model, dimensions, chunks, **kwargs - ): + def finish_ingestion(self, content_source_id, embedding_model, dimensions, chunks, **kwargs): # noop for tests return None @@ -96,9 +92,7 @@ def finish_ingestion( def make_ingestion_service(): class IS: - def create_job( - self, content_source_id, status, embedding_model, pipeline_version, **kwargs - ): + def create_job(self, content_source_id, status, embedding_model, pipeline_version, **kwargs): return SimpleNamespace(id=uuid.uuid4(), content_source_id=content_source_id) def update_job(self, job_id, status, error_message=None, **kwargs): @@ -163,9 +157,7 @@ def test_ingest_single_url_processes_chunks(monkeypatch): vec_svc = make_vector_service() event_bus = MagicMock() - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, embedding, chunk_svc, vec_svc, "weaviate", event_bus - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, model_loader, embedding, chunk_svc, vec_svc, "weaviate", event_bus) docs = [ DummyDoc("chunk1", {"start": 0, "end": 10}), @@ -174,12 +166,8 @@ def test_ingest_single_url_processes_chunks(monkeypatch): from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", lambda url: "dQw4w9WgXcQ") - monkeypatch.setattr( - YoutubeExtractor, "extract_metadata", lambda *args, **kwargs: MockMetadata() - ) - monkeypatch.setattr( - use_case, "_extract_and_split", lambda cmd, video_id, yt_extractor=None: docs - ) + monkeypatch.setattr(YoutubeExtractor, "extract_metadata", lambda *args, **kwargs: MockMetadata()) + monkeypatch.setattr(use_case, "_extract_and_split", lambda cmd, video_id, yt_extractor=None: docs) cmd = IngestYoutubeCommand( video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ", @@ -252,14 +240,10 @@ def mock_extract_id(url): from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", mock_extract_id) - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")] - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")]) from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor - monkeypatch.setattr( - YoutubeExtractor, "extract_metadata", lambda *args, **kwargs: MockMetadata() - ) + monkeypatch.setattr(YoutubeExtractor, "extract_metadata", lambda *args, **kwargs: MockMetadata()) cmd = IngestYoutubeCommand( video_urls=[ @@ -314,12 +298,8 @@ def test_ingest_playlist(monkeypatch): from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", lambda url: url) - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")] - ) - monkeypatch.setattr( - YoutubeExtractor, "extract_metadata", lambda *args, **kwargs: MockMetadata() - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")]) + monkeypatch.setattr(YoutubeExtractor, "extract_metadata", lambda *args, **kwargs: MockMetadata()) cmd = IngestYoutubeCommand( video_url="https://youtube.com/playlist?list=PL123", @@ -338,16 +318,12 @@ def test_ingest_playlist_empty_raises(monkeypatch): cs = make_cs_service() isvc = make_ingestion_service() model_loader = make_model_loader() - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, model_loader, None, None, None, "weaviate", MagicMock()) from src.application.dtos.enums.youtube_data_type import YoutubeDataType from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor - monkeypatch.setattr( - YoutubeExtractor, "extract_playlist_videos", lambda *args, **kwargs: [] - ) + monkeypatch.setattr(YoutubeExtractor, "extract_playlist_videos", lambda *args, **kwargs: []) cmd = IngestYoutubeCommand( video_url="https://youtube.com/playlist?list=empty", @@ -378,14 +354,10 @@ def test_resolve_subject_errors(monkeypatch): monkeypatch.setattr(ks, "get_by_name", lambda name: None) monkeypatch.setattr(ks, "get_subject_by_id", lambda id: None) - use_case = YoutubeIngestionUseCase( - ks, None, None, None, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, None, None, None, None, None, None, "weaviate", MagicMock()) cmd_name = IngestYoutubeCommand(video_url="v", subject_name="unknown") - with pytest.raises( - ValueError, match="KnowledgeSubject with name 'unknown' not found" - ): + with pytest.raises(ValueError, match="KnowledgeSubject with name 'unknown' not found"): use_case.execute(cmd_name) cmd_id = IngestYoutubeCommand(video_url="v", subject_id=uuid.uuid4()) @@ -393,9 +365,7 @@ def test_resolve_subject_errors(monkeypatch): use_case.execute(cmd_id) cmd_none = IngestYoutubeCommand(video_url="v") - with pytest.raises( - ValueError, match="Either subject_id or subject_name must be provided" - ): + with pytest.raises(ValueError, match="Either subject_id or subject_name must be provided"): use_case.execute(cmd_none) @@ -408,9 +378,7 @@ def test_ingest_fails_to_create_job(monkeypatch): model_loader = make_model_loader() chunk_svc = make_chunk_service() vec_svc = make_vector_service() - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock()) from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", lambda url: "vid") @@ -427,9 +395,7 @@ def test_ingest_fails_no_transcript(monkeypatch): model_loader = make_model_loader() chunk_svc = make_chunk_service() vec_svc = make_vector_service() - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock()) from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor @@ -453,26 +419,18 @@ def test_ingest_with_pre_created_job(monkeypatch): job_id = uuid.uuid4() # Mock IS to return a pre-created job - monkeypatch.setattr( - isvc, "get_by_id", lambda id: SimpleNamespace(id=id, content_source_id=None) - ) + monkeypatch.setattr(isvc, "get_by_id", lambda id: SimpleNamespace(id=id, content_source_id=None)) model_loader = make_model_loader() chunk_svc = make_chunk_service() vec_svc = make_vector_service() - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock()) from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", lambda url: "vid") - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")] - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")]) - cmd = IngestYoutubeCommand( - video_url="vid", subject_name="s", ingestion_job_id=str(job_id) - ) + cmd = IngestYoutubeCommand(video_url="vid", subject_name="s", ingestion_job_id=str(job_id)) result = use_case.execute(cmd) assert len(result.video_results) == 1 assert result.video_results[0].get("skipped") is not True @@ -482,13 +440,9 @@ def test_resolve_subject_by_id(monkeypatch): ks = make_ks_service() subject_id = uuid.uuid4() subject = SimpleNamespace(id=subject_id, name="subject_by_id") - monkeypatch.setattr( - ks, "get_subject_by_id", lambda id: subject if id == subject_id else None - ) + monkeypatch.setattr(ks, "get_subject_by_id", lambda id: subject if id == subject_id else None) - use_case = YoutubeIngestionUseCase( - ks, None, None, None, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, None, None, None, None, None, None, "weaviate", MagicMock()) # Test valid UUID object cmd = IngestYoutubeCommand(video_url="v", subject_id=subject_id) @@ -515,9 +469,7 @@ def test_execute_exception_recovery(monkeypatch): lambda id: SimpleNamespace(id=id, content_source_id=source_id), ) - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, None, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, None, None, None, None, "weaviate", MagicMock()) # Force error in _resolve_subject def mock_error(*args, **kwargs): @@ -547,16 +499,12 @@ def test_process_single_video_with_existing_but_not_done_source(monkeypatch): model_loader = make_model_loader() chunk_svc = make_chunk_service() vec_svc = make_vector_service() - use_case = YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, cs, isvc, model_loader, None, chunk_svc, vec_svc, "weaviate", MagicMock()) from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", lambda url: "vid") - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")] - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("content")]) cmd = IngestYoutubeCommand(video_url="vid", subject_name="s") result = use_case.execute(cmd) @@ -591,9 +539,7 @@ def test_execute_job_recovery_success(monkeypatch): "get_by_id", lambda id: SimpleNamespace(id=id, content_source_id=source_id), ) - monkeypatch.setattr( - cs, "get_by_id", lambda id: SimpleNamespace(id=id, processing_status="pending") - ) + monkeypatch.setattr(cs, "get_by_id", lambda id: SimpleNamespace(id=id, processing_status="pending")) use_case = YoutubeIngestionUseCase( ks, @@ -609,9 +555,7 @@ def test_execute_job_recovery_success(monkeypatch): from src.infrastructure.extractors.youtube_extractor import YoutubeExtractor monkeypatch.setattr(YoutubeExtractor, "get_video_id", lambda url: "vid") - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("c")] - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("c")]) cmd = IngestYoutubeCommand(video_url="v", subject_name="s", ingestion_job_id=job_id) result = use_case.execute(cmd) @@ -620,9 +564,7 @@ def test_execute_job_recovery_success(monkeypatch): def test_execute_playlist_no_url(): ks = make_ks_service() - use_case = YoutubeIngestionUseCase( - ks, None, None, None, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, None, None, None, None, None, None, "weaviate", MagicMock()) from src.application.dtos.enums.youtube_data_type import YoutubeDataType cmd = IngestYoutubeCommand( @@ -637,9 +579,7 @@ def test_execute_playlist_no_url(): def test_execute_no_urls(): ks = make_ks_service() - use_case = YoutubeIngestionUseCase( - ks, None, None, None, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(ks, None, None, None, None, None, None, "weaviate", MagicMock()) cmd = IngestYoutubeCommand(video_url=None, video_urls=[], subject_name="s") with pytest.raises(ValueError, match="No video_url.*s.* provided"): use_case.execute(cmd) @@ -684,20 +624,14 @@ def test_process_single_video_reprocess(monkeypatch): "weaviate", MagicMock(), ) - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("c")] - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("c")]) from unittest.mock import patch - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ) as mock_yt: + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor") as mock_yt: mock_yt.return_value.extract_metadata.return_value = MockMetadata() cmd = IngestYoutubeCommand(video_url="vid", subject_name="s", reprocess=True) - result = use_case._process_single_video( - "url", "vid", SimpleNamespace(id=uuid.uuid4()), cmd - ) + result = use_case._process_single_video("url", "vid", SimpleNamespace(id=uuid.uuid4()), cmd) assert deleted["sql"] is True assert deleted["vec"] is True @@ -735,9 +669,7 @@ def test_process_single_video_rollback_on_fail(monkeypatch): "weaviate", MagicMock(), ) - monkeypatch.setattr( - use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("c")] - ) + monkeypatch.setattr(use_case, "_extract_and_split", lambda *args, **kwargs: [DummyDoc("c")]) # Fail at index_chunks monkeypatch.setattr( @@ -748,25 +680,19 @@ def test_process_single_video_rollback_on_fail(monkeypatch): from unittest.mock import patch - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ) as mock_yt: + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor") as mock_yt: mock_yt.return_value.extract_metadata.return_value = MockMetadata() cmd = IngestYoutubeCommand(video_url="vid", subject_name="s") with pytest.raises(Exception, match="Vec index fail"): - use_case._process_single_video( - "url", "vid", SimpleNamespace(id=uuid.uuid4()), cmd - ) + use_case._process_single_video("url", "vid", SimpleNamespace(id=uuid.uuid4()), cmd) assert rolled_back["sql"] is True assert rolled_back["vec"] is True def test_resolve_subject_invalid_uuid(monkeypatch): - use_case = YoutubeIngestionUseCase( - None, None, None, None, None, None, None, "weaviate", MagicMock() - ) + use_case = YoutubeIngestionUseCase(None, None, None, None, None, None, None, "weaviate", MagicMock()) cmd = IngestYoutubeCommand(video_url="v", subject_id="not-a-uuid") with pytest.raises(ValueError, match="Invalid subject_id provided"): use_case._resolve_subject(cmd) @@ -791,6 +717,4 @@ def test_process_single_video_fails_to_create_job(monkeypatch): ) cmd = IngestYoutubeCommand(video_url="vid", subject_name="s") with pytest.raises(ValueError, match="Failed to create or retrieve ingestion job"): - use_case._process_single_video( - "url", "vid", SimpleNamespace(id=uuid.uuid4()), cmd - ) + use_case._process_single_video("url", "vid", SimpleNamespace(id=uuid.uuid4()), cmd) diff --git a/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py b/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py index 5ce5db1d..58df8ddf 100644 --- a/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py +++ b/tests/application/use_cases/test_youtube_ingestion_use_case_edge_cases.py @@ -36,9 +36,7 @@ def test_execute_job_recovery_fail(self, use_case, mock_services): ) # This should trigger the except block in execute for job recovery # But we need it to continue, so we don't mock it to fail completely - mock_services["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_services["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid.uuid4()) with patch.object( use_case, "_process_single_video", @@ -73,36 +71,24 @@ def test_execute_any_failed_with_ingestion(self, use_case, mock_services): subject_id=str(uuid.uuid4()), ingestion_job_id=str(uuid.uuid4()), ) - mock_services["ingestion_service"].get_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) - mock_services["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_services["ingestion_service"].get_by_id.return_value = MagicMock(id=uuid.uuid4()) + mock_services["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid.uuid4()) with ( patch( "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor.get_video_id", return_value="12345678901", ), - patch.object( - use_case, "_process_single_video", return_value={"error": "some error"} - ), + patch.object(use_case, "_process_single_video", return_value={"error": "some error"}), ): with pytest.raises(ValueError, match="some error"): use_case.execute(cmd) mock_services["ingestion_service"].update_job.assert_called() - def test_process_single_video_duplicate_fail_job_creation( - self, use_case, mock_services - ): + def test_process_single_video_duplicate_fail_job_creation(self, use_case, mock_services): video_id = "12345678901" - mock_services["cs_service"].get_by_source_info.return_value = MagicMock( - processing_status="done" - ) - mock_services["ingestion_service"].create_job.side_effect = Exception( - "Failed to create job" - ) + mock_services["cs_service"].get_by_source_info.return_value = MagicMock(processing_status="done") + mock_services["ingestion_service"].create_job.side_effect = Exception("Failed to create job") cmd = IngestYoutubeCommand(video_url="...", subject_id=str(uuid.uuid4())) subject = MagicMock() @@ -113,9 +99,7 @@ def test_process_single_video_duplicate_fail_job_creation( def test_process_single_video_job_reuse_not_found(self, use_case, mock_services): video_id = "12345678901" mock_services["cs_service"].get_by_source_info.return_value = None - mock_services[ - "ingestion_service" - ].get_by_id.return_value = None # Job not found + mock_services["ingestion_service"].get_by_id.return_value = None # Job not found cmd = IngestYoutubeCommand( video_url="...", @@ -125,12 +109,8 @@ def test_process_single_video_job_reuse_not_found(self, use_case, mock_services) subject = MagicMock() with patch.object(use_case, "_create_ingestion_job") as mock_create: - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ): - with pytest.raises( - Exception - ): # it will fail later but we check _create_ingestion_job call + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor"): + with pytest.raises(Exception): # it will fail later but we check _create_ingestion_job call use_case._process_single_video("url", video_id, subject, cmd) mock_create.assert_called() @@ -142,31 +122,23 @@ def test_resolve_subject_invalid_uuid(self, use_case): def test_resolve_subject_not_found_by_name(self, use_case, mock_services): cmd = IngestYoutubeCommand(video_url="...", subject_name="NonExistent") mock_services["ks_service"].get_by_name.return_value = None - with pytest.raises( - ValueError, match="KnowledgeSubject with name 'NonExistent' not found" - ): + with pytest.raises(ValueError, match="KnowledgeSubject with name 'NonExistent' not found"): use_case._resolve_subject(cmd) def test_resolve_subject_no_identifier(self, use_case): cmd = IngestYoutubeCommand(video_url="...") cmd.subject_id = None cmd.subject_name = None - with pytest.raises( - ValueError, match="Either subject_id or subject_name must be provided" - ): + with pytest.raises(ValueError, match="Either subject_id or subject_name must be provided"): use_case._resolve_subject(cmd) def test_fail_ingestion_and_job_exception_handling(self, use_case, mock_services): # Trigger the catch-all Exception in execute cmd = IngestYoutubeCommand(video_url="...", subject_id=str(uuid.uuid4())) - mock_services["ks_service"].get_subject_by_id.side_effect = Exception( - "General Failure" - ) + mock_services["ks_service"].get_subject_by_id.side_effect = Exception("General Failure") # Mocking to reach the error handlers at the end of execute - with patch.object( - use_case, "_resolve_subject", side_effect=Exception("Failure") - ): + with patch.object(use_case, "_resolve_subject", side_effect=Exception("Failure")): with pytest.raises(Exception): use_case.execute(cmd) @@ -177,14 +149,10 @@ def test_process_single_video_no_transcript_chunks(self, use_case, mock_services cmd = IngestYoutubeCommand(video_url="...", subject_id=str(uuid.uuid4())) with ( - patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ), + patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor"), patch.object(use_case, "_extract_and_split", return_value=[]), ): - with pytest.raises( - ValueError, match="No transcript chunks generated for video" - ): + with pytest.raises(ValueError, match="No transcript chunks generated for video"): use_case._process_single_video("url", video_id, subject, cmd) def test_process_single_video_fail_handlers_internal(self, use_case, mock_services): @@ -195,20 +163,12 @@ def test_process_single_video_fail_handlers_internal(self, use_case, mock_servic # Original error should bubble up if fail handlers also fail # But here we just want to see it reaches those lines. - mock_services["cs_service"].update_processing_status.side_effect = Exception( - "Failed status update" - ) - mock_services["ingestion_service"].update_job.side_effect = Exception( - "Failed job update" - ) + mock_services["cs_service"].update_processing_status.side_effect = Exception("Failed status update") + mock_services["ingestion_service"].update_job.side_effect = Exception("Failed job update") with ( - patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ), - patch.object( - use_case, "_extract_and_split", side_effect=Exception("Main Error") - ), + patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor"), + patch.object(use_case, "_extract_and_split", side_effect=Exception("Main Error")), ): with pytest.raises(Exception) as excinfo: use_case._process_single_video("url", video_id, subject, cmd) @@ -228,9 +188,7 @@ def test_build_chunk_entities_coverage(self, use_case, mock_services): source.external_source = "vid1" subject = MagicMock() subject.id = uuid.uuid4() - cmd = IngestYoutubeCommand( - video_url="...", subject_id=str(subject.id), language="en" - ) + cmd = IngestYoutubeCommand(video_url="...", subject_id=str(subject.id), language="en") job_id = uuid.uuid4() mock_services["model_loader_service"].model_name = "test-model" @@ -255,15 +213,11 @@ def test_execute_batch_failure_summary(self, use_case, mock_services): subject_id=str(uuid.uuid4()), ingestion_job_id=str(job_id), ) - mock_services["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_services["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid.uuid4()) mock_services["ingestion_service"].get_by_id.return_value = MagicMock(id=job_id) # Mocking extraction to return valid IDs - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor.get_video_id" - ) as mock_ext: + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor.get_video_id") as mock_ext: mock_ext.side_effect = ["v1", "v2"] with patch.object(use_case, "_process_single_video") as mock_process: mock_process.side_effect = [ @@ -275,16 +229,13 @@ def test_execute_batch_failure_summary(self, use_case, mock_services): # Verify update_job was called with summary assert mock_services["ingestion_service"].update_job.called found_failure = False - for call in mock_services[ - "ingestion_service" - ].update_job.call_args_list: + for call in mock_services["ingestion_service"].update_job.call_args_list: status = call.kwargs.get("status") # Handle both enum and string if necessary status_val = status.value if hasattr(status, "value") else status if ( status_val == IngestionJobStatus.FAILED.value - and "Ingestion failed for 1 items" - in call.kwargs.get("error_message", "") + and "Ingestion failed for 1 items" in call.kwargs.get("error_message", "") ): found_failure = True break @@ -304,9 +255,7 @@ def test_execute_batch_only_cancelled(self, use_case, mock_services): subject_id=str(uuid.uuid4()), ingestion_job_id=str(job_id), ) - mock_services["ks_service"].get_subject_by_id.return_value = MagicMock( - id=uuid.uuid4() - ) + mock_services["ks_service"].get_subject_by_id.return_value = MagicMock(id=uuid.uuid4()) mock_services["ingestion_service"].get_by_id.return_value = MagicMock(id=job_id) with patch( @@ -321,9 +270,7 @@ def test_execute_batch_only_cancelled(self, use_case, mock_services): use_case.execute(cmd) # Verify update_job called with FINISHED status but partial message - update_calls = mock_services[ - "ingestion_service" - ].update_job.call_args_list + update_calls = mock_services["ingestion_service"].update_job.call_args_list finished_call = None for c in update_calls: status = c.kwargs.get("status") @@ -347,22 +294,14 @@ def test_process_single_video_reprocess_cleanup(self, use_case, mock_services): ) subject = MagicMock(id=uuid.uuid4()) - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ) as mock_ext_cls: + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor") as mock_ext_cls: mock_ext = mock_ext_cls.return_value - mock_ext.extract_metadata.return_value = MagicMock( - full_title="Title", title="Title" - ) + mock_ext.extract_metadata.return_value = MagicMock(full_title="Title", title="Title") # Make it fail after cleanup to check if cleanup was called - with patch.object( - use_case, "_extract_and_split", side_effect=Exception("Stop here") - ): + with patch.object(use_case, "_extract_and_split", side_effect=Exception("Stop here")): with pytest.raises(Exception, match="Stop here"): - use_case._process_single_video( - "https://www.youtube.com/watch?v=v1", video_id, subject, cmd - ) + use_case._process_single_video("https://www.youtube.com/watch?v=v1", video_id, subject, cmd) assert mock_services["chunk_service"].delete_by_content_source.called assert mock_services["vector_service"].delete_by_video_id.called @@ -378,22 +317,14 @@ def test_known_exceptions_handling(self, use_case, mock_services): source = MagicMock(id=uuid.uuid4()) mock_services["cs_service"].get_by_source_info.return_value = source - cmd = IngestYoutubeCommand( - video_url="https://www.youtube.com/watch?v=v1", subject_id=str(uuid.uuid4()) - ) + cmd = IngestYoutubeCommand(video_url="https://www.youtube.com/watch?v=v1", subject_id=str(uuid.uuid4())) subject = MagicMock(id=uuid.uuid4()) - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ) as mock_ext_cls: + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor") as mock_ext_cls: mock_ext = mock_ext_cls.return_value - mock_ext.extract_metadata.side_effect = YoutubeVideoPrivateException( - "Private video" - ) + mock_ext.extract_metadata.side_effect = YoutubeVideoPrivateException("Private video") - result = use_case._process_single_video( - "https://www.youtube.com/watch?v=v1", video_id, subject, cmd - ) + result = use_case._process_single_video("https://www.youtube.com/watch?v=v1", video_id, subject, cmd) assert result["cancelled"] is True # Verify source marked as CANCELLED using the Enum @@ -417,15 +348,11 @@ def test_rollback_on_generic_error(self, use_case, mock_services): # Embedding model must be a string mock_services["model_loader_service"].model_name = "test-model" - cmd = IngestYoutubeCommand( - video_url="https://www.youtube.com/watch?v=v1", subject_id=str(uuid.uuid4()) - ) + cmd = IngestYoutubeCommand(video_url="https://www.youtube.com/watch?v=v1", subject_id=str(uuid.uuid4())) subject = MagicMock(id=uuid.uuid4()) subject.id = uuid.uuid4() - with patch( - "src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor" - ) as mock_ext_cls: + with patch("src.application.use_cases.youtube_ingestion_use_case.YoutubeExtractor") as mock_ext_cls: mock_ext = mock_ext_cls.return_value mock_ext.extract_metadata.return_value = MagicMock(full_title="Title") @@ -436,13 +363,9 @@ def test_rollback_on_generic_error(self, use_case, mock_services): with patch.object(use_case, "_extract_and_split", return_value=[doc]): # Fail at indexing - with patch.object( - use_case, "_index_chunks", side_effect=Exception("Indexing failed") - ): + with patch.object(use_case, "_index_chunks", side_effect=Exception("Indexing failed")): with pytest.raises(Exception, match="Indexing failed"): - use_case._process_single_video( - "https://www.youtube.com/watch?v=v1", video_id, subject, cmd - ) + use_case._process_single_video("https://www.youtube.com/watch?v=v1", video_id, subject, cmd) # Check rollback calls assert mock_services["chunk_service"].delete_by_job_id.called diff --git a/tests/application/use_cases/test_youtube_throttling.py b/tests/application/use_cases/test_youtube_throttling.py index d06d351d..37f4d80d 100644 --- a/tests/application/use_cases/test_youtube_throttling.py +++ b/tests/application/use_cases/test_youtube_throttling.py @@ -33,9 +33,7 @@ def make_use_case_mocks(): vec_svc = MagicMock() event_bus = MagicMock() - return YoutubeIngestionUseCase( - ks, cs, isvc, model_loader, embedding, chunk_svc, vec_svc, "weaviate", event_bus - ) + return YoutubeIngestionUseCase(ks, cs, isvc, model_loader, embedding, chunk_svc, vec_svc, "weaviate", event_bus) def test_throttling_logic(monkeypatch): diff --git a/tests/config/test_settings.py b/tests/config/test_settings.py index 4e055984..7b79be42 100644 --- a/tests/config/test_settings.py +++ b/tests/config/test_settings.py @@ -9,9 +9,7 @@ def test_allowed_log_levels_default(): - s = Settings( - app=App(list_log_levels=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]) - ) + s = Settings(app=App(list_log_levels=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])) expected = { logging.DEBUG, logging.INFO, @@ -59,9 +57,7 @@ def test_sql_url_postgres(): def test_sql_url_mysql(): test_pw = os.environ.get("TEST_SQL_PASSWORD", "p") - cfg = SQLConfig( - type="mysql", user="u", password=test_pw, host="h", port="3306", database="db" - ) + cfg = SQLConfig(type="mysql", user="u", password=test_pw, host="h", port="3306", database="db") assert cfg.url == f"mysql+pymysql://u:{test_pw}@h:3306/db" @@ -72,17 +68,13 @@ def test_weaviate_url_custom(): def test_sql_url_mariadb(): test_pw = os.environ.get("TEST_SQL_PASSWORD", "p") - cfg = SQLConfig( - type="mariadb", user="u", password=test_pw, host="h", port="3306", database="db" - ) + cfg = SQLConfig(type="mariadb", user="u", password=test_pw, host="h", port="3306", database="db") assert cfg.url == f"mariadb+pymysql://u:{test_pw}@h:3306/db" def test_sql_url_mssql(): test_pw = os.environ.get("TEST_SQL_PASSWORD", "p") - cfg = SQLConfig( - type="mssql", user="u", password=test_pw, host="h", port="1433", database="db" - ) + cfg = SQLConfig(type="mssql", user="u", password=test_pw, host="h", port="1433", database="db") assert cfg.url == f"mssql+pytds://u:{test_pw}@h:1433/db" diff --git a/tests/infrastructure/extractors/test_crawl4ai_extractor.py b/tests/infrastructure/extractors/test_crawl4ai_extractor.py index 2ea7057d..31b9fd1a 100644 --- a/tests/infrastructure/extractors/test_crawl4ai_extractor.py +++ b/tests/infrastructure/extractors/test_crawl4ai_extractor.py @@ -18,9 +18,7 @@ async def test_extract_success(self): mock_result.status_code = 200 # Mock AsyncWebCrawler - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.return_value = mock_result @@ -48,9 +46,7 @@ async def test_extract_not_exclude_links(self): mock_result.markdown = "[link](http://link.com)" mock_result.metadata = {} - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.return_value = mock_result @@ -67,9 +63,7 @@ async def test_extract_failed_result(self): mock_result.success = False mock_result.error_message = "Crawl error" - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.return_value = mock_result @@ -88,9 +82,7 @@ async def test_extract_with_selector(self): mock_result.markdown = "Filtered Content" mock_result.metadata = {} - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.return_value = mock_result @@ -124,9 +116,7 @@ async def test_extract_multi_depth_success(self): mock_sub_result.metadata = {"title": "Sub"} mock_sub_result.status_code = 200 - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.return_value = mock_result @@ -151,9 +141,7 @@ async def test_extract_multi_depth_no_links(self): mock_result.markdown = "Main Page" mock_result.links = {"internal": []} - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.return_value = mock_result @@ -168,9 +156,7 @@ async def test_extract_multi_depth_no_links(self): async def test_error_handling(self): url = "https://invalid-url" - with patch( - "src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler" - ) as mock_crawler_class: + with patch("src.infrastructure.extractors.crawl4ai_extractor.AsyncWebCrawler") as mock_crawler_class: mock_crawler = AsyncMock() mock_crawler.__aenter__.return_value = mock_crawler mock_crawler.arun.side_effect = Exception("Crawl failed") diff --git a/tests/infrastructure/extractors/test_docling_extractor.py b/tests/infrastructure/extractors/test_docling_extractor.py index 426c80a9..4b5bf1ec 100644 --- a/tests/infrastructure/extractors/test_docling_extractor.py +++ b/tests/infrastructure/extractors/test_docling_extractor.py @@ -116,9 +116,7 @@ def test_is_noisy_chunk(self): extractor = DoclingExtractor() # TOC like content - toc_content = ( - "Chapter 1 . . . . . . . . . . . 1\nChapter 2 . . . . . . . . . . . 2" - ) + toc_content = "Chapter 1 . . . . . . . . . . . 1\nChapter 2 . . . . . . . . . . . 2" assert extractor._is_noisy_chunk(toc_content) is True normal_content = "This is a normal paragraph of text with no noise." diff --git a/tests/infrastructure/extractors/test_youtube_extractor.py b/tests/infrastructure/extractors/test_youtube_extractor.py index cfb9b254..7f313ea9 100644 --- a/tests/infrastructure/extractors/test_youtube_extractor.py +++ b/tests/infrastructure/extractors/test_youtube_extractor.py @@ -27,9 +27,7 @@ def test_extract_transcript_success(self): video_id = "dummy_id" dummy_transcript = DummyTranscript() - with patch( - "src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi" - ) as mock_api_cls: + with patch("src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi") as mock_api_cls: mock_api = mock_api_cls.return_value mock_api.fetch.return_value = dummy_transcript with patch.object(logger, "info"), patch.object(logger, "debug"): @@ -42,26 +40,18 @@ def test_extract_transcript_success(self): if lang not in expected_languages: expected_languages.append(lang) - mock_api.fetch.assert_called_once_with( - video_id=video_id, languages=expected_languages - ) + mock_api.fetch.assert_called_once_with(video_id=video_id, languages=expected_languages) def test_extract_transcript_no_transcript_found(self): video_id = "dummy_id" transcript_data = "" # Should be a string requested_language_codes = ["pt"] message = "No transcript found" - with patch( - "src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi" - ) as mock_api_cls: + with patch("src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi") as mock_api_cls: mock_api = mock_api_cls.return_value # Both primary and fallback should fail to trigger the final exception - mock_api.fetch.side_effect = NoTranscriptFound( - transcript_data, requested_language_codes, message - ) - mock_api.list.side_effect = NoTranscriptFound( - transcript_data, requested_language_codes, message - ) + mock_api.fetch.side_effect = NoTranscriptFound(transcript_data, requested_language_codes, message) + mock_api.list.side_effect = NoTranscriptFound(transcript_data, requested_language_codes, message) with patch.object(logger, "info"), patch.object(logger, "error"): extractor = YoutubeExtractor(video_id) with pytest.raises(YoutubeTranscriptNotFoundException): @@ -69,9 +59,7 @@ def test_extract_transcript_no_transcript_found(self): def test_extract_transcript_transcripts_disabled(self): video_id = "dummy_id" - with patch( - "src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi" - ) as mock_api_cls: + with patch("src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi") as mock_api_cls: mock_api = mock_api_cls.return_value mock_api.fetch.side_effect = TranscriptsDisabled("Transcripts disabled") with patch.object(logger, "info"), patch.object(logger, "warning"): @@ -81,9 +69,7 @@ def test_extract_transcript_transcripts_disabled(self): def test_extract_transcript_generic_error(self): video_id = "dummy_id" - with patch( - "src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi" - ) as mock_api_cls: + with patch("src.infrastructure.extractors.youtube_extractor.YouTubeTranscriptApi") as mock_api_cls: mock_api = mock_api_cls.return_value mock_api.fetch.side_effect = Exception("Generic error") with patch.object(logger, "info"), patch.object(logger, "error"): @@ -113,9 +99,7 @@ def test_extract_metadata_success(self): "uploader_url": "https://youtube.com/uploader_dummy", } # Patch YoutubeDL in the correct module - with patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeDL" - ) as mock_ytdlp: + with patch("src.infrastructure.extractors.youtube_extractor.YoutubeDL") as mock_ytdlp: mock_instance = mock_ytdlp.return_value.__enter__.return_value mock_instance.extract_info.return_value = dummy_info with patch.object(logger, "info"): @@ -165,9 +149,7 @@ def test_extract_playlist_videos_success(self): None, ] } - with patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeDL" - ) as mock_ytdlp: + with patch("src.infrastructure.extractors.youtube_extractor.YoutubeDL") as mock_ytdlp: mock_instance = mock_ytdlp.return_value.__enter__.return_value mock_instance.extract_info.return_value = dummy_playlist_info @@ -183,9 +165,7 @@ def test_extract_playlist_videos_success(self): def test_extract_playlist_videos_empty_entries(self): playlist_url = "https://www.youtube.com/playlist?list=PL123" dummy_playlist_info = {"entries": []} - with patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeDL" - ) as mock_ytdlp: + with patch("src.infrastructure.extractors.youtube_extractor.YoutubeDL") as mock_ytdlp: mock_instance = mock_ytdlp.return_value.__enter__.return_value mock_instance.extract_info.return_value = dummy_playlist_info @@ -196,9 +176,7 @@ def test_extract_playlist_videos_empty_entries(self): @pytest.mark.PlaylistExtraction def test_extract_playlist_videos_error(self): playlist_url = "https://www.youtube.com/playlist?list=PL123" - with patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeDL" - ) as mock_ytdlp: + with patch("src.infrastructure.extractors.youtube_extractor.YoutubeDL") as mock_ytdlp: mock_instance = mock_ytdlp.return_value.__enter__.return_value mock_instance.extract_info.side_effect = Exception("Playlist Error") @@ -210,9 +188,7 @@ def test_extract_playlist_videos_error(self): def test_extract_playlist_videos_normalization(self): playlist_url = "https://www.youtube.com/watch?v=video1&list=PL123" dummy_playlist_info = {"entries": [{"id": "video1"}]} - with patch( - "src.infrastructure.extractors.youtube_extractor.YoutubeDL" - ) as mock_ytdlp: + with patch("src.infrastructure.extractors.youtube_extractor.YoutubeDL") as mock_ytdlp: mock_instance = mock_ytdlp.return_value.__enter__.return_value mock_instance.extract_info.return_value = dummy_playlist_info diff --git a/tests/infrastructure/loggers/test_std_logger.py b/tests/infrastructure/loggers/test_std_logger.py index 4d8e32fd..5e635d4c 100644 --- a/tests/infrastructure/loggers/test_std_logger.py +++ b/tests/infrastructure/loggers/test_std_logger.py @@ -18,9 +18,7 @@ def patch_settings(monkeypatch): @pytest.fixture def logger(): - log_format = ( - "{asctime} [{levelname}] {filename}:{lineno} {class}.{funcName} - {message}" - ) + log_format = "{asctime} [{levelname}] {filename}:{lineno} {class}.{funcName} - {message}" return StdLogger(log_format, name="test") diff --git a/tests/infrastructure/repositories/sql/test_additional_coverage.py b/tests/infrastructure/repositories/sql/test_additional_coverage.py index 4839a994..7e7e73c6 100644 --- a/tests/infrastructure/repositories/sql/test_additional_coverage.py +++ b/tests/infrastructure/repositories/sql/test_additional_coverage.py @@ -34,9 +34,7 @@ def test_more_sql_paths(): chunk_repo = ChunkIndexSQLRepository() # create a subject and exercise get_by_external_ref, list, get_by_name - ks_id = ks_repo.create_subject( - name="KS_extra", external_ref="external-extra", description="d" - ) + ks_id = ks_repo.create_subject(name="KS_extra", external_ref="external-extra", description="d") found_by_ext = ks_repo.get_by_external_ref("external-extra") assert found_by_ext is not None @@ -65,9 +63,7 @@ def test_more_sql_paths(): title="t", language="en", ) - by_source = cs_repo.get_by_source_info( - source_type=SourceType.YOUTUBE.value, external_source="v_extra" - ) + by_source = cs_repo.get_by_source_info(source_type=SourceType.YOUTUBE.value, external_source="v_extra") assert isinstance(by_source, list) and len(by_source) >= 1 by_subject = cs_repo.list_by_subject(ks_id2) @@ -80,16 +76,12 @@ def test_more_sql_paths(): cs_repo.finish_ingestion(UUID(int=0), embedding_model="m", dimensions=1, chunks=1) # create ingestion job and list_by_content_source - job_id = job_repo.create_job( - content_source_id=cs_id, status=IngestionJobStatus.STARTED.value - ) + job_id = job_repo.create_job(content_source_id=cs_id, status=IngestionJobStatus.STARTED.value) jobs = job_repo.list_by_content_source(cs_id) assert isinstance(jobs, list) # update non-existent job to hit not-found branch - job_repo.update_job( - UUID(int=0), status=IngestionJobStatus.FAILED.value, error_message="err" - ) + job_repo.update_job(UUID(int=0), status=IngestionJobStatus.FAILED.value, error_message="err") # chunk_index: create and search with filters chunk_id = uuid4() diff --git a/tests/infrastructure/repositories/sql/test_chunk_index_repository.py b/tests/infrastructure/repositories/sql/test_chunk_index_repository.py index 4a3eddc0..ff3ab520 100644 --- a/tests/infrastructure/repositories/sql/test_chunk_index_repository.py +++ b/tests/infrastructure/repositories/sql/test_chunk_index_repository.py @@ -219,9 +219,7 @@ def test_get_by_id_error(self, sqlite_memory): def test_delete_by_content_source_error(self, sqlite_memory): repo = ChunkIndexSQLRepository() - with patch( - "sqlalchemy.orm.Session.commit", side_effect=Exception("Commit Error") - ): + with patch("sqlalchemy.orm.Session.commit", side_effect=Exception("Commit Error")): with pytest.raises(Exception, match="Commit Error"): repo.delete_by_content_source(uuid4()) @@ -253,9 +251,7 @@ def test_delete_chunk_error(self, sqlite_memory): } ] ) - with patch( - "sqlalchemy.orm.Session.commit", side_effect=Exception("Delete Error") - ): + with patch("sqlalchemy.orm.Session.commit", side_effect=Exception("Delete Error")): with pytest.raises(Exception, match="Delete Error"): repo.delete_chunk(cid) @@ -275,8 +271,6 @@ def test_update_chunk_error(self, sqlite_memory): } ] ) - with patch( - "sqlalchemy.orm.Session.commit", side_effect=Exception("Update Error") - ): + with patch("sqlalchemy.orm.Session.commit", side_effect=Exception("Update Error")): with pytest.raises(Exception, match="Update Error"): repo.update_chunk(cid, "new") diff --git a/tests/infrastructure/repositories/sql/test_content_source_repository.py b/tests/infrastructure/repositories/sql/test_content_source_repository.py index 2912655c..1724c545 100644 --- a/tests/infrastructure/repositories/sql/test_content_source_repository.py +++ b/tests/infrastructure/repositories/sql/test_content_source_repository.py @@ -151,27 +151,21 @@ def test_count_by_subject_error(self, sqlite_memory): def test_update_status_error(self, sqlite_memory): repo = ContentSourceSQLRepository() cid = repo.create(uuid4(), "pdf", "1.pdf") - with patch( - "sqlalchemy.orm.Session.commit", side_effect=Exception("Update Error") - ): + with patch("sqlalchemy.orm.Session.commit", side_effect=Exception("Update Error")): with pytest.raises(Exception, match="Update Error"): repo.update_status(cid, "status") def test_update_title_error(self, sqlite_memory): repo = ContentSourceSQLRepository() cid = repo.create(uuid4(), "pdf", "1.pdf") - with patch( - "sqlalchemy.orm.Session.commit", side_effect=Exception("Update Error") - ): + with patch("sqlalchemy.orm.Session.commit", side_effect=Exception("Update Error")): with pytest.raises(Exception, match="Update Error"): repo.update_title(cid, "title") def test_finish_ingestion_error(self, sqlite_memory): repo = ContentSourceSQLRepository() cid = repo.create(uuid4(), "pdf", "1.pdf") - with patch( - "sqlalchemy.orm.Session.commit", side_effect=Exception("Finish Error") - ): + with patch("sqlalchemy.orm.Session.commit", side_effect=Exception("Finish Error")): with pytest.raises(Exception, match="Finish Error"): repo.finish_ingestion(cid, "m", 1, 1) diff --git a/tests/infrastructure/repositories/sql/test_diarization_repository.py b/tests/infrastructure/repositories/sql/test_diarization_repository.py index 0ff0807c..1e5bfa91 100644 --- a/tests/infrastructure/repositories/sql/test_diarization_repository.py +++ b/tests/infrastructure/repositories/sql/test_diarization_repository.py @@ -16,9 +16,7 @@ def test_create_pending(self, sqlite_memory): def test_save_new_and_update(self, sqlite_memory): repo = DiarizationRepository(sqlite_memory) - result = DiarizationResult( - segments=[Segment(start=0, end=1, text="t", speaker="S1")], language="en" - ) + result = DiarizationResult(segments=[Segment(start=0, end=1, text="t", speaker="S1")], language="en") # Save new record = repo.save(result, "T1", "upload", "f1", "/folder") @@ -27,9 +25,7 @@ def test_save_new_and_update(self, sqlite_memory): # Update existing result2 = DiarizationResult(segments=[], language="fr") - updated = repo.save( - result2, "T2", "upload", "f1", "/folder2", diarization_id=record.id - ) + updated = repo.save(result2, "T2", "upload", "f1", "/folder2", diarization_id=record.id) assert updated.id == record.id assert updated.name == "T2" assert updated.language == "fr" diff --git a/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py b/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py index 3bc9fa3c..f7dcbb33 100644 --- a/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py +++ b/tests/infrastructure/repositories/sql/test_ingestion_job_repository.py @@ -42,9 +42,7 @@ def test_create_job_error(self): # Mock session to raise an error during add from unittest.mock import patch - with patch( - "src.infrastructure.repositories.sql.ingestion_job_repository.Connector" - ) as mock_connector: + with patch("src.infrastructure.repositories.sql.ingestion_job_repository.Connector") as mock_connector: mock_session = mock_connector.return_value.__enter__.return_value mock_session.add.side_effect = Exception("DB Error") with pytest.raises(Exception): @@ -161,20 +159,14 @@ def test_list_jobs_filtering_and_search(self): self.repo.create_job(None, status="finished", source_title="gamma") self.repo.create_job(None, status="failed", source_title="delta") jid_dup = self.repo.create_job(None, status="failed", source_title="epsilon") - self.repo.update_job( - jid_dup, status="failed", error_message="Duplicate content detected" - ) + self.repo.update_job(jid_dup, status="failed", error_message="Duplicate content detected") self.repo.create_job(None, status="cancelled", source_title="zeta") # Test status filters assert len(self.repo.list_jobs(status="processing")) == 2 # started, processing assert len(self.repo.list_jobs(status="completed")) == 1 # finished - assert ( - len(self.repo.list_jobs(status="failed")) == 1 - ) # delta (epsilon is duplicate) - assert ( - len(self.repo.list_jobs(status="cancelled")) == 2 - ) # zeta, epsilon (duplicate) + assert len(self.repo.list_jobs(status="failed")) == 1 # delta (epsilon is duplicate) + assert len(self.repo.list_jobs(status="cancelled")) == 2 # zeta, epsilon (duplicate) assert len(self.repo.list_jobs(status="started")) == 1 # Test search @@ -199,9 +191,7 @@ def test_get_status_counts(self): self.repo.create_job(None, status="finished") # completed self.repo.create_job(None, status="failed") # failed jid = self.repo.create_job(None, status="failed") - self.repo.update_job( - jid, status="failed", error_message="Duplicate" - ) # cancelled + self.repo.update_job(jid, status="failed", error_message="Duplicate") # cancelled self.repo.create_job(None, status="cancelled") # cancelled counts = self.repo.get_status_counts() diff --git a/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py b/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py index 9523af64..dd220d04 100644 --- a/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py +++ b/tests/infrastructure/repositories/sql/test_knowledge_subject_repository.py @@ -16,9 +16,7 @@ class TestKnowledgeSubjectSQLRepository: def test_create_subject_success(self, sqlite_memory): repo = KnowledgeSubjectSQLRepository() - sid = repo.create_subject( - name="Test", external_ref="ref", description="desc", icon="icon" - ) + sid = repo.create_subject(name="Test", external_ref="ref", description="desc", icon="icon") assert sid is not None ks = repo.get_by_id(sid) diff --git a/tests/infrastructure/repositories/sql/test_repos_services.py b/tests/infrastructure/repositories/sql/test_repos_services.py index 465be75d..4d0d584b 100644 --- a/tests/infrastructure/repositories/sql/test_repos_services.py +++ b/tests/infrastructure/repositories/sql/test_repos_services.py @@ -30,9 +30,7 @@ def test_sql_repositories_and_services(): # KnowledgeSubject repository CRUD ks_repo = KnowledgeSubjectSQLRepository() - ks_id = ks_repo.create_subject( - name="KS1", external_ref="ext-ks1", description="desc1" - ) + ks_id = ks_repo.create_subject(name="KS1", external_ref="ext-ks1", description="desc1") assert isinstance(ks_id, UUID) ks_model = ks_repo.get_by_id(ks_id) diff --git a/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py b/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py index af2d92e1..fbdd43af 100644 --- a/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/chroma/test_chroma_chunk_repository.py @@ -140,9 +140,7 @@ def test_hybrid_search(self, mock_emb, monkeypatch): # Mock dependencies for BM25 mock_bm25 = MagicMock() monkeypatch.setitem(sys.modules, "rank_bm25", MagicMock()) - monkeypatch.setattr( - "rank_bm25.BM25Okapi", MagicMock(return_value=mock_bm25), raising=False - ) + monkeypatch.setattr("rank_bm25.BM25Okapi", MagicMock(return_value=mock_bm25), raising=False) mock_np = MagicMock() monkeypatch.setitem(sys.modules, "numpy", mock_np) mock_np.argsort.return_value = [0] diff --git a/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py b/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py index 8c66d678..2b09b659 100644 --- a/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py +++ b/tests/infrastructure/repositories/vector/faiss/test_chunk_repository_extended.py @@ -98,9 +98,7 @@ def test_hybrid_search_success(self, mock_emb, temp_index_path, monkeypatch): monkeypatch.setattr("rank_bm25.BM25Okapi", MagicMock(return_value=mock_bm25)) with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) @@ -130,9 +128,7 @@ def test_hybrid_search_success(self, mock_emb, temp_index_path, monkeypatch): def test_delete_simple_id(self, mock_emb, temp_index_path): with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) @@ -143,9 +139,7 @@ def test_delete_simple_id(self, mock_emb, temp_index_path): def test_list_chunks_with_limit(self, mock_emb, temp_index_path): with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) diff --git a/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py b/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py index 60589918..da642257 100644 --- a/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/faiss/test_faiss_chunk_repository.py @@ -52,9 +52,7 @@ def test_init_new(self, mock_emb, temp_index_path): def test_init_load_existing(self, mock_emb, temp_index_path): with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) @@ -72,9 +70,7 @@ def test_init_load_error(self, mock_emb, temp_index_path): def test_create_documents_new_store(self, mock_emb, temp_index_path): with patch("os.path.exists", return_value=False): - with patch( - "langchain_community.vectorstores.FAISS.from_texts" - ) as mock_from: + with patch("langchain_community.vectorstores.FAISS.from_texts") as mock_from: mock_store = MagicMock() mock_from.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) @@ -107,9 +103,7 @@ def test_retriever_not_init(self, mock_emb, temp_index_path): def test_retriever_semantic(self, mock_emb, temp_index_path): with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) @@ -133,24 +127,18 @@ def test_retriever_semantic(self, mock_emb, temp_index_path): def test_retriever_error(self, mock_emb, temp_index_path): with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) - mock_store.similarity_search_with_score.side_effect = Exception( - "Search error" - ) + mock_store.similarity_search_with_score.side_effect = Exception("Search error") with pytest.raises(Exception): repo.retriever("query") def test_bm25_search_empty(self, mock_emb, temp_index_path, monkeypatch): monkeypatch.setitem(sys.modules, "rank_bm25", MagicMock()) with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) @@ -160,9 +148,7 @@ def test_bm25_search_empty(self, mock_emb, temp_index_path, monkeypatch): def test_hybrid_search_empty(self, mock_emb, temp_index_path, monkeypatch): monkeypatch.setitem(sys.modules, "rank_bm25", MagicMock()) with patch("os.path.exists", return_value=True): - with patch( - "langchain_community.vectorstores.FAISS.load_local" - ) as mock_load: + with patch("langchain_community.vectorstores.FAISS.load_local") as mock_load: mock_store = MagicMock() mock_load.return_value = mock_store repo = ChunkFAISSRepository(mock_emb, temp_index_path) diff --git a/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py b/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py index 929cd4bb..d55a9b7d 100644 --- a/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/qdrant/test_chunk_repository.py @@ -38,9 +38,7 @@ def repo(self, mock_connector, mock_embedding_service): collection_name="test_collection", ) - def test_ensure_collection_exists_creates_if_not_present( - self, mock_connector, mock_embedding_service - ): + def test_ensure_collection_exists_creates_if_not_present(self, mock_connector, mock_embedding_service): mock_client = mock_connector.__enter__.return_value mock_client.collection_exists.return_value = False @@ -55,9 +53,7 @@ def test_ensure_collection_exists_creates_if_not_present( mock_client.create_collection.assert_called_once() mock_client.create_payload_index.assert_called_once() - def test_ensure_collection_exists_skips_if_present( - self, mock_connector, mock_embedding_service - ): + def test_ensure_collection_exists_skips_if_present(self, mock_connector, mock_embedding_service): mock_client = mock_connector.__enter__.return_value mock_client.collection_exists.return_value = True @@ -71,9 +67,7 @@ def test_ensure_collection_exists_skips_if_present( mock_client.create_collection.assert_not_called() - def test_create_documents_success( - self, repo, mock_connector, mock_embedding_service - ): + def test_create_documents_success(self, repo, mock_connector, mock_embedding_service): doc = ChunkModel( id=uuid4(), job_id=uuid4(), @@ -292,17 +286,9 @@ def test_bm25_search_with_existing_filters(self, repo, mock_connector): mock_client.query_points.return_value = MagicMock(points=[]) existing_filters = rest.Filter( - must=[ - rest.FieldCondition( - key="subject_id", match=rest.MatchValue(value="123") - ) - ], - should=[ - rest.FieldCondition(key="extra", match=rest.MatchValue(value="val")) - ], - must_not=[ - rest.FieldCondition(key="bad", match=rest.MatchValue(value="val")) - ], + must=[rest.FieldCondition(key="subject_id", match=rest.MatchValue(value="123"))], + should=[rest.FieldCondition(key="extra", match=rest.MatchValue(value="val"))], + must_not=[rest.FieldCondition(key="bad", match=rest.MatchValue(value="val"))], ) repo._bm25_search("query", 5, existing_filters) diff --git a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py index 1668851f..24831eac 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository.py @@ -25,9 +25,7 @@ def mock_embedding_service(self): @pytest.fixture def repo(self, mock_weaviate_client, mock_embedding_service): - with patch( - "src.infrastructure.repositories.vector.weaviate.weaviate_vector.WeaviateVector" - ): + with patch("src.infrastructure.repositories.vector.weaviate.weaviate_vector.WeaviateVector"): return ChunkWeaviateRepository( weaviate_client=mock_weaviate_client, embedding_service=mock_embedding_service, @@ -85,9 +83,7 @@ def test_create_documents_error(self, repo): subject_id=uuid4(), embedding_model="model", ) - repo.vector_store.__enter__.return_value.add_texts.side_effect = Exception( - "Weaviate error" - ) + repo.vector_store.__enter__.return_value.add_texts.side_effect = Exception("Weaviate error") with pytest.raises(Exception, match="Weaviate error"): repo.create_documents([doc]) diff --git a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py index b3c3d8d8..61942c32 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_chunk_repository_extended.py @@ -38,9 +38,7 @@ def mock_emb(self): @pytest.fixture def repo(self, mock_client, mock_emb): - with patch( - "src.infrastructure.repositories.vector.weaviate.weaviate_vector.WeaviateVector" - ): + with patch("src.infrastructure.repositories.vector.weaviate.weaviate_vector.WeaviateVector"): return ChunkWeaviateRepository(mock_client, mock_emb, "TestCollection") def create_mock_weaviate_obj(self, content="text", score=0.9): diff --git a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py index 616eac86..7846d5b7 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_chunk_repository.py @@ -47,12 +47,8 @@ def __enter__(self): class Collection: def __init__(self, response): self._response = response - self.query = SimpleNamespace( - fetch_objects=lambda **kwargs: self._response - ) - self.data = SimpleNamespace( - delete_many=lambda where: SimpleNamespace(matches=1) - ) + self.query = SimpleNamespace(fetch_objects=lambda **kwargs: self._response) + self.data = SimpleNamespace(delete_many=lambda where: SimpleNamespace(matches=1)) def get(self, _): return self diff --git a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_vector.py b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_vector.py index 1ca4dc7f..ac22fc08 100644 --- a/tests/infrastructure/repositories/vector/weaviate/test_weaviate_vector.py +++ b/tests/infrastructure/repositories/vector/weaviate/test_weaviate_vector.py @@ -37,9 +37,7 @@ def test_enter_returns_weaviate_vector_store(self, monkeypatch): captured = {} class FakeWeaviateVectorStore: - def __init__( - self, client, index_name, text_key, embedding, use_multi_tenancy - ): + def __init__(self, client, index_name, text_key, embedding, use_multi_tenancy): captured["client"] = client captured["index_name"] = index_name captured["text_key"] = text_key diff --git a/tests/infrastructure/services/test_chunk_index_service.py b/tests/infrastructure/services/test_chunk_index_service.py index b6fc4f81..522ae711 100644 --- a/tests/infrastructure/services/test_chunk_index_service.py +++ b/tests/infrastructure/services/test_chunk_index_service.py @@ -60,15 +60,11 @@ def test_create_chunks(self, service, mock_repo): def test_list_by_content_source(self, service, mock_repo): sid = uuid4() - mock_repo.list_by_content_source.return_value = [ - self.create_mock_model(content_source_id=sid) - ] + mock_repo.list_by_content_source.return_value = [self.create_mock_model(content_source_id=sid)] result = service.list_by_content_source(sid, limit=5, offset=0) assert len(result) == 1 assert result[0].content_source_id == sid - mock_repo.list_by_content_source.assert_called_once_with( - content_source_id=sid, limit=5, offset=0 - ) + mock_repo.list_by_content_source.assert_called_once_with(content_source_id=sid, limit=5, offset=0) def test_count_by_content_source(self, service, mock_repo): sid = uuid4() @@ -80,17 +76,13 @@ def test_delete_by_content_source(self, service, mock_repo): sid = uuid4() mock_repo.delete_by_content_source.return_value = 10 assert service.delete_by_content_source(sid) == 10 - mock_repo.delete_by_content_source.assert_called_once_with( - content_source_id=sid - ) + mock_repo.delete_by_content_source.assert_called_once_with(content_source_id=sid) def test_search(self, service, mock_repo): mock_repo.search.return_value = [self.create_mock_model()] result = service.search("query", top_k=3, filters={"a": "b"}) assert len(result) == 1 - mock_repo.search.assert_called_once_with( - query="query", top_k=3, filters={"a": "b"} - ) + mock_repo.search.assert_called_once_with(query="query", top_k=3, filters={"a": "b"}) def test_get_by_id(self, service, mock_repo): cid = uuid4() @@ -104,13 +96,9 @@ def test_get_by_id(self, service, mock_repo): def test_list_chunks(self, service, mock_repo): sid = uuid4() mock_repo.list_chunks.return_value = [self.create_mock_model()] - result = service.list_chunks( - limit=10, offset=5, source_id=sid, search_query="q" - ) + result = service.list_chunks(limit=10, offset=5, source_id=sid, search_query="q") assert len(result) == 1 - mock_repo.list_chunks.assert_called_once_with( - limit=10, offset=5, source_id=sid, search_query="q" - ) + mock_repo.list_chunks.assert_called_once_with(limit=10, offset=5, source_id=sid, search_query="q") def test_delete_chunk(self, service, mock_repo): cid = uuid4() diff --git a/tests/infrastructure/services/test_content_source_service.py b/tests/infrastructure/services/test_content_source_service.py index f6d8c3e8..8e12a524 100644 --- a/tests/infrastructure/services/test_content_source_service.py +++ b/tests/infrastructure/services/test_content_source_service.py @@ -42,9 +42,7 @@ def test_create_source(self, service, mock_repo): sid = uuid4() cid = uuid4() mock_repo.create.return_value = cid - mock_repo.get_by_id.return_value = self.create_mock_model( - id=cid, subject_id=sid - ) + mock_repo.get_by_id.return_value = self.create_mock_model(id=cid, subject_id=sid) res = service.create_source( subject_id=sid, diff --git a/tests/infrastructure/services/test_ingestion_job_service.py b/tests/infrastructure/services/test_ingestion_job_service.py index 57bc88fe..2d2f5b84 100644 --- a/tests/infrastructure/services/test_ingestion_job_service.py +++ b/tests/infrastructure/services/test_ingestion_job_service.py @@ -54,9 +54,7 @@ def test_create_job(self, service, mock_repo): assert result.id == jid mock_repo.create_job.assert_called_once_with( - content_source_id=mock_repo.create_job.call_args.kwargs.get( - "content_source_id" - ), + content_source_id=mock_repo.create_job.call_args.kwargs.get("content_source_id"), status="started", embedding_model="emb", pipeline_version="1.0", @@ -93,9 +91,7 @@ def test_link_job_to_source(self, service, mock_repo): jid = uuid4() sid = uuid4() service.link_job_to_source(jid, sid, "pdf") - mock_repo.link_job_to_source.assert_called_once_with( - job_id=jid, content_source_id=sid, ingestion_type="pdf" - ) + mock_repo.link_job_to_source.assert_called_once_with(job_id=jid, content_source_id=sid, ingestion_type="pdf") def test_get_by_id(self, service, mock_repo): jid = uuid4() @@ -124,6 +120,4 @@ def test_list_recent_jobs_by_subject(self, service, mock_repo): mock_repo.list_recent_jobs_by_subject.return_value = [self.create_mock_model()] result = service.list_recent_jobs_by_subject(sid, limit=5, offset=0) assert len(result) == 1 - mock_repo.list_recent_jobs_by_subject.assert_called_once_with( - sid, limit=5, offset=0 - ) + mock_repo.list_recent_jobs_by_subject.assert_called_once_with(sid, limit=5, offset=0) diff --git a/tests/infrastructure/services/test_knowledge_subject_service.py b/tests/infrastructure/services/test_knowledge_subject_service.py index dba0f8c8..8ad4b535 100644 --- a/tests/infrastructure/services/test_knowledge_subject_service.py +++ b/tests/infrastructure/services/test_knowledge_subject_service.py @@ -38,15 +38,11 @@ def test_create_subject(self, service, mock_repo): mock_repo.create_subject.return_value = sid mock_repo.get_by_id.return_value = self.create_mock_model(id=sid, name="New") - result = service.create_subject( - name="New", external_ref="ref", description="d", icon="i" - ) + result = service.create_subject(name="New", external_ref="ref", description="d", icon="i") assert result.id == sid assert result.name == "New" - mock_repo.create_subject.assert_called_once_with( - name="New", external_ref="ref", description="d", icon="i" - ) + mock_repo.create_subject.assert_called_once_with(name="New", external_ref="ref", description="d", icon="i") def test_get_by_name(self, service, mock_repo): mock_repo.get_by_name.return_value = self.create_mock_model(name="Target") @@ -62,17 +58,13 @@ def test_get_subject_by_id(self, service, mock_repo): mock_repo.get_by_id.assert_called_once_with(sid) def test_get_subject_by_external_ref(self, service, mock_repo): - mock_repo.get_by_external_ref.return_value = self.create_mock_model( - external_ref="ref123" - ) + mock_repo.get_by_external_ref.return_value = self.create_mock_model(external_ref="ref123") result = service.get_subject_by_external_ref("ref123") assert result.external_ref == "ref123" mock_repo.get_by_external_ref.assert_called_once_with("ref123") def test_get_or_create_by_external_ref_existing(self, service, mock_repo): - mock_repo.get_by_external_ref.return_value = self.create_mock_model( - external_ref="ext" - ) + mock_repo.get_by_external_ref.return_value = self.create_mock_model(external_ref="ext") result = service.get_or_create_by_external_ref("ext") assert result.external_ref == "ext" mock_repo.create_subject.assert_not_called() @@ -83,9 +75,7 @@ def test_get_or_create_by_external_ref_new(self, service, mock_repo): self.create_mock_model(external_ref="ext", name="Name"), ] mock_repo.create_subject.return_value = uuid4() - mock_repo.get_by_id.return_value = self.create_mock_model( - external_ref="ext", name="Name" - ) + mock_repo.get_by_id.return_value = self.create_mock_model(external_ref="ext", name="Name") result = service.get_or_create_by_external_ref("ext", name="Name") @@ -105,9 +95,7 @@ def test_list_subjects(self, service, mock_repo): def test_update_subject(self, service, mock_repo): sid = uuid4() service.update_subject(sid, name="Updated") - mock_repo.update.assert_called_once_with( - id=sid, name="Updated", description=None, external_ref=None, icon=None - ) + mock_repo.update.assert_called_once_with(id=sid, name="Updated", description=None, external_ref=None, icon=None) def test_delete_subject(self, service, mock_repo): sid = uuid4() diff --git a/tests/infrastructure/services/test_model_loader_service.py b/tests/infrastructure/services/test_model_loader_service.py index 2b959aff..cc76baf4 100644 --- a/tests/infrastructure/services/test_model_loader_service.py +++ b/tests/infrastructure/services/test_model_loader_service.py @@ -17,9 +17,7 @@ def reset_singleton(): @pytest.mark.Dependencies class TestModelLoaderService: def test_load_model_success(self): - with patch( - "src.infrastructure.services.model_loader_service.SentenceTransformer" - ) as mock_st: + with patch("src.infrastructure.services.model_loader_service.SentenceTransformer") as mock_st: mock_model = MagicMock() mock_st.return_value = mock_model @@ -43,9 +41,7 @@ def test_load_model_failure(self): service.load_model() def test_dimensions_property(self): - with patch( - "src.infrastructure.services.model_loader_service.SentenceTransformer" - ) as mock_st: + with patch("src.infrastructure.services.model_loader_service.SentenceTransformer") as mock_st: mock_model = MagicMock() mock_model.get_sentence_embedding_dimension.return_value = 384 mock_st.return_value = mock_model @@ -54,9 +50,7 @@ def test_dimensions_property(self): assert service.dimensions == 384 def test_dimensions_property_failure(self): - with patch( - "src.infrastructure.services.model_loader_service.SentenceTransformer" - ) as mock_st: + with patch("src.infrastructure.services.model_loader_service.SentenceTransformer") as mock_st: mock_model = MagicMock() mock_model.get_sentence_embedding_dimension.return_value = None mock_st.return_value = mock_model @@ -66,9 +60,7 @@ def test_dimensions_property_failure(self): assert service.dimensions == 0 def test_max_seq_length_property(self): - with patch( - "src.infrastructure.services.model_loader_service.SentenceTransformer" - ) as mock_st: + with patch("src.infrastructure.services.model_loader_service.SentenceTransformer") as mock_st: mock_model = MagicMock() mock_model.max_seq_length = 512 mock_st.return_value = mock_model @@ -77,9 +69,7 @@ def test_max_seq_length_property(self): assert service.max_seq_length == 512 def test_max_seq_length_property_default(self): - with patch( - "src.infrastructure.services.model_loader_service.SentenceTransformer" - ) as mock_st: + with patch("src.infrastructure.services.model_loader_service.SentenceTransformer") as mock_st: mock_model = MagicMock() # My implementation uses getattr with default 0 if hasattr(mock_model, "max_seq_length"): diff --git a/tests/infrastructure/services/test_pyannote_voice_recognizer.py b/tests/infrastructure/services/test_pyannote_voice_recognizer.py index d90c1f75..68be9819 100644 --- a/tests/infrastructure/services/test_pyannote_voice_recognizer.py +++ b/tests/infrastructure/services/test_pyannote_voice_recognizer.py @@ -82,13 +82,9 @@ def test_identify_dir(self): assert res.results["s1"].best_match == "A" def test_get_inference_internal(self): - with patch( - "src.infrastructure.services.pyannote_voice_recognizer.model_loader" - ) as mock_loader: + with patch("src.infrastructure.services.pyannote_voice_recognizer.model_loader") as mock_loader: mock_loader.get_voice_inference.return_value = MagicMock() recognizer = VoiceRecognizer(MagicMock(), hf_token="f") inf = recognizer._get_inference() assert inf is not None - mock_loader.get_voice_inference.assert_called_once_with( - hf_token="f", device=recognizer._device - ) + mock_loader.get_voice_inference.assert_called_once_with(hf_token="f", device=recognizer._device) diff --git a/tests/infrastructure/services/test_re_rank_service.py b/tests/infrastructure/services/test_re_rank_service.py index 358ae3dc..ad5043d9 100644 --- a/tests/infrastructure/services/test_re_rank_service.py +++ b/tests/infrastructure/services/test_re_rank_service.py @@ -19,9 +19,7 @@ class TestReRankService: def test_init_success(self, monkeypatch): mock_ranker_class = MagicMock() - monkeypatch.setattr( - "src.infrastructure.services.re_rank_service.Ranker", mock_ranker_class - ) + monkeypatch.setattr("src.infrastructure.services.re_rank_service.Ranker", mock_ranker_class) service = ReRankService(model_name="test-model") @@ -33,9 +31,7 @@ def test_init_success(self, monkeypatch): def test_init_failure(self, monkeypatch): mock_ranker_class = MagicMock(side_effect=Exception("Failed to load")) - monkeypatch.setattr( - "src.infrastructure.services.re_rank_service.Ranker", mock_ranker_class - ) + monkeypatch.setattr("src.infrastructure.services.re_rank_service.Ranker", mock_ranker_class) service = ReRankService() diff --git a/tests/infrastructure/services/test_text_splitter_service.py b/tests/infrastructure/services/test_text_splitter_service.py index 9464cdf4..e6365a56 100644 --- a/tests/infrastructure/services/test_text_splitter_service.py +++ b/tests/infrastructure/services/test_text_splitter_service.py @@ -13,9 +13,7 @@ def mock_tokenizer(self): # For simple testing, each word is an ID tokenizer.encode.side_effect = lambda text, **kwargs: [ord(c) for c in text] # Mock decode: converts IDs back to characters - tokenizer.decode.side_effect = lambda ids, **kwargs: "".join( - chr(i) for i in ids - ) + tokenizer.decode.side_effect = lambda ids, **kwargs: "".join(chr(i) for i in ids) return tokenizer def test_split_text_success(self, mock_tokenizer): @@ -60,9 +58,7 @@ def test_split_text_invalid_params(self, mock_tokenizer): service = TextSplitterService(tokenizer=mock_tokenizer) with pytest.raises(ValueError) as excinfo: service.split_text(text="some text", tokens_per_chunk=10, tokens_overlap=10) - assert "tokens_per_chunk must be greater than tokens_overlap" in str( - excinfo.value - ) + assert "tokens_per_chunk must be greater than tokens_overlap" in str(excinfo.value) def test_split_text_tokenizer_type_error(self, mock_tokenizer): # Test the fallback in split_text if add_special_tokens=False is not supported diff --git a/tests/infrastructure/services/test_voice_profile_service.py b/tests/infrastructure/services/test_voice_profile_service.py index b606f724..4fb1c05c 100644 --- a/tests/infrastructure/services/test_voice_profile_service.py +++ b/tests/infrastructure/services/test_voice_profile_service.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from src.infrastructure.repositories.sql.models.voice_record import VoiceRecord from src.infrastructure.services.voice_profile_service import VoiceDB @@ -12,9 +13,7 @@ class TestVoiceDB: def mock_infra(self): # Stub StorageService and Os logic globally for this class to prevent ANY footprint with ( - patch( - "src.infrastructure.services.voice_profile_service.StorageService" - ) as mock_cls, + patch("src.infrastructure.services.voice_profile_service.StorageService") as mock_cls, patch("os.path.exists", return_value=True), patch("os.path.isdir", return_value=True), patch("os.makedirs"), @@ -30,9 +29,7 @@ def test_add_voice_local_file(self, sqlite_memory): "src.infrastructure.services.voice_profile_service.get_best_device", return_value="cpu", ): - with patch( - "src.infrastructure.services.voice_profile_service.VoiceDB._get_inference" - ) as mock_inf_getter: + with patch("src.infrastructure.services.voice_profile_service.VoiceDB._get_inference") as mock_inf_getter: mock_inf = MagicMock() mock_inf_getter.return_value = mock_inf mock_inf.return_value = MagicMock(tolist=lambda: [0.1, 0.2]) @@ -47,6 +44,12 @@ def test_add_voice_local_file(self, sqlite_memory): voice_id, _ = db_service.add("Test User", "local.wav") assert voice_id is not None + + # Verify status was updated to ready + record = sqlite_memory.get(VoiceRecord, voice_id) + assert record.status == "ready" + assert record.status_message is None + voices = db_service.voices assert "Test User" in voices @@ -62,9 +65,7 @@ def test_remove_voice(self, sqlite_memory): sqlite_memory.add(v) sqlite_memory.commit() - self.mock_storage.list_files.return_value = [ - {"key": "voices/1/reference_1.wav"} - ] + self.mock_storage.list_files.return_value = [{"key": "voices/1/reference_1.wav"}] db_service = VoiceDB(sqlite_memory, hf_token="fake") db_service.remove("Test") @@ -76,9 +77,7 @@ def test_add_voice_s3_source(self, sqlite_memory): "src.infrastructure.services.voice_profile_service.get_best_device", return_value="cpu", ): - with patch( - "src.infrastructure.services.voice_profile_service.VoiceDB._get_inference" - ) as mock_inf_getter: + with patch("src.infrastructure.services.voice_profile_service.VoiceDB._get_inference") as mock_inf_getter: mock_inf = MagicMock() mock_inf_getter.return_value = mock_inf mock_inf.return_value = MagicMock(tolist=lambda: [0.1, 0.2]) @@ -149,3 +148,69 @@ def test_remove_by_invalid_name(self, sqlite_memory): db_service = VoiceDB(sqlite_memory, hf_token="fake") with pytest.raises(KeyError, match="not found"): db_service.remove("NonExistent") + + def test_add_voice_s3_download_failure(self, sqlite_memory): + self.mock_storage.download_file.side_effect = Exception("S3 Error") + db_service = VoiceDB(sqlite_memory, hf_token="fake") + + with pytest.raises(ValueError, match="Failed to download from S3"): + db_service.add("S3 User", "s3://bucket/voice.wav") + + # After failure, it should NOT have created a successful record + # but let's check if it created a fixed "failed" record if we implement it that way. + # Currently, if it fails at download, it doesn't even create the record yet in the DB. + # Wait, the record is created AFTER the S3 check block. + # So no record should exist in DB yet. + assert sqlite_memory.query(VoiceRecord).count() == 0 + + def test_add_voice_embedding_extraction_failure(self, sqlite_memory): + db_service = VoiceDB(sqlite_memory, hf_token="fake") + + # Fail during embedding extraction + with patch.object(db_service, "_extract_embedding", side_effect=Exception("Model Error")): + with pytest.raises(Exception, match="Model Error"): + db_service.add("Failed User", "local.wav") + + # Verify it marked as failed + record = sqlite_memory.query(VoiceRecord).filter(VoiceRecord.name == "Failed User").first() + assert record is not None + assert record.status == "failed" + assert "Model Error" in record.status_message + + def test_list_audio_files(self, sqlite_memory): + v = VoiceRecord(id="v1", name="V1", embedding=[0.1], audios_path="voices/v1/", status="ready") + sqlite_memory.add(v) + sqlite_memory.commit() + + self.mock_storage.list_files.return_value = [{"key": "f1.wav"}, {"key": "f2.wav"}] + + db_service = VoiceDB(sqlite_memory, hf_token="fake") + files = db_service.list_audio_files("v1") + + assert len(files) == 2 + self.mock_storage.list_files.assert_called_once_with(prefix="voices/v1/", extension=".wav") + + def test_delete_audio_file(self, sqlite_memory): + db_service = VoiceDB(sqlite_memory, hf_token="fake") + db_service.delete_audio_file("some/key.wav") + self.mock_storage.delete_file.assert_called_once_with("some/key.wav") + + def test_list_voices_and_len(self, sqlite_memory): + v1 = VoiceRecord(id="v1", name="Ready", embedding=[0.1], status="ready") + v2 = VoiceRecord(id="v2", name="Processing", embedding=[], status="processing") + sqlite_memory.add_all([v1, v2]) + sqlite_memory.commit() + + db_service = VoiceDB(sqlite_memory, hf_token="fake") + + # list_voices should only show ready ones + voice_list = db_service.list_voices() + assert "Ready" in voice_list + assert "Processing" not in voice_list + + # len should only show ready ones + assert len(db_service) == 1 + + # .voices property should only show ready ones + assert "Ready" in db_service.voices + assert "Processing" not in db_service.voices diff --git a/tests/infrastructure/services/test_whisperx_audio_diarizer.py b/tests/infrastructure/services/test_whisperx_audio_diarizer.py index 26dad877..65a693ca 100644 --- a/tests/infrastructure/services/test_whisperx_audio_diarizer.py +++ b/tests/infrastructure/services/test_whisperx_audio_diarizer.py @@ -10,13 +10,9 @@ class TestAudioDiarizer: @pytest.fixture(autouse=True) def mock_deps(self): - with patch( - "src.infrastructure.services.whisperx_audio_diarizer.model_loader" - ) as ml: + with patch("src.infrastructure.services.whisperx_audio_diarizer.model_loader") as ml: self.mock_model_loader = ml - with patch( - "src.infrastructure.services.whisperx_audio_diarizer.whisperx" - ) as wx: + with patch("src.infrastructure.services.whisperx_audio_diarizer.whisperx") as wx: self.whisperx = wx yield diff --git a/tests/infrastructure/services/test_youtube_audio_downloader.py b/tests/infrastructure/services/test_youtube_audio_downloader.py index 72d97a4a..899b5a96 100644 --- a/tests/infrastructure/services/test_youtube_audio_downloader.py +++ b/tests/infrastructure/services/test_youtube_audio_downloader.py @@ -16,15 +16,10 @@ def test_download_success(self, mock_makedirs, mock_ytdl): mock_instance.prepare_filename.return_value = "temp_audio/Test Audio.webm" extractor = YoutubeExtractor() - result = extractor.download_audio( - "https://www.youtube.com/watch?v=dummy", output_dir="temp_audio" - ) + result = extractor.download_audio("https://www.youtube.com/watch?v=dummy", output_dir="temp_audio") # In the code, it changes extension to .mp3 using Path.with_suffix - assert ( - result == "temp_audio\\Test Audio.mp3" - or result == "temp_audio/Test Audio.mp3" - ) + assert result == "temp_audio\\Test Audio.mp3" or result == "temp_audio/Test Audio.mp3" assert mock_instance.extract_info.called assert mock_makedirs.called diff --git a/tests/infrastructure/services/test_youtube_data_service.py b/tests/infrastructure/services/test_youtube_data_service.py index 9ad57e6d..67975f7a 100644 --- a/tests/infrastructure/services/test_youtube_data_service.py +++ b/tests/infrastructure/services/test_youtube_data_service.py @@ -42,28 +42,18 @@ def dummy_transcript(): def mock_model_loader_service(): mock = MagicMock() mock.model.tokenizer = MagicMock() - mock.model.tokenizer.encode.side_effect = lambda txt, add_special_tokens=False: [ - ord(c) for c in txt - ] - mock.model.tokenizer.decode.side_effect = lambda ids, skip_special_tokens=True: ( - "".join(chr(i) for i in ids) - ) + mock.model.tokenizer.encode.side_effect = lambda txt, add_special_tokens=False: [ord(c) for c in txt] + mock.model.tokenizer.decode.side_effect = lambda ids, skip_special_tokens=True: "".join(chr(i) for i in ids) return mock @pytest.mark.YoutubeDataProcessService class TestYoutubeDataService: - def test_split_by_time( - self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor - ): + def test_split_by_time(self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor): dummy_yt_extractor.extract_transcript.return_value = dummy_transcript with patch.object(logger, "info"), patch.object(logger, "debug"): - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) - docs = splitter.split_transcript( - mode="time", time_window_size=30, time_overlap=5 - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) + docs = splitter.split_transcript(mode="time", time_window_size=30, time_overlap=5) assert isinstance(docs, list) assert all(isinstance(doc, Document) for doc in docs) assert len(docs) > 0 @@ -72,17 +62,11 @@ def test_split_by_time( assert "window_end" in doc.metadata assert doc.metadata["video_id"] == "dummy_video_id" - def test_split_by_tokens( - self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor - ): + def test_split_by_tokens(self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor): dummy_yt_extractor.extract_transcript.return_value = dummy_transcript with patch.object(logger, "info"), patch.object(logger, "debug"): - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) - docs = splitter.split_transcript( - mode="tokens", tokens_per_chunk=10, tokens_overlap=2 - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) + docs = splitter.split_transcript(mode="tokens", tokens_per_chunk=10, tokens_overlap=2) assert isinstance(docs, list) assert all(isinstance(doc, Document) for doc in docs) assert len(docs) > 0 @@ -94,47 +78,29 @@ def test_split_by_tokens( def test_empty_transcript(self, mock_model_loader_service, dummy_yt_extractor): with patch.object(logger, "info"), patch.object(logger, "debug"): - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) - docs = splitter.split_transcript( - mode="time", time_window_size=30, time_overlap=5 - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) + docs = splitter.split_transcript(mode="time", time_window_size=30, time_overlap=5) assert docs == [] - def test_invalid_overlap( - self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor - ): + def test_invalid_overlap(self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor): dummy_yt_extractor.extract_transcript.return_value = dummy_transcript with patch.object(logger, "error"): - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) with pytest.raises(ValueError): - splitter.split_transcript( - mode="time", time_window_size=30, time_overlap=40 - ) + splitter.split_transcript(mode="time", time_window_size=30, time_overlap=40) with pytest.raises(ValueError): - splitter.split_transcript( - mode="tokens", tokens_per_chunk=5, tokens_overlap=10 - ) + splitter.split_transcript(mode="tokens", tokens_per_chunk=5, tokens_overlap=10) def test_no_tokenizer(self, dummy_transcript, dummy_yt_extractor): dummy_yt_extractor.extract_transcript.return_value = dummy_transcript with patch.object(logger, "error"): mock_model_loader_service = MagicMock() mock_model_loader_service.model.tokenizer = None - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) with pytest.raises(RuntimeError): - splitter.split_transcript( - mode="tokens", tokens_per_chunk=10, tokens_overlap=2 - ) + splitter.split_transcript(mode="tokens", tokens_per_chunk=10, tokens_overlap=2) - def test_encode_typeerror_fallback( - self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor - ): + def test_encode_typeerror_fallback(self, dummy_transcript, mock_model_loader_service, dummy_yt_extractor): dummy_yt_extractor.extract_transcript.return_value = dummy_transcript def encode_side_effect(txt, add_special_tokens=None): @@ -142,15 +108,9 @@ def encode_side_effect(txt, add_special_tokens=None): raise TypeError("unexpected argument") return [ord(c) for c in txt] - mock_model_loader_service.model.tokenizer.encode.side_effect = ( - encode_side_effect - ) - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) - docs = splitter.split_transcript( - mode="tokens", tokens_per_chunk=10, tokens_overlap=2 - ) + mock_model_loader_service.model.tokenizer.encode.side_effect = encode_side_effect + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) + docs = splitter.split_transcript(mode="tokens", tokens_per_chunk=10, tokens_overlap=2) assert isinstance(docs, list) assert all(isinstance(doc, Document) for doc in docs) assert len(docs) > 0 @@ -160,32 +120,20 @@ def encode_side_effect(txt, add_special_tokens=None): assert doc.metadata["video_id"] == "dummy_video_id" assert "tokens_count" in doc.metadata - def test_tokenize_skips_empty_snippet( - self, mock_model_loader_service, dummy_yt_extractor - ): + def test_tokenize_skips_empty_snippet(self, mock_model_loader_service, dummy_yt_extractor): class EmptySnippet: text = "" start = 0 duration = 0 transcript = [EmptySnippet()] - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) with patch.object(logger, "debug") as mock_debug: - splitter._tokenize_transcript( - transcript, mock_model_loader_service.model.tokenizer, {} - ) - mock_debug.assert_any_call( - "Skipping empty snippet", context={"snippet_index": 0} - ) - - def test_decode_typeerror_and_attributeerror( - self, mock_model_loader_service, dummy_yt_extractor - ): - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) + splitter._tokenize_transcript(transcript, mock_model_loader_service.model.tokenizer, {}) + mock_debug.assert_any_call("Skipping empty snippet", context={"snippet_index": 0}) + + def test_decode_typeerror_and_attributeerror(self, mock_model_loader_service, dummy_yt_extractor): + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) chunk_ids = [65, 66] transcript = [DummySnippet("A", 0, 1)] @@ -208,9 +156,7 @@ def decode_typeerror(ids, *args, **kwargs): def decode_attributeerror(ids, *args, **kwargs): raise AttributeError("no decode") - mock_model_loader_service.model.tokenizer.decode.side_effect = ( - decode_attributeerror - ) + mock_model_loader_service.model.tokenizer.decode.side_effect = decode_attributeerror docs = splitter._create_token_chunks( chunk_ids, [{"start": 0, "end": 1, "snippet_index": 0}] * len(chunk_ids), @@ -221,31 +167,21 @@ def decode_attributeerror(ids, *args, **kwargs): ) assert docs[0].page_content == str(chunk_ids) - def test_create_token_chunks_empty_meta( - self, mock_model_loader_service, dummy_yt_extractor - ): + def test_create_token_chunks_empty_meta(self, mock_model_loader_service, dummy_yt_extractor): import math - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) token_ids = [65, 66] token_meta = [] transcript = [DummySnippet("A", 0, 1)] - docs = splitter._create_token_chunks( - token_ids, token_meta, len(token_ids), len(token_ids), transcript, {} - ) + docs = splitter._create_token_chunks(token_ids, token_meta, len(token_ids), len(token_ids), transcript, {}) assert math.isclose(docs[0].metadata["window_start"], 0.0) assert math.isclose(docs[0].metadata["window_end"], 0.0) - def test_unknown_mode_error( - self, mock_model_loader_service, dummy_yt_extractor, dummy_transcript - ): + def test_unknown_mode_error(self, mock_model_loader_service, dummy_yt_extractor, dummy_transcript): dummy_yt_extractor.extract_transcript.return_value = dummy_transcript with patch.object(logger, "error") as mock_error: - splitter = YoutubeDataProcessService( - mock_model_loader_service, dummy_yt_extractor - ) + splitter = YoutubeDataProcessService(mock_model_loader_service, dummy_yt_extractor) with pytest.raises(ValueError) as exc: splitter.split_transcript(mode="unknown_mode", tokens_per_chunk=10) # type: ignore assert "Unknown splitting mode" in str(exc.value) diff --git a/tests/presentation/api/routes/test_audio_diarization_router.py b/tests/presentation/api/routes/test_audio_diarization_router.py index 0b80bc40..e0d78925 100644 --- a/tests/presentation/api/routes/test_audio_diarization_router.py +++ b/tests/presentation/api/routes/test_audio_diarization_router.py @@ -59,9 +59,7 @@ def test_update_diarization_segments_success(self, mock_db, mock_task_queue): mock_cs_repo = mock_cs_repo_cls.return_value mock_cs_repo.get_by_diarization_id.return_value = mock_cs - response = client.patch( - f"/rest/audio/{record_mock.id}", json={"segments": [{"text": "hello"}]} - ) + response = client.patch(f"/rest/audio/{record_mock.id}", json={"segments": [{"text": "hello"}]}) assert response.status_code == 200 assert response.json()["status"] == "success" @@ -145,9 +143,7 @@ def test_start_audio_processing_pipeline_duplicate(self, mock_db, mock_task_queu app.dependency_overrides.clear() - def test_start_audio_processing_pipeline_with_subject( - self, mock_db, mock_task_queue - ): + def test_start_audio_processing_pipeline_with_subject(self, mock_db, mock_task_queue): app.dependency_overrides[get_db] = lambda: mock_db app.dependency_overrides[get_task_queue_service] = lambda: mock_task_queue record_mock = MagicMock(id="new-uuid") @@ -173,9 +169,7 @@ def test_start_audio_processing_pipeline_with_subject( app.dependency_overrides.clear() - def test_start_audio_processing_pipeline_failed_retry( - self, mock_db, mock_task_queue - ): + def test_start_audio_processing_pipeline_failed_retry(self, mock_db, mock_task_queue): app.dependency_overrides[get_db] = lambda: mock_db app.dependency_overrides[get_task_queue_service] = lambda: mock_task_queue diff --git a/tests/presentation/api/routes/test_auth_router.py b/tests/presentation/api/routes/test_auth_router.py index 32bd5d9d..326315b9 100644 --- a/tests/presentation/api/routes/test_auth_router.py +++ b/tests/presentation/api/routes/test_auth_router.py @@ -78,9 +78,7 @@ async def test_google_callback_success(self, mock_auth_use_case): } ) - response = client.get( - "/rest/auth/google/callback?code=testcode&state=state123&expected_state=state123" - ) + response = client.get("/rest/auth/google/callback?code=testcode&state=state123&expected_state=state123") assert response.status_code == 200 assert response.json()["access_token"] == "jwt" diff --git a/tests/presentation/api/routes/test_chunk_router.py b/tests/presentation/api/routes/test_chunk_router.py index e33fab3c..cac4484c 100644 --- a/tests/presentation/api/routes/test_chunk_router.py +++ b/tests/presentation/api/routes/test_chunk_router.py @@ -55,9 +55,7 @@ def test_update_chunk_success(mock_chunk_index_service, mock_chunk_vector_servic assert response.status_code == 200 assert response.json() is True mock_chunk_index_service.get_by_id.assert_called_once() - mock_chunk_index_service.update_chunk.assert_called_once_with( - chunk_id, "new content" - ) + mock_chunk_index_service.update_chunk.assert_called_once_with(chunk_id, "new content") mock_chunk_vector_service.delete_by_id.assert_called_once_with(chunk_id) mock_chunk_vector_service.index_documents.assert_called_once() diff --git a/tests/presentation/api/routes/test_ingest_router.py b/tests/presentation/api/routes/test_ingest_router.py index c9a9af72..0ecd6f71 100644 --- a/tests/presentation/api/routes/test_ingest_router.py +++ b/tests/presentation/api/routes/test_ingest_router.py @@ -37,9 +37,7 @@ def test_ingest_youtube_success(mock_use_case): video_results=[], ) - response = client.post( - "/rest/ingest/youtube", json={"video_url": "https://youtube.com/watch?v=123"} - ) + response = client.post("/rest/ingest/youtube", json={"video_url": "https://youtube.com/watch?v=123"}) assert response.status_code == 200 mock_use_case.execute.assert_called_once() @@ -51,9 +49,7 @@ def test_ingest_youtube_skipped(mock_use_case): mock_result.reason = "This content has already been ingested." mock_use_case.execute.return_value = mock_result - response = client.post( - "/rest/ingest/youtube", json={"video_url": "https://youtube.com/watch?v=123"} - ) + response = client.post("/rest/ingest/youtube", json={"video_url": "https://youtube.com/watch?v=123"}) assert response.status_code == 409 assert response.json()["detail"] == "This content has already been ingested." @@ -71,9 +67,7 @@ def test_ingest_youtube_value_error(mock_use_case): def test_ingest_youtube_exception(mock_use_case): mock_use_case.execute.side_effect = Exception("Internal error") - response = client.post( - "/rest/ingest/youtube", json={"video_url": "https://youtube.com/watch?v=123"} - ) + response = client.post("/rest/ingest/youtube", json={"video_url": "https://youtube.com/watch?v=123"}) assert response.status_code == 500 assert response.json()["detail"] == "Internal error" diff --git a/tests/presentation/api/routes/test_ingest_router_file.py b/tests/presentation/api/routes/test_ingest_router_file.py index 3edb031d..3dc83fd7 100644 --- a/tests/presentation/api/routes/test_ingest_router_file.py +++ b/tests/presentation/api/routes/test_ingest_router_file.py @@ -42,10 +42,7 @@ def test_ingest_file_success(self, client, mock_file_use_case): response = client.post("/rest/ingest/file", files=files, data=data) assert response.status_code == 200 - assert ( - response.json()["message"] - == "File upload successful, ingestion started in background." - ) + assert response.json()["message"] == "File upload successful, ingestion started in background." assert response.json()["file_name"] == "test.txt" def test_ingest_file_invalid_uuid(self, client, mock_file_use_case): @@ -58,9 +55,7 @@ def test_ingest_file_invalid_uuid(self, client, mock_file_use_case): def test_ingest_file_missing_file(self, client): response = client.post("/rest/ingest/file", data={"subject_name": "test"}) - assert ( - response.status_code == 422 - ) # FastAPI validation error for missing required File + assert response.status_code == 422 # FastAPI validation error for missing required File def test_ingest_file_url_success(self, client): data = { diff --git a/tests/presentation/api/routes/test_settings_router.py b/tests/presentation/api/routes/test_settings_router.py index ff679fed..15d8eb43 100644 --- a/tests/presentation/api/routes/test_settings_router.py +++ b/tests/presentation/api/routes/test_settings_router.py @@ -69,9 +69,7 @@ def test_check_health_api(): def test_check_health_sql(): - with patch( - "src.presentation.api.routes.settings_router.Connector" - ) as mock_connector: + with patch("src.presentation.api.routes.settings_router.Connector") as mock_connector: mock_session = MagicMock() mock_connector.return_value.__enter__.return_value = mock_session @@ -109,9 +107,7 @@ def test_check_health_unknown(): def test_check_health_exception(): - with patch( - "src.presentation.api.routes.settings_router.Connector" - ) as mock_connector: + with patch("src.presentation.api.routes.settings_router.Connector") as mock_connector: mock_connector.side_effect = Exception("Connection failed") response = client.get("/rest/settings/check/sql") assert response.status_code == 200 diff --git a/tests/presentation/api/routes/test_source_router.py b/tests/presentation/api/routes/test_source_router.py index 9e838dae..c07a7a7a 100644 --- a/tests/presentation/api/routes/test_source_router.py +++ b/tests/presentation/api/routes/test_source_router.py @@ -80,9 +80,7 @@ def test_get_model_info_success(mock_model_loader): def test_get_model_info_error(mock_model_loader): # To trigger the router's internal try-except block, we mock an attribute access error. - type(mock_model_loader).model_name = property( - lambda x: exec('raise(Exception("attr fail"))') - ) + type(mock_model_loader).model_name = property(lambda x: exec('raise(Exception("attr fail"))')) response = client.get("/rest/sources/model") assert response.status_code == 500 diff --git a/tests/presentation/api/routes/test_subject_router.py b/tests/presentation/api/routes/test_subject_router.py index 737c6d5e..35ece52d 100644 --- a/tests/presentation/api/routes/test_subject_router.py +++ b/tests/presentation/api/routes/test_subject_router.py @@ -35,9 +35,7 @@ def test_create_subject_success(self, mock_ks_service): icon="icon", ) - response = client.post( - "/rest/subjects", json={"name": "test", "description": "desc"} - ) + response = client.post("/rest/subjects", json={"name": "test", "description": "desc"}) assert response.status_code == 201 mock_ks_service.create_subject.assert_called_once() diff --git a/tests/presentation/api/routes/test_voice_profile_router.py b/tests/presentation/api/routes/test_voice_profile_router.py index 3ad871e6..93a70fc2 100644 --- a/tests/presentation/api/routes/test_voice_profile_router.py +++ b/tests/presentation/api/routes/test_voice_profile_router.py @@ -22,14 +22,10 @@ class TestVoiceProfileRouter: def test_register_voice_profile_success(self): app.dependency_overrides[get_db] = lambda: MagicMock() mock_use_case = MagicMock() - app.dependency_overrides[get_register_voice_profile_use_case] = lambda: ( - mock_use_case - ) + app.dependency_overrides[get_register_voice_profile_use_case] = lambda: mock_use_case mock_use_case.execute.return_value = "v-123" - response = client.post( - "/rest/voices", json={"name": "Alice", "audio_path": "s3://path"} - ) + response = client.post("/rest/voices", json={"name": "Alice", "audio_path": "s3://path"}) assert response.status_code == 200 assert response.json()["voice_id"] == "v-123" @@ -66,14 +62,52 @@ def test_train_from_speaker_success(self): app.dependency_overrides.clear() + def test_train_from_speaker_not_found(self): + # 1. Mock dependencies + app.dependency_overrides[get_db] = lambda: MagicMock() + mock_repo = MagicMock() + app.dependency_overrides[get_diarization_repo] = lambda: mock_repo + mock_repo.get_by_id.return_value = None + + # 2. Execute request + payload = { + "diarization_id": "non-existent", + "speaker_label": "SPEAKER_00", + "name": "Bob", + } + response = client.post("/rest/voices/train-from-speaker", json=payload) + + # 3. Assert status and body + assert response.status_code == 404 + assert "Diarization not found" in response.json()["detail"] + + app.dependency_overrides.clear() + + def test_train_from_speaker_error(self): + # 1. Mock dependencies + app.dependency_overrides[get_db] = lambda: MagicMock() + mock_repo = MagicMock() + app.dependency_overrides[get_diarization_repo] = lambda: mock_repo + mock_repo.get_by_id.side_effect = Exception("System failure") + + # 2. Execute request + payload = { + "diarization_id": "d-1", + "speaker_label": "SPEAKER_00", + "name": "Bob", + } + response = client.post("/rest/voices/train-from-speaker", json=payload) + + # 3. Assert status and body + assert response.status_code == 500 + assert "System failure" in response.json()["detail"] + app.dependency_overrides.clear() def test_list_voices(self): app.dependency_overrides[get_db] = lambda: MagicMock() mock_use_case = MagicMock() - app.dependency_overrides[get_list_voice_profiles_use_case] = lambda: ( - mock_use_case - ) + app.dependency_overrides[get_list_voice_profiles_use_case] = lambda: mock_use_case mock_use_case.execute.return_value = [{"name": "Alice"}] response = client.get("/rest/voices") @@ -85,9 +119,7 @@ def test_list_voices(self): def test_delete_voice_success(self): app.dependency_overrides[get_db] = lambda: MagicMock() mock_use_case = MagicMock() - app.dependency_overrides[get_delete_voice_profile_use_case] = lambda: ( - mock_use_case - ) + app.dependency_overrides[get_delete_voice_profile_use_case] = lambda: mock_use_case mock_use_case.execute.return_value = None response = client.delete("/rest/voices/Alice") @@ -99,9 +131,7 @@ def test_delete_voice_success(self): def test_delete_voice_not_found(self): app.dependency_overrides[get_db] = lambda: MagicMock() mock_use_case = MagicMock() - app.dependency_overrides[get_delete_voice_profile_use_case] = lambda: ( - mock_use_case - ) + app.dependency_overrides[get_delete_voice_profile_use_case] = lambda: mock_use_case mock_use_case.execute.side_effect = KeyError("not found") response = client.delete("/rest/voices/Unknown") @@ -112,9 +142,7 @@ def test_delete_voice_not_found(self): def test_upload_voice_profile_success(self): app.dependency_overrides[get_db] = lambda: MagicMock() mock_use_case = MagicMock() - app.dependency_overrides[get_register_voice_profile_use_case] = lambda: ( - mock_use_case - ) + app.dependency_overrides[get_register_voice_profile_use_case] = lambda: mock_use_case mock_use_case.execute.return_value = "v-123" @@ -139,11 +167,100 @@ def test_upload_voice_profile_success(self): app.dependency_overrides.clear() + def test_upload_voice_profile_no_filename(self): + app.dependency_overrides[get_db] = lambda: MagicMock() + + from io import BytesIO + + file_content = b"fake audio content" + # Create a file object with an empty filename + # Using None as filename can sometimes bypass FastAPI/HTTPX validation + # and result in an UploadFile with filename=None or "" + files = {"file": (" ", BytesIO(file_content), "audio/wav")} + data = {"name": "Alice"} + + response = client.post("/rest/voices/upload", data=data, files=files) + + assert response.status_code == 400 + assert "No filename provided" in response.json()["detail"] + + app.dependency_overrides.clear() + + def test_upload_voice_profile_error(self): + app.dependency_overrides[get_db] = lambda: MagicMock() + mock_use_case = MagicMock() + app.dependency_overrides[get_register_voice_profile_use_case] = lambda: mock_use_case + + mock_use_case.execute.side_effect = Exception("Processing error") + + from io import BytesIO + + file_content = b"fake audio content" + files = {"file": ("test.wav", BytesIO(file_content), "audio/wav")} + data = {"name": "Alice"} + + with ( + patch("tempfile.mkdtemp", return_value="/tmp/test"), + patch("os.path.exists", return_value=True), + patch("os.remove"), + patch("os.rmdir"), + patch("anyio.open_file", side_effect=Exception("Async write error")), + ): + response = client.post("/rest/voices/upload", data=data, files=files) + + assert response.status_code == 400 + assert "Async write error" in response.json()["detail"] + + app.dependency_overrides.clear() + + def test_list_voice_audio_files(self): + app.dependency_overrides[get_db] = lambda: MagicMock() + mock_use_case = MagicMock() + from src.presentation.api.dependencies import get_list_voice_audio_files_use_case + + app.dependency_overrides[get_list_voice_audio_files_use_case] = lambda: mock_use_case + mock_use_case.execute.return_value = ["audio1.wav", "audio2.wav"] + + response = client.get("/rest/voices/v-123/audios") + assert response.status_code == 200 + assert len(response.json()) == 2 + + app.dependency_overrides.clear() + + def test_delete_voice_audio_file_success(self): + app.dependency_overrides[get_db] = lambda: MagicMock() + mock_use_case = MagicMock() + from src.presentation.api.dependencies import ( + get_delete_voice_audio_file_use_case, + ) + + app.dependency_overrides[get_delete_voice_audio_file_use_case] = lambda: mock_use_case + + response = client.delete("/rest/voices/audios/path/to/audio.wav") + assert response.status_code == 200 + assert response.json()["message"] == "Audio file deleted" + + app.dependency_overrides.clear() + + def test_delete_voice_audio_file_error(self): + app.dependency_overrides[get_db] = lambda: MagicMock() + mock_use_case = MagicMock() + from src.presentation.api.dependencies import ( + get_delete_voice_audio_file_use_case, + ) + + app.dependency_overrides[get_delete_voice_audio_file_use_case] = lambda: mock_use_case + mock_use_case.execute.side_effect = Exception("File not found in S3") + + response = client.delete("/rest/voices/audios/path/to/missing.wav") + assert response.status_code == 404 + assert "File not found in S3" in response.json()["detail"] + + app.dependency_overrides.clear() + def test_get_voice_audio_url_success(self): # Patch the StorageService class inside the router module - with patch( - "src.presentation.api.routes.voice_profile_management_router.StorageService" - ) as mock_storage_cls: + with patch("src.presentation.api.routes.voice_profile_management_router.StorageService") as mock_storage_cls: mock_storage = mock_storage_cls.return_value mock_storage.get_presigned_url.return_value = "http://presigned-url" @@ -151,3 +268,14 @@ def test_get_voice_audio_url_success(self): assert response.status_code == 200 assert response.json()["url"] == "http://presigned-url" + + def test_get_voice_audio_url_error(self): + # Patch the StorageService class inside the router module + with patch("src.presentation.api.routes.voice_profile_management_router.StorageService") as mock_storage_cls: + mock_storage = mock_storage_cls.return_value + mock_storage.get_presigned_url.side_effect = Exception("S3 access denied") + + response = client.get("/rest/voices/audios/path/to/voice.wav") + + assert response.status_code == 404 + assert "S3 access denied" in response.json()["detail"] diff --git a/tests/presentation/api/test_dependencies.py b/tests/presentation/api/test_dependencies.py index b92d9524..bfaa8606 100644 --- a/tests/presentation/api/test_dependencies.py +++ b/tests/presentation/api/test_dependencies.py @@ -40,9 +40,7 @@ def test_get_repositories(self): def test_get_model_loader(self): settings = MagicMock() settings.model_embedding.name = "test-model" - with patch( - "src.infrastructure.services.model_loader_service.SentenceTransformer" - ): + with patch("src.infrastructure.services.model_loader_service.SentenceTransformer"): loader = get_model_loader(settings) assert loader is not None @@ -53,9 +51,7 @@ def test_get_embedding_service(self): def test_get_weaviate_client(self): settings = MagicMock() - with patch( - "src.infrastructure.repositories.vector.weaviate.weaviate_client.WeaviateClient" - ) as _: + with patch("src.infrastructure.repositories.vector.weaviate.weaviate_client.WeaviateClient") as _: client = get_weaviate_client(settings) assert client is not None @@ -66,12 +62,8 @@ def test_get_vector_repository_weaviate(self): loader = MagicMock() with ( - patch( - "src.infrastructure.repositories.vector.weaviate.weaviate_client.WeaviateClient" - ), - patch( - "src.infrastructure.repositories.vector.weaviate.chunk_repository.ChunkWeaviateRepository" - ), + patch("src.infrastructure.repositories.vector.weaviate.weaviate_client.WeaviateClient"), + patch("src.infrastructure.repositories.vector.weaviate.chunk_repository.ChunkWeaviateRepository"), ): repo = get_vector_repository(settings, loader) assert repo is not None @@ -84,9 +76,7 @@ def test_get_vector_repository_chroma(self): settings.vector.collection_name_chunks = "test_collection" loader = MagicMock() - with patch( - "src.infrastructure.repositories.vector.chroma.chunk_repository.ChunkChromaRepository" - ): + with patch("src.infrastructure.repositories.vector.chroma.chunk_repository.ChunkChromaRepository"): repo = get_vector_repository(settings, loader) assert repo is not None @@ -96,9 +86,7 @@ def test_get_vector_repository_faiss(self): settings.vector.vector_index_path = "test_path" loader = MagicMock() - with patch( - "src.infrastructure.repositories.vector.faiss.chunk_repository.ChunkFAISSRepository" - ): + with patch("src.infrastructure.repositories.vector.faiss.chunk_repository.ChunkFAISSRepository"): repo = get_vector_repository(settings, loader) assert repo is not None