Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ logs/
# Ignore Testing Coverage Results
tests/coverage/.coverage

env/
env/
speaker-diarization-community-1/
71 changes: 71 additions & 0 deletions app/data/speaker_diarization_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
This module is responsible for handling the speaker diarization data layer.
"""
from app.utils.logger import logger


class SpeakerDiarizationDataLayer:
    """
    Data layer for speaker diarization: lazily builds the configured backend
    model and exposes normalized speaker segments to callers.
    """

    def __init__(self, config: dict):
        """
        Initialize the Speaker Diarization Data Layer.
        :param config: The configuration object containing model and device info.
        """
        self.root_config = config
        self.config = config.get('speaker_diarization', {})

        self.debug = config.get('debug')
        self.default_backend = self.config.get('backend', 'pyannote')

        # Backend model instance; created on first use by _get_model().
        self.model = None

    def _get_model(self):
        """
        Return the backend model, constructing it on first call.
        :raises ValueError: if the configured backend is not supported.
        """
        if self.model is None:
            if self.default_backend != "pyannote":
                raise ValueError(f"Unsupported speaker diarization backend: {self.default_backend}")

            # Imported lazily so the app can start without pyannote installed.
            from app.models.speaker_diarization_model import PyannoteSpeakerDiarization
            self.model = PyannoteSpeakerDiarization(self.root_config)

        return self.model

    @staticmethod
    def _normalize_segments(segments: list) -> list:
        """
        Normalize backend speaker labels into stable speaker_1..N labels.
        """
        speaker_ids: dict = {}
        ordered = sorted(segments, key=lambda seg: (seg['start'], seg['end'], seg['speaker']))

        def stable_label(raw_label):
            # Labels are numbered in order of first appearance after sorting,
            # so the mapping is deterministic for a given segment set.
            if raw_label not in speaker_ids:
                speaker_ids[raw_label] = f"speaker_{len(speaker_ids) + 1}"
            return speaker_ids[raw_label]

        return [
            {
                'speaker': stable_label(seg['speaker']),
                'start': float(seg['start']),
                'end': float(seg['end']),
            }
            for seg in ordered
        ]

    def diarize(self, audio_file_path: str) -> dict:
        """
        Process the audio file and return speaker diarization segments.
        :param audio_file_path: Path to the audio file.
        :return: Speaker diarization segments.
        """
        try:
            model = self._get_model()
            normalized = self._normalize_segments(model(audio_file_path))
        except Exception as e:
            # Best-effort: diarization failure is reported, not raised.
            logger.warning(
                f"[warning] [Data Layer] [SpeakerDiarizationDataLayer] [diarize] "
                f"Speaker diarization unavailable: {str(e)}"
            )
            return {'error': 'Speaker diarization is unavailable.'}

        return {'segments': normalized}
171 changes: 171 additions & 0 deletions app/models/speaker_diarization_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
This module defines the PyannoteSpeakerDiarization model wrapper.
"""
import os
from contextlib import contextmanager
from pathlib import Path

import numpy as np
import torch
from pydub import AudioSegment


class PyannoteSpeakerDiarization:
    """
    Lazy wrapper around a pyannote.audio speaker diarization pipeline.

    Heavy dependencies (pyannote.audio, torchaudio) are imported only when a
    diarization is actually requested, so application startup does not fail
    when diarization is unavailable on the host.
    """

    def __init__(self, config: dict) -> None:
        """
        Initialize the pyannote speaker diarization pipeline lazily.
        :param config: The configuration object containing speaker diarization settings.
        """
        self.debug = config.get('debug')

        self.config = config.get('speaker_diarization', {})
        # Optional on-disk model location tried before any Hugging Face download.
        self.local_model_path = self.config.get('local_model_path')
        self.model_name = self.config.get('model_name', 'pyannote/speaker-diarization-community-1')
        # Name of the environment variable that holds the Hugging Face token.
        self.token_env = self.config.get('token_env', 'HUGGINGFACE_TOKEN')

        # Populated on first use by _load_pipeline().
        self.pipeline = None

    @staticmethod
    def _resolve_local_model_path(local_model_path: str | None) -> Path | None:
        """
        Resolve a configured local model path relative to the repository root when needed.

        :param local_model_path: Configured path; absolute, repo-relative, or None/empty.
        :return: Absolute path (pointing at config.yaml when the directory
                 contains one), or None when no path was configured.
        """
        if not local_model_path:
            return None

        path = Path(local_model_path)
        if not path.is_absolute():
            # parents[2]: <this file> -> app/models -> app -> repository root.
            repo_root = Path(__file__).resolve().parents[2]
            path = repo_root / path

        if path.is_dir():
            # A model directory is loaded via its pipeline config file when present.
            config_path = path / "config.yaml"
            if config_path.exists():
                path = config_path

        return path.resolve()

    def _load_pipeline(self):
        """
        Lazy-load the pyannote pipeline so app startup does not fail when diarization is unavailable.

        Load order: cached pipeline -> local model path -> Hugging Face download
        (which requires the token from the configured environment variable).
        :raises RuntimeError: if pyannote.audio cannot be imported, or if no
            usable local model exists and no Hugging Face token is set.
        """
        if self.pipeline is not None:
            return self.pipeline

        # Shim must be in place before pyannote (via speechbrain) imports torchaudio.
        self._patch_torchaudio_for_speechbrain()

        try:
            from pyannote.audio import Pipeline
        except Exception as exc:
            raise RuntimeError(
                "pyannote.audio is not installed or could not be imported."
            ) from exc

        local_model_path = self._resolve_local_model_path(self.local_model_path)
        if local_model_path and local_model_path.exists():
            # Local checkpoints need full (non-weights-only) torch deserialization.
            with self._legacy_torch_load():
                self.pipeline = Pipeline.from_pretrained(local_model_path)
            return self.pipeline

        token = os.getenv(self.token_env)
        if not token:
            # Distinguish "configured local model is missing" from "no token at all".
            if local_model_path:
                raise RuntimeError(
                    f"Local speaker diarization model not found at: {local_model_path}. "
                    f"Either restore that directory or set {self.token_env} for Hugging Face loading."
                )
            raise RuntimeError(
                f"Missing Hugging Face token for speaker diarization. "
                f"Set the {self.token_env} environment variable."
            )

        try:
            self.pipeline = Pipeline.from_pretrained(self.model_name, token=token)
        except TypeError:
            # Backward compatibility with older pyannote releases.
            self.pipeline = Pipeline.from_pretrained(self.model_name, use_auth_token=token)

        return self.pipeline

    @staticmethod
    @contextmanager
    def _legacy_torch_load():
        """
        Newer torch defaults to weights_only=True, but the trusted local
        pyannote checkpoints still require full checkpoint deserialization.

        Temporarily monkey-patches torch.load; the original function is always
        restored, even if loading raises.
        """
        original_torch_load = torch.load

        def patched_torch_load(*args, **kwargs):
            # Only override when the caller did not choose a value themselves
            # (an absent key and an explicit None are treated the same here).
            if kwargs.get("weights_only") is None:
                kwargs["weights_only"] = False
            return original_torch_load(*args, **kwargs)

        torch.load = patched_torch_load
        try:
            yield
        finally:
            torch.load = original_torch_load

    @staticmethod
    def _patch_torchaudio_for_speechbrain() -> None:
        """
        SpeechBrain still expects torchaudio.list_audio_backends on import, but
        newer torchaudio builds can omit it on Windows. Provide a minimal shim
        so local pyannote loading can proceed.
        """
        try:
            import torchaudio
        except Exception:
            # No torchaudio at all: nothing to patch; the later pyannote import
            # will surface the real error.
            return

        if not hasattr(torchaudio, "list_audio_backends"):
            torchaudio.list_audio_backends = lambda: ["ffmpeg"]

    @staticmethod
    def _load_audio_input(audio_file_path: str) -> dict:
        """
        Preload audio into memory so pyannote does not rely on torchaudio/torchcodec
        decoding on this machine.

        :param audio_file_path: Path to any audio file readable by pydub/ffmpeg.
        :return: {"waveform": float32 tensor of shape (channels, samples),
                  "sample_rate": int} — the in-memory input dict pyannote accepts.
        """
        # Force 16-bit samples so the fixed 32768.0 normalization below is correct.
        audio = AudioSegment.from_file(audio_file_path).set_sample_width(2)

        samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
        if audio.channels > 1:
            # pydub interleaves channels; reshape to (samples, channels) then
            # transpose to (channels, samples).
            samples = samples.reshape((-1, audio.channels)).T
        else:
            samples = samples.reshape((1, -1))

        # Scale int16 range to roughly [-1, 1).
        waveform = torch.from_numpy(samples / 32768.0)
        return {
            "waveform": waveform,
            "sample_rate": audio.frame_rate,
        }

    def __call__(self, audio_file_path: str) -> list:
        """
        Perform speaker diarization on the given audio file.
        :param audio_file_path: Path to the audio file.
        :return: Raw speaker diarization segments as dicts with 'speaker' (str),
                 'start' (float) and 'end' (float) keys.
        """
        pipeline = self._load_pipeline()
        output = pipeline(self._load_audio_input(audio_file_path))

        # Prefer the 'exclusive' diarization when the pipeline exposes it
        # (presumably non-overlapping turns — confirm against pyannote docs),
        # then fall back to the plain annotation, then to the output itself.
        diarization = getattr(output, 'exclusive_speaker_diarization', None)
        if diarization is None:
            diarization = getattr(output, 'speaker_diarization', output)

        segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            start = float(turn.start)
            end = float(turn.end)
            # Drop degenerate zero- or negative-length turns.
            if end <= start:
                continue

            segments.append({
                'speaker': str(speaker),
                'start': start,
                'end': end,
            })

        return segments
117 changes: 117 additions & 0 deletions app/services/local_audio_transcription_sentiment_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
Local-only audio -> speaker -> transcript -> sentiment pipeline used by test.py.
This keeps the original Flask pipeline untouched.
"""
import copy
import os

from app.config import Config
from app.services.audio_service import AudioService
from app.services.sentiment_service import SentimentService
from app.services.speaker_diarization_service import SpeakerDiarizationService
from app.services.transcript_service import TranscriptService
from app.utils.logger import logger


class LocalAudioTranscriptionSentimentPipeline:
    """
    Local-only audio -> speaker -> transcript -> sentiment pipeline.

    Diarization failures degrade gracefully to "UNKNOWN" speaker labels;
    audio-extraction or transcription errors abort the request with an
    {'error': ...} result.
    """

    def __init__(self, local_model_path: str = "speaker-diarization-community-1"):
        """
        Build the pipeline services from a deep copy of the global config so the
        original Flask pipeline's configuration object is never mutated.
        :param local_model_path: Default local pyannote model location, used
            only when the config does not already define one.
        """
        root_config = copy.deepcopy(Config().config)
        speaker_config = root_config.setdefault("speaker_diarization", {})
        # setdefault() keeps any value already present in the loaded config.
        speaker_config.setdefault("enabled", True)
        speaker_config.setdefault("backend", "pyannote")
        speaker_config.setdefault("local_model_path", local_model_path)
        speaker_config.setdefault("model_name", "pyannote/speaker-diarization-community-1")
        speaker_config.setdefault("token_env", "HUGGINGFACE_TOKEN")

        self.debug = root_config.get("debug")
        self.remove_audio = root_config.get("audio_transcription_sentiment_pipeline", {}).get("remove_audio", False)

        self.audio_service = AudioService()
        self.speaker_diarization_service = SpeakerDiarizationService(root_config)
        self.transcript_service = TranscriptService()
        self.sentiment_service = SentimentService()

    @staticmethod
    def _resolve_chunk_speaker(chunk_timestamp, speaker_segments: list) -> str:
        """
        Pick the diarization speaker whose segment overlaps the chunk the most.
        Ties on positive overlap are broken by the earliest segment start.
        :param chunk_timestamp: (start, end) pair in seconds; may be None or
            contain None endpoints (ASR backends often leave the final chunk
            open-ended).
        :param speaker_segments: Normalized segments with 'speaker', 'start', 'end'.
        :return: Best-matching speaker label, or "UNKNOWN" when nothing overlaps.
        """
        # Guard against missing/malformed timestamps, including a None start
        # or end which would otherwise raise TypeError in float() below and
        # fail the whole request.
        if not chunk_timestamp or len(chunk_timestamp) != 2:
            return "UNKNOWN"
        if chunk_timestamp[0] is None or chunk_timestamp[1] is None:
            return "UNKNOWN"

        chunk_start = float(chunk_timestamp[0])
        chunk_end = float(chunk_timestamp[1])

        best_speaker = "UNKNOWN"
        best_overlap = 0.0
        best_start = float("inf")

        for segment in speaker_segments:
            # Length of the intersection between the chunk and this segment.
            overlap = max(
                0.0,
                min(chunk_end, float(segment["end"])) - max(chunk_start, float(segment["start"]))
            )
            segment_start = float(segment["start"])

            if overlap > best_overlap or (overlap == best_overlap and overlap > 0 and segment_start < best_start):
                best_overlap = overlap
                best_speaker = segment["speaker"]
                best_start = segment_start

        return best_speaker if best_overlap > 0 else "UNKNOWN"

    def _assign_speakers_to_chunks(self, chunks: list, speaker_segments: list) -> list:
        """
        Annotate each transcription chunk in place with its best-matching speaker.
        :return: The same chunk list, for convenience.
        """
        for chunk in chunks:
            chunk["speaker"] = self._resolve_chunk_speaker(chunk.get("timestamp"), speaker_segments)
        return chunks

    def process(self, url: str, start_time_ms: int, end_time_ms: int = None, user_id: str = None) -> dict:
        """
        Run the full pipeline for one media URL.
        :param url: Source media URL.
        :param start_time_ms: Clip start in milliseconds.
        :param end_time_ms: Optional clip end in milliseconds.
        :param user_id: Optional user identifier forwarded to audio extraction.
        :return: Result dict with transcription and per-utterance sentiment,
                 or {'error': ...} on failure.
        """
        try:
            audio_result = self.audio_service.extract_audio(url, start_time_ms, end_time_ms, user_id)
            if isinstance(audio_result, dict) and "error" in audio_result:
                return {"error": audio_result["error"]}

            audio_path = audio_result["audio_path"]
            start_time_ms = audio_result["start_time_ms"]
            end_time_ms = audio_result["end_time_ms"]

            # Diarization is best-effort: on failure we log and continue with
            # empty segments, which yields "UNKNOWN" speaker labels.
            speaker_result = self.speaker_diarization_service.diarize(audio_path)
            speaker_segments = []
            if isinstance(speaker_result, dict) and "error" in speaker_result:
                logger.warning(
                    "[warning] [LocalAudioTranscriptionSentimentPipeline] [process] "
                    "Speaker diarization unavailable. Falling back to UNKNOWN speaker labels. "
                    f"Details: {speaker_result['error']}"
                )
            else:
                speaker_segments = speaker_result.get("segments", [])

            transcription_result = self.transcript_service.transcribe(audio_path)
            if isinstance(transcription_result, dict) and "error" in transcription_result:
                return {"error": transcription_result["error"]}

            transcription = transcription_result["transcription"]
            chunks = self._assign_speakers_to_chunks(transcription_result["chunks"], speaker_segments)

            if self.remove_audio:
                # NOTE(review): audio_path is still returned below even after
                # removal; callers must not assume the file exists.
                os.remove(audio_path)

            # Sentiment is per chunk; a failing chunk records its own error
            # and does not abort the remaining chunks.
            for chunk in chunks:
                sentiment_result = self.sentiment_service.analyze(chunk["text"])
                if isinstance(sentiment_result, dict) and "error" in sentiment_result:
                    chunk["error"] = sentiment_result["error"]
                    continue

                chunk["label"] = sentiment_result["label"]
                chunk["confidence"] = sentiment_result["confidence"]

            return {
                "audio_path": audio_path,
                "start_time_ms": start_time_ms,
                "end_time_ms": end_time_ms,
                "transcription": transcription,
                "utterances_sentiment": chunks,
            }
        except Exception as exc:
            logger.error(
                "[error] [LocalAudioTranscriptionSentimentPipeline] [process] "
                f"An error occurred during processing: {str(exc)}"
            )
            return {"error": "An unexpected error occurred while processing the request."}
Loading