Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ logs/
# Ignore Testing Coverage Results
tests/coverage/.coverage

env/
env/venv/
venv/
46 changes: 35 additions & 11 deletions app/models/bertweet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn as nn

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, Any

class BertweetSentiment(nn.Module):
def __init__(self,config: dict)->None:
Expand All @@ -14,53 +15,76 @@ def __init__(self,config: dict)->None:
"""
self.debug = config.get('debug')

self.config = config.get('sentiment_analysis').get('bertweet')
# ✅ Add null check
sentiment_config = config.get('sentiment_analysis')
if not sentiment_config:
raise ValueError("'sentiment_analysis' not found in config")

self.config = sentiment_config.get('bertweet')
if not self.config:
raise ValueError("'bertweet' not found in sentiment_analysis config")


self.model_name = self.config.get('model_name')
self.device = self.config.get('device')

super(BertweetSentiment, self).__init__()
# Initialize the Tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

# Initialize the Model
# Initializing the Model
self.model= AutoModelForSequenceClassification.from_pretrained(self.model_name)
self.model.to(self.device)

# Load the model configuration to get class labels
# Loading the model configuration to get class labels
self.model_config = self.model.config

# Get Labels
# Geting the Labels
if hasattr(self.model_config, 'id2label'):
self.class_labels = [self.model_config.id2label[i] for i in range(len(self.model_config.id2label))]
else:
self.class_labels = None

def forward(self,text)->tuple:
def forward(self, text) -> Dict[str, Any]:
    """
    Perform sentiment analysis on the given text.

    Args:
        text (str): Input text for sentiment analysis.

    Returns:
        Dict[str, Any]: A dictionary with keys
            "logits" (list): raw model logits,
            "probabilities" (list): softmax class probabilities,
            "label" (str): predicted class label,
            "score" (float): confidence of the predicted class.

    Raises:
        ValueError: If the model configuration exposed no ``id2label``
            mapping, so ``self.class_labels`` is ``None``.
    """
    # Tokenize the input text and move the tensors to the configured device.
    inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)

    # Forward pass through the sequence-classification model.
    outputs = self.model(**inputs)

    # Convert raw logits to a probability distribution over classes.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # Compute the argmax once (the original ran torch.argmax twice and
    # discarded the first result). .item() already yields a Python int;
    # int() is kept for explicitness.
    predicted_class = int(torch.argmax(probabilities, dim=1).item())

    # Fail loudly if the model config provided no class-label mapping.
    if self.class_labels is None:
        raise ValueError("Class labels not available")

    # Map the predicted index to its human-readable label.
    predicted_label = self.class_labels[predicted_class]

    return {
        "logits": outputs.logits.tolist(),
        "probabilities": probabilities.tolist(),
        "label": predicted_label,
        "score": probabilities[0][predicted_class].item()
    }


# if __name__ == "__main__":
Expand Down
34 changes: 30 additions & 4 deletions app/models/whisper_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn as nn

from transformers import pipeline
from typing import Dict, Any


class WhisperTranscript(nn.Module):
Expand All @@ -15,7 +16,14 @@ def __init__(self, config: dict) -> None:
"""
self.debug = config.get('debug')

self.config = config.get('transcription').get('whisper')
transcription_config = config.get('transcription')
if not transcription_config:
raise ValueError("'transcription' not found in config")

self.config = transcription_config.get('whisper')
if not self.config:
raise ValueError("'whisper' not found in transcription config")

self.model_size = self.config.get('model_size')
self.device = self.config.get('device')
self.chunk_length_s = self.config.get('chunk_length_s')
Expand All @@ -32,20 +40,38 @@ def __init__(self, config: dict) -> None:
)


def forward(self, audio_file: str) -> tuple:
def forward(self, audio_file: str) -> Dict[str, Any]:
    """
    Perform transcription on the given audio file.

    Args:
        audio_file (str): Path to the audio file.

    Returns:
        Dict[str, Any]: A dictionary with keys
            "text" (str): the full transcript (empty string if absent),
            "chunks" (list): timestamped segments as returned by the
                pipeline (empty list if absent).
    """
    # Run the ASR pipeline; return_timestamps=True requests per-chunk timing.
    out = self.pipeline(audio_file, return_timestamps=True)

    # Extract text/chunks directly on each branch — the original
    # pre-initialized both variables and then unconditionally overwrote
    # them, so the defaults were dead code.
    if isinstance(out, dict):
        text = out.get("text", "")
        chunks = out.get("chunks", [])
    else:
        # Fallback for NON-dict, attribute-style results.
        # NOTE(review): the pipeline normally returns a plain dict —
        # confirm this branch is ever taken in practice.
        text = getattr(out, "text", "")
        chunks = getattr(out, "chunks", [])

    return {
        "text": text,
        "chunks": chunks
    }


# if __name__ == "__main__":
# config = {
Expand Down