diff --git a/.gitignore b/.gitignore
index 564b8d3..4717a04 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,5 @@ logs/
 
 # Ignore Testing Coverage Results
 tests/coverage/.coverage
-env/
\ No newline at end of file
+env/venv/
+venv/
diff --git a/app/models/bertweet_model.py b/app/models/bertweet_model.py
index 2342c7c..5d57d73 100644
--- a/app/models/bertweet_model.py
+++ b/app/models/bertweet_model.py
@@ -5,6 +5,7 @@
 import torch.nn as nn
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from typing import Dict, Any
 
 
 class BertweetSentiment(nn.Module):
     def __init__(self,config: dict)->None:
@@ -14,7 +15,16 @@ def __init__(self,config: dict)->None:
         """
 
         self.debug = config.get('debug')
-        self.config = config.get('sentiment_analysis').get('bertweet')
+        # ✅ Add null check
+        sentiment_config = config.get('sentiment_analysis')
+        if not sentiment_config:
+            raise ValueError("'sentiment_analysis' not found in config")
+
+        self.config = sentiment_config.get('bertweet')
+        if not self.config:
+            raise ValueError("'bertweet' not found in sentiment_analysis config")
+
+        self.model_name = self.config.get('model_name')
         self.device = self.config.get('device')
 
 
@@ -22,20 +32,20 @@
         # Initialize the Tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
-        # Initialize the Model
+        # Initializing the Model
         self.model= AutoModelForSequenceClassification.from_pretrained(self.model_name)
         self.model.to(self.device)
 
-        # Load the model configuration to get class labels
+        # Loading the model configuration to get class labels
         self.model_config = self.model.config
 
-        # Get Labels
+        # Getting the Labels
        if hasattr(self.model_config, 'id2label'):
             self.class_labels = [self.model_config.id2label[i] for i in range(len(self.model_config.id2label))]
         else:
             self.class_labels = None
 
 
-    def forward(self,text)->tuple:
+    def forward(self,text)-> Dict[str, Any]:
         """
         Perform sentiment analysis on the given text.
@@ -43,24 +53,38 @@ def forward(self,text)->tuple:
             text (str): Input text for sentiment analysis.
 
         Returns:
-            tuple: Model outputs, probabilities, predicted label, and confidence score.
+            Dict: Model outputs, probabilities, predicted label, and confidence score.
         """
 
-        # Tokenize the input text
+        # Tokenizing the input text
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)
 
         # Forward pass
         outputs = self.model(**inputs)
 
-        # Convert logits to probabilities
+        # Converting logits to probabilities
         probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
-        # Get the predicted sentiment
+        # Getting the predicted sentiment
         predicted_class = torch.argmax(probabilities, dim=1).item()
-        # Get the corresponding class label
+
+        # Converting it to the integer explicitly
+        predicted_class = int(torch.argmax(probabilities, dim=1).item())
+
+        # Adding a null check
+        if self.class_labels is None:
+            raise ValueError("Class labels not available")
+
+
+        # Getting the corresponding class label
         predicted_label = self.class_labels[predicted_class]
 
-        return outputs, probabilities, predicted_label, probabilities[0][predicted_class].item()
+        return {
+            "logits": outputs.logits.tolist(),
+            "probabilities": probabilities.tolist(),
+            "label": predicted_label,
+            "score": probabilities[0][predicted_class].item()
+        }
 
 
 
 # if __name__ == "__main__":
diff --git a/app/models/whisper_model.py b/app/models/whisper_model.py
index 9217bf2..c2ea654 100644
--- a/app/models/whisper_model.py
+++ b/app/models/whisper_model.py
@@ -5,6 +5,7 @@
 import torch.nn as nn
 from transformers import pipeline
+from typing import Dict, Any
 
 
 class WhisperTranscript(nn.Module):
 
@@ -15,7 +16,14 @@ def __init__(self, config: dict) -> None:
         """
 
         self.debug = config.get('debug')
-        self.config = config.get('transcription').get('whisper')
+        transcription_config = config.get('transcription')
+        if not transcription_config:
+            raise ValueError("'transcription' not found in config")
+
+        self.config = transcription_config.get('whisper')
+        if not self.config:
+            raise ValueError("'whisper' not found in transcription config")
+
         self.model_size = self.config.get('model_size')
         self.device = self.config.get('device')
         self.chunk_length_s = self.config.get('chunk_length_s')
@@ -32,7 +40,7 @@
         )
 
 
-    def forward(self, audio_file: str) -> tuple:
+    def forward(self, audio_file: str) -> Dict[str, Any]:
         """
         Perform transcription on the given audio file.
 
@@ -40,12 +48,30 @@ def forward(self, audio_file: str) -> tuple:
             audio_file (str): Path to the audio file.
 
         Returns:
-            tuple: Transcribed text and timestamped chunks.
+            Dict: Transcribed text and timestamped chunks.
         """
 
         # Forward pass
         out = self.pipeline(audio_file, return_timestamps=True)
-        return out["text"], out["chunks"]
+
+        # Initialize to avoid "possibly unbound" error
+        text = ""
+        chunks = []
+
+        # Extracting the text and chunks safely
+        if isinstance(out, dict):
+            text = out.get("text", "")
+            chunks = out.get("chunks", [])
+        else:
+            # For dict-like objects (not necessarily dict type)
+            text = getattr(out, "text", "")
+            chunks = getattr(out, "chunks", [])
+
+        return {
+            "text": text,
+            "chunks": chunks
+        }
+
 
 # if __name__ == "__main__":
 #     config = {