diff --git a/README.md b/README.md
index 5826330..ae88780 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,7 @@ sentiment-analysis-api/
 │   │   └── transcript_data.py
 │   ├── models/               # Contains the models for sentiment analysis (Whisper, BERTweet)
 │   │   ├── bertweet_model.py
+│   │   ├── roberta_model.py
 │   │   └── whisper_model.py
 │   ├── routes/               # Defines the routes for the API
 │   │   ├── __init__.py
diff --git a/app/data/sentiment_data.py b/app/data/sentiment_data.py
index 93d469f..7e57209 100644
--- a/app/data/sentiment_data.py
+++ b/app/data/sentiment_data.py
@@ -3,7 +3,7 @@
 """
 # Model Layer
 from app.models.bertweet_model import BertweetSentiment
-
+from app.models.roberta_model import RoBERTaSentiment
 from app.utils.logger import logger
 
 class SentimentDataLayer:
@@ -16,9 +16,12 @@ def __init__(self, config: dict):
         self.config = config.get('sentiment_analysis')
         self.default_model = self.config.get('default_model')
 
+        logger.debug(f"Selected sentiment model: '{self.default_model}'")
         # Initialize the appropriate model based on the configuration
         if self.default_model == "bertweet":
             self.model = BertweetSentiment(config)
+        elif self.default_model == "roberta":
+            self.model = RoBERTaSentiment(config)
         # elif self.default_model == "another_model":
         #     self.model = AnotherModel(config)  # Replace with your other model class
         else:
diff --git a/app/models/roberta_model.py b/app/models/roberta_model.py
new file mode 100644
index 0000000..f3267e2
--- /dev/null
+++ b/app/models/roberta_model.py
@@ -0,0 +1,74 @@
+"""
+This module defines the RoBERTaSentiment class, which is a PyTorch model for sentiment analysis using the RoBERTa model.
+"""
+import torch
+import torch.nn as nn
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+# Mapping from RoBERTa labels to standard labels
+LABEL_MAPPING = {
+    "positive": "POS",
+    "neutral": "NEU",
+    "negative": "NEG"
+}
+
+class RoBERTaSentiment(nn.Module):
+    def __init__(self, config: dict) -> None:
+        """
+        Initialize the RoBERTa model for sentiment analysis.
+        :param config: The configuration object containing model and device info.
+        """
+        self.debug = config.get('debug')
+
+        self.config = config.get('sentiment_analysis').get('roberta')
+        self.model_name = self.config.get('model_name')
+        self.device = self.config.get('device')
+
+        super(RoBERTaSentiment, self).__init__()
+
+        # Initialize the Tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        # Initialize the Model
+        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
+        self.model.to(self.device)
+
+        # Load the model configuration to get class labels
+        self.model_config = self.model.config
+
+        # Get Labels
+        if hasattr(self.model_config, 'id2label'):
+            self.class_labels = [self.model_config.id2label[i] for i in range(len(self.model_config.id2label))]
+        else:
+            self.class_labels = None
+
+    def forward(self, text) -> tuple:
+        """
+        Perform sentiment analysis on the given text.
+
+        Args:
+            text (str): Input text for sentiment analysis.
+
+        Returns:
+            tuple: Model outputs, probabilities, predicted label, and confidence score.
+        """
+        # Tokenize the input text
+        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)
+
+        # Forward pass
+        outputs = self.model(**inputs)
+
+        # Convert logits to probabilities
+        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
+
+        # Get the predicted sentiment
+        predicted_class = torch.argmax(probabilities, dim=1).item()
+
+        # Get the corresponding class label
+        raw_label = self.class_labels[predicted_class]
+
+        # Map the label to standard format (POS, NEU, NEG)
+        predicted_label = LABEL_MAPPING.get(raw_label.lower(), raw_label)
+
+        return outputs, probabilities, predicted_label, probabilities[0][predicted_class].item()
diff --git a/config.yaml b/config.yaml
index 930c68b..be34133 100644
--- a/config.yaml
+++ b/config.yaml
@@ -26,14 +26,18 @@ transcription:
 
 # Sentiment Analysis Configuration
 sentiment_analysis:
-  default_model: "bertweet" # Specify the default sentiment analysis model (e.g., bertweet, another_model)
-  bertweet: # Vader-specific configuration
+  default_model: "bertweet" # Specified default model
+  bertweet: # Bertweet-specific configuration
     model_name: "finiteautomata/bertweet-base-sentiment-analysis"
-    device: 'cpu' # `cpu` for CPU, or `cuda` GPU device
-    # device: 'cuda' # `cpu` for CPU, or `cuda` GPU device
+    device: 'cpu'
+    # device: 'cuda'
+  roberta: # RoBERTa-specific configuration
+    model_name: "cardiffnlp/twitter-roberta-base-sentiment-latest"
+    device: 'cpu'
+    # device: 'cuda'
 # another_model: # Placeholder for another sentiment analysis model's configuration
 #   api_key: "your_api_key"
 #   endpoint: "https://api.example.com/sentiment"
 # AudioTranscriptionSentimentPipeline Configuration
 audio_transcription_sentiment_pipeline:
-  remove_audio: false # Specify whether to remove audio files after processing
\ No newline at end of file
+  remove_audio: false # Specify whether to remove audio files after processing