Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions app/models/bertweet_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""
This module defines the BertweetSentiment class, which is a PyTorch model for sentiment analysis using the Bertweet model.
This module defines the BertweetSentiment class, optimized with ONNX Runtime
for low-latency CPU sentiment analysis using the Bertweet model.
"""
import torch
import torch.nn as nn
import logging

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import AutoTokenizer
# Injecting Hugging Face Optimum for ONNX Runtime acceleration
from optimum.onnxruntime import ORTModelForSequenceClassification

logger = logging.getLogger(__name__)

class BertweetSentiment(nn.Module):
def __init__(self,config: dict)->None:
def __init__(self, config: dict) -> None:
"""
Initialize the Bertweet model for sentiment analysis.
Initialize the ONNX-optimized Bertweet model for sentiment analysis.
:param config: The configuration object containing model and device info.
"""
self.debug = config.get('debug')
Expand All @@ -19,13 +25,19 @@ def __init__(self,config: dict)->None:
self.device = self.config.get('device')

super(BertweetSentiment, self).__init__()

logger.info(f"Initializing ONNX-optimized sentiment model: {self.model_name} on {self.device}")

# Initialize the Tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

# Initialize the Model
self.model= AutoModelForSequenceClassification.from_pretrained(self.model_name)
self.model.to(self.device)

# Initialize the Model dynamically into an ONNX graph using export=True
# This bypasses the heavy native PyTorch execution
self.model = ORTModelForSequenceClassification.from_pretrained(
self.model_name,
export=True
)

# Load the model configuration to get class labels
self.model_config = self.model.config

Expand All @@ -35,9 +47,9 @@ def __init__(self,config: dict)->None:
else:
self.class_labels = None

def forward(self,text)->tuple:
def forward(self, text) -> tuple:
"""
Perform sentiment analysis on the given text.
Perform sentiment analysis on the given text using ONNX runtime optimizations.

Args:
text (str): Input text for sentiment analysis.
Expand All @@ -48,7 +60,7 @@ def forward(self,text)->tuple:
# Tokenize the input text
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.device)

# Forward pass
# Forward pass through the ONNX graph
outputs = self.model(**inputs)

# Convert logits to probabilities
Expand All @@ -63,30 +75,28 @@ def forward(self,text)->tuple:
return outputs, probabilities, predicted_label, probabilities[0][predicted_class].item()


# if __name__ == "__main__":
# config = {
# 'debug': True,
# 'sentiment_analysis': {
# 'default_model': "bertweet", # Specify the default sentiment analysis model (e.g., bertweet, another_model)
# 'bertweet': {
# 'model_name': "finiteautomata/bertweet-base-sentiment-analysis",
# 'device': 'cpu'
# }
# }
# }
# print("config",config)
# model = BertweetSentiment(config)
# print("model",model)
# print("model.class_labels",model.class_labels)

# text = "I love the new features of the app!"
# print(model(text))

# text = "I hate the new features of the app!"
# print(model(text))

# text = "Hi how are u?"
# print(model(text))

# # Run:
# # python -m app.models.bertweet_model
if __name__ == "__main__":
    # Smoke-test entry point: build a minimal config, construct the
    # ONNX-backed sentiment model, and classify a few sample tweets.
    # Run with: python -m app.models.bertweet_model
    logging.basicConfig(level=logging.INFO)

    demo_config = {
        'debug': True,
        'sentiment_analysis': {
            'default_model': "bertweet",
            'bertweet': {
                'model_name': "finiteautomata/bertweet-base-sentiment-analysis",
                'device': 'cpu'
            }
        }
    }

    print("Testing ONNX Inference Implementation...")
    sentiment_model = BertweetSentiment(demo_config)
    print("Model initialized successfully.")

    sample_texts = (
        "I love the new features of the app!",
        "I hate the new features of the app!",
        "Hi how are u?",
    )

    for sample in sample_texts:
        # forward() returns (outputs, probabilities, label, confidence);
        # only the label and its confidence are reported here.
        result = sentiment_model(sample)
        label, confidence = result[2], result[3]
        print(f"Text: '{sample}' | Label: {label} | Confidence: {confidence:.4f}")
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ transformers==4.48.2
emoji==0.6.0

pydantic==2.10.6

optimum>=1.14.0
onnxruntime>=1.16.0

## Testing
coverage==7.6.10
Expand Down