diff --git a/.gitignore b/.gitignore index 564b8d3..6ca20e3 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ logs/ # Ignore Testing Coverage Results tests/coverage/.coverage -env/ \ No newline at end of file +env/ +speaker-diarization-community-1/ diff --git a/app/data/speaker_diarization_data.py b/app/data/speaker_diarization_data.py new file mode 100644 index 0000000..1b36d74 --- /dev/null +++ b/app/data/speaker_diarization_data.py @@ -0,0 +1,71 @@ +""" +This module is responsible for handling the speaker diarization data layer. +""" +from app.utils.logger import logger + + +class SpeakerDiarizationDataLayer: + def __init__(self, config: dict): + """ + Initialize the Speaker Diarization Data Layer. + :param config: The configuration object containing model and device info. + """ + self.debug = config.get('debug') + + self.config = config.get('speaker_diarization', {}) + self.default_backend = self.config.get('backend', 'pyannote') + + self.root_config = config + self.model = None + + def _get_model(self): + if self.model is not None: + return self.model + + if self.default_backend == "pyannote": + from app.models.speaker_diarization_model import PyannoteSpeakerDiarization + + self.model = PyannoteSpeakerDiarization(self.root_config) + return self.model + + raise ValueError(f"Unsupported speaker diarization backend: {self.default_backend}") + + @staticmethod + def _normalize_segments(segments: list) -> list: + """ + Normalize backend speaker labels into stable speaker_1..N labels. + """ + normalized_segments = [] + label_mapping = {} + + for segment in sorted(segments, key=lambda item: (item['start'], item['end'], item['speaker'])): + speaker = segment['speaker'] + if speaker not in label_mapping: + label_mapping[speaker] = f"speaker_{len(label_mapping) + 1}" + + normalized_segments.append({ + 'speaker': label_mapping[speaker], + 'start': float(segment['start']), + 'end': float(segment['end']), + }) + + return normalized_segments + + def diarize(self, audio_file_path: str) -> dict: + """ + Process the audio file and return speaker diarization segments. + :param audio_file_path: Path to the audio file. + :return: Speaker diarization segments. + """ + try: + raw_segments = self._get_model()(audio_file_path) + return { + 'segments': self._normalize_segments(raw_segments) + } + + except Exception as e: + logger.warning( + f"[warning] [Data Layer] [SpeakerDiarizationDataLayer] [diarize] " + f"Speaker diarization unavailable: {str(e)}" + ) + return {'error': 'Speaker diarization is unavailable.'} diff --git a/app/models/speaker_diarization_model.py b/app/models/speaker_diarization_model.py new file mode 100644 index 0000000..857a05c --- /dev/null +++ b/app/models/speaker_diarization_model.py @@ -0,0 +1,171 @@ +""" +This module defines the PyannoteSpeakerDiarization model wrapper. +""" +import os +from contextlib import contextmanager +from pathlib import Path + +import numpy as np +import torch +from pydub import AudioSegment + + +class PyannoteSpeakerDiarization: + def __init__(self, config: dict) -> None: + """ + Initialize the pyannote speaker diarization pipeline lazily. + :param config: The configuration object containing speaker diarization settings. + """ + self.debug = config.get('debug') + + self.config = config.get('speaker_diarization', {}) + self.local_model_path = self.config.get('local_model_path') + self.model_name = self.config.get('model_name', 'pyannote/speaker-diarization-community-1') + self.token_env = self.config.get('token_env', 'HUGGINGFACE_TOKEN') + + self.pipeline = None + + @staticmethod + def _resolve_local_model_path(local_model_path: str | None) -> Path | None: + """ + Resolve a configured local model path relative to the repository root when needed. + """ + if not local_model_path: + return None + + path = Path(local_model_path) + if not path.is_absolute(): + repo_root = Path(__file__).resolve().parents[2] + path = repo_root / path + + if path.is_dir(): + config_path = path / "config.yaml" + if config_path.exists(): + path = config_path + + return path.resolve() + + def _load_pipeline(self): + """ + Lazy-load the pyannote pipeline so app startup does not fail when diarization is unavailable. + """ + if self.pipeline is not None: + return self.pipeline + + self._patch_torchaudio_for_speechbrain() + + try: + from pyannote.audio import Pipeline + except Exception as exc: + raise RuntimeError( + "pyannote.audio is not installed or could not be imported." + ) from exc + + local_model_path = self._resolve_local_model_path(self.local_model_path) + if local_model_path and local_model_path.exists(): + with self._legacy_torch_load(): + self.pipeline = Pipeline.from_pretrained(local_model_path) + return self.pipeline + + token = os.getenv(self.token_env) + if not token: + if local_model_path: + raise RuntimeError( + f"Local speaker diarization model not found at: {local_model_path}. " + f"Either restore that directory or set {self.token_env} for Hugging Face loading." + ) + raise RuntimeError( + f"Missing Hugging Face token for speaker diarization. " + f"Set the {self.token_env} environment variable." + ) + + try: + self.pipeline = Pipeline.from_pretrained(self.model_name, token=token) + except TypeError: + # Backward compatibility with older pyannote releases. + self.pipeline = Pipeline.from_pretrained(self.model_name, use_auth_token=token) + + return self.pipeline + + @staticmethod + @contextmanager + def _legacy_torch_load(): + """ + Newer torch defaults to weights_only=True, but the trusted local + pyannote checkpoints still require full checkpoint deserialization. + """ + original_torch_load = torch.load + + def patched_torch_load(*args, **kwargs): + if kwargs.get("weights_only") is None: + kwargs["weights_only"] = False + return original_torch_load(*args, **kwargs) + + torch.load = patched_torch_load + try: + yield + finally: + torch.load = original_torch_load + + @staticmethod + def _patch_torchaudio_for_speechbrain() -> None: + """ + SpeechBrain still expects torchaudio.list_audio_backends on import, but + newer torchaudio builds can omit it on Windows. Provide a minimal shim + so local pyannote loading can proceed. + """ + try: + import torchaudio + except Exception: + return + + if not hasattr(torchaudio, "list_audio_backends"): + torchaudio.list_audio_backends = lambda: ["ffmpeg"] + + @staticmethod + def _load_audio_input(audio_file_path: str) -> dict: + """ + Preload audio into memory so pyannote does not rely on torchaudio/torchcodec + decoding on this machine. + """ + audio = AudioSegment.from_file(audio_file_path).set_sample_width(2) + + samples = np.array(audio.get_array_of_samples(), dtype=np.float32) + if audio.channels > 1: + samples = samples.reshape((-1, audio.channels)).T + else: + samples = samples.reshape((1, -1)) + + waveform = torch.from_numpy(samples / 32768.0) + return { + "waveform": waveform, + "sample_rate": audio.frame_rate, + } + + def __call__(self, audio_file_path: str) -> list: + """ + Perform speaker diarization on the given audio file. + :param audio_file_path: Path to the audio file. + :return: Raw speaker diarization segments. + """ + pipeline = self._load_pipeline() + output = pipeline(self._load_audio_input(audio_file_path)) + + diarization = getattr(output, 'exclusive_speaker_diarization', None) + if diarization is None: + diarization = getattr(output, 'speaker_diarization', output) + + segments = [] + for turn, _, speaker in diarization.itertracks(yield_label=True): + start = float(turn.start) + end = float(turn.end) + if end <= start: + continue + + segments.append({ + 'speaker': str(speaker), + 'start': start, + 'end': end, + }) + + return segments diff --git a/app/services/local_audio_transcription_sentiment_pipeline.py b/app/services/local_audio_transcription_sentiment_pipeline.py new file mode 100644 index 0000000..11c2f3f --- /dev/null +++ b/app/services/local_audio_transcription_sentiment_pipeline.py @@ -0,0 +1,117 @@ +""" +Local-only audio -> speaker -> transcript -> sentiment pipeline used by test.py. +This keeps the original Flask pipeline untouched. +""" +import copy +import os + +from app.config import Config +from app.services.audio_service import AudioService +from app.services.sentiment_service import SentimentService +from app.services.speaker_diarization_service import SpeakerDiarizationService +from app.services.transcript_service import TranscriptService +from app.utils.logger import logger + + +class LocalAudioTranscriptionSentimentPipeline: + def __init__(self, local_model_path: str = "speaker-diarization-community-1"): + root_config = copy.deepcopy(Config().config) + speaker_config = root_config.setdefault("speaker_diarization", {}) + speaker_config.setdefault("enabled", True) + speaker_config.setdefault("backend", "pyannote") + speaker_config.setdefault("local_model_path", local_model_path) + speaker_config.setdefault("model_name", "pyannote/speaker-diarization-community-1") + speaker_config.setdefault("token_env", "HUGGINGFACE_TOKEN") + + self.debug = root_config.get("debug") + self.remove_audio = root_config.get("audio_transcription_sentiment_pipeline", {}).get("remove_audio", False) + + self.audio_service = AudioService() + self.speaker_diarization_service = SpeakerDiarizationService(root_config) + self.transcript_service = TranscriptService() + self.sentiment_service = SentimentService() + + @staticmethod + def _resolve_chunk_speaker(chunk_timestamp, speaker_segments: list) -> str: + if not chunk_timestamp or len(chunk_timestamp) != 2: + return "UNKNOWN" + + chunk_start = float(chunk_timestamp[0]) + chunk_end = float(chunk_timestamp[1]) + + best_speaker = "UNKNOWN" + best_overlap = 0.0 + best_start = float("inf") + + for segment in speaker_segments: + overlap = max( + 0.0, + min(chunk_end, float(segment["end"])) - max(chunk_start, float(segment["start"])) + ) + segment_start = float(segment["start"]) + + if overlap > best_overlap or (overlap == best_overlap and overlap > 0 and segment_start < best_start): + best_overlap = overlap + best_speaker = segment["speaker"] + best_start = segment_start + + return best_speaker if best_overlap > 0 else "UNKNOWN" + + def _assign_speakers_to_chunks(self, chunks: list, speaker_segments: list) -> list: + for chunk in chunks: + chunk["speaker"] = self._resolve_chunk_speaker(chunk.get("timestamp"), speaker_segments) + return chunks + + def process(self, url: str, start_time_ms: int, end_time_ms: int = None, user_id: str = None) -> dict: + try: + audio_result = self.audio_service.extract_audio(url, start_time_ms, end_time_ms, user_id) + if isinstance(audio_result, dict) and "error" in audio_result: + return {"error": audio_result["error"]} + + audio_path = audio_result["audio_path"] + start_time_ms = audio_result["start_time_ms"] + end_time_ms = audio_result["end_time_ms"] + + speaker_result = self.speaker_diarization_service.diarize(audio_path) + speaker_segments = [] + if isinstance(speaker_result, dict) and "error" in speaker_result: + logger.warning( + "[warning] [LocalAudioTranscriptionSentimentPipeline] [process] " + "Speaker diarization unavailable. Falling back to UNKNOWN speaker labels. " + f"Details: {speaker_result['error']}" + ) + else: + speaker_segments = speaker_result.get("segments", []) + + transcription_result = self.transcript_service.transcribe(audio_path) + if isinstance(transcription_result, dict) and "error" in transcription_result: + return {"error": transcription_result["error"]} + + transcription = transcription_result["transcription"] + chunks = self._assign_speakers_to_chunks(transcription_result["chunks"], speaker_segments) + + if self.remove_audio: + os.remove(audio_path) + + for chunk in chunks: + sentiment_result = self.sentiment_service.analyze(chunk["text"]) + if isinstance(sentiment_result, dict) and "error" in sentiment_result: + chunk["error"] = sentiment_result["error"] + continue + + chunk["label"] = sentiment_result["label"] + chunk["confidence"] = sentiment_result["confidence"] + + return { + "audio_path": audio_path, + "start_time_ms": start_time_ms, + "end_time_ms": end_time_ms, + "transcription": transcription, + "utterances_sentiment": chunks, + } + except Exception as exc: + logger.error( + "[error] [LocalAudioTranscriptionSentimentPipeline] [process] " + f"An error occurred during processing: {str(exc)}" + ) + return {"error": "An unexpected error occurred while processing the request."} diff --git a/app/services/speaker_diarization_service.py b/app/services/speaker_diarization_service.py new file mode 100644 index 0000000..e188d0f --- /dev/null +++ b/app/services/speaker_diarization_service.py @@ -0,0 +1,49 @@ +""" +This module contains the service layer for speaker diarization. +""" +import os + +from app.config import Config +from app.data.speaker_diarization_data import SpeakerDiarizationDataLayer +from app.utils.logger import logger + + +class SpeakerDiarizationService: + def __init__(self, config: dict | None = None): + self.root_config = config or Config().config + self.debug = self.root_config.get('debug') + + self.config = self.root_config.get('speaker_diarization', {}) + self.enabled = self.config.get('enabled', True) + + self.speaker_diarization_data_layer = SpeakerDiarizationDataLayer(self.root_config) + + def diarize(self, audio_file_path: str) -> dict: + """ + Perform speaker diarization on the given audio file. + :param audio_file_path: Path to the audio file. + :return: Speaker diarization segments or an error. + """ + try: + if not self.enabled: + return {'segments': []} + + if not os.path.exists(audio_file_path) or not os.path.isfile(audio_file_path): + return {'error': f'Audio file not found: {audio_file_path}'} + + result = self.speaker_diarization_data_layer.diarize(audio_file_path) + if isinstance(result, dict) and 'error' in result: + return { + 'error': result['error'] + } + + return { + 'segments': result['segments'] + } + + except Exception as e: + logger.warning( + f"[warning] [Service Layer] [SpeakerDiarizationService] [diarize] " + f"Speaker diarization unavailable: {str(e)}" + ) + return {'error': 'Speaker diarization is unavailable.'} diff --git a/test.py b/test.py new file mode 100644 index 0000000..bd402b6 --- /dev/null +++ b/test.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +import argparse +import io +import json +import logging +from pathlib import Path +from contextlib import redirect_stderr, redirect_stdout +import warnings + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run local audio transcript sentiment analysis with speaker labels." + ) + parser.add_argument( + "file", + nargs="?", + help="Local audio/video file path. Defaults to samples/sample_1.mp3.", + ) + parser.add_argument( + "--start-ms", + type=int, + default=0, + help="Start time in milliseconds. Default: 0", + ) + parser.add_argument( + "--end-ms", + type=int, + default=None, + help="End time in milliseconds. Default: end of file", + ) + parser.add_argument( + "--json", + action="store_true", + help="Print the full result as JSON instead of a readable summary.", + ) + return parser.parse_args() + + +def resolve_input_file(user_path: str | None) -> Path: + repo_root = Path(__file__).resolve().parent + default_file = repo_root / "samples" / "sample_1.mp3" + + if user_path: + candidate = Path(user_path) + if not candidate.is_absolute(): + candidate = repo_root / candidate + return candidate.resolve() + + return default_file.resolve() + + +def format_chunk_line(index: int, chunk: dict) -> str: + timestamp = chunk.get("timestamp", []) + if len(timestamp) == 2: + start_s = f"{float(timestamp[0]):.2f}s" + end_s = f"{float(timestamp[1]):.2f}s" + else: + start_s = "?" + end_s = "?" + + speaker = chunk.get("speaker", "UNKNOWN") + + if "error" in chunk: + sentiment = f"ERROR: {chunk['error']}" + else: + label = chunk.get("label", "UNKNOWN") + confidence = chunk.get("confidence") + if confidence is None: + sentiment = label + else: + sentiment = f"{label} ({float(confidence):.4f})" + + text = chunk.get("text", "").strip() + return f"[{index}] {start_s} -> {end_s} | {speaker} | {sentiment}\n {text}" + + +def print_summary(audio_file: Path, result: dict) -> None: + print(f"Input file: {audio_file}") + print(f"Audio path: {result.get('audio_path')}") + print(f"Range: {result.get('start_time_ms')}ms -> {result.get('end_time_ms')}ms") + print(f"Transcription: {result.get('transcription', '').strip()}") + print("") + print("Chunks:") + + chunks = result.get("utterances_sentiment", []) + for index, chunk in enumerate(chunks, start=1): + print(format_chunk_line(index, chunk)) + + +def configure_quiet_runtime() -> None: + warnings.filterwarnings( + "ignore", + message=r".*torchcodec is not installed correctly so built-in audio decoding will fail.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*std\(\): degrees of freedom is <= 0.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*The input name `inputs` is deprecated.*", + category=FutureWarning, + ) + warnings.filterwarnings( + "ignore", + message=r".*transcription using a multilingual Whisper will default to language detection.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*Passing a tuple of `past_key_values` is deprecated.*", + ) + warnings.filterwarnings( + "ignore", + message=r".*The attention mask is not set and cannot be inferred.*", + ) + + for logger_name in ("transformers", "pyannote", "speechbrain"): + logging.getLogger(logger_name).setLevel(logging.ERROR) + + try: + from transformers.utils import logging as transformers_logging + + transformers_logging.set_verbosity_error() + except Exception: + pass + + +def run_pipeline_quietly(pipeline, **kwargs) -> dict: + sink = io.StringIO() + with redirect_stdout(sink), redirect_stderr(sink): + return pipeline.process(**kwargs) + + +def main() -> int: + configure_quiet_runtime() + + from app.services.local_audio_transcription_sentiment_pipeline import ( + LocalAudioTranscriptionSentimentPipeline, + ) + + args = parse_args() + audio_file = resolve_input_file(args.file) + + if not audio_file.exists(): + print(f"Input file not found: {audio_file}") + return 1 + + if args.start_ms < 0: + print("--start-ms cannot be negative.") + return 1 + + if args.end_ms is not None and args.end_ms < 0: + print("--end-ms cannot be negative.") + return 1 + + if args.end_ms is not None and args.end_ms <= args.start_ms: + print("--end-ms must be greater than --start-ms.") + return 1 + + pipeline = LocalAudioTranscriptionSentimentPipeline() + result = run_pipeline_quietly( + pipeline, + url=str(audio_file), + start_time_ms=args.start_ms, + end_time_ms=args.end_ms, + ) + + if "error" in result: + print(f"Processing failed: {result['error']}") + return 1 + + if args.json: + print(json.dumps(result, indent=2, ensure_ascii=False)) + else: + print_summary(audio_file, result) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())