diff --git a/app/models/bertweet_model.py b/app/models/bertweet_model.py index 2342c7c..019fc5c 100644 --- a/app/models/bertweet_model.py +++ b/app/models/bertweet_model.py @@ -1,15 +1,21 @@ """ -This module defines the BertweetSentiment class, which is a PyTorch model for sentiment analysis using the Bertweet model. +This module defines the BertweetSentiment class, optimized with ONNX Runtime +for low-latency CPU sentiment analysis using the Bertweet model. """ import torch import torch.nn as nn +import logging -from transformers import AutoTokenizer, AutoModelForSequenceClassification +from transformers import AutoTokenizer +# Import Hugging Face Optimum for ONNX Runtime acceleration +from optimum.onnxruntime import ORTModelForSequenceClassification + +logger = logging.getLogger(__name__) class BertweetSentiment(nn.Module): - def __init__(self,config: dict)->None: + def __init__(self, config: dict) -> None: """ - Initialize the Bertweet model for sentiment analysis. + Initialize the ONNX-optimized Bertweet model for sentiment analysis. :param config: The configuration object containing model and device info. 
""" self.debug = config.get('debug') @@ -19,13 +25,19 @@ def __init__(self,config: dict)->None: self.device = self.config.get('device') super(BertweetSentiment, self).__init__() + + logger.info(f"Initializing ONNX-optimized sentiment model: {self.model_name} on {self.device}") + # Initialize the Tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - # Initialize the Model - self.model= AutoModelForSequenceClassification.from_pretrained(self.model_name) - self.model.to(self.device) - + # Initialize the Model dynamically into an ONNX graph using export=True + # This bypasses the heavy native PyTorch execution + self.model = ORTModelForSequenceClassification.from_pretrained( + self.model_name, + export=True + ) + # Load the model configuration to get class labels self.model_config = self.model.config @@ -35,9 +47,9 @@ def __init__(self,config: dict)->None: else: self.class_labels = None - def forward(self,text)->tuple: + def forward(self, text) -> tuple: """ - Perform sentiment analysis on the given text. + Perform sentiment analysis on the given text using ONNX runtime optimizations. Args: text (str): Input text for sentiment analysis. 
@@ -48,7 +60,7 @@ def forward(self,text)->tuple: # Tokenize the input text inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device) - # Forward pass + # Forward pass through the ONNX graph outputs = self.model(**inputs) # Convert logits to probabilities @@ -63,30 +75,28 @@ def forward(self,text)->tuple: return outputs, probabilities, predicted_label, probabilities[0][predicted_class].item() -# if __name__ == "__main__": -# config = { -# 'debug': True, -# 'sentiment_analysis': { -# 'default_model': "bertweet", # Specify the default sentiment analysis model (e.g., bertweet, another_model) -# 'bertweet': { -# 'model_name': "finiteautomata/bertweet-base-sentiment-analysis", -# 'device': 'cpu' -# } -# } -# } -# print("config",config) -# model = BertweetSentiment(config) -# print("model",model) -# print("model.class_labels",model.class_labels) - -# text = "I love the new features of the app!" -# print(model(text)) - -# text = "I hate the new features of the app!" -# print(model(text)) - -# text = "Hi how are u?" -# print(model(text)) - -# # Run: -# # python -m app.models.bertweet_model \ No newline at end of file +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + config = { + 'debug': True, + 'sentiment_analysis': { + 'default_model': "bertweet", + 'bertweet': { + 'model_name': "finiteautomata/bertweet-base-sentiment-analysis", + 'device': 'cpu' + } + } + } + print("Testing ONNX Inference Implementation...") + model = BertweetSentiment(config) + print("Model initialized successfully.") + + texts_to_test = [ + "I love the new features of the app!", + "I hate the new features of the app!", + "Hi how are u?" 
+ ] + + for t in texts_to_test: + _, _, label, conf = model(t) + print(f"Text: '{t}' | Label: {label} | Confidence: {conf:.4f}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 40f1e3f..7559930 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,8 @@ transformers==4.48.2 emoji==0.6.0 pydantic==2.10.6 - +optimum>=1.14.0 +onnxruntime>=1.16.0 ## Testing coverage==7.6.10